Source code for chemotools.plotting._qq_plot

"""Q-Q plot for assessing normality of residuals."""

from typing import Optional, Any, Tuple
import numpy as np
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from scipy import stats

from chemotools.plotting._base import BasePlot
from chemotools.plotting._utils import (
    annotate_points,
    validate_data,
)


[docs] class QQPlot(BasePlot): """Quantile-Quantile plot to assess if residuals follow a normal distribution. This class creates Q-Q plots comparing the quantiles of residuals against theoretical quantiles from a normal distribution. Points falling on the diagonal line indicate normality, while deviations suggest non-normality. Parameters ---------- residuals : np.ndarray Residual values with shape (n_samples,) for univariate or (n_samples, n_targets) for multivariate regression. target_index : int, optional For multivariate residuals, which target to plot (default: 0). Ignored if residuals is 1D. annotations : list[str], optional Labels for annotating individual points (e.g., outliers). label : str, optional Legend label for this dataset (default: "Residuals"). color : str, optional Color for the points (default: "#008BFB"). add_reference_line : bool, optional Whether to add the diagonal reference line (default: True). add_confidence_band : bool or float, optional Whether to add confidence bands around the reference line. - If True: uses 95% confidence band - If float: uses specified confidence level (0 < level < 1) - If False or None: no bands (default: None) Raises ------ ValueError If residuals have invalid shapes. Examples -------- **Basic Q-Q plot:** >>> residuals = y_true - y_pred >>> plot = QQPlot(residuals) >>> fig = plot.show(title="Q-Q Plot of Residuals") **With confidence bands:** >>> plot = QQPlot(residuals, add_confidence_band=0.95) >>> fig = plot.show(title="Q-Q Plot with 95% Confidence Band") **Multiple datasets compared:** >>> fig, axes = plt.subplots(1, 2, figsize=(12, 5)) >>> QQPlot(train_residuals, label="Train").render(axes[0]) >>> QQPlot(test_residuals, label="Test").render(axes[1]) >>> plt.show() **Multivariate regression - plot specific target:** >>> residuals = y_true - y_pred # shape (n_samples, n_targets) >>> plot = QQPlot(residuals, target_index=1) >>> fig = plot.show(title="Q-Q Plot for Target 2") **With outlier annotations:** >>> outlier_indices = [5, 23, 47] >>> annotations = [f"S{i}" if i in outlier_indices else "" for i in range(len(residuals))] >>> plot = QQPlot(residuals, annotations=annotations) >>> fig = plot.show(title="Q-Q Plot with Outliers") Notes ----- The Q-Q plot compares: - X-axis: Theoretical quantiles from standard normal distribution N(0,1) - Y-axis: Sample quantiles (standardized residuals) Points should fall approximately on the diagonal line y=x if residuals are normally distributed. Common patterns: - S-curve: Heavy or light tails - Points above line: Right skew - Points below line: Left skew """ def __init__( self, residuals: np.ndarray, *, target_index: int = 0, annotations: Optional[list[str]] = None, label: str = "Residuals", color: str = "#008BFB", add_reference_line: bool = True, add_confidence_band: Optional[bool | float] = None, ): self.residuals = validate_data(residuals, name="residuals", ensure_2d=False) self.target_index = target_index self.annotations = annotations self.label = label self.color = color self.add_reference_line = add_reference_line self.add_confidence_band = add_confidence_band # Validate inputs self._validate_residuals() # Extract the specific target's residuals if multivariate if self.residuals.ndim == 2: if target_index >= self.residuals.shape[1]: raise ValueError( f"target_index {target_index} is out of bounds for " f"residuals with {self.residuals.shape[1]} targets" ) self.residuals_1d = self.residuals[:, target_index] elif self.residuals.ndim == 1: self.residuals_1d = self.residuals # Calculate Q-Q plot data self._calculate_qq_data() def _validate_residuals(self) -> None: """Validate residuals array.""" if self.residuals.size < 3: raise ValueError("Need at least 3 residuals for Q-Q plot") def _calculate_qq_data(self) -> None: """Calculate theoretical and sample quantiles for Q-Q plot.""" # Use scipy.stats.probplot to get the Q-Q plot data # probplot returns ((theoretical_quantiles, ordered_values), (slope, intercept, r)) ( (self.theoretical_quantiles, self.sample_quantiles), ( self.slope, self.intercept, self.r_value, ), ) = stats.probplot(self.residuals_1d, dist="norm") def _get_default_labels(self) -> dict[str, str]: if self.residuals.ndim == 2: title = f"Q-Q Plot for Target {self.target_index + 1}" else: title = "Q-Q Plot" return { "xlabel": "Theoretical Quantiles", "ylabel": "Sample Quantiles", "title": title, }
[docs] def show( self, *, figsize: Optional[Tuple[float, float]] = None, title: Optional[str] = None, xlabel: Optional[str] = None, ylabel: Optional[str] = None, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, **kwargs: Any, ) -> Figure: """Create and return a complete figure with the QQ plot. This method handles figure creation and then delegates to `render()`. Parameters ---------- figsize : tuple[float, float], optional Figure size in inches (width, height). title : str, optional Figure title. xlabel : str, optional Custom x-axis label. If None, uses existing label or default. ylabel : str, optional Custom y-axis label. If None, uses existing label or default. xlim : tuple[float, float], optional X-axis limits as (xmin, xmax). ylim : tuple[float, float], optional Y-axis limits as (ymin, ymax). **kwargs : Any Additional keyword arguments passed to the render() method. Returns ------- Figure The matplotlib Figure object containing the plot. """ return super().show( figsize=figsize, title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, **kwargs, )
[docs] def render( self, ax: Optional[Axes] = None, *, xlabel: Optional[str] = None, ylabel: Optional[str] = None, xlim: Optional[tuple[float, float]] = None, ylim: Optional[tuple[float, float]] = None, **kwargs: Any, ) -> tuple[Figure, Axes]: """Render the plot on existing or new axes. Parameters ---------- ax : Axes, optional Matplotlib axes to render on. If None, creates new figure/axes. xlabel : str, optional Custom x-axis label. If None, uses existing label or default. ylabel : str, optional Custom y-axis label. If None, uses existing label or default. xlim : tuple[float, float], optional X-axis limits (min, max). ylim : tuple[float, float], optional Y-axis limits (min, max). **kwargs : Any Additional keyword arguments passed to scatter plot. Returns ------- tuple[Figure, Axes] The Figure and Axes objects containing the plot. """ fig, ax = super().render( ax=ax, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, **kwargs, ) return fig, ax
def _render_plot(self, ax: Axes, **kwargs: Any) -> None: """Internal method to render the plot on given axes.""" # Create scatter plot of theoretical vs sample quantiles scatter_kwargs = { "alpha": kwargs.get("alpha", 0.7), "s": kwargs.get("s", 50), "edgecolors": kwargs.get("edgecolors", "black"), "linewidths": kwargs.get("linewidths", 0.5), "label": self.label, } ax.scatter( self.theoretical_quantiles, self.sample_quantiles, c=self.color, **scatter_kwargs, ) # Add reference line (diagonal) if self.add_reference_line: # Calculate the line based on the fit line_x = np.array( [self.theoretical_quantiles.min(), self.theoretical_quantiles.max()] ) line_y = self.slope * line_x + self.intercept ax.plot( line_x, line_y, "r-", linewidth=2, alpha=0.8, label=f"Reference Line (R²={self.r_value**2:.3f})", ) # Add confidence bands if requested if self.add_confidence_band is not None: if isinstance(self.add_confidence_band, bool): confidence_level = 0.95 else: confidence_level = float(self.add_confidence_band) # Calculate confidence bands using standard error n = len(self.residuals_1d) se = np.std(self.residuals_1d) * np.sqrt( (1 / n) + (self.theoretical_quantiles**2) / np.sum(self.theoretical_quantiles**2) ) # Critical value for the confidence level z_crit = stats.norm.ppf((1 + confidence_level) / 2) upper_band = ( self.slope * self.theoretical_quantiles + self.intercept + z_crit * se ) lower_band = ( self.slope * self.theoretical_quantiles + self.intercept - z_crit * se ) ax.fill_between( self.theoretical_quantiles, lower_band, upper_band, color="red", alpha=0.2, label=f"{confidence_level * 100:.0f}% Confidence Band", ) # Add annotations if provided if self.annotations: annotate_points( ax, self.theoretical_quantiles, self.sample_quantiles, self.annotations ) # Enforce equal scaling so the reference line is visually meaningful ax.set_aspect("equal", adjustable="box") # When available make the axes box square to avoid tiny drawing areas set_box_aspect = getattr(ax, "set_box_aspect", None) if callable(set_box_aspect): set_box_aspect(1)