"""
The :mod:`chemotools.scale._point_scaler` module implements a Point Scaler transformer.
"""
# Authors: Pau Cabaneros
# License: MIT
from numbers import Integral
from typing import Optional
import numpy as np
from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin
from sklearn.utils._param_validation import Interval
from sklearn.utils.validation import check_is_fitted, validate_data
from chemotools._axis_mixin import XAxisMixin
from chemotools._deprecation import (
DEPRECATED_PARAMETER,
deprecated_parameter_constraint,
)
[docs]
class PointScaler(XAxisMixin, TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
"""
A transformer that scales the input data by the intensity value at a given point.
The point can be specified either by a column index or by an x-axis value (e.g. a
wavenumber).
Parameters
----------
point : int, optional, default=0
The point to scale the data by. It can be an index or an x-axis value.
x_axis : array-like, optional, default=None
The x-axis values of the input data. If not provided, the indices will be used
instead. Default is None. If provided, the values must be in ascending order.
wavenumbers : array-like, optional
Deprecated alias for ``x_axis``.
Attributes
----------
n_features_in_ : int
The number of features in the input data.
point_index_ : int
The column index of the point to scale the data by. It equals ``point``
when no x-axis is provided.
Examples
--------
>>> from chemotools.datasets import load_fermentation_train
>>> from chemotools.scale import PointScaler
>>> # Load sample data
>>> X, _ = load_fermentation_train()
>>> # Initialize PointScaler with point index
>>> scaler = PointScaler(point=10)
>>> # Fit and transform the data
>>> X_scaled = scaler.fit_transform(X)
"""
# Hyper-parameter constraints, checked by ``self._validate_params()`` at the
# start of ``fit``.
_parameter_constraints: dict = {
# Non-negative integer: a column index when no x-axis is given, otherwise an
# x-axis value that ``fit`` maps to a column index.
# NOTE(review): the Integral constraint rejects non-integer x-axis values
# (e.g. 1550.5) even though ``point`` may be an x-axis value — confirm this
# restriction is intended.
"point": [Interval(Integral, 0, None, closed="left")],
"x_axis": ["array-like", None],
# Deprecated alias for ``x_axis``; the extra constraint flags its use.
"wavenumbers": ["array-like", None, deprecated_parameter_constraint()],
}
def __init__(
    self,
    point: int = 0,
    x_axis: Optional[np.ndarray] = None,
    wavenumbers=DEPRECATED_PARAMETER,
):
    """Store the hyper-parameters; see the class docstring for their meaning."""
    # scikit-learn convention: keep every constructor argument verbatim under
    # its own name and defer all validation to ``fit`` (which calls
    # ``_validate_params``), so ``get_params``/``clone`` round-trip cleanly.
    self.point, self.x_axis, self.wavenumbers = point, x_axis, wavenumbers
[docs]
def fit(self, X: np.ndarray, y=None) -> "PointScaler":
    """
    Fit the transformer to the input data.

    Resolves ``point`` into a column index of ``X`` and stores it as
    ``point_index_``. Without an x-axis, ``point`` is used directly as the
    column index; otherwise it is treated as an x-axis value and mapped to
    its index.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_features)
        The input data to fit the transformer to.
    y : None
        Ignored to align with API.

    Returns
    -------
    self : PointScaler
        The fitted transformer.

    Raises
    ------
    ValueError
        If the resolved point index does not address a column of ``X``.
    """
    # Validate the hyper-parameters against ``_parameter_constraints``.
    self._validate_params()

    # Check that X is a 2D float64 array with only finite values;
    # ``reset=True`` (re)records the fitted feature count.
    X = validate_data(
        self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
    )

    # Resolve the x-axis from ``x_axis`` / the deprecated ``wavenumbers``
    # alias; ``None`` means no axis is available.
    axis_values = self._resolve_x_axis(self.x_axis, self.wavenumbers)

    # Set the point index
    if axis_values is None:
        self.point_index_ = self.point
    else:
        # Presumably the position of the matching/closest axis value.
        self.point_index_ = self._find_index(self.point, axis_values)

    # Fail early instead of storing an index that cannot address the data
    # (e.g. point=1000 on 500 features, or an x-axis longer than X).
    n_features = X.shape[1]
    if not 0 <= self.point_index_ < n_features:
        raise ValueError(
            f"point index {self.point_index_} is out of range for X with "
            f"{n_features} features."
        )
    return self