my_project.plots
Utilities to evaluate a trained model and generate performance visualizations.
- evaluate_and_plot(...) : runs eval on a dataloader and saves:
    * confusion_matrix.png
    * per_class_accuracy.png
    * misclassified_grid.png
- plot_curves_from_csvlogger(...) : plots curves from Lightning's CSVLogger
    * train_loss_[step|epoch].png
    * val_acc_epoch.png (if logged during validation)

Outputs are saved under reports/figures/ by default.
This module does NOT modify your model architecture.
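A minimal usage sketch (assuming a trained `model` and a Lightning-style `datamodule` exposing `test_dataloader()`; both are placeholders here):

    from my_project.plots import evaluate_and_plot, plot_curves_from_csvlogger

    results = evaluate_and_plot(model, datamodule, out_dir="reports/figures")
    print(f"test accuracy: {results['test_accuracy']:.4f}")
    plot_curves_from_csvlogger("logs/fashion/version_0")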
1""" 2Utilities to evaluate a trained model and generate performance visualizations. 3- evaluate_and_plot(...) : runs eval on a dataloader and saves: 4 * confusion_matrix.png 5 * per_class_accuracy.png 6 * misclassified_grid.png 7- plot_curves_from_csvlogger(...) : plots curves from Lightning's CSVLogger 8 * train_loss_[step|epoch].png 9 * val_acc_epoch.png (if logged during validation) 10Outputs are saved under reports/figures/ by default. 11 12This module does NOT modify your model architecture. 13""" 14 15import os 16import math 17from typing import Tuple, List, Optional 18 19import torch 20import torch.nn.functional as F 21import matplotlib.pyplot as plt 22import numpy as np 23import pandas as pd 24from sklearn.metrics import confusion_matrix 25from sklearn.calibration import calibration_curve 26from scipy.cluster import hierarchy as sch 27from scipy.spatial.distance import pdist 28 29 30# Fashion-MNIST class names (0..9) 31FASHION_CLASSES = [ 32 "T-shirt/top", 33 "Trouser", 34 "Pullover", 35 "Dress", 36 "Coat", 37 "Sandal", 38 "Shirt", 39 "Sneaker", 40 "Bag", 41 "Ankle boot", 42] 43 44 45def _ensure_dir(path: str) -> None: 46 os.makedirs(path, exist_ok=True) 47 48 49def _to_numpy(x: torch.Tensor) -> np.ndarray: 50 return x.detach().cpu().numpy() 51 52 53def _get_device(module: torch.nn.Module) -> torch.device: 54 return next(module.parameters()).device 55 56 57@torch.no_grad() 58def evaluate_model( 59 model: torch.nn.Module, 60 dataloader: torch.utils.data.DataLoader, 61) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 62 """ 63 Run the model over a dataloader and return predictions. 64 65 Parameters 66 ---------- 67 model : torch.nn.Module 68 Trained model to evaluate. 69 dataloader : torch.utils.data.DataLoader 70 DataLoader for evaluation (test/validation). 71 72 Returns 73 ------- 74 tuple 75 (y_true, y_pred, y_prob_max): 76 - y_true (ndarray[int]): Ground-truth labels. 77 - y_pred (ndarray[int]): Predicted labels. 78 - y_prob_max (ndarray[float]): Max softmax confidence per sample. 79 """ 80 81 model.eval() 82 device = _get_device(model) 83 84 all_true: List[int] = [] 85 all_pred: List[int] = [] 86 all_conf: List[float] = [] 87 88 for xb, yb in dataloader: 89 xb = xb.to(device) 90 yb = yb.to(device) 91 92 logits = model(xb) 93 probs = F.softmax(logits, dim=1) 94 conf, pred = probs.max(dim=1) 95 96 all_true.extend(_to_numpy(yb)) 97 all_pred.extend(_to_numpy(pred)) 98 all_conf.extend(_to_numpy(conf)) 99 100 return ( 101 np.asarray(all_true, dtype=np.int64), 102 np.asarray(all_pred, dtype=np.int64), 103 np.asarray(all_conf, dtype=np.float32), 104 ) 105 106 107def plot_confusion_matrix( 108 y_true: np.ndarray, 109 y_pred: np.ndarray, 110 class_names: List[str], 111 out_path: str, 112 normalize: bool = True, 113) -> None: 114 """ 115 Plot and save a confusion matrix. 116 117 Parameters 118 ---------- 119 y_true : np.ndarray 120 Ground-truth labels. 121 y_pred : np.ndarray 122 Predicted labels. 123 class_names : list of str 124 Class names corresponding to label indices. 125 out_path : str 126 File path to save the plot. 127 normalize : bool, optional (default=True) 128 If True, normalize counts to percentages. 
129 """ 130 131 cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names)))) 132 if normalize: 133 with np.errstate(all="ignore"): 134 cm = cm.astype(np.float64) / cm.sum(axis=1, keepdims=True) 135 cm = np.nan_to_num(cm) 136 137 plt.figure(figsize=(6, 5)) 138 plt.imshow(cm, interpolation="nearest", cmap="coolwarm") 139 plt.title("Confusion Matrix " if normalize else "Confusion Matrix") 140 plt.colorbar() 141 tick_marks = np.arange(len(class_names)) 142 plt.xticks(tick_marks, class_names, rotation=45, ha="right") 143 plt.yticks(tick_marks, class_names) 144 145 # Annotate cells 146 thresh = cm.max() / 2.0 if cm.size > 0 else 0.5 147 for i in range(cm.shape[0]): 148 for j in range(cm.shape[1]): 149 val = cm[i, j] 150 txt = f"{val:.2f}" if normalize else f"{int(val)}" 151 plt.text( 152 j, 153 i, 154 txt, 155 horizontalalignment="center", 156 verticalalignment="center", 157 fontsize=8, 158 color="white" if cm[i, j] > thresh else "black", 159 ) 160 161 plt.ylabel("True label") 162 plt.xlabel("Predicted label") 163 plt.tight_layout() 164 plt.savefig(out_path, dpi=160) 165 plt.close() 166 167 168def plot_per_class_accuracy( 169 y_true: np.ndarray, 170 y_pred: np.ndarray, 171 class_names: List[str], 172 out_path: str, 173) -> None: 174 """ 175 Plot and save per-class accuracy. 176 177 Parameters 178 ---------- 179 y_true : np.ndarray 180 Ground-truth labels. 181 y_pred : np.ndarray 182 Predicted labels. 183 class_names : list of str 184 Class names corresponding to label indices. 185 out_path : str 186 File path to save the plot. 187 """ 188 189 num_classes = len(class_names) 190 correct = np.zeros(num_classes, dtype=np.int64) 191 total = np.zeros(num_classes, dtype=np.int64) 192 193 for t, p in zip(y_true, y_pred): 194 total[t] += 1 195 if t == p: 196 correct[t] += 1 197 198 acc = np.divide(correct, np.maximum(total, 1), where=total > 0) 199 200 plt.figure(figsize=(9, 4)) 201 plt.bar(np.arange(num_classes), acc) 202 plt.ylim(0, 1) 203 plt.xticks(np.arange(num_classes), class_names, rotation=45, ha="right") 204 plt.ylabel("Accuracy") 205 plt.title("Per-class Accuracy") 206 plt.tight_layout() 207 plt.savefig(out_path, dpi=160) 208 plt.close() 209 210 211def _denormalize_img(img: torch.Tensor) -> np.ndarray: 212 """ 213 Convert a normalized tensor image back to NumPy format. 214 215 Parameters 216 ---------- 217 img : torch.Tensor 218 Image tensor (CHW or HW). 219 220 Returns 221 ------- 222 np.ndarray 223 Image array suitable for plotting. 224 """ 225 226 if img.dim() == 3 and img.shape[0] == 1: 227 img_np = img.squeeze(0).detach().cpu().numpy() 228 return img_np 229 # If already HxW 230 return img.detach().cpu().numpy() 231 232 233@torch.no_grad() 234def plot_misclassified_grid( 235 model: torch.nn.Module, 236 dataloader: torch.utils.data.DataLoader, 237 class_names: List[str], 238 out_path: str, 239 max_examples: int = 16, 240) -> None: 241 """ 242 Plot a grid of misclassified samples. 243 244 Parameters 245 ---------- 246 model : torch.nn.Module 247 Trained model to evaluate. 248 dataloader : torch.utils.data.DataLoader 249 DataLoader for evaluation (test/validation). 250 class_names : list of str 251 Class names corresponding to label indices. 252 out_path : str 253 File path to save the plot. 254 max_examples : int, optional (default=16) 255 Maximum number of misclassified samples to show. 
256 """ 257 model.eval() 258 device = _get_device(model) 259 260 images = [] 261 labels_true = [] 262 labels_pred = [] 263 264 # Collect up to max_examples misclassified samples 265 for xb, yb in dataloader: 266 xb = xb.to(device) 267 yb = yb.to(device) 268 logits = model(xb) 269 pred = logits.argmax(dim=1) 270 271 mism = pred != yb 272 if mism.any(): 273 mis_idx = torch.nonzero(mism).flatten() 274 for idx in mis_idx: 275 images.append(xb[idx].cpu()) 276 labels_true.append(int(yb[idx].cpu())) 277 labels_pred.append(int(pred[idx].cpu())) 278 if len(images) >= max_examples: 279 break 280 if len(images) >= max_examples: 281 break 282 283 if len(images) == 0: 284 # Nothing to plot 285 fig = plt.figure(figsize=(6, 2)) 286 plt.text(0.5, 0.5, "No misclassified samples found.", ha="center", va="center") 287 plt.axis("off") 288 plt.tight_layout() 289 fig.savefig(out_path, dpi=160) 290 plt.close() 291 return 292 293 # Determine grid size 294 cols = int(math.ceil(math.sqrt(len(images)))) 295 rows = int(math.ceil(len(images) / cols)) 296 297 plt.figure(figsize=(cols * 1.5, rows * 1.5)) 298 for i, img in enumerate(images): 299 ax = plt.subplot(rows, cols, i + 1) 300 ax.imshow(_denormalize_img(img), cmap="gray") 301 true_name = class_names[labels_true[i]] 302 pred_name = class_names[labels_pred[i]] 303 ax.set_title(f"{true_name} → {pred_name}", fontsize=8) 304 ax.axis("off") 305 306 plt.suptitle("Misclassified Samples", y=0.98) 307 plt.tight_layout() 308 plt.savefig(out_path, dpi=160) 309 plt.close() 310 311 312def evaluate_and_plot( 313 model: torch.nn.Module, 314 datamodule, 315 out_dir: str = "reports/figures", 316) -> dict: 317 318 _ensure_dir(out_dir) 319 320 # Use the datamodule's test dataloader (Lightning-style) 321 test_loader = datamodule.test_dataloader() 322 323 # 1) Predictions + basic accuracy 324 y_true, y_pred, y_conf = evaluate_model(model, test_loader) 325 test_acc = float((y_true == y_pred).mean()) 326 327 # 2) Confusion matrix (normalized) 328 cm_path = os.path.join(out_dir, "confusion_matrix.png") 329 plot_confusion_matrix(y_true, y_pred, FASHION_CLASSES, cm_path, normalize=True) 330 331 # 3) Per-class accuracy bars 332 pca_path = os.path.join(out_dir, "per_class_accuracy.png") 333 plot_per_class_accuracy(y_true, y_pred, FASHION_CLASSES, pca_path) 334 335 # 4) Misclassified samples grid 336 mis_path = os.path.join(out_dir, "misclassified_grid.png") 337 plot_misclassified_grid( 338 model, test_loader, FASHION_CLASSES, mis_path, max_examples=6 339 ) 340 341 # 5) Calibration (reliability) curve 342 calib_path = os.path.join(out_dir, "calibration_curve.png") 343 y_correct = (y_true == y_pred).astype(int) 344 plot_calibration_curve(y_correct, y_conf, calib_path, n_bins=10) 345 346 return { 347 "test_accuracy": test_acc, 348 "confusion_matrix": cm_path, 349 "per_class_accuracy": pca_path, 350 "misclassified_grid": mis_path, 351 "calibration_curve": calib_path, 352 } 353 354 355def plot_calibration_curve(y_true, y_prob, out_path, n_bins: int = 10): 356 357 # Calibration curve 358 prob_true, prob_pred = calibration_curve( 359 y_true, y_prob, n_bins=n_bins, strategy="uniform" 360 ) 361 362 plt.figure(figsize=(6, 5)) 363 plt.plot(prob_pred, prob_true, marker="o", label="Model") 364 plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfectly calibrated") 365 plt.xlabel("Predicted probability") 366 plt.ylabel("Observed frequency") 367 plt.title("Calibration Curve") 368 plt.legend() 369 plt.tight_layout() 370 plt.savefig(out_path, dpi=160) 371 plt.close() 372 373 374def 
def plot_curves_from_csvlogger(
    csv_log_dir: str,
    out_dir: str = "reports/figures",
    train_loss_keys: Optional[List[str]] = None,
    val_acc_keys: Optional[List[str]] = None,
) -> List[str]:
    """
    If you use Lightning's CSVLogger, this parses the `metrics.csv` file and
    produces line plots for:
      - train loss (step/epoch)
      - val accuracy (epoch)

    Usage in train.py:
        from pytorch_lightning.loggers import CSVLogger
        logger = CSVLogger("logs", name="fashion")
        trainer = pl.Trainer(..., logger=logger)

    Then call:
        plot_curves_from_csvlogger(logger.log_dir)

    Args:
        csv_log_dir: e.g., "logs/fashion/version_0"
        train_loss_keys: possible metric column names for train loss
        val_acc_keys: possible metric column names for val acc

    Returns:
        list of saved figure paths
    """
    _ensure_dir(out_dir)
    metrics_path = os.path.join(csv_log_dir, "metrics.csv")
    if not os.path.exists(metrics_path):
        return []

    df = pd.read_csv(metrics_path)

    # Reasonable defaults; adjust if you log different names
    if train_loss_keys is None:
        # Lightning often stores as "train_loss_step" and/or "train_loss_epoch"
        train_loss_keys = ["train_loss_step", "train_loss_epoch", "loss", "train_loss"]
    if val_acc_keys is None:
        # You already log 'val_acc' in your model
        val_acc_keys = ["val_acc", "val_acc_epoch"]

    saved = []

    # Train loss vs step/epoch
    for key in train_loss_keys:
        if key in df.columns:
            # CSVLogger writes each metric on its own row, so drop the NaN
            # rows belonging to other metrics before plotting.
            series = df[key].dropna()
            # Prefer 'step' index if present, else 'epoch'
            if "step" in df.columns and not df["step"].isna().all():
                x = df.loc[series.index, "step"]
                x_label = "Step"
            elif "epoch" in df.columns:
                x = df.loc[series.index, "epoch"]
                x_label = "Epoch"
            else:
                x = np.arange(len(series))
                x_label = "Epoch"

            plt.figure(figsize=(6, 4))
            plt.plot(x, series)
            plt.xlabel(x_label)
            plt.ylabel("Train Loss")
            plt.title(f"{key} over {x_label.lower()}")
            plt.tight_layout()
            out_path = os.path.join(out_dir, f"{key}.png")
            plt.savefig(out_path, dpi=160)
            plt.close()
            saved.append(out_path)
            # Only plot the first key that exists (avoid duplicates)
            break

    # Validation accuracy vs epoch
    for key in val_acc_keys:
        if key in df.columns:
            series = df[key].dropna()
            if "epoch" in df.columns:
                x = df.loc[series.index, "epoch"]
            else:
                x = np.arange(len(series))
            plt.figure(figsize=(6, 4))
            plt.plot(x, series)
            plt.xlabel("Epoch")
            plt.ylabel("Validation Accuracy")
            plt.title(f"{key} over epochs")
            plt.tight_layout()
            out_path = os.path.join(out_dir, f"{key}.png")
            plt.savefig(out_path, dpi=160)
            plt.close()
            saved.append(out_path)
            break

    return saved
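
# Note on the metrics.csv layout: Lightning's CSVLogger writes one row per
# logging event, so step-level and epoch-level metrics land on different rows
# and the other columns are NaN on that row. A toy illustration (values are
# hypothetical):
#
#     step, epoch, train_loss_step, val_acc
#     0,    0,     2.31,            NaN
#     50,   0,     1.87,            NaN
#     99,   0,     NaN,             0.62
#
# This is why the plotting helpers call .dropna() on each metric column before
# aligning it with the 'step'/'epoch' axis.
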
473 """ 474 train_loss_fig, val_acc_fig = None, None 475 476 # Train loss vs step/epoch 477 train_loss_keys = ["train_loss_step", "train_loss_epoch", "loss", "train_loss"] 478 val_acc_keys = ["val_acc", "val_acc_epoch"] 479 480 for key in train_loss_keys: 481 if key in df.columns and not df[key].isna().all(): 482 series = df[key].dropna() 483 # Prefer 'step' index if present, else 'epoch' 484 if "step" in df.columns and not df["step"].isna().all(): 485 x = df.loc[series.index, "step"] 486 x_label = "Step" 487 else: 488 x = df.loc[series.index, "epoch"] 489 x_label = "Epoch" 490 491 train_loss_fig = plt.figure(figsize=(6, 4)) 492 plt.plot(x, series) 493 plt.xlabel(x_label) 494 plt.ylabel("Train Loss") 495 plt.title(f"Training Loss vs. {x_label}") 496 plt.tight_layout() 497 # Only plot first that exists 498 break 499 500 # Validation accuracy vs epoch 501 for key in val_acc_keys: 502 if key in df.columns and not df[key].isna().all(): 503 # Filter to get only the steps where validation was run 504 series = df[key].dropna() 505 x = df.loc[series.index, "epoch"] 506 val_acc_fig = plt.figure(figsize=(6, 4)) 507 plt.plot(x, series, marker='o', linestyle='-') 508 plt.xlabel("Epoch") 509 plt.ylabel("Validation Accuracy") 510 plt.title("Validation Accuracy vs. Epoch") 511 plt.ylim(0, 1) 512 plt.tight_layout() 513 # Only plot first that exists 514 break 515 516 # Close figures if they were not created to avoid empty plots 517 if train_loss_fig is None: plt.close() 518 if val_acc_fig is None: plt.close() 519 520 return train_loss_fig, val_acc_fig 521 522 523def plot_class_distribution( 524 df: pd.DataFrame, class_names: List[str], out_path: Optional[str] = None 525) -> plt.Figure: 526 """ 527 Plots the distribution of classes in a dataset. 528 529 Args: 530 df (pd.DataFrame): DataFrame with a 'label' column. 531 class_names (List[str]): List of class names. 532 out_path (Optional[str], optional): If provided, saves the plot. Defaults to None. 533 534 Returns: 535 plt.Figure: The matplotlib figure object. 536 """ 537 class_counts = df["label"].value_counts().sort_index() 538 total_samples = class_counts.sum() 539 540 fig = plt.figure(figsize=(9, 4)) 541 bars = plt.bar(class_counts.index, class_counts.values) 542 plt.xticks(np.arange(len(class_names)), class_names, rotation=45, ha="right") 543 plt.ylabel("Number of Samples") 544 plt.title("Dataset Class Distribution") 545 546 # Add count and percentage labels 547 for bar in bars: 548 height = bar.get_height() 549 plt.text(bar.get_x() + bar.get_width()/2., height, f'{height}\n({height/total_samples:.1%})', ha='center', va='bottom', fontsize=8) 550 551 plt.tight_layout() 552 553 if out_path: 554 _ensure_dir(os.path.dirname(out_path)) 555 plt.savefig(out_path, dpi=160) 556 plt.close() 557 558 return fig 559 560 561def get_sample_images_for_gallery( 562 df: pd.DataFrame, class_names: List[str], n_samples: int = 20 563) -> List[Tuple[np.ndarray, str]]: 564 """ 565 Gets a list of random sample images and their labels for a Gradio Gallery. 566 567 Args: 568 df (pd.DataFrame): DataFrame with image data. 569 class_names (List[str]): List of class names. 570 n_samples (int, optional): Number of samples to retrieve. Defaults to 20. 571 572 Returns: 573 List[Tuple[np.ndarray, str]]: A list of tuples, each containing an image and its label. 
574 """ 575 samples = [] 576 if len(df) > n_samples: 577 df_sample = df.sample(n=n_samples) 578 else: 579 df_sample = df 580 581 for _, row in df_sample.iterrows(): 582 label = class_names[int(row["label"])] 583 image_data = row.iloc[1:].values.astype(np.uint8) 584 image = image_data.reshape(28, 28) 585 # The label is now just plain text 586 samples.append((image, label)) 587 return samples 588 589 590def plot_class_correlation_dendrogram( 591 df: pd.DataFrame, class_names: List[str], out_path: Optional[str] = None 592) -> plt.Figure: 593 """ 594 Calculates and plots a dendrogram showing the similarity between the 595 average image of each class. 596 597 Args: 598 df (pd.DataFrame): DataFrame with image data and labels. 599 class_names (List[str]): List of class names. 600 out_path (Optional[str], optional): If provided, saves the plot. Defaults to None. 601 602 Returns: 603 plt.Figure: The matplotlib figure object. 604 """ 605 mean_images = [] 606 for i in range(len(class_names)): 607 # Filter for the class, get pixel data, and calculate the mean image 608 mean_img = df[df["label"] == i].iloc[:, 1:].mean().values 609 mean_images.append(mean_img) 610 611 # Convert list of mean images to a matrix 612 mean_images_matrix = np.vstack(mean_images) 613 614 # Perform hierarchical clustering 615 linked = sch.linkage(mean_images_matrix, method="ward") 616 617 fig = plt.figure(figsize=(9, 4)) 618 sch.dendrogram(linked, orientation="top", labels=class_names, leaf_rotation=90) 619 plt.ylabel("Euclidean Distance (between mean images)") 620 plt.title("Class Similarity Dendrogram") 621 plt.tight_layout() 622 623 if out_path: 624 _ensure_dir(os.path.dirname(out_path)) 625 plt.savefig(out_path, dpi=160) 626 plt.close() 627 628 return fig