my_project.train

Train and evaluate the Fashion-MNIST model.

This script can be run from the command line to:

  1. Initialize the DataModule and model with default or specified parameters.
  2. Run the training process for a fixed number of epochs.
  3. Evaluate the trained model on the test set.
  4. Save evaluation figures (such as a confusion matrix) to reports/figures/.
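
The internals of the evaluation step are not shown on this page, so the following is only a generic sketch of the quantities that steps 3 and 4 refer to: a confusion matrix and the per-class accuracy derived from it. The function names and the toy labels are assumptions made for illustration; they are not taken from my_project.plots.

```python
from collections import Counter


def confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes nested list; rows are true labels."""
    counts = Counter(zip(y_true, y_pred))
    return [
        [counts[(t, p)] for p in range(num_classes)]
        for t in range(num_classes)
    ]


def per_class_accuracy(matrix):
    """Diagonal count divided by the row total (None for an absent class)."""
    result = []
    for i, row in enumerate(matrix):
        total = sum(row)
        result.append(row[i] / total if total else None)
    return result


# Toy labels for illustration only; these are not Fashion-MNIST results.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
acc = per_class_accuracy(cm)
print(cm)
print(acc)
```

Rows index the true class and columns the predicted class, so each diagonal entry counts the correct predictions for that class.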

Generated Artifacts

When you run this script, the following directories and files may be created:

  • data/:

    • Contains the downloaded Fashion-MNIST dataset.
  • models/lightning_logs/:

    • Stores logs and checkpoints from PyTorch Lightning during training.
  • reports/figures/:

    • Contains the figures produced by the evaluation step, such as confusion_matrix.png and per_class_accuracy.png.
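
To check which of these directories a run actually produced, a small helper along the following lines may be useful. It is a sketch only: summarize_artifacts is a hypothetical name, and the helper assumes the script was run from the project root so that the relative paths listed above resolve.

```python
from pathlib import Path

# Directories named in the "Generated Artifacts" list above.
ARTIFACT_DIRS = ("data", "models/lightning_logs", "reports/figures")


def summarize_artifacts(root="."):
    """Map each expected directory to its sorted entry names, or None if absent."""
    summary = {}
    for rel in ARTIFACT_DIRS:
        path = Path(root) / rel
        if path.is_dir():
            summary[rel] = sorted(p.name for p in path.iterdir())
        else:
            summary[rel] = None
    return summary


for rel, entries in summarize_artifacts().items():
    status = "missing" if entries is None else f"{len(entries)} entries"
    print(f"{rel}: {status}")
```

A value of None means the directory does not exist yet, which is expected before the first run.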

Examples

Run training from the command line:

$ python -m my_project.train
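
The script as listed on this page reads its settings from my_project.config and defines no command-line flags. If per-run overrides were wanted, one possible approach is an argparse front end such as the sketch below. The flag names (--max-epochs, --batch-size, --data-dir) are hypothetical and are not part of the project.

```python
import argparse


def build_parser():
    """Build a parser for hypothetical overrides; None means 'use the config value'."""
    parser = argparse.ArgumentParser(
        prog="python -m my_project.train",
        description="Train and evaluate the Fashion-MNIST model.",
    )
    parser.add_argument("--max-epochs", type=int, default=None,
                        help="override the configured number of epochs")
    parser.add_argument("--batch-size", type=int, default=None,
                        help="override the configured batch size")
    parser.add_argument("--data-dir", default="data/",
                        help="directory for the downloaded dataset")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--max-epochs", "3"])
print(args.max_epochs, args.batch_size, args.data_dir)
```

Leaving the defaults as None keeps the configuration module as the single source of default values; the caller would substitute the configured constant whenever a flag is absent.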
"""
Train and evaluate the Fashion-MNIST model.

This script can be run from the command line to:
1. Initialize the DataModule and model with default or specified parameters.
2. Run the training process for a fixed number of epochs.
3. Evaluate the trained model on the test set.
4. Save evaluation figures (such as a confusion matrix) to `reports/figures/`.

Generated Artifacts
-------------------
When you run this script, the following directories and files may be created:

- `data/`:
  - Contains the downloaded Fashion-MNIST dataset.

- `models/lightning_logs/`:
  - Stores logs and checkpoints from PyTorch Lightning during training.

- `reports/figures/`:
  - Contains the figures produced by the evaluation step, such as `confusion_matrix.png` and `per_class_accuracy.png`.

Examples
--------
Run training from the command line::

    $ python -m my_project.train
"""

import pytorch_lightning as pl

from my_project.config import BATCH_SIZE, MAX_EPOCHS, NUM_WORKERS
from my_project.dataset import FashionMNISTDataModule
from my_project.model import Net
from my_project.plots import evaluate_and_plot


def main():
    """
    Train and evaluate the Fashion-MNIST model.

    This function:
    1. Initializes the DataModule and model.
    2. Runs training for a fixed number of epochs.
    3. Evaluates on the test set.
    4. Saves evaluation figures in `reports/figures/`.

    Returns
    -------
    None
        The function is run for its side effects: training logs,
        checkpoints, and saved figures.
    """

    # Batch size and worker count come from my_project.config.
    data_module = FashionMNISTDataModule(
        data_dir="data/",
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    net = Net(num_filters=32, hidden_size=64)

    # Let Lightning pick the accelerator and devices; logs and checkpoints
    # are written under models/lightning_logs.
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",
        devices="auto",
        default_root_dir="models/lightning_logs",
    )

    trainer.fit(net, datamodule=data_module)
    trainer.test(net, datamodule=data_module)

    # Every key other than "test_accuracy" is reported as a saved figure.
    artifacts = evaluate_and_plot(net, data_module, out_dir="reports/figures")
    print(f"Test accuracy: {artifacts['test_accuracy']:.4f}")
    print("Saved figures:")
    for k, v in artifacts.items():
        if k != "test_accuracy":
            print(f"  - {k}: {v}")


if __name__ == "__main__":
    main()
def main():

Train and evaluate the Fashion-MNIST model.

This function:

  1. Initializes the DataModule and model.
  2. Runs training for a fixed number of epochs.
  3. Evaluates on the test set.
  4. Saves evaluation figures in reports/figures/.

Returns

None. The function is run for its side effects: training logs, checkpoints, and saved figures.
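
The final reporting step in main() treats the value returned by evaluate_and_plot as a dictionary that holds a "test_accuracy" entry alongside one entry per saved figure. The sketch below reproduces that handling on a stand-in dictionary. The helper names are hypothetical, and the sample key, path, and accuracy are placeholders rather than real results.

```python
def split_artifacts(artifacts):
    """Separate the scalar accuracy from the figure entries."""
    figures = {
        name: path
        for name, path in artifacts.items()
        if name != "test_accuracy"
    }
    return artifacts["test_accuracy"], figures


def format_report(artifacts):
    """Build the same text that main() prints, as a single string."""
    accuracy, figures = split_artifacts(artifacts)
    lines = [f"Test accuracy: {accuracy:.4f}", "Saved figures:"]
    lines += [f"  - {name}: {path}" for name, path in figures.items()]
    return "\n".join(lines)


# Placeholder values for illustration; not the output of a real run.
example = {
    "test_accuracy": 0.5,
    "confusion_matrix": "reports/figures/confusion_matrix.png",
}
report = format_report(example)
print(report)
```

Returning the report as a string rather than printing inside the helper makes the formatting straightforward to check in a unit test.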