my_project.model
Defines the neural network architecture for the Fashion-MNIST classifier.
This module contains the Net class, a PyTorch LightningModule that
encapsulates the model's structure (a simple CNN), the forward pass logic,
and the training, validation, and test steps.
"""
Defines the neural network architecture for the Fashion-MNIST classifier.

This module contains the `Net` class, a PyTorch LightningModule that
encapsulates the model's structure (a simple CNN), the forward pass logic,
and the training, validation, and test steps.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl


class Net(pl.LightningModule):
    """
    Simple convolutional neural network for Fashion-MNIST classification.

    Architecture
    ------------
    - Conv2d(1 → num_filters, kernel_size=3)
    - ReLU
    - MaxPool2d(kernel_size=2)
    - Flatten
    - Linear(num_filters*13*13 → hidden_size)
    - ReLU
    - Linear(hidden_size → 10)

    With the default arguments this is Conv2d(1 → 16), Linear(16*13*13 → 32)
    and Linear(32 → 10).

    Loss: CrossEntropyLoss

    Parameters
    ----------
    num_filters : int, default 16
        Number of output channels of the convolutional layer.
    hidden_size : int, default 32
        Number of units in the hidden fully connected layer.
    lr : float, default 1e-3
        Learning rate used by the Adam optimizer in ``configure_optimizers``.

    Examples
    --------
    >>> model = Net()
    >>> x = torch.randn(8, 1, 28, 28)
    >>> out = model(x)
    >>> out.shape
    torch.Size([8, 10])
    """

    def __init__(
        self,
        num_filters: int = 16,
        hidden_size: int = 32,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()  # Save num_filters, hidden_size and lr

        self.num_filters = num_filters
        self.hidden_size = hidden_size
        self.lr = lr

        # Geometry that the first linear layer depends on. Inputs are expected
        # to be 28x28 single-channel images (see ``forward``).
        image_size, kernel_size, pool_size = 28, 3, 2
        # Unpadded stride-1 conv: 28 - 3 + 1 = 26; 2x2 max-pool (floor): 13.
        pooled_side = (image_size - kernel_size + 1) // pool_size
        flat_features = self.num_filters * pooled_side * pooled_side

        self.conv = nn.Conv2d(1, self.num_filters, kernel_size)
        self.pool = nn.MaxPool2d(pool_size)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Forward pass of the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, 1, 28, 28). The 28x28 spatial size is
            required because ``fc1`` is sized for it in ``__init__``.

        Returns
        -------
        torch.Tensor
            Raw (unnormalised) logits of shape (N, 10).
        """
        x = self.pool(F.relu(self.conv(x)))
        x = self.flat(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """
        Training step for a single batch.

        Computes the cross-entropy loss and logs it as ``train_loss`` per step
        and per epoch, also showing it in the progress bar.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        torch.Tensor
            Training loss.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step for a single batch.

        Logs ``val_acc``: the fraction of samples whose argmax prediction
        matches the label.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        None
        """
        xb, yb = batch
        out = self(xb)
        preds = out.argmax(1)
        acc = (preds == yb).float().mean()
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """
        Test step for a single batch.

        Logs ``test_acc``: the fraction of samples whose argmax prediction
        matches the label, aggregated per epoch.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        None
        """
        xb, yb = batch
        out = self(xb)
        preds = out.argmax(1)
        acc = (preds == yb).float().mean()
        self.log("test_acc", acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """
        Define optimizer for training.

        Returns
        -------
        torch.optim.Optimizer
            Adam over all model parameters with learning rate ``self.lr``
            (1e-3 unless overridden in the constructor); all other Adam
            settings are left at their defaults.
        """
        return optim.Adam(self.parameters(), lr=self.lr)
class Net(pl.LightningModule):
    """
    Simple convolutional neural network for Fashion-MNIST classification.

    Architecture
    ------------
    - Conv2d(1 → num_filters, kernel_size=3)
    - ReLU
    - MaxPool2d(kernel_size=2)
    - Flatten
    - Linear(num_filters*13*13 → hidden_size)
    - ReLU
    - Linear(hidden_size → 10)

    With the default arguments this is Conv2d(1 → 16), Linear(16*13*13 → 32)
    and Linear(32 → 10).

    Loss: CrossEntropyLoss

    Parameters
    ----------
    num_filters : int, default 16
        Number of output channels of the convolutional layer.
    hidden_size : int, default 32
        Number of units in the hidden fully connected layer.
    lr : float, default 1e-3
        Learning rate used by the Adam optimizer in ``configure_optimizers``.

    Examples
    --------
    >>> model = Net()
    >>> x = torch.randn(8, 1, 28, 28)
    >>> out = model(x)
    >>> out.shape
    torch.Size([8, 10])
    """

    def __init__(
        self,
        num_filters: int = 16,
        hidden_size: int = 32,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()  # Save num_filters, hidden_size and lr

        self.num_filters = num_filters
        self.hidden_size = hidden_size
        self.lr = lr

        # Geometry that the first linear layer depends on. Inputs are expected
        # to be 28x28 single-channel images (see ``forward``).
        image_size, kernel_size, pool_size = 28, 3, 2
        # Unpadded stride-1 conv: 28 - 3 + 1 = 26; 2x2 max-pool (floor): 13.
        pooled_side = (image_size - kernel_size + 1) // pool_size
        flat_features = self.num_filters * pooled_side * pooled_side

        self.conv = nn.Conv2d(1, self.num_filters, kernel_size)
        self.pool = nn.MaxPool2d(pool_size)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Forward pass of the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, 1, 28, 28). The 28x28 spatial size is
            required because ``fc1`` is sized for it in ``__init__``.

        Returns
        -------
        torch.Tensor
            Raw (unnormalised) logits of shape (N, 10).
        """
        x = self.pool(F.relu(self.conv(x)))
        x = self.flat(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """
        Training step for a single batch.

        Computes the cross-entropy loss and logs it as ``train_loss`` per step
        and per epoch, also showing it in the progress bar.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        torch.Tensor
            Training loss.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step for a single batch.

        Logs ``val_acc``: the fraction of samples whose argmax prediction
        matches the label.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        None
        """
        xb, yb = batch
        out = self(xb)
        preds = out.argmax(1)
        acc = (preds == yb).float().mean()
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """
        Test step for a single batch.

        Logs ``test_acc``: the fraction of samples whose argmax prediction
        matches the label, aggregated per epoch.

        Parameters
        ----------
        batch : tuple
            A batch of data (images, labels).
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        None
        """
        xb, yb = batch
        out = self(xb)
        preds = out.argmax(1)
        acc = (preds == yb).float().mean()
        self.log("test_acc", acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """
        Define optimizer for training.

        Returns
        -------
        torch.optim.Optimizer
            Adam over all model parameters with learning rate ``self.lr``
            (1e-3 unless overridden in the constructor); all other Adam
            settings are left at their defaults.
        """
        return optim.Adam(self.parameters(), lr=self.lr)
Simple convolutional neural network for Fashion-MNIST classification.
Architecture
- Conv2d(1 → 16, kernel_size=3)  # output channels configurable via `num_filters`
- ReLU
- MaxPool2d(kernel_size=2)
- Flatten
- Linear(16*13*13 → 32)  # configurable via `num_filters` and `hidden_size`
- ReLU
- Linear(32 → 10)  # hidden width configurable via `hidden_size`
Loss: CrossEntropyLoss
Examples
>>> model = Net()
>>> x = torch.randn(8, 1, 28, 28)
>>> out = model(x)
>>> out.shape
torch.Size([8, 10])
def __init__(
    self,
    num_filters: int = 16,
    hidden_size: int = 32,
    lr: float = 1e-3,
):
    """
    Build the layers of the network.

    Parameters
    ----------
    num_filters : int, default 16
        Number of output channels of the convolutional layer.
    hidden_size : int, default 32
        Number of units in the hidden fully connected layer.
    lr : float, default 1e-3
        Learning rate, stored on ``self.lr`` for the optimizer.
    """
    super().__init__()
    self.save_hyperparameters()  # Save num_filters, hidden_size and lr

    self.num_filters = num_filters
    self.hidden_size = hidden_size
    self.lr = lr

    # Geometry that the first linear layer depends on. Inputs are expected
    # to be 28x28 single-channel images (see ``forward``).
    image_size, kernel_size, pool_size = 28, 3, 2
    # Unpadded stride-1 conv: 28 - 3 + 1 = 26; 2x2 max-pool (floor): 13.
    pooled_side = (image_size - kernel_size + 1) // pool_size
    flat_features = self.num_filters * pooled_side * pooled_side

    self.conv = nn.Conv2d(1, self.num_filters, kernel_size)
    self.pool = nn.MaxPool2d(pool_size)
    self.flat = nn.Flatten()
    self.fc1 = nn.Linear(flat_features, self.hidden_size)
    self.fc2 = nn.Linear(self.hidden_size, 10)
    self.loss_fn = nn.CrossEntropyLoss()
Attributes: prepare_data_per_node: If True, the process with LOCAL_RANK=0 on each node calls `prepare_data()`. Otherwise, only the process with NODE_RANK=0, LOCAL_RANK=0 prepares data. allow_zero_length_dataloader_with_multiple_devices: If True, a dataloader with zero length on a local rank is allowed. Default value is False.
def forward(self, x):
    """
    Run a batch of images through the network and return class scores.

    Parameters
    ----------
    x : torch.Tensor
        Batch of images shaped (N, 1, 28, 28).

    Returns
    -------
    torch.Tensor
        Unnormalised class logits shaped (N, 10).
    """
    # Feature extractor: conv -> ReLU -> max-pool, then flatten per sample.
    features = self.flat(self.pool(F.relu(self.conv(x))))
    # Classifier head: hidden layer with ReLU, then the output layer.
    hidden = F.relu(self.fc1(features))
    logits = self.fc2(hidden)
    return logits
Forward pass of the network.
Parameters
x : torch.Tensor Input tensor of shape (N, 1, 28, 28).
Returns
torch.Tensor Output logits of shape (N, 10).
def training_step(self, batch, batch_idx):
    """
    Compute and log the training loss for one batch.

    Parameters
    ----------
    batch : tuple
        Pair of (images, labels).
    batch_idx : int
        Position of the batch within the epoch (not used).

    Returns
    -------
    torch.Tensor
        Cross-entropy loss for the batch.
    """
    images, targets = batch
    batch_loss = self.loss_fn(self(images), targets)
    # Logged per step and aggregated per epoch; also shown in the progress bar.
    self.log("train_loss", batch_loss, on_step=True, on_epoch=True, prog_bar=True)
    return batch_loss
Training step for a single batch.
Parameters
batch : tuple A batch of data (images, labels). batch_idx : int Index of the batch.
Returns
torch.Tensor Training loss.
def validation_step(self, batch, batch_idx):
    """
    Log classification accuracy on one validation batch.

    Parameters
    ----------
    batch : tuple
        Pair of (images, labels).
    batch_idx : int
        Position of the batch within the epoch (not used).

    Returns
    -------
    None
    """
    images, targets = batch
    predicted = self(images).argmax(1)
    # Fraction of samples whose highest-scoring class matches the label.
    accuracy = (predicted == targets).float().mean()
    self.log("val_acc", accuracy, prog_bar=True)
Validation step for a single batch.
Parameters
batch : tuple A batch of data (images, labels). batch_idx : int Index of the batch.
Returns
None
128 def test_step(self, batch, batch_idx): 129 """ 130 Test step for a single batch. 131 132 Parameters 133 ---------- 134 batch : tuple 135 A batch of data (images, labels). 136 batch_idx : int 137 Index of the batch. 138 139 Returns 140 ------- 141 None 142 """ 143 144 xb, yb = batch 145 out = self(xb) 146 preds = out.argmax(1) 147 acc = (preds == yb).float().mean() 148 self.log("test_acc", acc, prog_bar=True, on_epoch=True)
Test step for a single batch.
Parameters
batch : tuple A batch of data (images, labels). batch_idx : int Index of the batch.
Returns
None
def configure_optimizers(self):
    """
    Define optimizer for training.

    Returns
    -------
    torch.optim.Optimizer
        Adam over all model parameters with learning rate ``self.lr``;
        all other Adam settings (betas, eps, weight decay) are left at
        their defaults.
    """
    return optim.Adam(self.parameters(), lr=self.lr)
Define optimizer for training.
Returns
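torch.optim.Optimizer Adam optimizer over all model parameters, using the learning rate stored in `self.lr` (1e-3 by default); other Adam settings are left at their defaults.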