my_project.model

Defines the neural network architecture for the Fashion-MNIST classifier.

This module contains the Net class, a PyTorch LightningModule that encapsulates the model's structure (a simple CNN), the forward pass logic, and the training, validation, and test steps.

  1"""
  2Defines the neural network architecture for the Fashion-MNIST classifier.
  3
  4This module contains the `Net` class, a PyTorch LightningModule that
  5encapsulates the model's structure (a simple CNN), the forward pass logic,
  6and the training, validation, and test steps.
  7"""
  8
  9import torch
 10import torch.nn as nn
 11import torch.optim as optim
 12import torch.nn.functional as F
 13import pytorch_lightning as pl
 14
 15
 16class Net(pl.LightningModule):
 17    """
 18    Simple convolutional neural network for Fashion-MNIST classification.
 19
 20    Architecture
 21    ------------
 22    - Conv2d(1 → 16, kernel_size=3)
 23    - ReLU
 24    - MaxPool2d(kernel_size=2)
 25    - Flatten
 26    - Linear(16*13*13 → 32)  # Note: This will be configurable
 27    - ReLU
 28    - Linear(32 → 10)       # Note: This will be configurable
 29
 30    Loss: CrossEntropyLoss
 31
 32    Examples
 33    --------
 34    >>> model = Net()
 35    >>> x = torch.randn(8, 1, 28, 28)
 36    >>> out = model(x)
 37    >>> out.shape
 38    torch.Size([8, 10])
 39    """
 40
 41    def __init__(
 42        self,
 43        num_filters: int = 16,
 44        hidden_size: int = 32,
 45        lr: float = 1e-3,
 46    ):
 47        super().__init__()
 48        self.save_hyperparameters()  # Save hyperparameters
 49
 50        self.num_filters = num_filters
 51        self.hidden_size = hidden_size
 52        self.lr = lr
 53
 54        self.conv = nn.Conv2d(1, self.num_filters, 3)
 55        self.pool = nn.MaxPool2d(2)
 56        self.flat = nn.Flatten()
 57        # Calculate the flattened size after conv and pool
 58        self.fc1 = nn.Linear(self.num_filters * 13 * 13, self.hidden_size)
 59        self.fc2 = nn.Linear(self.hidden_size, 10)
 60        self.loss_fn = nn.CrossEntropyLoss()
 61
 62    def forward(self, x):
 63        """
 64        Forward pass of the network.
 65
 66        Parameters
 67        ----------
 68        x : torch.Tensor
 69            Input tensor of shape (N, 1, 28, 28).
 70
 71        Returns
 72        -------
 73        torch.Tensor
 74            Output logits of shape (N, 10).
 75        """
 76
 77        x = self.pool(F.relu(self.conv(x)))
 78        x = self.flat(x)
 79        x = F.relu(self.fc1(x))
 80        return self.fc2(x)
 81
 82    def training_step(self, batch, batch_idx):
 83        """
 84        Training step for a single batch.
 85
 86        Parameters
 87        ----------
 88        batch : tuple
 89            A batch of data (images, labels).
 90        batch_idx : int
 91            Index of the batch.
 92
 93        Returns
 94        -------
 95        torch.Tensor
 96            Training loss.
 97        """
 98
 99        xb, yb = batch
100        out = self(xb)
101        loss = self.loss_fn(out, yb)
102        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
103        return loss
104
105    def validation_step(self, batch, batch_idx):
106        """
107        Validation step for a single batch.
108
109        Parameters
110        ----------
111        batch : tuple
112            A batch of data (images, labels).
113        batch_idx : int
114            Index of the batch.
115
116        Returns
117        -------
118        None
119        """
120
121        xb, yb = batch
122        out = self(xb)
123        preds = out.argmax(1)
124        acc = (preds == yb).float().mean()
125        self.log("val_acc", acc, prog_bar=True)
126
127    def test_step(self, batch, batch_idx):
128        """
129        Test step for a single batch.
130
131        Parameters
132        ----------
133        batch : tuple
134            A batch of data (images, labels).
135        batch_idx : int
136            Index of the batch.
137
138        Returns
139        -------
140        None
141        """
142
143        xb, yb = batch
144        out = self(xb)
145        preds = out.argmax(1)
146        acc = (preds == yb).float().mean()
147        self.log("test_acc", acc, prog_bar=True, on_epoch=True)
148
149    def configure_optimizers(self):
150        """
151        Define optimizer for training.
152
153        Returns
154        -------
155        torch.optim.Optimizer
156            Adam optimizer with default parameters.
157        """
158
159        return optim.Adam(self.parameters(), lr=self.lr)
class Net(pytorch_lightning.core.module.LightningModule):
 17class Net(pl.LightningModule):
 18    """
 19    Simple convolutional neural network for Fashion-MNIST classification.
 20
 21    Architecture
 22    ------------
 23    - Conv2d(1 → 16, kernel_size=3)
 24    - ReLU
 25    - MaxPool2d(kernel_size=2)
 26    - Flatten
 27    - Linear(16*13*13 → 32)  # Note: This will be configurable
 28    - ReLU
 29    - Linear(32 → 10)       # Note: This will be configurable
 30
 31    Loss: CrossEntropyLoss
 32
 33    Examples
 34    --------
 35    >>> model = Net()
 36    >>> x = torch.randn(8, 1, 28, 28)
 37    >>> out = model(x)
 38    >>> out.shape
 39    torch.Size([8, 10])
 40    """
 41
 42    def __init__(
 43        self,
 44        num_filters: int = 16,
 45        hidden_size: int = 32,
 46        lr: float = 1e-3,
 47    ):
 48        super().__init__()
 49        self.save_hyperparameters()  # Save hyperparameters
 50
 51        self.num_filters = num_filters
 52        self.hidden_size = hidden_size
 53        self.lr = lr
 54
 55        self.conv = nn.Conv2d(1, self.num_filters, 3)
 56        self.pool = nn.MaxPool2d(2)
 57        self.flat = nn.Flatten()
 58        # Calculate the flattened size after conv and pool
 59        self.fc1 = nn.Linear(self.num_filters * 13 * 13, self.hidden_size)
 60        self.fc2 = nn.Linear(self.hidden_size, 10)
 61        self.loss_fn = nn.CrossEntropyLoss()
 62
 63    def forward(self, x):
 64        """
 65        Forward pass of the network.
 66
 67        Parameters
 68        ----------
 69        x : torch.Tensor
 70            Input tensor of shape (N, 1, 28, 28).
 71
 72        Returns
 73        -------
 74        torch.Tensor
 75            Output logits of shape (N, 10).
 76        """
 77
 78        x = self.pool(F.relu(self.conv(x)))
 79        x = self.flat(x)
 80        x = F.relu(self.fc1(x))
 81        return self.fc2(x)
 82
 83    def training_step(self, batch, batch_idx):
 84        """
 85        Training step for a single batch.
 86
 87        Parameters
 88        ----------
 89        batch : tuple
 90            A batch of data (images, labels).
 91        batch_idx : int
 92            Index of the batch.
 93
 94        Returns
 95        -------
 96        torch.Tensor
 97            Training loss.
 98        """
 99
100        xb, yb = batch
101        out = self(xb)
102        loss = self.loss_fn(out, yb)
103        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
104        return loss
105
106    def validation_step(self, batch, batch_idx):
107        """
108        Validation step for a single batch.
109
110        Parameters
111        ----------
112        batch : tuple
113            A batch of data (images, labels).
114        batch_idx : int
115            Index of the batch.
116
117        Returns
118        -------
119        None
120        """
121
122        xb, yb = batch
123        out = self(xb)
124        preds = out.argmax(1)
125        acc = (preds == yb).float().mean()
126        self.log("val_acc", acc, prog_bar=True)
127
128    def test_step(self, batch, batch_idx):
129        """
130        Test step for a single batch.
131
132        Parameters
133        ----------
134        batch : tuple
135            A batch of data (images, labels).
136        batch_idx : int
137            Index of the batch.
138
139        Returns
140        -------
141        None
142        """
143
144        xb, yb = batch
145        out = self(xb)
146        preds = out.argmax(1)
147        acc = (preds == yb).float().mean()
148        self.log("test_acc", acc, prog_bar=True, on_epoch=True)
149
150    def configure_optimizers(self):
151        """
152        Define optimizer for training.
153
154        Returns
155        -------
156        torch.optim.Optimizer
157            Adam optimizer with default parameters.
158        """
159
160        return optim.Adam(self.parameters(), lr=self.lr)

Simple convolutional neural network for Fashion-MNIST classification.

Architecture

  • Conv2d(1 → 16, kernel_size=3)
  • ReLU
  • MaxPool2d(kernel_size=2)
  • Flatten
  • Linear(16*13*13 → 32) # Note: This will be configurable
  • ReLU
  • Linear(32 → 10) # Note: This will be configurable

Loss: CrossEntropyLoss

Examples

>>> model = Net()
>>> x = torch.randn(8, 1, 28, 28)
>>> out = model(x)
>>> out.shape
torch.Size([8, 10])
Net(num_filters: int = 16, hidden_size: int = 32, lr: float = 0.001)
42    def __init__(
43        self,
44        num_filters: int = 16,
45        hidden_size: int = 32,
46        lr: float = 1e-3,
47    ):
48        super().__init__()
49        self.save_hyperparameters()  # Save hyperparameters
50
51        self.num_filters = num_filters
52        self.hidden_size = hidden_size
53        self.lr = lr
54
55        self.conv = nn.Conv2d(1, self.num_filters, 3)
56        self.pool = nn.MaxPool2d(2)
57        self.flat = nn.Flatten()
58        # Calculate the flattened size after conv and pool
59        self.fc1 = nn.Linear(self.num_filters * 13 * 13, self.hidden_size)
60        self.fc2 = nn.Linear(self.hidden_size, 10)
61        self.loss_fn = nn.CrossEntropyLoss()

Attributes:
  • prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
  • allow_zero_length_dataloader_with_multiple_devices: If True, a dataloader with zero length within the local rank is allowed. Default value is False.

num_filters
hidden_size
lr
conv
pool
flat
fc1
fc2
loss_fn
def forward(self, x):
63    def forward(self, x):
64        """
65        Forward pass of the network.
66
67        Parameters
68        ----------
69        x : torch.Tensor
70            Input tensor of shape (N, 1, 28, 28).
71
72        Returns
73        -------
74        torch.Tensor
75            Output logits of shape (N, 10).
76        """
77
78        x = self.pool(F.relu(self.conv(x)))
79        x = self.flat(x)
80        x = F.relu(self.fc1(x))
81        return self.fc2(x)

Forward pass of the network.

Parameters

x : torch.Tensor
    Input tensor of shape (N, 1, 28, 28).

Returns

torch.Tensor
    Output logits of shape (N, 10).

def training_step(self, batch, batch_idx):
 83    def training_step(self, batch, batch_idx):
 84        """
 85        Training step for a single batch.
 86
 87        Parameters
 88        ----------
 89        batch : tuple
 90            A batch of data (images, labels).
 91        batch_idx : int
 92            Index of the batch.
 93
 94        Returns
 95        -------
 96        torch.Tensor
 97            Training loss.
 98        """
 99
100        xb, yb = batch
101        out = self(xb)
102        loss = self.loss_fn(out, yb)
103        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
104        return loss

Training step for a single batch.

Parameters

batch : tuple
    A batch of data (images, labels).
batch_idx : int
    Index of the batch.

Returns

torch.Tensor
    Training loss.

def validation_step(self, batch, batch_idx):
106    def validation_step(self, batch, batch_idx):
107        """
108        Validation step for a single batch.
109
110        Parameters
111        ----------
112        batch : tuple
113            A batch of data (images, labels).
114        batch_idx : int
115            Index of the batch.
116
117        Returns
118        -------
119        None
120        """
121
122        xb, yb = batch
123        out = self(xb)
124        preds = out.argmax(1)
125        acc = (preds == yb).float().mean()
126        self.log("val_acc", acc, prog_bar=True)

Validation step for a single batch.

Parameters

batch : tuple
    A batch of data (images, labels).
batch_idx : int
    Index of the batch.

Returns

None

def test_step(self, batch, batch_idx):
128    def test_step(self, batch, batch_idx):
129        """
130        Test step for a single batch.
131
132        Parameters
133        ----------
134        batch : tuple
135            A batch of data (images, labels).
136        batch_idx : int
137            Index of the batch.
138
139        Returns
140        -------
141        None
142        """
143
144        xb, yb = batch
145        out = self(xb)
146        preds = out.argmax(1)
147        acc = (preds == yb).float().mean()
148        self.log("test_acc", acc, prog_bar=True, on_epoch=True)

Test step for a single batch.

Parameters

batch : tuple
    A batch of data (images, labels).
batch_idx : int
    Index of the batch.

Returns

None

def configure_optimizers(self):
150    def configure_optimizers(self):
151        """
152        Define optimizer for training.
153
154        Returns
155        -------
156        torch.optim.Optimizer
157            Adam optimizer with default parameters.
158        """
159
160        return optim.Adam(self.parameters(), lr=self.lr)

Define optimizer for training.

Returns

torch.optim.Optimizer
    Adam optimizer over the model's parameters, using the learning rate lr passed to the constructor.