# Source code for chemotools.plotting._feature_selection

"""
The :mod:`chemotools.plotting._feature_selection` module implements the FeatureSelectionPlot class.
"""

from typing import Any, Optional

import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from chemotools.plotting._spectra import SpectraPlot
from chemotools.plotting._utils import validate_data


class FeatureSelectionPlot(SpectraPlot):
    """Plot class for visualizing feature selection on spectral data.

    This class extends SpectraPlot to highlight excluded features using
    colored vertical spans.

    Parameters
    ----------
    x : np.ndarray
        X-axis data (e.g., wavelengths, wavenumbers).
    y : np.ndarray
        Y-axis data (e.g., spectra intensities).
    support : np.ndarray
        Boolean mask indicating selected features (True means selected).
        Must have same length as x.
    selection_color : str, optional
        Color to use for highlighting excluded features. Default is "red".
    selection_alpha : float, optional
        Transparency of the selection highlight. Default is 0.2.
    **kwargs : Any
        Additional arguments passed to SpectraPlot.__init__.

    Raises
    ------
    ValueError
        If the length of ``support`` differs from the length of the stored
        x data.
    """

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray,
        support: np.ndarray,
        *,
        selection_color: str = "red",
        selection_alpha: float = 0.2,
        **kwargs: Any,
    ):
        super().__init__(x, y, **kwargs)

        # The mask is validated as non-numeric 1D-compatible data and then
        # coerced to bool so that ``~self.support`` is a logical negation.
        self.support = validate_data(
            support, name="support", ensure_2d=False, numeric=False
        ).astype(bool)

        if len(self.support) != len(self.x):
            raise ValueError(
                f"Support mask length ({len(self.support)}) must match x length ({len(self.x)})"
            )

        self.selection_color = selection_color
        self.selection_alpha = selection_alpha

    def show(
        self,
        figsize: tuple[float, float] = (12, 4),
        title: Optional[str] = None,
        xlabel: str = "X-axis",
        ylabel: str = "Y-axis",
        xlim: Optional[tuple[float, float]] = None,
        ylim: Optional[tuple[float, float]] = None,
        **kwargs: Any,
    ) -> Figure:
        """Show the spectra plot with given figure size and labels.

        Excluded features are highlighted with vertical spans drawn in
        ``selection_color`` ("red" unless overridden at construction).
        All arguments are forwarded unchanged to ``SpectraPlot.show``.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size as (width, height) in inches. Default is (12, 4).
        title : str, optional
            Title for the plot. Default is None; how None is handled is
            decided by ``SpectraPlot.show``.
        xlabel : str, optional
            X-axis label. Default is "X-axis".
        ylabel : str, optional
            Y-axis label. Default is "Y-axis".
        xlim : tuple, optional
            X-axis limits as (xmin, xmax). Default is None.
        ylim : tuple, optional
            Y-axis limits as (ymin, ymax). Default is None.
        **kwargs : Any
            Additional keyword arguments passed to the plot function.

        Returns
        -------
        Figure
            The matplotlib Figure object containing the plot.
        """
        return super().show(
            figsize=figsize,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )

    def render(
        self,
        ax: Optional[Axes] = None,
        *,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        xlim: Optional[tuple[float, float]] = None,
        ylim: Optional[tuple[float, float]] = None,
        **kwargs: Any,
    ) -> tuple[Figure, Axes]:
        """Render the plot on the given axes or create new ones.

        All arguments are forwarded unchanged to ``SpectraPlot.render``;
        the axis-limit and axes-creation behavior described below is
        implemented there.

        Parameters
        ----------
        ax : Axes, optional
            Matplotlib axes to plot on. If None, creates new figure and axes.
        xlabel : str, optional
            X-axis label. Default is None.
        ylabel : str, optional
            Y-axis label. Default is None.
        xlim : tuple[float, float], optional
            X-axis limits as (xmin, xmax). When set without ylim, the y-axis
            automatically scales to fit the data within the x-range.
        ylim : tuple[float, float], optional
            Y-axis limits as (ymin, ymax). When provided, disables automatic
            y-scaling.
        **kwargs : Any
            Additional keyword arguments passed to the plot function.

        Returns
        -------
        fig : Figure
            The matplotlib Figure object.
        ax : Axes
            The matplotlib Axes object with the rendered plot.
        """
        fig, ax = super().render(
            ax=ax,
            xlabel=xlabel,
            ylabel=ylabel,
            xlim=xlim,
            ylim=ylim,
            **kwargs,
        )
        return fig, ax

    def _render_plot(self, ax: Axes, **kwargs: Any) -> None:
        """Render the plot with feature selection highlights.

        Parameters
        ----------
        ax : Axes
            Matplotlib axes to plot on.
        **kwargs : Any
            Additional keyword arguments passed to the plot function.
        """
        # 1. Render the standard spectrum plot.
        super()._render_plot(ax, **kwargs)

        # 2. Overlay one vertical span per contiguous run of excluded
        #    (support == False) features.
        excluded_regions = self._get_continuous_regions(~self.support)

        for position, (start_idx, end_idx) in enumerate(excluded_regions):
            # Spans use the exact coordinates of the first and last excluded
            # point, so a run of a single point yields a zero-width span.
            x_start = self.x[start_idx]
            x_end = self.x[end_idx]

            # Handle descending x-axes (e.g. wavenumbers) so that the span is
            # always given as (low, high).
            if x_start > x_end:
                x_start, x_end = x_end, x_start

            ax.axvspan(
                x_start,
                x_end,
                color=self.selection_color,
                alpha=self.selection_alpha,
                # Only the first span carries a label, so the legend shows a
                # single "Excluded Features" entry.
                label="Excluded Features" if position == 0 else None,
                zorder=-1,  # Put behind spectra
            )

    def _get_continuous_regions(self, mask: np.ndarray) -> list[tuple[int, int]]:
        """Convert a mask to a list of inclusive (start, end) index pairs.

        Parameters
        ----------
        mask : np.ndarray
            One-dimensional mask. Values are interpreted by truthiness, so
            non-boolean input (e.g. integer counts) is handled as well.

        Returns
        -------
        list[tuple[int, int]]
            One ``(start, end)`` pair of Python ints per contiguous run of
            truthy values; ``end`` is the index of the last element of the
            run (inclusive). Empty if the mask has no truthy values.
        """
        # Coerce to bool first: on a non-boolean mask the edge differences
        # could be values other than +1/-1 and runs would be missed.
        flags = np.asarray(mask, dtype=bool)

        # Pad with False on both sides so runs touching either end of the
        # mask still produce a rising and a falling edge.
        padded = np.concatenate(([False], flags, [False]))

        # +1 marks False -> True (run start), -1 marks True -> False (one
        # past the run end).
        edges = np.diff(padded.astype(np.int8))
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1) - 1

        # tolist() converts NumPy integer scalars to plain Python ints to
        # match the annotated return type.
        return list(zip(starts.tolist(), ends.tolist()))