# Source code for chemotools.smooth._median_filter
"""
The :mod:`chemotools.smooth._median_filter` module implements
the Median Filter (MD) transformation.
"""
# Authors: Pau Cabaneros
# License: MIT
from numbers import Integral
from typing import Literal
import numpy as np
from scipy.ndimage import median_filter
from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_is_fitted, validate_data
from chemotools._deprecation import (
DEPRECATED_PARAMETER,
deprecated_parameter_constraint,
resolve_renamed_parameter,
)
from chemotools._doc_mixin import DocLinkMixin
from chemotools._parallel import apply_rows
class MedianFilter(DocLinkMixin, TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
    """
    A smoothing transformer that calculates the median filter of the input data.

    Each sample (row) is filtered along the feature axis only, using a sliding
    window of ``window_length`` points.

    Parameters
    ----------
    window_length : int, optional, default=3
        The size of the window to use for the median filter. Must be an odd
        integer greater than or equal to 3. Default is 3.

    mode : str, optional, default="nearest"
        The border mode passed to :func:`scipy.ndimage.median_filter`. Can be
        "nearest", "constant", "reflect", "wrap", "mirror", "grid-constant",
        "grid-mirror" or "grid-wrap". Default is "nearest".

    n_jobs : int, optional, default=1
        Number of parallel jobs used to process samples independently.
        Uses serial execution when set to 1. Must be non-zero
        (``n_jobs >= 1`` or ``n_jobs <= -1``).

    window_size : int, optional
        Deprecated alias for ``window_length``.

    Attributes
    ----------
    n_features_in_ : int
        The number of features in the training data.

    window_length_ : int
        The window length used by :meth:`transform`, resolved in :meth:`fit`
        from ``window_length`` and the deprecated ``window_size`` alias.

    Examples
    --------
    >>> from chemotools.datasets import load_fermentation_train
    >>> from chemotools.smooth import MedianFilter
    >>> # Load sample data
    >>> X, _ = load_fermentation_train()
    >>> # Initialize MedianFilter
    >>> md = MedianFilter()
    >>> # Fit and transform the data
    >>> X_smoothed = md.fit_transform(X)
    """

    _parameter_constraints: dict = {
        "window_length": [Interval(Integral, 3, None, closed="left")],
        "mode": [
            StrOptions(
                {
                    "nearest",
                    "constant",
                    "reflect",
                    "wrap",
                    "mirror",
                    "grid-constant",
                    "grid-mirror",
                    "grid-wrap",
                }
            )
        ],
        "window_size": [
            Interval(Integral, 3, None, closed="left"),
            deprecated_parameter_constraint(),
        ],
        # n_jobs may be any non-zero integer: <= -1 or >= 1.
        "n_jobs": [
            Interval(Integral, None, -1, closed="right"),
            Interval(Integral, 1, None, closed="left"),
        ],
    }

    def __init__(
        self,
        window_length: int = 3,
        mode: Literal[
            "reflect",
            "constant",
            "nearest",
            "mirror",
            "wrap",
            "grid-constant",
            "grid-mirror",
            "grid-wrap",
        ] = "nearest",
        n_jobs: int = 1,
        window_size=DEPRECATED_PARAMETER,
    ) -> None:
        self.window_length = window_length
        self.window_size = window_size
        self.mode = mode
        self.n_jobs = n_jobs

    def __setstate__(self, state: dict) -> None:
        """Restore state while keeping backward compatibility with old pickles.

        Pickles that only carry ``window_size`` get ``window_length`` filled in
        from it; pickles that lack ``window_size`` get the deprecation sentinel.
        A missing ``n_jobs`` defaults to 1. For fitted pickles that lack the
        ``window_length_`` attribute read by :meth:`transform`, it is rebuilt
        from the restored ``window_length``.
        """
        super().__setstate__(state)
        if "window_length" not in self.__dict__ and "window_size" in self.__dict__:
            self.window_length = self.window_size
        if "window_size" not in self.__dict__ and "window_length" in self.__dict__:
            self.window_size = DEPRECATED_PARAMETER
        if "n_jobs" not in self.__dict__:
            self.n_jobs = 1
        # "Fitted" is detected the same way transform() checks it
        # (``n_features_in_``). Without this, transform() on an old fitted
        # pickle would pass check_is_fitted and then fail with AttributeError
        # on ``window_length_``.
        if (
            "window_length_" not in self.__dict__
            and "n_features_in_" in self.__dict__
            and "window_length" in self.__dict__
        ):
            self.window_length_ = self.window_length

    def fit(self, X: np.ndarray, y=None) -> "MedianFilter":
        """
        Fit the transformer to the input data.

        Parameters
        ----------
        X : np.ndarray of shape (n_samples, n_features)
            The input data to fit the transformer to.

        y : None
            Ignored to align with API.

        Returns
        -------
        self : MedianFilter
            The fitted transformer.

        Raises
        ------
        ValueError
            If the resolved window length is even.
        """
        # Validate the input parameters
        self._validate_params()

        # Check that X is a 2D array and has only finite values
        X = validate_data(
            self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
        )

        # Pick the effective window length, honouring the deprecated alias.
        self.window_length_ = resolve_renamed_parameter(
            new_name="window_length",
            new_value=self.window_length,
            new_default=3,
            old_name="window_size",
            old_value=self.window_size,
        )

        if self.window_length_ % 2 == 0:
            raise ValueError("window_length must be odd")

        return self

    def transform(self, X: np.ndarray, y=None) -> np.ndarray:
        """
        Transform the input data by calculating the median filter.

        Parameters
        ----------
        X : np.ndarray of shape (n_samples, n_features)
            The input data to transform.

        y : None
            Ignored to align with API.

        Returns
        -------
        X_transformed : np.ndarray of shape (n_samples, n_features)
            The transformed data.
        """
        # Check that the estimator is fitted
        check_is_fitted(self, "n_features_in_")

        # Check that X is a 2D array with only finite values and the same
        # number of features seen in fit (reset=False).
        X_ = validate_data(
            self,
            X,
            y="no_validation",
            ensure_2d=True,
            copy=True,
            reset=False,
            dtype=np.float64,
        )

        X_transformed = apply_rows(X_, n_jobs=self.n_jobs, fn=self._transform_block)
        return X_transformed

    def _transform_block(self, X_block: np.ndarray) -> np.ndarray:
        """Median-filter a block of rows along the feature axis (axis 1) only."""
        return median_filter(
            X_block, size=self.window_length_, mode=self.mode, axes=(1,)
        )