Source code for chemotools.plotting._predicted_vs_actual

"""Predicted vs Actual plot for regression model evaluation."""

from typing import Literal, Optional, Any, Tuple
import numpy as np
from matplotlib.figure import Figure
from matplotlib.axes import Axes

from chemotools.plotting._base import BasePlot, ColoringMixin
from chemotools.plotting._utils import validate_data, scatter_with_colormap


class PredictedVsActualPlot(BasePlot, ColoringMixin):
    """Scatter plot of predicted vs actual values to assess regression fit.

    This class creates scatter plots comparing predicted values against
    actual (true) values with an ideal prediction line (y=x). Useful for
    visually assessing model accuracy and detecting systematic errors or
    bias.

    Parameters
    ----------
    y_true : np.ndarray
        True (actual) y values with shape (n_samples,) for univariate or
        (n_samples, n_targets) for multivariate regression.
    y_pred : np.ndarray
        Predicted y values with same shape as y_true.
    target_index : int, optional
        For multivariate predictions, which target to plot (default: 0).
        Negative values index from the last target. Ignored if
        y_true/y_pred are 1D.
    color_by : np.ndarray, optional
        Values for coloring samples. Can be either:

        - Continuous (numeric): shows colorbar
        - Categorical (strings/classes): shows legend with discrete colors
    label : str, optional
        Legend label for this dataset (default: None).
    color : str, optional
        Color for all points when color_by is None (default: auto-assigned).
    colormap : str, optional
        Colormap name. Colorblind-friendly defaults:

        - "tab10" for categorical data
        - "viridis" for continuous data
    marker : str, optional
        Marker style for scatter points (default: "o").
        Examples: "o", "s", "^", "v", "D".
    add_ideal_line : bool, optional
        Whether to add diagonal y=x line showing ideal predictions
        (default: True).
    color_mode : {"continuous", "categorical"}, optional
        Explicitly specify coloring mode. If None (default), automatically
        detects based on dtype and unique values of color_by.
    colorbar_label : str, optional
        Label for the colorbar when using continuous coloring. Default is
        "Value". Only applies when color_by is continuous.

    Raises
    ------
    ValueError
        If y_true and y_pred have mismatched shapes, if target_index is out
        of bounds for 2D input, or if color_by does not provide one value
        per sample.

    Examples
    --------
    **Basic predicted vs actual plot:**

    >>> plot = PredictedVsActualPlot(y_true, y_pred)
    >>> fig = plot.show(title="Model Performance")

    **With categorical coloring (e.g., by batch):**

    >>> batches = np.array(['A', 'B', 'A', 'B', ...])
    >>> plot = PredictedVsActualPlot(y_true, y_pred, color_by=batches)
    >>> fig = plot.show(title="Predictions by Batch")

    **Multiple models compared:**

    >>> fig, ax = plt.subplots()
    >>> PredictedVsActualPlot(y_true, y_pred_model1, label="Model 1", color="blue").render(ax)
    >>> PredictedVsActualPlot(y_true, y_pred_model2, label="Model 2", color="red").render(ax)
    >>> ax.legend()
    >>> plt.show()

    **Multivariate regression:**

    >>> fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    >>> for i in range(3):
    ...     PredictedVsActualPlot(y_true, y_pred, target_index=i).render(axes[i])
    ...     axes[i].set_title(f"Target {i+1}")
    >>> plt.tight_layout()
    >>> plt.show()
    """

    def __init__(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        *,
        target_index: int = 0,
        color_by: Optional[np.ndarray] = None,
        label: Optional[str] = None,
        color: Optional[str] = None,
        colormap: Optional[str] = None,
        marker: str = "o",
        add_ideal_line: bool = True,
        color_mode: Optional[Literal["continuous", "categorical"]] = None,
        colorbar_label: str = "Value",
    ):
        # Validate and store input data
        self.y_true = validate_data(y_true, name="y_true", ensure_2d=False)
        self.y_pred = validate_data(y_pred, name="y_pred", ensure_2d=False)
        self.target_index = target_index
        self.label = label
        self.color = color
        self.marker = marker
        self.add_ideal_line = add_ideal_line
        self.colorbar_label = colorbar_label

        # Validate shapes match
        if self.y_true.shape != self.y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have same shape. "
                f"Got y_true: {self.y_true.shape}, y_pred: {self.y_pred.shape}"
            )

        # Extract the specific target if multivariate
        if self.y_true.ndim == 2:
            n_targets = self.y_true.shape[1]
            # Check both bounds: an out-of-range negative index would
            # otherwise escape as a raw IndexError from NumPy.
            if not -n_targets <= target_index < n_targets:
                raise ValueError(
                    f"target_index {target_index} is out of bounds for "
                    f"y_true with {n_targets} targets"
                )
            self.y_true_1d = self.y_true[:, target_index]
            self.y_pred_1d = self.y_pred[:, target_index]
        else:
            self.y_true_1d = self.y_true
            self.y_pred_1d = self.y_pred

        # Validate color_by if provided
        if color_by is not None:
            color_by = validate_data(
                color_by, name="color_by", ensure_2d=False, numeric=False
            )
            # Fail early with a clear message instead of deep inside the
            # scatter call when the sizes disagree.
            if len(color_by) != len(self.y_true_1d):
                raise ValueError(
                    f"color_by must have one value per sample. "
                    f"Got {len(color_by)} values for "
                    f"{len(self.y_true_1d)} samples"
                )

        # Initialize coloring
        self._init_coloring(
            color_by,
            colormap,
            color_mode=color_mode,
            colorbar_label=self.colorbar_label,
        )

    def _get_default_labels(self) -> dict[str, str]:
        """Return the default axis labels and title for this plot.

        For 2D input the title names the plotted target using a 1-based
        position; 1D input gets a generic title.
        """
        if self.y_true.ndim == 2:
            # Normalise negative indices so that e.g. -1 is reported as the
            # last target rather than "Target 0".
            position = self.target_index % self.y_true.shape[1] + 1
            default_title = f"Predicted vs Actual (Target {position})"
        else:
            default_title = "Predicted vs Actual"
        return {
            "xlabel": "Actual",
            "ylabel": "Predicted",
            "title": default_title,
        }

    def show(
        self,
        *,
        figsize: Optional[Tuple[float, float]] = None,
        title: Optional[str] = None,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Optional[Tuple[float, float]] = None,
        **kwargs: Any,
    ) -> Figure:
        """Create and return a complete figure with the predicted vs actual plot.

        This method handles figure creation and then delegates to `render()`.

        Parameters
        ----------
        figsize : tuple[float, float], optional
            Figure size in inches (width, height).
        title : str, optional
            Figure title.
        xlabel : str, optional
            Custom x-axis label. If None, uses existing label or default.
        ylabel : str, optional
            Custom y-axis label. If None, uses existing label or default.
        xlim : tuple[float, float], optional
            X-axis limits as (xmin, xmax).
        ylim : tuple[float, float], optional
            Y-axis limits as (ymin, ymax).
        **kwargs : Any
            Additional keyword arguments passed to the render() method.

        Returns
        -------
        Figure
            The matplotlib Figure object containing the plot.
        """
        return super().show(
            figsize=figsize,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )

    def render(
        self,
        ax: Optional[Axes] = None,
        *,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        xlim: Optional[tuple[float, float]] = None,
        ylim: Optional[tuple[float, float]] = None,
        **kwargs: Any,
    ) -> tuple[Figure, Axes]:
        """Render the plot on existing or new axes.

        Parameters
        ----------
        ax : Axes, optional
            Matplotlib axes to render on. If None, creates new figure/axes.
        xlabel : str, optional
            Custom x-axis label. If None, uses existing label or default.
        ylabel : str, optional
            Custom y-axis label. If None, uses existing label or default.
        xlim : tuple[float, float], optional
            X-axis limits (min, max).
        ylim : tuple[float, float], optional
            Y-axis limits (min, max).
        **kwargs : Any
            Additional keyword arguments passed to scatter plot.

        Returns
        -------
        tuple[Figure, Axes]
            The Figure and Axes objects containing the plot.
        """
        # Remove figsize from kwargs as it is not used in render
        kwargs.pop("figsize", None)

        fig, ax = super().render(
            ax=ax,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )

        # Add colorbar for continuous data
        self._add_colorbar_if_needed(ax)

        # Add legend if categorical or if label is provided
        if self.is_categorical or self.label:
            # Only add legend if there are labeled artists
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

        return fig, ax

    def _render_plot(self, ax: Axes, **kwargs: Any) -> None:
        """Internal method to render the plot on given axes.

        Draws the scatter of actual (x) against predicted (y) values and,
        when ``add_ideal_line`` is set, the dashed y=x reference line.
        Recognised keyword overrides are ``alpha``, ``s`` and
        ``edgecolors``; anything else is forwarded to the scatter helper.
        """
        alpha = kwargs.pop("alpha", 0.7)
        s = kwargs.pop("s", 50)
        edgecolors = kwargs.pop("edgecolors", "none")

        # Draw the samples; coloring is resolved by the shared helper
        scatter_with_colormap(
            ax,
            self.y_true_1d,
            self.y_pred_1d,
            color_by=self.color_by,
            is_categorical=self.is_categorical,
            colormap=self.colormap,
            color=self.color if self.color is not None else "steelblue",
            label=self.label,
            alpha=alpha,
            s=s,
            marker=self.marker,
            edgecolors=edgecolors,
            **kwargs,
        )

        # Add ideal prediction line (y=x); skipped for empty data because
        # min()/max() of an empty array raise.
        if self.add_ideal_line and self.y_true_1d.size:
            lims = [
                min(self.y_true_1d.min(), self.y_pred_1d.min()),
                max(self.y_true_1d.max(), self.y_pred_1d.max()),
            ]
            # When several plots share one Axes, keep a single "Ideal"
            # legend entry: underscore-prefixed labels are left out of
            # matplotlib legends.
            _, existing_labels = ax.get_legend_handles_labels()
            ideal_label = "_nolegend_" if "Ideal" in existing_labels else "Ideal"
            ax.plot(
                lims,
                lims,
                "k--",
                alpha=0.5,
                zorder=0,
                label=ideal_label,
                linewidth=1.5,
            )