my_project.plots

Utilities to evaluate a trained model and generate performance visualizations.

  • evaluate_and_plot(...) : runs eval on a dataloader and saves:
    • confusion_matrix.png
    • per_class_accuracy.png
    • misclassified_grid.png
  • plot_curves_from_csvlogger(...) : plots curves from Lightning's CSVLogger
    • train_loss_[step|epoch].png
    • val_acc_epoch.png (if logged during validation)

Outputs are saved under reports/figures/ by default.

This module does NOT modify your model architecture.
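
A minimal usage sketch (assuming a trained model and a Lightning-style
datamodule with a test split; the names `model`, `dm`, and the log
directory below are placeholders):

    from my_project.plots import evaluate_and_plot, plot_curves_from_csvlogger

    # Run evaluation and write all figures under reports/figures/
    results = evaluate_and_plot(model, dm, out_dir="reports/figures")
    print(f"test accuracy: {results['test_accuracy']:.4f}")

    # If training used Lightning's CSVLogger, plot the logged curves too
    plot_curves_from_csvlogger("logs/fashion/version_0")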

  1"""
  2Utilities to evaluate a trained model and generate performance visualizations.
  3- evaluate_and_plot(...) : runs eval on a dataloader and saves:
  4    * confusion_matrix.png
  5    * per_class_accuracy.png
  6    * misclassified_grid.png
  7- plot_curves_from_csvlogger(...) : plots curves from Lightning's CSVLogger
  8    * train_loss_[step|epoch].png
  9    * val_acc_epoch.png (if logged during validation)
 10Outputs are saved under reports/figures/ by default.
 11
 12This module does NOT modify your model architecture.
 13"""
 14
 15import os
 16import math
 17from typing import Tuple, List, Optional
 18
 19import torch
 20import torch.nn.functional as F
 21import matplotlib.pyplot as plt
 22import numpy as np
 23import pandas as pd
 24from sklearn.metrics import confusion_matrix
 25from sklearn.calibration import calibration_curve
 26from scipy.cluster import hierarchy as sch
 27from scipy.spatial.distance import pdist
 28
 29
# Fashion-MNIST class names (0..9)
FASHION_CLASSES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)


def _to_numpy(x: torch.Tensor) -> np.ndarray:
    return x.detach().cpu().numpy()


def _get_device(module: torch.nn.Module) -> torch.device:
    return next(module.parameters()).device


@torch.no_grad()
def evaluate_model(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Run the model over a dataloader and return predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model to evaluate.
    dataloader : torch.utils.data.DataLoader
        DataLoader for evaluation (test/validation).

    Returns
    -------
    tuple
        (y_true, y_pred, y_prob_max):
        - y_true (ndarray[int]): Ground-truth labels.
        - y_pred (ndarray[int]): Predicted labels.
        - y_prob_max (ndarray[float]): Max softmax confidence per sample.
    """

    model.eval()
    device = _get_device(model)

    all_true: List[int] = []
    all_pred: List[int] = []
    all_conf: List[float] = []

    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)

        logits = model(xb)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(dim=1)

        all_true.extend(_to_numpy(yb))
        all_pred.extend(_to_numpy(pred))
        all_conf.extend(_to_numpy(conf))

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_pred, dtype=np.int64),
        np.asarray(all_conf, dtype=np.float32),
    )


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: List[str],
    out_path: str,
    normalize: bool = True,
) -> None:
    """
    Plot and save a confusion matrix.

    Parameters
    ----------
    y_true : np.ndarray
        Ground-truth labels.
    y_pred : np.ndarray
        Predicted labels.
    class_names : list of str
        Class names corresponding to label indices.
    out_path : str
        File path to save the plot.
    normalize : bool, optional (default=True)
        If True, normalize each row to fractions of the true-class count.
    """

    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    if normalize:
        with np.errstate(all="ignore"):
            cm = cm.astype(np.float64) / cm.sum(axis=1, keepdims=True)
        cm = np.nan_to_num(cm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm, interpolation="nearest", cmap="coolwarm")
    plt.title("Confusion Matrix (normalized)" if normalize else "Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45, ha="right")
    plt.yticks(tick_marks, class_names)

    # Annotate cells
    thresh = cm.max() / 2.0 if cm.size > 0 else 0.5
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            val = cm[i, j]
            txt = f"{val:.2f}" if normalize else f"{int(val)}"
            plt.text(
                j,
                i,
                txt,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=8,
                color="white" if cm[i, j] > thresh else "black",
            )

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def plot_per_class_accuracy(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: List[str],
    out_path: str,
) -> None:
    """
    Plot and save per-class accuracy.

    Parameters
    ----------
    y_true : np.ndarray
        Ground-truth labels.
    y_pred : np.ndarray
        Predicted labels.
    class_names : list of str
        Class names corresponding to label indices.
    out_path : str
        File path to save the plot.
    """

    num_classes = len(class_names)
    correct = np.zeros(num_classes, dtype=np.int64)
    total = np.zeros(num_classes, dtype=np.int64)

    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1

    # Guard against division by zero for classes absent from y_true
    acc = correct / np.maximum(total, 1)

    plt.figure(figsize=(9, 4))
    plt.bar(np.arange(num_classes), acc)
    plt.ylim(0, 1)
    plt.xticks(np.arange(num_classes), class_names, rotation=45, ha="right")
    plt.ylabel("Accuracy")
    plt.title("Per-class Accuracy")
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def _denormalize_img(img: torch.Tensor) -> np.ndarray:
    """
    Convert a tensor image (single-channel CHW, or HW) to a NumPy array
    suitable for plotting. Note: this does not undo dataloader
    normalization; matplotlib rescales intensities for display.

    Parameters
    ----------
    img : torch.Tensor
        Image tensor (CHW or HW).

    Returns
    -------
    np.ndarray
        Image array suitable for plotting.
    """

    if img.dim() == 3 and img.shape[0] == 1:
        return img.squeeze(0).detach().cpu().numpy()
    # Already HxW
    return img.detach().cpu().numpy()


@torch.no_grad()
def plot_misclassified_grid(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    class_names: List[str],
    out_path: str,
    max_examples: int = 16,
) -> None:
    """
    Plot a grid of misclassified samples.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model to evaluate.
    dataloader : torch.utils.data.DataLoader
        DataLoader for evaluation (test/validation).
    class_names : list of str
        Class names corresponding to label indices.
    out_path : str
        File path to save the plot.
    max_examples : int, optional (default=16)
        Maximum number of misclassified samples to show.
    """
    model.eval()
    device = _get_device(model)

    images = []
    labels_true = []
    labels_pred = []

    # Collect up to max_examples misclassified samples
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        logits = model(xb)
        pred = logits.argmax(dim=1)

        mism = pred != yb
        if mism.any():
            mis_idx = torch.nonzero(mism).flatten()
            for idx in mis_idx:
                images.append(xb[idx].cpu())
                labels_true.append(int(yb[idx].cpu()))
                labels_pred.append(int(pred[idx].cpu()))
                if len(images) >= max_examples:
                    break
        if len(images) >= max_examples:
            break

    if len(images) == 0:
        # Nothing to plot
        fig = plt.figure(figsize=(6, 2))
        plt.text(0.5, 0.5, "No misclassified samples found.", ha="center", va="center")
        plt.axis("off")
        plt.tight_layout()
        fig.savefig(out_path, dpi=160)
        plt.close()
        return

    # Determine grid size
    cols = int(math.ceil(math.sqrt(len(images))))
    rows = int(math.ceil(len(images) / cols))

    plt.figure(figsize=(cols * 1.5, rows * 1.5))
    for i, img in enumerate(images):
        ax = plt.subplot(rows, cols, i + 1)
        ax.imshow(_denormalize_img(img), cmap="gray")
        true_name = class_names[labels_true[i]]
        pred_name = class_names[labels_pred[i]]
        ax.set_title(f"{true_name} → {pred_name}", fontsize=8)
        ax.axis("off")

    plt.suptitle("Misclassified Samples", y=0.98)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def evaluate_and_plot(
    model: torch.nn.Module,
    datamodule,
    out_dir: str = "reports/figures",
) -> dict:
    """
    Evaluate `model` on the datamodule's test split and save every figure
    (confusion matrix, per-class accuracy, misclassified grid, calibration
    curve) under `out_dir`.

    Returns a dict with the test accuracy and the path of each saved figure.
    """
    _ensure_dir(out_dir)

    # Use the datamodule's test dataloader (Lightning-style)
    test_loader = datamodule.test_dataloader()

    # 1) Predictions + basic accuracy
    y_true, y_pred, y_conf = evaluate_model(model, test_loader)
    test_acc = float((y_true == y_pred).mean())

    # 2) Confusion matrix (normalized)
    cm_path = os.path.join(out_dir, "confusion_matrix.png")
    plot_confusion_matrix(y_true, y_pred, FASHION_CLASSES, cm_path, normalize=True)

    # 3) Per-class accuracy bars
    pca_path = os.path.join(out_dir, "per_class_accuracy.png")
    plot_per_class_accuracy(y_true, y_pred, FASHION_CLASSES, pca_path)

    # 4) Misclassified samples grid
    mis_path = os.path.join(out_dir, "misclassified_grid.png")
    plot_misclassified_grid(
        model, test_loader, FASHION_CLASSES, mis_path, max_examples=6
    )

    # 5) Calibration (reliability) curve
    calib_path = os.path.join(out_dir, "calibration_curve.png")
    y_correct = (y_true == y_pred).astype(int)
    plot_calibration_curve(y_correct, y_conf, calib_path, n_bins=10)

    return {
        "test_accuracy": test_acc,
        "confusion_matrix": cm_path,
        "per_class_accuracy": pca_path,
        "misclassified_grid": mis_path,
        "calibration_curve": calib_path,
    }


def plot_calibration_curve(y_true, y_prob, out_path, n_bins: int = 10) -> None:
    """
    Plot a reliability diagram: mean predicted confidence per bin vs. the
    observed frequency of correct predictions.

    `y_true` is binary (1 if the prediction was correct) and `y_prob` is
    the model's max-softmax confidence per sample.
    """
    prob_true, prob_pred = calibration_curve(
        y_true, y_prob, n_bins=n_bins, strategy="uniform"
    )

    plt.figure(figsize=(6, 5))
    plt.plot(prob_pred, prob_true, marker="o", label="Model")
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfectly calibrated")
    plt.xlabel("Predicted probability")
    plt.ylabel("Observed frequency")
    plt.title("Calibration Curve")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()


def plot_curves_from_csvlogger(
    csv_log_dir: str,
    out_dir: str = "reports/figures",
    train_loss_keys: Optional[List[str]] = None,
    val_acc_keys: Optional[List[str]] = None,
) -> List[str]:
    """
    If you use Lightning's CSVLogger, this parses the `metrics.csv` file and
    produces line plots for:
      - train loss (step/epoch)
      - val accuracy (epoch)

    Usage in train.py:
        from pytorch_lightning.loggers import CSVLogger
        logger = CSVLogger("logs", name="fashion")
        trainer = pl.Trainer(..., logger=logger)

    Then call:
        plot_curves_from_csvlogger(logger.log_dir)

    Args:
        csv_log_dir: e.g., "logs/fashion/version_0"
        train_loss_keys: possible metric column names for train loss
        val_acc_keys: possible metric column names for val acc

    Returns:
        list of saved figure paths
    """
    _ensure_dir(out_dir)
    metrics_path = os.path.join(csv_log_dir, "metrics.csv")
    if not os.path.exists(metrics_path):
        return []

    df = pd.read_csv(metrics_path)

    # Reasonable defaults; adjust if you log different names
    if train_loss_keys is None:
        # Lightning often stores these as "train_loss_step" and/or "train_loss_epoch"
        train_loss_keys = ["train_loss_step", "train_loss_epoch", "loss", "train_loss"]
    if val_acc_keys is None:
        # The model already logs 'val_acc'
        val_acc_keys = ["val_acc", "val_acc_epoch"]

    saved = []

    # Train loss vs step/epoch
    for key in train_loss_keys:
        if key in df.columns and not df[key].isna().all():
            # CSVLogger writes one row per logging event, so unrelated rows
            # hold NaN for this metric; drop them before plotting.
            series = df[key].dropna()
            # Prefer 'step' on the x-axis if present, else 'epoch'
            if "step" in df.columns and not df["step"].isna().all():
                x = df.loc[series.index, "step"]
                x_label = "Step"
            elif "epoch" in df.columns:
                x = df.loc[series.index, "epoch"]
                x_label = "Epoch"
            else:
                x = series.index
                x_label = "Epoch"

            plt.figure(figsize=(6, 4))
            plt.plot(x, series)
            plt.xlabel(x_label)
            plt.ylabel("Train Loss")
            plt.title(f"{key} over {x_label.lower()}")
            plt.tight_layout()
            out_path = os.path.join(out_dir, f"{key}.png")
            plt.savefig(out_path, dpi=160)
            plt.close()
            saved.append(out_path)
            # Only plot the first key that exists (avoid duplicates)
            break

    # Validation accuracy vs epoch
    for key in val_acc_keys:
        if key in df.columns and not df[key].isna().all():
            series = df[key].dropna()
            x = df.loc[series.index, "epoch"] if "epoch" in df.columns else series.index
            plt.figure(figsize=(6, 4))
            plt.plot(x, series)
            plt.xlabel("Epoch")
            plt.ylabel("Validation Accuracy")
            plt.title(f"{key} over epochs")
            plt.tight_layout()
            out_path = os.path.join(out_dir, f"{key}.png")
            plt.savefig(out_path, dpi=160)
            plt.close()
            saved.append(out_path)
            break

    return saved


def plot_learning_curves_from_df(
    df: pd.DataFrame,
) -> Tuple[Optional[plt.Figure], Optional[plt.Figure]]:
    """
    Parses a metrics DataFrame and produces line plots for train loss and val accuracy.

    Args:
        df (pd.DataFrame): DataFrame with metrics (e.g., from CSVLogger).

    Returns:
        A tuple of (train_loss_figure, val_acc_figure). Either figure may be
        None if the corresponding metric is absent from the DataFrame.
    """
    train_loss_fig, val_acc_fig = None, None

    train_loss_keys = ["train_loss_step", "train_loss_epoch", "loss", "train_loss"]
    val_acc_keys = ["val_acc", "val_acc_epoch"]

    # Train loss vs step/epoch
    for key in train_loss_keys:
        if key in df.columns and not df[key].isna().all():
            series = df[key].dropna()
            # Prefer 'step' on the x-axis if present, else 'epoch'
            if "step" in df.columns and not df["step"].isna().all():
                x = df.loc[series.index, "step"]
                x_label = "Step"
            else:
                x = df.loc[series.index, "epoch"]
                x_label = "Epoch"

            train_loss_fig = plt.figure(figsize=(6, 4))
            plt.plot(x, series)
            plt.xlabel(x_label)
            plt.ylabel("Train Loss")
            plt.title(f"Training Loss vs. {x_label}")
            plt.tight_layout()
            # Only plot the first key that exists
            break

    # Validation accuracy vs epoch
    for key in val_acc_keys:
        if key in df.columns and not df[key].isna().all():
            # Keep only the rows where validation was actually run
            series = df[key].dropna()
            x = df.loc[series.index, "epoch"]
            val_acc_fig = plt.figure(figsize=(6, 4))
            plt.plot(x, series, marker="o", linestyle="-")
            plt.xlabel("Epoch")
            plt.ylabel("Validation Accuracy")
            plt.title("Validation Accuracy vs. Epoch")
            plt.ylim(0, 1)
            plt.tight_layout()
            # Only plot the first key that exists
            break

    # Figures are returned open so the caller can display or save them;
    # calling plt.close() here would close the current (possibly valid) figure.
    return train_loss_fig, val_acc_fig

def plot_class_distribution(
    df: pd.DataFrame, class_names: List[str], out_path: Optional[str] = None
) -> plt.Figure:
    """
    Plots the distribution of classes in a dataset.

    Args:
        df (pd.DataFrame): DataFrame with a 'label' column.
        class_names (List[str]): List of class names.
        out_path (Optional[str], optional): If provided, saves the plot. Defaults to None.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    class_counts = df["label"].value_counts().sort_index()
    total_samples = class_counts.sum()

    fig = plt.figure(figsize=(9, 4))
    bars = plt.bar(class_counts.index, class_counts.values)
    plt.xticks(np.arange(len(class_names)), class_names, rotation=45, ha="right")
    plt.ylabel("Number of Samples")
    plt.title("Dataset Class Distribution")

    # Add count and percentage labels above each bar
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height}\n({height / total_samples:.1%})",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    plt.tight_layout()

    if out_path:
        # Fall back to the current directory for bare filenames
        _ensure_dir(os.path.dirname(out_path) or ".")
        plt.savefig(out_path, dpi=160)
        plt.close()

    return fig

def get_sample_images_for_gallery(
    df: pd.DataFrame, class_names: List[str], n_samples: int = 20
) -> List[Tuple[np.ndarray, str]]:
    """
    Gets a list of random sample images and their labels for a Gradio Gallery.

    Args:
        df (pd.DataFrame): DataFrame with image data.
        class_names (List[str]): List of class names.
        n_samples (int, optional): Number of samples to retrieve. Defaults to 20.

    Returns:
        List[Tuple[np.ndarray, str]]: A list of (image, caption) tuples.
    """
    samples = []
    if len(df) > n_samples:
        df_sample = df.sample(n=n_samples)
    else:
        df_sample = df

    for _, row in df_sample.iterrows():
        label = class_names[int(row["label"])]
        # Pixel columns follow the 'label' column; reshape to a 28x28 image
        image_data = row.iloc[1:].values.astype(np.uint8)
        image = image_data.reshape(28, 28)
        samples.append((image, label))
    return samples


def plot_class_correlation_dendrogram(
    df: pd.DataFrame, class_names: List[str], out_path: Optional[str] = None
) -> plt.Figure:
    """
    Calculates and plots a dendrogram showing the similarity between the
    average image of each class.

    Args:
        df (pd.DataFrame): DataFrame with image data and labels.
        class_names (List[str]): List of class names.
        out_path (Optional[str], optional): If provided, saves the plot. Defaults to None.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    mean_images = []
    for i in range(len(class_names)):
        # Filter for the class, take the pixel columns, and average them
        mean_img = df[df["label"] == i].iloc[:, 1:].mean().values
        mean_images.append(mean_img)

    # Stack the mean images into a (num_classes, num_pixels) matrix
    mean_images_matrix = np.vstack(mean_images)

    # Hierarchical clustering (Ward linkage uses Euclidean distance)
    linked = sch.linkage(mean_images_matrix, method="ward")

    fig = plt.figure(figsize=(9, 4))
    sch.dendrogram(linked, orientation="top", labels=class_names, leaf_rotation=90)
    plt.ylabel("Euclidean Distance (between mean images)")
    plt.title("Class Similarity Dendrogram")
    plt.tight_layout()

    if out_path:
        # Fall back to the current directory for bare filenames
        _ensure_dir(os.path.dirname(out_path) or ".")
        plt.savefig(out_path, dpi=160)
        plt.close()

    return fig
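
A short sketch of the dataset-exploration helpers above, assuming a
Fashion-MNIST-style CSV whose first column is 'label' followed by 784
pixel columns (the file path below is a placeholder):

    import pandas as pd
    from my_project.plots import (
        FASHION_CLASSES,
        plot_class_distribution,
        plot_class_correlation_dendrogram,
    )

    df = pd.read_csv("data/fashion-mnist_train.csv")

    # Bar chart of samples per class, annotated with counts and percentages
    plot_class_distribution(
        df, FASHION_CLASSES, out_path="reports/figures/class_distribution.png"
    )

    # Dendrogram clustering classes by the similarity of their mean images
    plot_class_correlation_dendrogram(
        df, FASHION_CLASSES, out_path="reports/figures/class_dendrogram.png"
    )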