# Source code for fastcpd.visualization

"""Visualization utilities for change point detection.

This module provides plotting functions to visualize:
- Change point detection results
- Metric comparisons across algorithms
- Multi-annotator scenarios
- Dataset characteristics
"""

import numpy as np
from typing import Union, List, Dict, Optional, Any
import warnings

try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    from matplotlib.gridspec import GridSpec
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
    warnings.warn(
        "matplotlib not installed. Install with: pip install matplotlib",
        ImportWarning
    )


def _check_matplotlib():
    """Check if matplotlib is available."""
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for visualization. "
            "Install with: pip install matplotlib"
        )


def plot_detection(
    data: Union[np.ndarray, List],
    true_cps: Union[List, np.ndarray],
    pred_cps: Union[List, np.ndarray],
    metric_result: Optional[Dict] = None,
    title: str = "Change Point Detection",
    figsize: tuple = (14, 8),
    show_legend: bool = True,
    save_path: Optional[str] = None
) -> tuple:
    """Draw a series together with its true and predicted change points.

    The series is plotted on a main axes; true change points are drawn as
    green dashed vertical lines and predicted ones as red dotted vertical
    lines. When ``metric_result`` is given, a second axes below the main one
    shows a two-line text summary of selected metrics.

    Args:
        data: Time series data. A 1D array is plotted as a single black line;
            otherwise at most the first three columns are plotted.
        true_cps: True change point indices.
        pred_cps: Predicted change point indices.
        metric_result: Optional dict whose ``'point_metrics'``,
            ``'distance_metrics'`` and ``'segmentation_metrics'`` entries are
            read for the summary panel (missing entries fall back to
            defaults).
        title: Title of the main axes.
        figsize: Figure size (width, height).
        show_legend: Whether to draw the legend on the main axes.
        save_path: If truthy, the figure is saved there at 300 dpi.

    Returns:
        ``(fig, (ax_main, ax_metrics))`` when a metrics panel was created,
        otherwise ``(fig, (ax_main,))``.

    Raises:
        ImportError: If matplotlib is not available.

    Examples:
        >>> from fastcpd.visualization import plot_detection
        >>> from fastcpd.datasets import make_mean_change
        >>> from fastcpd import fastcpd
        >>> from fastcpd.metrics import evaluate_all
        >>>
        >>> data_dict = make_mean_change(n_samples=500, n_changepoints=3)
        >>> result = fastcpd(data_dict['data'], family='mean')
        >>> metrics = evaluate_all(data_dict['changepoints'],
        ...                        result.cp_set.tolist(),
        ...                        n_samples=500, margin=10)
        >>> plot_detection(data_dict['data'], data_dict['changepoints'],
        ...                result.cp_set.tolist(), metrics)
    """
    _check_matplotlib()

    series = np.asarray(data)
    # (positions, legend label, colour, line style) in drawing order.
    marker_groups = (
        (np.asarray(true_cps), 'True CP', 'green', '--'),
        (np.asarray(pred_cps), 'Predicted CP', 'red', ':'),
    )

    # Layout: a single axes, or a 3:1 main/summary split when metrics are given.
    if metric_result is None:
        fig, ax_main = plt.subplots(figsize=figsize)
        ax_metrics = None
    else:
        fig = plt.figure(figsize=figsize)
        grid = GridSpec(2, 1, height_ratios=[3, 1], hspace=0.3)
        ax_main = fig.add_subplot(grid[0])
        ax_metrics = fig.add_subplot(grid[1])

    # The series itself.
    if series.ndim == 1:
        ax_main.plot(series, 'k-', linewidth=0.8, alpha=0.7, label='Data')
    else:
        n_shown = min(3, series.shape[1])
        for dim in range(n_shown):
            ax_main.plot(series[:, dim], linewidth=0.8, alpha=0.7,
                         label=f'Dim {dim+1}')

    # Vertical markers; only the first line of each group gets a legend entry.
    for positions, legend_name, colour, dash in marker_groups:
        for idx, position in enumerate(positions):
            ax_main.axvline(
                x=position,
                color=colour,
                linestyle=dash,
                linewidth=2,
                alpha=0.8,
                label=legend_name if idx == 0 else None,
            )

    ax_main.set_xlabel('Time', fontsize=12)
    ax_main.set_ylabel('Value', fontsize=12)
    ax_main.set_title(title, fontsize=14, fontweight='bold')
    ax_main.grid(True, alpha=0.3)
    if show_legend:
        ax_main.legend(loc='upper right', fontsize=10)

    # Text-only summary panel (exists exactly when metric_result was given).
    if ax_metrics is not None:
        ax_metrics.axis('off')

        point = metric_result.get('point_metrics', {})
        distance = metric_result.get('distance_metrics', {})
        segmentation = metric_result.get('segmentation_metrics', {})

        summary = (
            f"Precision: {point.get('precision', 0):.3f} | "
            f"Recall: {point.get('recall', 0):.3f} | "
            f"F1: {point.get('f1_score', 0):.3f} | "
            f"Hausdorff: {distance.get('hausdorff', np.nan):.1f} | "
            f"ARI: {segmentation.get('adjusted_rand_index', 0):.3f}\n"
            f"TP: {point.get('true_positives', 0)} | "
            f"FP: {point.get('false_positives', 0)} | "
            f"FN: {point.get('false_negatives', 0)}"
        )
        box_style = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.3}
        ax_metrics.text(0.5, 0.5, summary,
                        transform=ax_metrics.transAxes,
                        fontsize=11, ha='center', va='center',
                        bbox=box_style)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    if ax_metrics:
        return fig, (ax_main, ax_metrics)
    return fig, (ax_main,)
def plot_metric_comparison(
    results_dict: Dict[str, Dict],
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 6),
    save_path: Optional[str] = None
) -> tuple:
    """Compare metrics across multiple algorithms as a grouped bar chart.

    One group of bars is drawn per algorithm and one bar per requested
    metric. Metric values are looked up in each result's ``'point_metrics'``
    sub-dict; a metric that is missing there is plotted as 0.

    Args:
        results_dict: Dict mapping algorithm names to metric results
            ``{'PELT': metrics_dict, 'SeGD': metrics_dict, ...}``.
        metrics: Names of the metrics to compare (default: ``precision``,
            ``recall``, ``f1_score``). Must not be empty.
        figsize: Figure size (width, height).
        save_path: If truthy, the figure is saved there at 300 dpi.

    Returns:
        Tuple of (fig, ax).

    Raises:
        ValueError: If ``metrics`` is an empty sequence.
        ImportError: If matplotlib is not available.

    Examples:
        >>> from fastcpd.metrics import evaluate_all
        >>> from fastcpd.visualization import plot_metric_comparison
        >>>
        >>> # Assume you have multiple algorithm results
        >>> pelt_metrics = evaluate_all(true_cps, pelt_cps, n_samples=500, margin=10)
        >>> segd_metrics = evaluate_all(true_cps, segd_cps, n_samples=500, margin=10)
        >>>
        >>> results = {'PELT': pelt_metrics, 'SeGD': segd_metrics}
        >>> plot_metric_comparison(results)
    """
    # Validate before touching matplotlib: an empty metric list would
    # otherwise divide by zero when computing the bar width, after a figure
    # had already been created.
    metrics = (
        ['precision', 'recall', 'f1_score'] if metrics is None
        else list(metrics)
    )
    if not metrics:
        raise ValueError("metrics must contain at least one metric name")

    _check_matplotlib()

    algorithms = list(results_dict)
    n_metrics = len(metrics)

    # One list of scores per metric, ordered like `algorithms`.
    metric_values = {
        metric: [
            results_dict[algo].get('point_metrics', {}).get(metric, 0)
            for algo in algorithms
        ]
        for metric in metrics
    }

    fig, ax = plt.subplots(figsize=figsize)

    # Grouped bars: each group is 0.8 wide and centred on its algorithm tick.
    x = np.arange(len(algorithms))
    width = 0.8 / n_metrics
    for i, metric in enumerate(metrics):
        offset = (i - n_metrics / 2 + 0.5) * width
        ax.bar(x + offset, metric_values[metric], width,
               label=metric.replace('_', ' ').title())

    ax.set_xlabel('Algorithm', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Algorithm Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(algorithms)
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim([0, 1.1])

    # Operate on `fig` directly rather than on pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, ax
def plot_annotators(
    data: Union[np.ndarray, List],
    annotators_list: List[List],
    pred_cps: Union[List, np.ndarray],
    title: str = "Multi-Annotator Change Point Detection",
    figsize: tuple = (14, 6),
    save_path: Optional[str] = None
) -> tuple:
    """Overlay several annotators' change points and the predicted ones.

    The series (or its first column when it is not 1D) is drawn in black.
    Each annotator's change points are dashed vertical lines in a colour
    sampled from the ``Set3`` colormap; predictions are solid red lines.

    Args:
        data: Time series data.
        annotators_list: List of lists, each containing one annotator's
            change points.
        pred_cps: Predicted change points.
        title: Plot title.
        figsize: Figure size (width, height).
        save_path: If truthy, the figure is saved there at 300 dpi.

    Returns:
        Tuple of (fig, ax).

    Raises:
        ImportError: If matplotlib is not available.

    Examples:
        >>> from fastcpd.datasets import add_annotation_noise
        >>> from fastcpd.visualization import plot_annotators
        >>>
        >>> true_cps = [100, 200, 300]
        >>> annotators = add_annotation_noise(true_cps, n_annotators=5, seed=42)
        >>> pred_cps = [98, 205, 295]
        >>>
        >>> plot_annotators(data, annotators, pred_cps)
    """
    _check_matplotlib()

    series = np.asarray(data)
    predictions = np.asarray(pred_cps)

    fig, ax = plt.subplots(figsize=figsize)

    # Only the first column is shown for non-1D input.
    trace = series if series.ndim == 1 else series[:, 0]
    ax.plot(trace, 'k-', linewidth=0.8, alpha=0.5, label='Data')

    # One colour per annotator; only an annotator's first line gets a label
    # so the legend has a single entry per annotator.
    palette = plt.cm.Set3(np.linspace(0, 1, len(annotators_list)))
    for number, (marks, colour) in enumerate(zip(annotators_list, palette), start=1):
        for position_index, position in enumerate(marks):
            legend_name = f'Annotator {number}' if position_index == 0 else None
            ax.axvline(x=position, color=colour, linestyle='--',
                       linewidth=1.5, alpha=0.6, label=legend_name)

    # Predicted change points, drawn last.
    for position_index, position in enumerate(predictions):
        legend_name = 'Algorithm' if position_index == 0 else None
        ax.axvline(x=position, color='red', linestyle='-',
                   linewidth=2.5, alpha=0.8, label=legend_name)

    ax.set_xlabel('Time', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=9, ncol=2)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, ax
def plot_dataset_characteristics(
    data_dict: Dict[str, Any],
    figsize: tuple = (14, 10),
    save_path: Optional[str] = None
) -> tuple:
    """Visualize dataset characteristics and metadata.

    Builds a 3x2 grid with four panels: the data with its change points
    (top row), a 50-bin histogram of all values, a bar chart of segment
    lengths, and a text box summarising selected metadata (bottom row).

    Args:
        data_dict: Dictionary from datasets module (e.g., make_mean_change).
            ``'data'`` and ``'changepoints'`` are required; ``'metadata'`` is
            optional and may be missing or ``None``. Within the metadata, the
            keys ``segment_lengths``, ``snr_db``, ``difficulty``,
            ``change_type``, ``r_squared_per_segment`` and ``family`` are
            used when present.
        figsize: Figure size (width, height).
        save_path: If truthy, the figure is saved there at 300 dpi.

    Returns:
        Tuple of (fig, axes) where axes is ``(ax1, ax2, ax3, ax4)`` in the
        panel order described above.

    Raises:
        ImportError: If matplotlib is not available.
        KeyError: If ``'data'`` or ``'changepoints'`` is missing.

    Examples:
        >>> from fastcpd.datasets import make_mean_change
        >>> from fastcpd.visualization import plot_dataset_characteristics
        >>>
        >>> data_dict = make_mean_change(n_samples=500, n_changepoints=3, seed=42)
        >>> plot_dataset_characteristics(data_dict)
    """
    _check_matplotlib()

    # Read the inputs before creating a figure so a missing key does not
    # leave an empty figure behind.
    data = np.asarray(data_dict['data'])
    changepoints = data_dict['changepoints']
    # `or {}` also covers an explicit `'metadata': None` entry.
    metadata = data_dict.get('metadata') or {}

    fig = plt.figure(figsize=figsize)
    gs = GridSpec(3, 2, hspace=0.3, wspace=0.3)

    # 1. Data plot with CPs
    ax1 = fig.add_subplot(gs[0, :])
    if data.ndim == 1:
        ax1.plot(data, 'b-', linewidth=0.8, alpha=0.7)
    else:
        # Only the first three columns are shown.
        for i in range(min(3, data.shape[1])):
            ax1.plot(data[:, i], linewidth=0.8, alpha=0.7, label=f'Dim {i+1}')

    for cp in changepoints:
        ax1.axvline(x=cp, color='red', linestyle='--', linewidth=2, alpha=0.7)

    ax1.set_title('Data with True Change Points', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Value')
    ax1.grid(True, alpha=0.3)
    if data.ndim > 1:
        ax1.legend()

    # 2. Histogram
    ax2 = fig.add_subplot(gs[1, 0])
    ax2.hist(data.flatten(), bins=50, alpha=0.7, edgecolor='black')
    ax2.set_title('Data Distribution', fontsize=11, fontweight='bold')
    ax2.set_xlabel('Value')
    ax2.set_ylabel('Frequency')
    ax2.grid(True, alpha=0.3, axis='y')

    # 3. Segment lengths
    ax3 = fig.add_subplot(gs[1, 1])
    raw_lengths = metadata.get('segment_lengths')
    # Normalise to a flat array and test `.size`: truth-testing a NumPy array
    # with more than one element raises ValueError.
    segment_lengths = np.asarray(
        [] if raw_lengths is None else raw_lengths
    ).ravel()
    if segment_lengths.size:
        ax3.bar(range(segment_lengths.size), segment_lengths,
                alpha=0.7, edgecolor='black')
        ax3.set_title('Segment Lengths', fontsize=11, fontweight='bold')
        ax3.set_xlabel('Segment')
        ax3.set_ylabel('Length')
        ax3.grid(True, alpha=0.3, axis='y')

    # 4. Metadata text
    ax4 = fig.add_subplot(gs[2, :])
    ax4.axis('off')

    metadata_text = "Dataset Metadata:\n"
    metadata_text += f" n_samples: {data.shape[0]}\n"
    metadata_text += f" n_changepoints: {len(changepoints)}\n"

    # Add specific metadata
    if 'snr_db' in metadata:
        metadata_text += f" SNR: {metadata['snr_db']:.2f} dB\n"
    if 'difficulty' in metadata:
        metadata_text += f" Difficulty: {metadata['difficulty']:.3f}\n"
    if 'change_type' in metadata:
        metadata_text += f" Change type: {metadata['change_type']}\n"
    if 'r_squared_per_segment' in metadata:
        r2_vals = metadata['r_squared_per_segment']
        metadata_text += f" R² per segment: {[f'{r:.3f}' for r in r2_vals]}\n"
    if 'family' in metadata:
        metadata_text += f" Family: {metadata['family']}\n"

    ax4.text(0.5, 0.5, metadata_text,
             transform=ax4.transAxes,
             fontsize=11, ha='center', va='center',
             family='monospace',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

    # Operate on `fig` directly rather than on pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, (ax1, ax2, ax3, ax4)
def plot_roc_curve(
    true_cps: Union[List, np.ndarray],
    pred_cps_list: List[Union[List, np.ndarray]],
    labels: List[str],
    n_samples: int,
    figsize: tuple = (8, 8),
    save_path: Optional[str] = None,
    margins: Optional[List[int]] = None
) -> tuple:
    """Plot ROC-like curve for change point detection.

    For each threshold (margin), precision and recall are computed with
    ``fastcpd.metrics.precision_recall`` and plotted as one
    precision-vs-recall curve per algorithm, together with a dashed diagonal
    reference line.

    Args:
        true_cps: True change points.
        pred_cps_list: List of predicted CP arrays (one per algorithm).
        labels: Algorithm labels, paired positionally with ``pred_cps_list``.
        n_samples: Total number of samples. Accepted for API compatibility;
            this function does not currently use it.
        figsize: Figure size (width, height).
        save_path: If truthy, the figure is saved there at 300 dpi.
        margins: Margins to evaluate, one curve point per margin
            (default: ``[1, 2, 5, 10, 20, 50, 100]``). Must not be empty.

    Returns:
        Tuple of (fig, ax).

    Raises:
        ValueError: If ``margins`` is an empty sequence.
        ImportError: If matplotlib is not available.
    """
    # Validate before touching matplotlib so bad arguments fail fast and no
    # empty figure is left behind.
    margins = [1, 2, 5, 10, 20, 50, 100] if margins is None else list(margins)
    if not margins:
        raise ValueError("margins must contain at least one value")

    _check_matplotlib()

    from fastcpd.metrics import precision_recall

    # zip() below stops at the shorter input; surface that instead of
    # silently dropping curves or labels.
    if len(pred_cps_list) != len(labels):
        warnings.warn(
            f"pred_cps_list has {len(pred_cps_list)} entries but labels has "
            f"{len(labels)}; extra entries are ignored",
            UserWarning,
            stacklevel=2,
        )

    fig, ax = plt.subplots(figsize=figsize)

    for pred_cps, label in zip(pred_cps_list, labels):
        results = [
            precision_recall(true_cps, pred_cps, margin=margin)
            for margin in margins
        ]
        recalls = [result['recall'] for result in results]
        precisions = [result['precision'] for result in results]

        ax.plot(recalls, precisions, 'o-', linewidth=2, label=label)

    ax.set_xlabel('Recall', fontsize=12, fontweight='bold')
    ax.set_ylabel('Precision', fontsize=12, fontweight='bold')
    ax.set_title('Precision-Recall Curve (varying margin)',
                 fontsize=14, fontweight='bold')
    ax.legend(loc='lower left', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim([0, 1.05])
    ax.set_ylim([0, 1.05])

    # Add diagonal reference line (unlabelled, so it stays out of the legend)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, linewidth=1)

    # Operate on `fig` directly rather than on pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig, ax