my_project.dataset

Defines the PyTorch Lightning DataModule for the Fashion-MNIST dataset.

This module is responsible for:

  • Downloading the Fashion-MNIST dataset.
  • Applying necessary transformations (e.g., normalization).
  • Splitting the data into training, validation, and test sets.
  • Creating and providing DataLoaders for each set.
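The normalization step maps ToTensor's [0, 1] pixel range onto [-1, 1] via (x - mean) / std with mean = std = 0.5. A quick pure-Python check of that arithmetic (the helper name is illustrative, not part of the module):

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Same per-channel formula that transforms.Normalize applies.
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black pixel)
print(normalize(1.0))  # 1.0  (white pixel)
print(normalize(0.5))  # 0.0  (mid-gray)
```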
"""
Defines the PyTorch Lightning DataModule for the Fashion-MNIST dataset.

This module is responsible for:
- Downloading the Fashion-MNIST dataset.
- Applying necessary transformations (e.g., normalization).
- Splitting the data into training, validation, and test sets.
- Creating and providing DataLoaders for each set.
"""

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class FashionMNISTDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for the Fashion-MNIST dataset.
    It handles the downloading, splitting, and loading of the data.
    """

    def __init__(self, data_dir: str = "data/", batch_size: int = 128, val_split: float = 0.2, num_workers: int = 4):
        """
        Args:
            data_dir (str): Directory where the data will be downloaded/stored.
            batch_size (int): The batch size for the data loaders.
            val_split (float): The fraction of the training data to use for validation.
            num_workers (int): Number of subprocesses to use for data loading.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),  # Map [0, 1] pixels to [-1, 1]
            ]
        )
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.train_val_dataset = None  # Full training split, kept for the data exploration tab

    def prepare_data(self):
        """
        Downloads the Fashion-MNIST dataset if it is not already present.
        Lightning calls this method from a single process, so no state
        should be assigned here.
        """
        datasets.FashionMNIST(self.data_dir, train=True, download=True)
        datasets.FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """
        Assigns the train/val/test datasets used by the dataloaders.
        Lightning calls this method on every process/GPU.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_val_dataset = datasets.FashionMNIST(self.data_dir, train=True, transform=self.transform)
            n_samples = len(self.train_val_dataset)
            n_val = int(self.val_split * n_samples)
            n_train = n_samples - n_val
            self.train_dataset, self.val_dataset = random_split(self.train_val_dataset, [n_train, n_val])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_dataset = datasets.FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
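The split in setup() floors the validation count and gives the remainder to training, so the two subsets always partition the full set. With Fashion-MNIST's 60,000 training images and the default val_split=0.2, the arithmetic can be sanity-checked without torch (the helper name is illustrative):

```python
def split_sizes(n_samples: int, val_split: float) -> tuple:
    # Mirrors the computation in setup(): floor the validation count,
    # give the remainder to training so the two always sum to n_samples.
    n_val = int(val_split * n_samples)
    n_train = n_samples - n_val
    return n_train, n_val

print(split_sizes(60_000, 0.2))  # (48000, 12000)
print(split_sizes(59_999, 0.2))  # (48000, 11999) -- still sums to 59,999
```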
def train_dataloader(self):

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning docs.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above is only necessary for distributed processing; do not assign state in `prepare_data()`, since it runs in a single process.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
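Of the three loaders in this module, only train_dataloader() sets shuffle=True: training sees a fresh permutation each epoch, while validation and test iterate in a fixed order. A pure-Python sketch of that distinction (toy indices and names, not the DataLoader internals):

```python
import random

indices = list(range(8))  # stand-in for dataset indices

def epoch_order(shuffle: bool, epoch_seed: int) -> list:
    # shuffle=True  -> a fresh permutation per epoch, as for training;
    # shuffle=False -> stable dataset order, as for val/test.
    if not shuffle:
        return list(indices)
    rng = random.Random(epoch_seed)
    return rng.sample(indices, k=len(indices))

# Evaluation order is identical across epochs; training order is a permutation.
assert epoch_order(False, 0) == epoch_order(False, 1) == indices
assert sorted(epoch_order(True, 0)) == indices
```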

def val_dataloader(self):

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning docs.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It's recommended that all data downloads and preparation happen in `prepare_data()`.

See also:

  • Trainer.fit()
  • Trainer.validate()
  • prepare_data()
  • setup()

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a validation dataset and a validation_step(), you don't need to implement this method.

def test_dataloader(self):

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the Lightning docs.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above is only necessary for distributed processing; do not assign state in `prepare_data()`, since it runs in a single process.

Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note: If you don't need a test dataset and a test_step(), you don't need to implement this method.
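With the defaults above (batch_size=128, val_split=0.2, and Fashion-MNIST's 60,000/10,000 train/test images), the per-epoch batch counts follow directly from the split sizes. A quick check of that arithmetic (the helper name is illustrative):

```python
import math

def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # DataLoader yields floor(n/bs) full batches plus one partial batch,
    # unless drop_last=True (which this module does not set).
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

print(batches_per_epoch(48_000, 128))  # 375 training batches
print(batches_per_epoch(12_000, 128))  # 94 validation batches (last one holds 96 samples)
print(batches_per_epoch(10_000, 128))  # 79 test batches
```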