# Source: chemotools.inspector._preprocessing_inspector

"""Preprocessing Inspector for pipeline step visualization."""

from __future__ import annotations

import warnings
from typing import (
    TYPE_CHECKING,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
from sklearn.base import is_classifier, is_regressor
from sklearn.cross_decomposition._pls import _PLS
from sklearn.decomposition._base import _BasePCA
from sklearn.pipeline import Pipeline
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

from .core.base import InspectorDataset, _DataHoldingBase
from .core.spectra import SpectraMixin
from .core.summaries import PreprocessingSummary
from .core.utils import (
    get_xlabel_for_features,
    normalize_datasets,
    prepare_color_values,
)
from .helpers._preprocessing import create_preprocessing_step_plot

if TYPE_CHECKING:
    from matplotlib.figure import Figure


def _is_model_step(step: object) -> bool:
    """Return ``True`` if *step* is a model/estimator that should be excluded.

    Steps excluded from preprocessing visualization:
    - PCA and other decomposition models (``_BasePCA``)
    - PLS and related cross-decomposition models (``_PLS``)
    - All estimators from ``sklearn.decomposition`` (NMF, FastICA, etc.)
    - All estimators from ``sklearn.cross_decomposition``
    - Classifiers (``is_classifier``)
    - Regressors (``is_regressor``)
    - The string ``"passthrough"`` (no-op placeholder)
    """
    # Strings (e.g. "passthrough") are placeholders, never preprocessing.
    if isinstance(step, str):
        return True

    # Known decomposition / cross-decomposition base classes.
    if isinstance(step, (_BasePCA, _PLS)):
        return True

    # Supervised estimators are terminal model steps, not transformations.
    if is_classifier(step) or is_regressor(step):
        return True

    # Decomposition models that do not derive from _BasePCA (NMF, FastICA, …)
    # are identified by their defining module; str.startswith accepts a tuple.
    step_module = type(step).__module__ or ""
    excluded_prefixes = ("sklearn.decomposition", "sklearn.cross_decomposition")
    return step_module.startswith(excluded_prefixes)


class PreprocessingInspector(SpectraMixin, _DataHoldingBase):
    """Inspector for visualizing the effects of each preprocessing step in a pipeline.

    The ``PreprocessingInspector`` takes a **fitted** scikit-learn
    :class:`~sklearn.pipeline.Pipeline` together with the datasets that were
    used for training (and, optionally, testing/validation). It walks through
    the pipeline steps, applies each preprocessing transformer cumulatively,
    and generates one plot per step so that users can visually inspect how
    each transformation modifies their data.

    Steps that are *model* estimators — such as PCA, PLS, classifiers, or
    regressors — are automatically detected and **excluded** from the
    visualization, because they do not represent a preprocessing
    transformation.

    The class also inherits :class:`SpectraMixin`, which provides the
    :meth:`inspect_spectra` method for a quick *raw vs. fully preprocessed*
    comparison.

    Parameters
    ----------
    pipeline : Pipeline
        A **fitted** scikit-learn ``Pipeline``. All steps must already be
        fitted (i.e. ``pipeline.fit(X)`` has been called).
    X_train : array-like of shape (n_samples, n_features)
        Training feature matrix (required).
    y_train : array-like of shape (n_samples,), optional
        Training target values. Used for colouring plots when
        ``color_by='y'``.
    X_test : array-like of shape (n_samples, n_features), optional
        Test feature matrix.
    y_test : array-like of shape (n_samples,), optional
        Test target values.
    X_val : array-like of shape (n_samples, n_features), optional
        Validation feature matrix.
    y_val : array-like of shape (n_samples,), optional
        Validation target values.
    x_axis : array-like of shape (n_features,), optional
        Feature names or axis values (e.g. wavenumbers). If ``None``, integer
        indices are used.

    Attributes
    ----------
    pipeline : Pipeline
        The original fitted pipeline.
    model : Pipeline
        Alias for ``pipeline`` (consistent with ``PCAInspector`` /
        ``PLSRegressionInspector``).
    preprocessing_steps : list of tuple
        ``(name, transformer)`` pairs for every step that will be visualised
        (model steps are excluded).
    datasets_ : dict of str to InspectorDataset
        Dictionary of loaded datasets keyed by ``'train'``, ``'test'``,
        ``'val'``.
    n_features_in_ : int
        Number of input features.

    Raises
    ------
    TypeError
        If *pipeline* is not a :class:`~sklearn.pipeline.Pipeline`.
    RuntimeError
        If the pipeline has not been fitted.
    ValueError
        If ``X_train`` has inconsistent shape with other datasets.

    Examples
    --------
    >>> from sklearn.pipeline import make_pipeline
    >>> from sklearn.preprocessing import StandardScaler, MinMaxScaler
    >>> from sklearn.decomposition import PCA
    >>> from chemotools.inspector import PreprocessingInspector
    >>>
    >>> pipe = make_pipeline(StandardScaler(), MinMaxScaler(), PCA(n_components=3))
    >>> pipe.fit(X_train)
    >>>
    >>> inspector = PreprocessingInspector(pipe, X_train, y_train)
    >>> figures = inspector.inspect()  # one plot per preprocessing step
    >>> figures = inspector.inspect_spectra()  # raw vs. fully preprocessed
    """

    def __init__(
        self,
        pipeline: Pipeline,
        X_train: np.ndarray,
        y_train: Optional[np.ndarray] = None,
        X_test: Optional[np.ndarray] = None,
        y_test: Optional[np.ndarray] = None,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None,
        x_axis: Optional[np.ndarray] = None,
    ) -> None:
        self._validate_pipeline(pipeline)
        self._pipeline = pipeline

        # --- Identify preprocessing vs. model steps ----------------------------
        self._preprocessing_steps: List[Tuple[str, object]] = [
            (name, step) for name, step in pipeline.steps if not _is_model_step(step)
        ]
        if not self._preprocessing_steps:
            raise ValueError(
                "The pipeline does not contain any preprocessing steps to "
                "visualise. All steps were identified as model/estimator steps."
            )

        # Warn about nested pipelines / ColumnTransformer — not fully supported.
        # Imported locally to keep the module import surface minimal.
        from sklearn.compose import ColumnTransformer

        for name, step in self._preprocessing_steps:
            if isinstance(step, (Pipeline, ColumnTransformer)):
                warnings.warn(
                    f"Step '{name}' is a {type(step).__name__}, which is treated "
                    f"as a single opaque step. Individual sub-steps inside it "
                    f"will not be visualised separately.",
                    UserWarning,
                    stacklevel=2,
                )

        # --- Validate and build datasets ----------------------------------------
        X_train = check_array(
            X_train,
            dtype="numeric",
            ensure_2d=True,
            ensure_all_finite=True,
            input_name="X_train",
        )
        datasets: Dict[str, InspectorDataset] = {
            "train": InspectorDataset(
                X=X_train,
                y=self._validate_y(y_train, X_train.shape[0], "y_train"),
            ),
        }
        self._add_optional_dataset(datasets, "test", X_test, y_test, X_train.shape[1])
        self._add_optional_dataset(datasets, "val", X_val, y_val, X_train.shape[1])

        # --- Initialise data-holding base ----------------------------------------
        super().__init__(
            datasets=datasets,
            n_features_in=X_train.shape[1],
            feature_names=x_axis,
        )

    # ------------------------------------------------------------------
    # Static / private helpers for __init__
    # ------------------------------------------------------------------
    @staticmethod
    def _validate_pipeline(pipeline: object) -> None:
        """Validate that *pipeline* is a fitted sklearn Pipeline.

        Raises
        ------
        TypeError
            If *pipeline* is not a :class:`~sklearn.pipeline.Pipeline`.
        RuntimeError
            If the pipeline has not been fitted (chained from the sklearn
            ``check_is_fitted`` error).
        """
        if not isinstance(pipeline, Pipeline):
            raise TypeError(
                f"Expected a sklearn Pipeline, got {type(pipeline).__name__}."
            )
        try:
            check_is_fitted(pipeline)
        except Exception as exc:
            raise RuntimeError(
                "The pipeline must be fitted before passing it to "
                "PreprocessingInspector."
            ) from exc

    @staticmethod
    def _add_optional_dataset(
        datasets: Dict[str, InspectorDataset],
        name: str,
        X: Optional[np.ndarray],
        y: Optional[np.ndarray],
        expected_features: int,
    ) -> None:
        """Validate and store an optional (test / val) dataset.

        Does nothing when *X* is ``None``. Raises ``ValueError`` if the
        feature count does not match the training data.
        """
        if X is None:
            return
        X = check_array(
            X,
            dtype="numeric",
            ensure_2d=True,
            ensure_all_finite=True,
            input_name=f"X_{name}",
        )
        if X.shape[1] != expected_features:
            raise ValueError(
                f"X_{name} must have the same number of features as X_train. "
                f"Got {X.shape[1]} vs {expected_features}."
            )
        datasets[name] = InspectorDataset(
            X=X,
            y=PreprocessingInspector._validate_y(y, X.shape[0], f"y_{name}"),
        )

    # ------------------------------------------------------------------
    # Static helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _validate_y(
        y: Optional[np.ndarray], expected_n: int, name: str
    ) -> Optional[np.ndarray]:
        """Validate and normalise a target array.

        Returns ``None`` when *y* is ``None``; otherwise a 1-D array with
        *expected_n* samples (column vectors are flattened).
        """
        if y is None:
            return None
        arr = check_array(
            y,
            dtype=None,
            ensure_2d=False,
            ensure_all_finite=True,
            input_name=name,
        )
        # Accept (n, 1) column vectors and flatten them to (n,).
        if arr.ndim == 2 and arr.shape[1] == 1:
            arr = arr.ravel()
        if arr.shape[0] != expected_n:
            raise ValueError(
                f"{name} must have {expected_n} samples, got {arr.shape[0]}."
            )
        return arr

    # ------------------------------------------------------------------
    # SpectraMixin protocol implementation
    # ------------------------------------------------------------------
    @property
    def transformer(self) -> Optional[Pipeline]:
        """Return a pipeline containing only the preprocessing steps.

        This is used by :class:`SpectraMixin` to generate *raw vs.
        preprocessed* comparison plots via :meth:`inspect_spectra`.
        """
        if not self._preprocessing_steps:
            return None
        return Pipeline(list(self._preprocessing_steps))

    def _get_preprocessed_data(self, dataset: str) -> np.ndarray:
        """Return fully preprocessed X (through all preprocessing steps)."""
        # NOTE(review): assumes ``self._preprocessed_cache`` is initialised by
        # a base class (not visible here) — confirm against _DataHoldingBase.
        if dataset in self._preprocessed_cache:
            return self._preprocessed_cache[dataset]
        X = self._get_dataset(dataset).X
        transformer = self.transformer
        if transformer is None:
            result = X
        else:
            result = transformer.transform(X)
        self._preprocessed_cache[dataset] = result
        return result

    def _get_preprocessed_x_axis(self) -> np.ndarray:
        """Return x-axis values after full preprocessing.

        If preprocessing changes the number of features (e.g. feature
        selection), the returned array will reflect the new dimensionality.
        """
        X_prep = self._get_preprocessed_data("train")
        return self._x_axis_for_n_features(X_prep.shape[1])

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------
    @property
    def pipeline(self) -> Pipeline:
        """Return the original pipeline."""
        return self._pipeline

    @property
    def model(self) -> Pipeline:
        """Return the original pipeline.

        Alias for :attr:`pipeline`, provided for consistency with
        ``PCAInspector`` and ``PLSRegressionInspector``.
        """
        return self._pipeline

    @property
    def preprocessing_steps(self) -> List[Tuple[str, object]]:
        """Return the list of ``(name, transformer)`` preprocessing steps."""
        return list(self._preprocessing_steps)

    # ------------------------------------------------------------------
    # Representation
    # ------------------------------------------------------------------
    def __repr__(self) -> str:  # noqa: D105
        datasets = ", ".join(
            f"{name}({ds.n_samples})" for name, ds in self.datasets_.items()
        )
        return (
            f"PreprocessingInspector("
            f"steps={len(self._preprocessing_steps)}, "
            f"features={self.n_features_in_}, "
            f"datasets=[{datasets}])"
        )

    # ------------------------------------------------------------------
    # inspect() — step-by-step preprocessing visualization
    # ------------------------------------------------------------------
    def inspect(
        self,
        dataset: Union[str, Sequence[str]] = "train",
        color_by: Optional[Union[str, Dict[str, np.ndarray]]] = "y",
        xlim: Optional[Tuple[float, float]] = None,
        figsize: Tuple[float, float] = (12, 5),
        color_mode: Optional[Literal["continuous", "categorical"]] = None,
    ) -> Dict[str, "Figure"]:
        """Generate one plot per preprocessing step showing cumulative effects.

        For a pipeline with steps ``[A, B, C, PCA]`` (where PCA is excluded),
        this method produces:

        1. **Raw** – the original input data
        2. **After A** – ``A.transform(X)``
        3. **After A + B** – ``B.transform(A.transform(X))``
        4. **After A + B + C** – ``C.transform(B.transform(A.transform(X)))``

        Parameters
        ----------
        dataset : str or sequence of str, default='train'
            Dataset(s) to visualise. When a sequence is given, all datasets
            are overlaid on the same axes, coloured by dataset name.
        color_by : str or dict, default='y'
            Colouring specification (single-dataset mode only):

            - ``'y'``: colour by target values (if available)
            - ``'sample_index'``: colour by sample index
            - dict mapping dataset names to colour arrays

            Ignored when multiple datasets are provided (colours by dataset
            instead).
        xlim : tuple of float, optional
            X-axis limits for zooming into a spectral region.
        figsize : tuple of float, default=(12, 5)
            Figure size ``(width, height)`` in inches for each subplot.
        color_mode : {``'continuous'``, ``'categorical'``}, optional
            Override automatic colour-mode detection.

        Returns
        -------
        figures : dict of str to Figure
            Dictionary mapping step names to matplotlib ``Figure`` objects.
            Keys follow the pattern ``'raw'``, ``'step_1_<name>'``,
            ``'step_2_<name>'``, etc.

        Examples
        --------
        >>> inspector = PreprocessingInspector(pipeline, X_train, y_train)
        >>> figures = inspector.inspect()
        >>> figures['raw'].savefig('raw_spectra.png')
        >>> figures['step_1_standardscaler'].savefig('after_scaling.png')
        """
        self.close_figures()

        datasets = normalize_datasets(dataset)
        is_multi = len(datasets) > 1
        xlabel = get_xlabel_for_features(self.feature_names is not None)

        figures: Dict[str, "Figure"] = {}

        # --- Raw data plot -----------------------------------------------------
        if is_multi:
            figures["raw"] = self._plot_multi_dataset_step(
                datasets,
                step_data=None,
                title="Raw Spectra",
                xlabel=xlabel,
                xlim=xlim,
                figsize=figsize,
                color_mode=color_mode,
            )
        else:
            ds_name = datasets[0]
            ds = self._get_dataset(ds_name)
            color_values = prepare_color_values(color_by, ds_name, ds.y, ds.X.shape[0])
            figures["raw"] = create_preprocessing_step_plot(
                X=ds.X,
                x_axis=self._x_axis,
                title=f"Raw Spectra ({ds_name.capitalize()})",
                xlabel=xlabel,
                color_values=color_values,
                xlim=xlim,
                figsize=figsize,
                color_mode=color_mode,
            )

        # --- One plot per preprocessing step (iterative: O(N) transforms) -----
        # Keep track of the cumulative transformed X for each dataset so that
        # each step only applies its own transform on the previous output,
        # avoiding the O(N²) cost of rebuilding sub-pipelines.
        cumulative: Dict[str, np.ndarray] = {
            ds_name: self._get_dataset(ds_name).X for ds_name in datasets
        }

        for step_idx, (step_name, step_transformer) in enumerate(
            self._preprocessing_steps, start=1
        ):
            # Apply only this step's transform to the previous cumulative output
            for ds_name in datasets:
                cumulative[ds_name] = step_transformer.transform(cumulative[ds_name])  # type: ignore[union-attr]

            # Cumulative step label for the title/key
            latest_step_type = type(step_transformer).__name__
            fig_key = f"step_{step_idx}_{step_name}"
            title = f"Step {step_idx}: after {latest_step_type}"

            if is_multi:
                step_x_axis = self._resolve_step_x_axis(cumulative[datasets[0]])
                figures[fig_key] = self._plot_multi_dataset_step(
                    datasets,
                    step_data=dict(cumulative),
                    title=title,
                    xlabel=xlabel,
                    xlim=xlim,
                    figsize=figsize,
                    color_mode=color_mode,
                    step_x_axis=step_x_axis,
                )
            else:
                ds_name = datasets[0]
                ds = self._get_dataset(ds_name)
                step_x_axis = self._resolve_step_x_axis(cumulative[ds_name])
                color_values = prepare_color_values(
                    color_by, ds_name, ds.y, ds.X.shape[0]
                )
                figures[fig_key] = create_preprocessing_step_plot(
                    X=cumulative[ds_name],
                    x_axis=step_x_axis,
                    title=f"{title} ({ds_name.capitalize()})",
                    xlabel=xlabel,
                    color_values=color_values,
                    xlim=xlim,
                    figsize=figsize,
                    color_mode=color_mode,
                )

        return self._track_figures(figures)

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    def summary(self) -> PreprocessingSummary:
        """Return a summary of the pipeline and preprocessing steps.

        Returns
        -------
        summary : PreprocessingSummary
            Typed summary dataclass. Printing the returned object produces a
            human-readable table (via ``__repr__``).
        """
        steps_info = [
            {
                "step": i,
                "name": name,
                "type": type(transformer).__name__,
            }
            for i, (name, transformer) in enumerate(self._preprocessing_steps, start=1)
        ]
        excluded = [
            {
                "name": name,
                "type": type(transformer).__name__,
            }
            for name, transformer in self._pipeline.steps
            if _is_model_step(transformer)
        ]
        return PreprocessingSummary(
            pipeline_type=type(self._pipeline).__name__,
            total_steps=len(self._pipeline.steps),
            n_preprocessing_steps=len(self._preprocessing_steps),
            n_excluded_steps=len(excluded),
            n_features=self.n_features_in_,
            n_samples=self.n_samples,
            steps=steps_info,
            excluded=excluded,
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _resolve_step_x_axis(self, X_step: np.ndarray) -> np.ndarray:
        """Return the appropriate x-axis for a transformed array."""
        return self._x_axis_for_n_features(X_step.shape[1])

    def _x_axis_for_n_features(self, n_features_out: int) -> np.ndarray:
        """Return the appropriate x-axis for a given output dimensionality.

        If the number of features is unchanged and feature names were
        provided, use them. Otherwise fall back to integer indices.
        """
        if n_features_out == self.n_features_in_ and self.feature_names is not None:
            return self.feature_names.copy()
        if n_features_out == self.n_features_in_:
            return self._x_axis.copy()
        return np.arange(n_features_out)

    def _plot_multi_dataset_step(
        self,
        datasets: List[str],
        step_data: Optional[Dict[str, np.ndarray]],
        title: str,
        xlabel: str,
        xlim: Optional[Tuple[float, float]],
        figsize: Tuple[float, float],
        color_mode: Optional[Literal["continuous", "categorical"]],
        step_x_axis: Optional[np.ndarray] = None,
    ) -> "Figure":
        """Create a single figure with multiple datasets overlaid.

        When *step_data* is ``None``, the raw dataset matrices are plotted;
        otherwise the per-dataset transformed arrays in *step_data* are used.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        from chemotools.plotting import SpectraPlot
        from chemotools.plotting._styles import DATASET_COLORS

        fig, ax = plt.subplots(figsize=figsize)
        for ds_name in datasets:
            if step_data is not None:
                X = step_data[ds_name]
                x_ax = step_x_axis if step_x_axis is not None else np.arange(X.shape[1])
            else:
                X = self._get_dataset(ds_name).X
                x_ax = self._x_axis
            color = DATASET_COLORS.get(ds_name, "black")
            # Label only the first line per dataset so the legend shows one
            # entry per dataset rather than one per spectrum.
            labels: List[Optional[str]] = [ds_name.capitalize()] + [None] * (
                X.shape[0] - 1
            )
            plot = SpectraPlot(x=x_ax, y=X, labels=labels, color_mode=color_mode)
            plot.render(ax=ax, color=color, alpha=0.6, linewidth=1)

        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel("Intensity", fontsize=12)
        ax.grid(alpha=0.3)
        # Explicit None check: a zero-valued limit tuple must still be applied.
        if xlim is not None:
            ax.set_xlim(xlim)
        ax.legend()

        return fig