my_project.app
Gradio application for interactive Fashion-MNIST model training and data exploration.
This script launches a web-based interface with two main tabs:
- Data Exploration: Allows users to visualize the Fashion-MNIST dataset. Users can switch between training and test sets, view class distributions, see a class similarity dendrogram, and browse a gallery of sample images with filtering options.
- Train & Evaluate: Provides an interface to train a model with adjustable hyperparameters (e.g., batch size, learning rate, epochs). After training, it displays evaluation results, including test accuracy, learning curves, a confusion matrix, and other performance plots.
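To try the app locally, a minimal sketch (assumes the `my_project` package and its dependencies are installed; note that importing `my_project.app` downloads Fashion-MNIST into `data/` and builds the exploration DataFrames as a side effect):

```python
# Start the demo; equivalent to running the script directly.
from my_project.app import main

main()  # calls demo.launch(); by default Gradio serves at http://127.0.0.1:7860
```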
1""" 2Gradio application for interactive Fashion-MNIST model training and data exploration. 3 4This script launches a web-based interface with two main tabs: 51. **Data Exploration**: Allows users to visualize the Fashion-MNIST dataset. 6 Users can switch between training and test sets, view class distributions, 7 see a class similarity dendrogram, and browse a gallery of sample images 8 with filtering options. 9 102. **Train & Evaluate**: Provides an interface to train a model with adjustable 11 hyperparameters (e.g., batch size, learning rate, epochs). After training, 12 it displays evaluation results, including test accuracy, learning curves, 13 a confusion matrix, and other performance plots. 14""" 15import gradio as gr 16import pytorch_lightning as pl 17from pytorch_lightning.loggers import CSVLogger 18import torch 19import pandas as pd 20import tempfile 21import os 22 23from my_project.dataset import FashionMNISTDataModule 24from my_project.model import Net 25from my_project.plots import ( 26 evaluate_and_plot, 27 FASHION_CLASSES, 28 plot_class_distribution, 29 get_sample_images_for_gallery, 30 plot_learning_curves_from_df, 31 plot_class_correlation_dendrogram, 32) 33 34# Set matmul precision for Tensor Cores 35if torch.cuda.is_available(): 36 torch.set_float32_matmul_precision('high') 37 38 39def get_df_from_dataset(dataset): 40 """Converts a torchvision dataset to a pandas DataFrame.""" 41 images = dataset.data.view(len(dataset), -1).numpy() 42 labels = dataset.targets.numpy() 43 df = pd.DataFrame(images) 44 df.columns = [f"pixel{i}" for i in range(images.shape[1])] 45 df["label"] = labels 46 return df 47 48 49# --- Load data for exploration using the DataModule --- 50explore_datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=128) 51explore_datamodule.prepare_data() 52explore_datamodule.setup() 53train_df = get_df_from_dataset(explore_datamodule.train_val_dataset) 54test_df = get_df_from_dataset(explore_datamodule.test_dataset) 55 56 57def update_data_exploration(dataset_choice, class_filter): 58 """ 59 Updates the components in the Data Exploration tab based on user selection. 60 """ 61 df = train_df if dataset_choice == "Train" else test_df 62 63 # 1. Update the class distribution plot 64 dist_plot_fig = plot_class_distribution(df, FASHION_CLASSES) 65 66 # 2. Update the dendrogram 67 dendrogram_fig = plot_class_correlation_dendrogram(df, FASHION_CLASSES) 68 69 # 2. Update the gallery 70 df_to_sample = df 71 if class_filter != "All": 72 class_index = FASHION_CLASSES.index(class_filter) 73 df_to_sample = df[df["label"] == class_index] 74 75 gallery_images = get_sample_images_for_gallery( 76 df_to_sample, FASHION_CLASSES, n_samples=15 77 ) 78 79 # 3. Update statistics text 80 stats_md = f""" 81 ### Dataset Statistics 82 - **Selected Set:** {dataset_choice} 83 - **Total Samples:** {len(df)} 84 - **Number of Classes:** {len(FASHION_CLASSES)} 85 - **Image Size:** 28x28 pixels (grayscale) 86 """ 87 88 return stats_md, dist_plot_fig, dendrogram_fig, gallery_images 89 90 91 92def train_and_evaluate( 93 batch_size: int, 94 max_epochs: int, 95 lr: float, 96 num_filters: int, 97 hidden_size: int, 98 progress=gr.Progress(track_tqdm=True), 99): 100 """ 101 A function to train and evaluate the model with given hyperparameters. 102 This will be connected to the Gradio interface. 
103 """ 104 progress(0, desc="Initializing DataModule...") 105 # The data will be downloaded to the 'data/' directory if not present 106 datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=int(batch_size)) 107 108 progress(0.1, desc="Initializing Model...") 109 model = Net(num_filters=int(num_filters), hidden_size=int(hidden_size), lr=lr) 110 111 # Use a simple callback to update progress 112 class GradioProgressCallback(pl.Callback): 113 def on_train_epoch_end(self, trainer, pl_module): 114 progress( 115 trainer.current_epoch / trainer.max_epochs, 116 desc=f"Epoch {trainer.current_epoch+1}/{trainer.max_epochs}", 117 ) 118 119 # Use a temporary directory for logs 120 with tempfile.TemporaryDirectory() as tmpdir: 121 logger = CSVLogger(save_dir=tmpdir, name="gradio_logs") 122 trainer = pl.Trainer( 123 max_epochs=int(max_epochs), 124 accelerator="gpu", # Explicitly use GPU 125 devices="auto", 126 logger=logger, 127 callbacks=[GradioProgressCallback()], 128 enable_checkpointing=False, 129 ) 130 131 progress(0.2, desc="Starting Training...") 132 trainer.fit(model, datamodule=datamodule) 133 134 progress(0.85, desc="Generating Learning Curves...") 135 metrics_path = os.path.join(logger.log_dir, "metrics.csv") 136 train_loss_fig, val_acc_fig = None, None 137 if os.path.exists(metrics_path): 138 metrics_df = pd.read_csv(metrics_path) 139 train_loss_fig, val_acc_fig = plot_learning_curves_from_df(metrics_df) 140 141 progress(0.9, desc="Evaluating on Test Set...") 142 # Manually call setup for the 'test' stage to ensure test_dataset is initialized. 143 datamodule.setup(stage="test") 144 145 # The evaluate_and_plot function already returns the paths to the plots 146 artifacts = evaluate_and_plot( 147 model, datamodule, out_dir="reports/figures/gradio" 148 ) 149 150 return ( 151 f"{artifacts['test_accuracy']:.4f}", 152 # Learning curves 153 train_loss_fig, 154 val_acc_fig, 155 # Evaluation plots 156 artifacts["confusion_matrix"], 157 artifacts["per_class_accuracy"], 158 artifacts["misclassified_grid"], 159 artifacts["calibration_curve"], 160 ) 161 162 163# We define the Gradio interface using Blocks for more control. 164with gr.Blocks( 165 css=""" 166.gradio-container .gallery-item { max-height: 100px !important; min-height: 100px !important; } 167""" 168) as demo: 169 gr.Markdown("# Fashion-MNIST Interactive Demo") 170 171 with gr.Tab("Data Exploration"): 172 gr.Markdown("## Interactive Exploration of the Fashion-MNIST Dataset") 173 gr.Markdown( 174 "Use the controls to switch between the training and test datasets, or to filter images by class." 
175 ) 176 177 with gr.Row(variant="panel"): 178 with gr.Column(scale=1, min_width=250): 179 gr.Markdown("### Controls") 180 dataset_selector = gr.Radio( 181 ["Train", "Test"], value="Train", label="Select Dataset" 182 ) 183 stats_md_box = gr.Markdown() 184 185 gr.Markdown("### Image Gallery") 186 class_selector = gr.Dropdown( 187 ["All"] + FASHION_CLASSES, value="All", label="Filter by Class" 188 ) 189 refresh_button = gr.Button("Refresh Images") 190 191 with gr.Column(scale=3): 192 with gr.Tabs(): 193 with gr.TabItem("Class Distribution"): 194 dist_plot = gr.Plot() 195 with gr.TabItem("Class Similarity"): 196 dendrogram_plot = gr.Plot() 197 gr.Markdown("### Image Samples") 198 gallery = gr.Gallery( 199 label="Random samples from the dataset", 200 columns=5, 201 object_fit="contain", 202 height="auto", 203 show_label=True, 204 ) 205 206 # --- Event Listeners for Interactivity --- 207 # Combine selectors to a single update function for efficiency 208 controls = [dataset_selector, class_selector] 209 outputs = [stats_md_box, dist_plot, dendrogram_plot, gallery] 210 211 for control in controls: 212 control.change(fn=update_data_exploration, inputs=controls, outputs=outputs) 213 refresh_button.click( 214 fn=update_data_exploration, inputs=controls, outputs=outputs 215 ) 216 # Initial load for the data exploration tab 217 demo.load(fn=update_data_exploration, inputs=[gr.State("Train"), gr.State("All")], outputs=outputs) 218 219 with gr.Tab("Train & Evaluate"): 220 with gr.Row(): 221 with gr.Column(scale=1): 222 gr.Markdown("## Hyperparameters") 223 batch_size_slider = gr.Slider( 224 32, 512, value=128, step=32, label="Batch Size" 225 ) 226 epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Max Epochs") 227 lr_slider = gr.Slider( 228 1e-5, 1e-2, value=1e-3, label="Learning Rate", step=1e-5 229 ) 230 num_filters_slider = gr.Slider( 231 8, 64, value=32, step=8, label="Conv Filters" 232 ) 233 hidden_size_slider = gr.Slider( 234 32, 256, value=64, step=32, label="Hidden Layer Size" 235 ) 236 train_button = gr.Button("Start Training", variant="primary") 237 238 with gr.Column(scale=3): 239 gr.Markdown("## Evaluation Results") 240 test_acc_box = gr.Textbox(label="Test Accuracy") 241 with gr.Tabs() as eval_tabs: 242 with gr.TabItem("Learning Curves"): 243 with gr.Row(): 244 train_loss_plot = gr.Plot(label="Training Loss") 245 val_acc_plot = gr.Plot(label="Validation Accuracy") 246 with gr.TabItem("Confusion & Calibration"): 247 with gr.Row(): 248 cm_plot = gr.Image(label="Confusion Matrix") 249 calibration_plot = gr.Image(label="Calibration Curve") 250 with gr.TabItem("Per-Class Accuracy"): 251 pca_plot = gr.Image(label="Per-Class Accuracy") 252 with gr.TabItem("Misclassified Samples"): 253 misclassified_plot = gr.Image(label="Misclassified Samples") 254 255 train_button.click( 256 fn=train_and_evaluate, 257 inputs=[ 258 batch_size_slider, 259 epochs_slider, 260 lr_slider, 261 num_filters_slider, 262 hidden_size_slider, 263 ], 264 outputs=[ 265 test_acc_box, 266 train_loss_plot, 267 val_acc_plot, 268 cm_plot, 269 pca_plot, 270 misclassified_plot, 271 calibration_plot, 272 ], 273 ) 274 275def main(): 276 """Launch Gradio interface.""" 277 demo.launch() 278 279if __name__ == "__main__": 280 main()
def get_df_from_dataset(dataset):
Converts a torchvision dataset to a pandas DataFrame.
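A minimal usage sketch. It relies on torchvision's `FashionMNIST` dataset class, which exposes the `.data` and `.targets` tensors the function reads; note that importing `my_project.app` also runs the module-level data loading.

```python
from torchvision.datasets import FashionMNIST
from my_project.app import get_df_from_dataset

# Test split: 10,000 images of 28x28 pixels
ds = FashionMNIST(root="data/", train=False, download=True)
df = get_df_from_dataset(ds)

print(df.shape)               # (10000, 785): 784 pixel columns plus "label"
print(df["label"].nunique())  # 10 classes
```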
explore_datamodule = <my_project.dataset.FashionMNISTDataModule object>
train_df = <pandas DataFrame, 60000 rows x 785 columns: pixel0 ... pixel783 plus label; the flattened 28x28 images of the training split>
test_df = <pandas DataFrame, 10000 rows x 785 columns: pixel0 ... pixel783 plus label; the flattened 28x28 images of the test split>
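Because each row stores the row-major flattening of a 28x28 image, a single sample can be recovered and displayed. A sketch using matplotlib and the module-level `train_df` (already built at import time):

```python
import numpy as np
import matplotlib.pyplot as plt

from my_project.app import train_df
from my_project.plots import FASHION_CLASSES

# Take the first sample, drop the label column, and reshape back to 28x28
row = train_df.iloc[0]
img = row.drop("label").to_numpy(dtype=np.uint8).reshape(28, 28)

plt.imshow(img, cmap="gray")
plt.title(FASHION_CLASSES[int(row["label"])])
plt.show()
```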
def update_data_exploration(dataset_choice, class_filter):
Updates the components in the Data Exploration tab based on user selection.
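The function can also be called outside the UI, for example to regenerate the exploration figures in a notebook. A sketch; the exact figure and gallery types are defined by `my_project.plots` and are assumed here to be objects Gradio's Plot and Gallery components accept.

```python
from my_project.app import update_data_exploration

# "Train" or "Test" picks the DataFrame; "All" or any FASHION_CLASSES entry filters the gallery
stats_md, dist_fig, dendro_fig, gallery_images = update_data_exploration("Test", "All")

print(stats_md)             # markdown summary of the selected split
print(len(gallery_images))  # up to 15 sampled gallery entries
```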
def train_and_evaluate(batch_size: int, max_epochs: int, lr: float, num_filters: int, hidden_size: int, progress=gr.Progress(track_tqdm=True)):
Trains and evaluates the model with the given hyperparameters. Connected to the "Start Training" button in the Gradio interface.
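Stripped of the progress callbacks, CSV logging, and plotting, the core of this function is an ordinary Lightning run. A rough sketch with the UI's default hyperparameter values:

```python
import pytorch_lightning as pl

from my_project.dataset import FashionMNISTDataModule
from my_project.model import Net

# Defaults mirrored from the Gradio sliders
datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=128)
model = Net(num_filters=32, hidden_size=64, lr=1e-3)

trainer = pl.Trainer(max_epochs=5, accelerator="auto", devices="auto",
                     enable_checkpointing=False)
trainer.fit(model, datamodule=datamodule)
```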
def main():
Launch the Gradio interface.
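`main()` calls `demo.launch()` with Gradio's defaults; for other deployments the `demo` Blocks object can be launched directly with standard `launch()` options (the host and port below are illustrative):

```python
from my_project.app import demo

# Bind to all interfaces on a fixed port, e.g. when running inside a container
demo.launch(server_name="0.0.0.0", server_port=7860)
```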