"""
The :mod:`chemotools.adaptation._x_axis_interpolator` allows resampling each X to a
common x_axis grid.
"""
# Author: Pau Cabaneros
# License: MIT
from numbers import Real
import numpy as np
from scipy.interpolate import CubicSpline, PchipInterpolator
from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_is_fitted, validate_data
from chemotools._doc_mixin import DocLinkMixin
[ドキュメント]
class XAxisInterpolator(DocLinkMixin, TransformerMixin, BaseEstimator):
"""Interpolate every row of ``X`` onto a shared ``common_x_axis``.
The transformer resamples each row of ``X`` from a sample-specific (or shared)
input grid ``x_axis`` onto a fixed ``common_x_axis`` provided at instantiation
time.
``x_axis`` is consumed as **metadata** so it flows correctly through
``Pipeline``, ``ColumnTransformer``, ``GridSearchCV``, ``cross_validate`` etc.,
once metadata routing is enabled via:
``sklearn.set_config(enable_metadata_routing=True)``.
Parameters
----------
common_x_axis : array-like of shape (n_output_features,)
Strictly increasing target grid. ``transform`` returns an array with
``n_output_features`` columns.
method : str, default="cubic"
The interpolation mode. One of:
``"linear"``, ``"cubic"`` or ``"pchip"``.
left : float, default=np.nan
Value returned for query points below the input grid (passed to
:func:`numpy.interp`).
right : float, default=np.nan
Value returned for query points above the input grid.
Attributes
----------
common_x_axis_ : ndarray of shape (n_output_features,)
Validated copy of ``common_x_axis``.
n_features_in_ : int
Number of input features seen during ``fit``.
feature_names_in_ : ndarray of shape (``n_features_in_``,)
Names of features seen during ``fit`` (only if ``X`` had names).
"""
_parameter_constraints: dict = {
"common_x_axis": ["array-like"],
"method": [StrOptions({"linear", "cubic", "pchip"})],
"left": [Interval(Real, None, None, closed="neither"), None, np.nan.__class__],
"right": [Interval(Real, None, None, closed="neither"), None, np.nan.__class__],
}
def __init__(
self,
common_x_axis: np.ndarray,
method: str = "cubic",
left=np.nan,
right=np.nan,
):
self.common_x_axis = common_x_axis
self.method = method
self.left = left
self.right = right
[ドキュメント]
@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X: np.ndarray, y=None, x_axis=None) -> "XAxisInterpolator":
"""Validate input and the configured ``common_x_axis``.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The input data to fit the transformer to.
y : Ignored
Ignored to align with API.
x_axis : Ignored
Accepted only so that metadata routing through ``fit_transform``
(used by ``Pipeline``) generates a ``set_fit_request`` method.
``fit`` itself does not use it.
Returns
-------
self : object
"""
# Validate the input parameters
self._validate_params()
# Validate the X data
validate_data(self, X, ensure_2d=True, dtype="numeric")
# Validate the common_x_axis
common_x_axis = np.asarray(self.common_x_axis, dtype=float)
if common_x_axis.ndim != 1:
raise ValueError(
f"common_x_axis must be 1D, got shape {common_x_axis.shape}."
)
if common_x_axis.size < 2:
raise ValueError("common_x_axis must contain at least 2 points.")
if not np.all(np.isfinite(common_x_axis)):
raise ValueError("common_x_axis must contain only finite values.")
if not np.all(np.diff(common_x_axis) > 0):
raise ValueError("common_x_axis must be strictly increasing.")
self.common_x_axis_ = common_x_axis
return self
[ドキュメント]
def get_feature_names_out(self, input_features=None):
"""Names are the positions of the common grid."""
check_is_fitted(self)
return np.asarray([f"x_{v:g}" for v in self.common_x_axis_], dtype=object)
def _prepare_x_axis(
self, x_axis: np.ndarray, X_shape: tuple[int, int]
) -> np.ndarray:
"""Validate and normalize x_axis to shape (n_samples, n_features)."""
n_samples, n_features = X_shape
x_axis = np.asarray(x_axis, dtype=float)
if x_axis.ndim == 1:
if x_axis.shape[0] != n_features:
raise ValueError(
f"x_axis has shape {x_axis.shape}, expected ({n_features},) "
f"or ({n_samples}, {n_features})."
)
x_per_row = np.broadcast_to(x_axis, (n_samples, n_features))
elif x_axis.ndim == 2:
if x_axis.shape != X_shape:
raise ValueError(
f"x_axis has shape {x_axis.shape}, expected {X_shape}."
)
x_per_row = x_axis
else:
raise ValueError(f"x_axis must be 1D or 2D, got {x_axis.ndim}D array.")
if not np.all(np.isfinite(x_per_row)):
raise ValueError("x_axis must contain only finite values.")
if not np.all(np.diff(x_per_row, axis=1) > 0):
raise ValueError("x_axis must be strictly increasing along axis=1.")
return x_per_row
@staticmethod
def _resolve_fill_values(
row: np.ndarray, left, right
) -> tuple[float | np.floating, float | np.floating]:
"""Resolve boundary fill values for one row."""
left_value = row[0] if left is None else left
right_value = row[-1] if right is None else right
return left_value, right_value
def _interpolate_row(
self,
*,
xi: np.ndarray,
yi: np.ndarray,
target: np.ndarray,
left_value,
right_value,
) -> np.ndarray:
"""Interpolate one row according to the configured method."""
method = self.method
if method == "linear":
return np.interp(target, xi, yi, left=left_value, right=right_value)
if method == "cubic":
interp_func = CubicSpline(xi, yi, bc_type="not-a-knot", extrapolate=False)
else: # method == "pchip"
interp_func = PchipInterpolator(xi, yi, extrapolate=False)
row = interp_func(target)
row[target < xi[0]] = left_value
row[target > xi[-1]] = right_value
return row