"""
my_project.dataset

Defines the PyTorch Lightning DataModule for the Fashion-MNIST dataset.

This module is responsible for:
- Downloading the Fashion-MNIST dataset.
- Applying necessary transformations (e.g., normalization).
- Splitting the data into training, validation, and test sets.
- Creating and providing DataLoaders for each set.
"""

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class FashionMNISTDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for the Fashion-MNIST dataset.

    It handles the downloading, splitting, and loading of the data.
    """

    def __init__(self, data_dir: str = "data/", batch_size: int = 128, val_split: float = 0.2, num_workers: int = 4):
        """
        Args:
            data_dir (str): Directory where the data will be downloaded/stored.
            batch_size (int): The batch size for the data loaders.
            val_split (float): The fraction of the training data to use for validation.
                Must be in [0.0, 1.0).
            num_workers (int): Number of subprocesses to use for data loading.

        Raises:
            ValueError: If ``val_split`` is outside [0.0, 1.0).
        """
        super().__init__()
        # Fail fast on an invalid split fraction instead of letting
        # random_split() raise an obscure error later in setup().
        if not 0.0 <= val_split < 1.0:
            raise ValueError(f"val_split must be in [0.0, 1.0), got {val_split}")
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),  # Normalize pixel values to [-1, 1]
            ]
        )
        # Datasets are populated lazily in setup(); None until then.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.train_val_dataset = None  # Full (un-split) training set, kept for data exploration tab

    def prepare_data(self):
        """
        Downloads the Fashion-MNIST dataset if it's not already present.

        This method is called only on a single GPU/process, so it must not
        assign any state on ``self`` — it only triggers the download.
        """
        datasets.FashionMNIST(self.data_dir, train=True, download=True)
        datasets.FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """
        Assigns train/val/test datasets for dataloaders.

        This method is called on every GPU.

        Args:
            stage (Optional[str]): Either "fit", "test", or None. When None,
                both the fit and test datasets are prepared.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_val_dataset = datasets.FashionMNIST(self.data_dir, train=True, transform=self.transform)
            n_samples = len(self.train_val_dataset)
            n_val = int(self.val_split * n_samples)
            n_train = n_samples - n_val
            # NOTE(review): random_split is unseeded here, so the train/val
            # partition differs between runs — confirm that is intended.
            self.train_dataset, self.val_dataset = random_split(self.train_val_dataset, [n_train, n_val])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_dataset = datasets.FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training DataLoader (shuffled each epoch)."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation DataLoader (no shuffling)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test DataLoader (no shuffling)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)