my_project.train
Train and evaluate the Fashion-MNIST model.
This script can be run from the command line to:
- Initialize the DataModule and model with default or specified parameters.
- Run the training process for a fixed number of epochs.
- Evaluate the trained model on the test set.
- Save evaluation figures (such as the confusion matrix) in reports/figures/.
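As written, main() takes no arguments. It reads BATCH_SIZE, NUM_WORKERS and MAX_EPOCHS from my_project.config and hardcodes the Net hyperparameters. If "specified parameters" were to come from the command line, one option is a small argparse layer whose defaults are the config values. The sketch below is hypothetical: the function name parse_args and the flags --batch-size, --num-workers and --max-epochs are assumptions, not part of the current script.

```python
import argparse
from typing import Mapping, Optional, Sequence


def parse_args(
    defaults: Mapping[str, int], argv: Optional[Sequence[str]] = None
) -> argparse.Namespace:
    """Parse optional overrides; any flag not given falls back to `defaults`.

    Hypothetical: the current my_project.train defines no CLI options.
    """
    parser = argparse.ArgumentParser(
        description="Train and evaluate the Fashion-MNIST model."
    )
    # argparse maps --batch-size to the attribute name batch_size, and so on.
    parser.add_argument("--batch-size", type=int, default=defaults["batch_size"])
    parser.add_argument("--num-workers", type=int, default=defaults["num_workers"])
    parser.add_argument("--max-epochs", type=int, default=defaults["max_epochs"])
    # argv=None makes argparse read sys.argv[1:]; pass a list to parse explicitly.
    return parser.parse_args(argv)
```

Inside main(), the defaults would be {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS, "max_epochs": MAX_EPOCHS}. The resulting args.batch_size, args.num_workers and args.max_epochs would then be passed to FashionMNISTDataModule and pl.Trainer. With that wiring, `python -m my_project.train --max-epochs 10` would override only the epoch count.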
Generated Artifacts
When you run this script, the following directories and files may be created:
- data/: Contains the downloaded Fashion-MNIST dataset.
- models/lightning_logs/: Stores logs and checkpoints from PyTorch Lightning during training.
- reports/figures/: Contains output visualizations from the evaluation step, such as confusion_matrix.png, per_class_accuracy.png, etc.
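The script uses relative paths such as "data/" and "reports/figures", so these directories are created relative to the working directory the command is run from. After a run, a quick stdlib check can confirm they exist. This is a minimal sketch: the helper name missing_artifact_dirs is hypothetical and not part of my_project.

```python
from pathlib import Path
from typing import Iterable, List

# Directories listed under "Generated Artifacts", relative to where the script runs.
EXPECTED_DIRS = ("data", "models/lightning_logs", "reports/figures")


def missing_artifact_dirs(
    root: Path = Path("."), expected: Iterable[str] = EXPECTED_DIRS
) -> List[str]:
    """Return the expected artifact directories that do not exist under root."""
    return [d for d in expected if not (Path(root) / d).is_dir()]
```

Called from the project root after `python -m my_project.train`, an empty list means all three directories are present. Note that data/ may already exist from an earlier run, so its presence alone does not prove the latest run succeeded.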
Examples
Run training from the command line:
$ python -m my_project.train
```python
"""
Train and evaluate the Fashion-MNIST model.

This script can be run from the command line to:
1. Initialize the DataModule and model with default or specified parameters.
2. Run the training process for a fixed number of epochs.
3. Evaluate the trained model on the test set.
4. Save evaluation figures (such as the confusion matrix) in `reports/figures/`.

Generated Artifacts
-------------------
When you run this script, the following directories and files may be created:

- `data/`:
  - Contains the downloaded Fashion-MNIST dataset.

- `models/lightning_logs/`:
  - Stores logs and checkpoints from PyTorch Lightning during training.

- `reports/figures/`:
  - Contains output visualizations from the evaluation step, such as
    `confusion_matrix.png`, `per_class_accuracy.png`, etc.

Examples
--------
Run training from the command line::

    $ python -m my_project.train
"""

import pytorch_lightning as pl

from my_project.config import BATCH_SIZE, MAX_EPOCHS, NUM_WORKERS
from my_project.dataset import FashionMNISTDataModule
from my_project.model import Net
from my_project.plots import evaluate_and_plot


def main():
    """
    Train and evaluate the Fashion-MNIST model.

    This function:
    1. Initializes the DataModule and model.
    2. Runs training for a fixed number of epochs.
    3. Evaluates on the test set.
    4. Saves evaluation figures in `reports/figures/`.

    Returns
    -------
    None
        The test accuracy and the path of each saved figure are printed.
    """
    # Fashion-MNIST is stored under data/; loader settings come from my_project.config.
    data_module = FashionMNISTDataModule(
        data_dir="data/",
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    # Model hyperparameters are fixed here rather than read from the config.
    net = Net(num_filters=32, hidden_size=64)

    # Lightning chooses the accelerator and device count automatically; logs and
    # checkpoints are written under default_root_dir (by default in a
    # lightning_logs/version_N/ subfolder).
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",
        devices="auto",
        default_root_dir="models/lightning_logs",
    )

    trainer.fit(net, datamodule=data_module)
    trainer.test(net, datamodule=data_module)

    # evaluate_and_plot returns a dict holding "test_accuracy" plus the saved figure paths.
    artifacts = evaluate_and_plot(net, data_module, out_dir="reports/figures")
    print(f"Test accuracy: {artifacts['test_accuracy']:.4f}")
    print("Saved figures:")
    for k, v in artifacts.items():
        if k != "test_accuracy":
            print(f" - {k}: {v}")


if __name__ == "__main__":
    main()
```
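The reporting at the end of main() prints directly, which makes it awkward to test. One option is to build the lines in a pure function and print them from main(). This is a sketch: format_report is a hypothetical name. The dict shape (a float under "test_accuracy" and figure paths under other keys) is inferred from how main() uses the return value of evaluate_and_plot. The example key "confusion_matrix" is an assumption.

```python
from typing import Dict, List, Union


def format_report(artifacts: Dict[str, Union[float, str]]) -> List[str]:
    """Build the same lines main() prints, without printing them."""
    lines = [f"Test accuracy: {artifacts['test_accuracy']:.4f}", "Saved figures:"]
    for name, path in artifacts.items():
        # Every entry except the accuracy is treated as a saved figure path.
        if name != "test_accuracy":
            lines.append(f" - {name}: {path}")
    return lines
```

main() could then end with `print("\n".join(format_report(artifacts)))`. Its console output would stay the same, and the formatting could be checked without training a model.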
def main():
Train and evaluate the Fashion-MNIST model.
This function:
- Initializes the DataModule and model.
- Runs training for a fixed number of epochs (MAX_EPOCHS from my_project.config).
- Evaluates on the test set.
- Saves evaluation figures in reports/figures/.
Returns
None. The test accuracy and the path of each saved figure are printed rather than returned.
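evaluate_and_plot is not shown here, so the following is not its implementation. It is a minimal pure-Python sketch of the standard quantities behind confusion_matrix.png, per_class_accuracy.png and the printed test accuracy. It assumes integer labels 0..num_classes-1 (Fashion-MNIST has 10 classes) and the usual convention of rows = true class and columns = predicted class. All function names are illustrative.

```python
from typing import List, Sequence


def confusion_matrix(
    y_true: Sequence[int], y_pred: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Count predictions: cm[t][p] is the number of class-t samples predicted as p."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_accuracy(cm: List[List[int]]) -> List[float]:
    """Fraction of each true class predicted correctly (diagonal / row sum)."""
    return [
        cm[i][i] / sum(row) if sum(row) else float("nan")
        for i, row in enumerate(cm)
    ]


def overall_accuracy(cm: List[List[int]]) -> float:
    """Correct predictions (the diagonal) divided by all predictions."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total if total else float("nan")
```

For y_true = [0, 0, 1, 1, 2, 2] and y_pred = [0, 1, 1, 1, 2, 0], the matrix is [[1, 1, 0], [0, 2, 0], [1, 0, 1]]. Per-class accuracy is [0.5, 1.0, 0.5] and overall accuracy is 4/6. Overall accuracy weights classes by their sample counts. Because Fashion-MNIST's test set is balanced across classes, it equals the mean of the per-class accuracies there.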