Source code for chemotools.plotting._y_residuals

"""Y residuals plot for regression diagnostics and homoscedasticity analysis."""

from typing import Optional, Any, Literal, Tuple
import numpy as np
from matplotlib.figure import Figure
from matplotlib.axes import Axes

from chemotools.plotting._base import BasePlot, ColoringMixin
from chemotools.plotting._utils import (
    annotate_points,
    validate_data,
    scatter_with_colormap,
)


class YResidualsPlot(BasePlot, ColoringMixin):
    """Plot of residuals to assess homoscedasticity and model fit quality.

    This class creates scatter plots of Y residuals (observed - predicted)
    versus sample index or a given vector (e.g., predicted values,
    experimental conditions). Useful for detecting heteroscedasticity,
    patterns in residuals, and model issues.

    Parameters
    ----------
    residuals : np.ndarray
        Residual values with shape (n_samples,) for univariate or
        (n_samples, n_targets) for multivariate regression.
        Residuals should be calculated as (y_true - y_pred).
    x_values : np.ndarray, optional
        Values for the x-axis. If None, uses sample indices (0, 1, 2, ...).
        Common choices: predicted values, experimental conditions, time
        points. Its first dimension must match the number of samples.
    target_index : int, optional
        For multivariate residuals, which target to plot (default: 0).
        Ignored if residuals is 1D.
    color_by : np.ndarray, optional
        Values for coloring samples. Can be either:

        - Continuous (numeric): shows colorbar
        - Categorical (strings/classes): shows legend with discrete colors
    annotations : list[str], optional
        Labels for annotating individual points.
    label : str, optional
        Legend label for this dataset (default: "Residuals").
    color : str, optional
        Color for all points when color_by is None (default: "steelblue").
    colormap : str, optional
        Colormap name. Colorblind-friendly defaults:

        - "tab10" for categorical data
        - "viridis" for continuous data
    add_zero_line : bool, optional
        Whether to add a horizontal line at y=0 (default: True).
    add_confidence_band : bool or float, optional
        Whether to add confidence bands (±n*std) around zero.

        - If True: uses ±2*std (95% for normal distribution)
        - If float: uses ±value*std
        - If False or None: no bands (default: None)
    color_mode : {"continuous", "categorical"}, optional
        Explicitly specify coloring mode. If None (default), automatically
        detects based on dtype and unique values of color_by.
    colorbar_label : str, optional
        Label for the colorbar when using continuous coloring.
        Default is "Value". Only applies when color_by is continuous.

    Raises
    ------
    ValueError
        If residuals are not 1D or 2D, if target_index is out of range for
        2D residuals, or if the length of x_values does not match the
        number of samples.

    Examples
    --------
    **Simple residuals plot vs sample index:**

    >>> residuals = y_true - y_pred
    >>> plot = YResidualsPlot(residuals)
    >>> fig = plot.show(title="Residuals vs Sample Index")

    **Residuals vs predicted values (check for heteroscedasticity):**

    >>> plot = YResidualsPlot(residuals, x_values=y_pred)
    >>> fig = plot.show(
    ...     title="Residuals vs Predicted",
    ...     xlabel="Predicted Values",
    ...     ylabel="Residuals"
    ... )

    **With confidence bands:**

    >>> plot = YResidualsPlot(
    ...     residuals,
    ...     x_values=y_pred,
    ...     add_confidence_band=2.0  # ±2 standard deviations
    ... )
    >>> fig = plot.show(title="Residuals with 95% Confidence Band")

    **Multiple datasets composed together:**

    >>> fig, ax = plt.subplots()
    >>> YResidualsPlot(train_residuals, label="Train", color="blue").render(ax)
    >>> YResidualsPlot(test_residuals, label="Test", color="red").render(ax)
    >>> ax.legend()
    >>> plt.show()
    """

    def __init__(
        self,
        residuals: np.ndarray,
        *,
        x_values: Optional[np.ndarray] = None,
        target_index: int = 0,
        color_by: Optional[np.ndarray] = None,
        annotations: Optional[list[str]] = None,
        label: str = "Residuals",
        color: Optional[str] = None,
        colormap: Optional[str] = None,
        add_zero_line: bool = True,
        add_confidence_band: Optional[bool | float] = None,
        color_mode: Optional[Literal["continuous", "categorical"]] = None,
        colorbar_label: str = "Value",
    ):
        self.residuals = validate_data(residuals, name="residuals", ensure_2d=False)

        self.x_values: Optional[np.ndarray]
        if x_values is not None:
            self.x_values = validate_data(x_values, name="x_values", ensure_2d=False)
        else:
            self.x_values = None

        self.target_index = target_index
        self.annotations = annotations
        self.label = label
        self.color = color
        self.add_zero_line = add_zero_line
        self.add_confidence_band = add_confidence_band

        # Populated by _init_xy_data() below.
        self.x_axis: np.ndarray
        self.x_label: str

        self._validate_residuals()
        self._init_xy_data()

        if color_by is not None:
            color_by = validate_data(
                color_by, name="color_by", ensure_2d=False, numeric=False
            )

        # Initialize coloring
        self._init_coloring(
            color_by, colormap, color_mode=color_mode, colorbar_label=colorbar_label
        )

    def _validate_residuals(self) -> None:
        """Validate residuals shape and target index.

        Sets ``self.residuals_1d`` to the 1D vector that will be plotted.

        Raises
        ------
        ValueError
            If residuals are not 1D or 2D, or if ``target_index`` is outside
            ``[0, n_targets)`` for 2D residuals.
        """
        if self.residuals.ndim == 1:
            self.residuals_1d = self.residuals
        elif self.residuals.ndim == 2:
            n_targets = self.residuals.shape[1]
            if self.target_index < 0 or self.target_index >= n_targets:
                raise ValueError(
                    f"Invalid target_index {self.target_index}. "
                    f"Residuals have {n_targets} targets."
                )
            self.residuals_1d = self.residuals[:, self.target_index]
        else:
            # Without this branch residuals_1d would stay unset and a later
            # line would fail with a confusing AttributeError.
            raise ValueError(
                "residuals must be 1D (n_samples,) or 2D (n_samples, n_targets), "
                f"got {self.residuals.ndim}D array"
            )

    def _init_xy_data(self) -> None:
        """Initialize x/y data for plotting.

        Uses sample indices for the x-axis when no ``x_values`` were given;
        otherwise checks that ``x_values`` has one entry per sample.

        Raises
        ------
        ValueError
            If the length of ``x_values`` differs from the number of samples.
        """
        n_samples = self.residuals_1d.shape[0]

        if self.x_values is None:
            self.x_axis = np.arange(n_samples)
            self.x_label = "Sample Index"
        else:
            if self.x_values.shape[0] != n_samples:
                raise ValueError(
                    f"x_values length ({self.x_values.shape[0]}) must match "
                    f"residuals length ({n_samples})"
                )
            self.x_axis = self.x_values
            self.x_label = "X Values"

    def _get_default_labels(self) -> dict[str, str]:
        """Return the default x label, y label and title for this plot."""
        if self.residuals.ndim == 2:
            # Titles are 1-based for readability; target_index is 0-based.
            default_title = f"Residuals Plot - Target {self.target_index + 1}"
        else:
            default_title = "Residuals Plot"

        return {
            "xlabel": self.x_label,
            "ylabel": "Residuals",
            "title": default_title,
        }

    def _confidence_band_multiplier(self) -> Optional[float]:
        """Return the std multiplier for the confidence band, or None.

        ``None`` and ``False`` disable the band, ``True`` maps to 2.0 and any
        other value is converted with ``float()``.
        """
        band = self.add_confidence_band
        if band is None:
            return None
        # bool is checked before the numeric conversion because False must
        # mean "no band" rather than a multiplier of 0.0 or the default 2.0.
        if isinstance(band, (bool, np.bool_)):
            return 2.0 if band else None
        return float(band)

    def show(
        self,
        *,
        figsize: Optional[Tuple[float, float]] = None,
        title: Optional[str] = None,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Optional[Tuple[float, float]] = None,
        **kwargs: Any,
    ) -> Figure:
        """Create and return a complete figure with the y-residuals plot.

        This method handles figure creation and then delegates to `render()`.

        Parameters
        ----------
        figsize : tuple[float, float], optional
            Figure size in inches (width, height).
        title : str, optional
            Figure title.
        xlabel : str, optional
            Custom x-axis label. If None, uses existing label or default.
        ylabel : str, optional
            Custom y-axis label. If None, uses existing label or default.
        xlim : tuple[float, float], optional
            X-axis limits as (xmin, xmax).
        ylim : tuple[float, float], optional
            Y-axis limits as (ymin, ymax).
        **kwargs : Any
            Additional keyword arguments passed to the render() method.

        Returns
        -------
        Figure
            The matplotlib Figure object containing the plot.
        """
        return super().show(
            figsize=figsize,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )

    def render(
        self,
        ax: Optional[Axes] = None,
        *,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        xlim: Optional[tuple[float, float]] = None,
        ylim: Optional[tuple[float, float]] = None,
        **kwargs: Any,
    ) -> tuple[Figure, Axes]:
        """Render the plot on existing or new axes.

        Parameters
        ----------
        ax : Axes, optional
            Matplotlib axes to render on. If None, creates new figure/axes.
        xlabel : str, optional
            Custom x-axis label. If None, uses existing label or default.
        ylabel : str, optional
            Custom y-axis label. If None, uses existing label or default.
        xlim : tuple[float, float], optional
            X-axis limits (min, max).
        ylim : tuple[float, float], optional
            Y-axis limits (min, max).
        **kwargs : Any
            Additional keyword arguments passed to scatter plot.

        Returns
        -------
        tuple[Figure, Axes]
            The Figure and Axes objects containing the plot.
        """
        fig, ax = super().render(
            ax=ax,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )

        # Add colorbar for continuous data
        self._add_colorbar_if_needed(ax)

        # Add legend if categorical or if label is provided
        if self.is_categorical or self.label:
            # Only add legend if there are labeled artists
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

        return fig, ax

    def _render_plot(self, ax: Axes, **kwargs: Any) -> None:
        """Internal method to render the plot on given axes.

        Draws the residual scatter, then the optional zero line, confidence
        band and point annotations. Scatter styling defaults (``alpha``,
        ``s``, ``edgecolors``, ``linewidths``) can be overridden via kwargs;
        remaining kwargs are forwarded to ``scatter_with_colormap``.
        """
        alpha = kwargs.pop("alpha", 0.6)
        s = kwargs.pop("s", 50)
        edgecolors = kwargs.pop("edgecolors", "black")
        linewidths = kwargs.pop("linewidths", 0.5)

        scatter_with_colormap(
            ax,
            self.x_axis,
            self.residuals_1d,
            color_by=self.color_by,
            is_categorical=self.is_categorical,
            colormap=self.colormap,
            color=self.color if self.color is not None else "steelblue",
            label=self.label,
            alpha=alpha,
            s=s,
            edgecolors=edgecolors,
            linewidths=linewidths,
            **kwargs,
        )

        # Add zero reference line
        if self.add_zero_line:
            ax.axhline(y=0, color="black", linestyle="-", linewidth=1.5, alpha=0.7)

        # Add confidence bands if requested (None/False -> no bands)
        n_std = self._confidence_band_multiplier()
        if n_std is not None:
            half_width = n_std * np.std(self.residuals_1d)

            # Only the upper line is labeled so the legend shows one entry.
            ax.axhline(
                y=half_width,
                color="red",
                linestyle="--",
                linewidth=1.5,
                alpha=0.5,
                label=f"±{n_std:.1f}σ",
            )
            ax.axhline(
                y=-half_width,
                color="red",
                linestyle="--",
                linewidth=1.5,
                alpha=0.5,
            )
            # axhspan covers the full axes width in axes coordinates, so the
            # shading follows later x-limit changes just like the axhline
            # boundaries and does not alter the x data limits.
            ax.axhspan(
                -half_width,
                half_width,
                facecolor="red",
                edgecolor="none",
                alpha=0.1,
            )

        # Add annotations if provided
        if self.annotations:
            annotate_points(ax, self.x_axis, self.residuals_1d, self.annotations)