my_project.app

Gradio application for interactive Fashion-MNIST model training and data exploration.

This script launches a web-based interface with two main tabs:

  1. Data Exploration: Allows users to visualize the Fashion-MNIST dataset. Users can switch between training and test sets, view class distributions, see a class similarity dendrogram, and browse a gallery of sample images with filtering options.
  2. Train & Evaluate: Provides an interface to train a model with adjustable hyperparameters (e.g., batch size, learning rate, epochs). After training, it displays evaluation results, including test accuracy, learning curves, a confusion matrix, and other performance plots.
  1"""
  2Gradio application for interactive Fashion-MNIST model training and data exploration.
  3
  4This script launches a web-based interface with two main tabs:
  51.  **Data Exploration**: Allows users to visualize the Fashion-MNIST dataset.
  6    Users can switch between training and test sets, view class distributions,
  7    see a class similarity dendrogram, and browse a gallery of sample images
  8    with filtering options.
  9
 102.  **Train & Evaluate**: Provides an interface to train a model with adjustable
 11    hyperparameters (e.g., batch size, learning rate, epochs). After training,
 12    it displays evaluation results, including test accuracy, learning curves,
 13    a confusion matrix, and other performance plots.
 14"""
 15import gradio as gr
 16import pytorch_lightning as pl
 17from pytorch_lightning.loggers import CSVLogger
 18import torch
 19import pandas as pd
 20import tempfile
 21import os
 22
 23from my_project.dataset import FashionMNISTDataModule
 24from my_project.model import Net
 25from my_project.plots import (
 26    evaluate_and_plot,
 27    FASHION_CLASSES,
 28    plot_class_distribution,
 29    get_sample_images_for_gallery,
 30    plot_learning_curves_from_df,
 31    plot_class_correlation_dendrogram,
 32)
 33
 34# Set matmul precision for Tensor Cores
 35if torch.cuda.is_available():
 36    torch.set_float32_matmul_precision('high')
 37
 38
 39def get_df_from_dataset(dataset):
 40    """Converts a torchvision dataset to a pandas DataFrame."""
 41    images = dataset.data.view(len(dataset), -1).numpy()
 42    labels = dataset.targets.numpy()
 43    df = pd.DataFrame(images)
 44    df.columns = [f"pixel{i}" for i in range(images.shape[1])]
 45    df["label"] = labels
 46    return df
 47
 48
 49# --- Load data for exploration using the DataModule ---
 50explore_datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=128)
 51explore_datamodule.prepare_data()
 52explore_datamodule.setup()
 53train_df = get_df_from_dataset(explore_datamodule.train_val_dataset)
 54test_df = get_df_from_dataset(explore_datamodule.test_dataset)
 55
 56
 57def update_data_exploration(dataset_choice, class_filter):
 58    """
 59    Updates the components in the Data Exploration tab based on user selection.
 60    """
 61    df = train_df if dataset_choice == "Train" else test_df
 62
 63    # 1. Update the class distribution plot
 64    dist_plot_fig = plot_class_distribution(df, FASHION_CLASSES)
 65
 66    # 2. Update the dendrogram
 67    dendrogram_fig = plot_class_correlation_dendrogram(df, FASHION_CLASSES)
 68
 69    # 2. Update the gallery
 70    df_to_sample = df
 71    if class_filter != "All":
 72        class_index = FASHION_CLASSES.index(class_filter)
 73        df_to_sample = df[df["label"] == class_index]
 74
 75    gallery_images = get_sample_images_for_gallery(
 76        df_to_sample, FASHION_CLASSES, n_samples=15
 77    )
 78
 79    # 3. Update statistics text
 80    stats_md = f"""
 81    ### Dataset Statistics
 82    - **Selected Set:** {dataset_choice}
 83    - **Total Samples:** {len(df)}
 84    - **Number of Classes:** {len(FASHION_CLASSES)}
 85    - **Image Size:** 28x28 pixels (grayscale)
 86    """
 87
 88    return stats_md, dist_plot_fig, dendrogram_fig, gallery_images
 89
 90
 91
 92def train_and_evaluate(
 93    batch_size: int,
 94    max_epochs: int,
 95    lr: float,
 96    num_filters: int,
 97    hidden_size: int,
 98    progress=gr.Progress(track_tqdm=True),
 99):
100    """
101    A function to train and evaluate the model with given hyperparameters.
102    This will be connected to the Gradio interface.
103    """
104    progress(0, desc="Initializing DataModule...")
105    # The data will be downloaded to the 'data/' directory if not present
106    datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=int(batch_size))
107
108    progress(0.1, desc="Initializing Model...")
109    model = Net(num_filters=int(num_filters), hidden_size=int(hidden_size), lr=lr)
110
111    # Use a simple callback to update progress
112    class GradioProgressCallback(pl.Callback):
113        def on_train_epoch_end(self, trainer, pl_module):
114            progress(
115                trainer.current_epoch / trainer.max_epochs,
116                desc=f"Epoch {trainer.current_epoch+1}/{trainer.max_epochs}",
117            )
118
119    # Use a temporary directory for logs
120    with tempfile.TemporaryDirectory() as tmpdir:
121        logger = CSVLogger(save_dir=tmpdir, name="gradio_logs")
122        trainer = pl.Trainer(
123            max_epochs=int(max_epochs),
124            accelerator="gpu",  # Explicitly use GPU
125            devices="auto",
126            logger=logger,
127            callbacks=[GradioProgressCallback()],
128            enable_checkpointing=False,
129        )
130
131        progress(0.2, desc="Starting Training...")
132        trainer.fit(model, datamodule=datamodule)
133
134        progress(0.85, desc="Generating Learning Curves...")
135        metrics_path = os.path.join(logger.log_dir, "metrics.csv")
136        train_loss_fig, val_acc_fig = None, None
137        if os.path.exists(metrics_path):
138            metrics_df = pd.read_csv(metrics_path)
139            train_loss_fig, val_acc_fig = plot_learning_curves_from_df(metrics_df)
140
141        progress(0.9, desc="Evaluating on Test Set...")
142        # Manually call setup for the 'test' stage to ensure test_dataset is initialized.
143        datamodule.setup(stage="test")
144
145        # The evaluate_and_plot function already returns the paths to the plots
146        artifacts = evaluate_and_plot(
147            model, datamodule, out_dir="reports/figures/gradio"
148        )
149
150    return (
151        f"{artifacts['test_accuracy']:.4f}",
152        # Learning curves
153        train_loss_fig,
154        val_acc_fig,
155        # Evaluation plots
156        artifacts["confusion_matrix"],
157        artifacts["per_class_accuracy"],
158        artifacts["misclassified_grid"],
159        artifacts["calibration_curve"],
160    )
161
162
# We define the Gradio interface using Blocks for more control over layout.
with gr.Blocks(
    css="""
.gradio-container .gallery-item { max-height: 100px !important; min-height: 100px !important; }
"""
) as demo:
    gr.Markdown("# Fashion-MNIST Interactive Demo")

    with gr.Tab("Data Exploration"):
        gr.Markdown("## Interactive Exploration of the Fashion-MNIST Dataset")
        gr.Markdown(
            "Use the controls to switch between the training and test datasets, or to filter images by class."
        )

        with gr.Row(variant="panel"):
            with gr.Column(scale=1, min_width=250):
                gr.Markdown("### Controls")
                dataset_selector = gr.Radio(
                    ["Train", "Test"], value="Train", label="Select Dataset"
                )
                stats_md_box = gr.Markdown()

                gr.Markdown("### Image Gallery")
                class_selector = gr.Dropdown(
                    ["All"] + FASHION_CLASSES, value="All", label="Filter by Class"
                )
                refresh_button = gr.Button("Refresh Images")

            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("Class Distribution"):
                        dist_plot = gr.Plot()
                    with gr.TabItem("Class Similarity"):
                        dendrogram_plot = gr.Plot()
                gr.Markdown("### Image Samples")
                gallery = gr.Gallery(
                    label="Random samples from the dataset",
                    columns=5,
                    object_fit="contain",
                    height="auto",
                    show_label=True,
                )

        # --- Event Listeners for Interactivity ---
        # Every selector change, and the refresh button, re-runs the same
        # update function with the current values of both selectors.
        controls = [dataset_selector, class_selector]
        outputs = [stats_md_box, dist_plot, dendrogram_plot, gallery]

        for control in controls:
            control.change(fn=update_data_exploration, inputs=controls, outputs=outputs)
        refresh_button.click(
            fn=update_data_exploration, inputs=controls, outputs=outputs
        )
        # Initial load: read the selectors' own initial values rather than
        # hard-coded copies, so the first render cannot drift from the widgets.
        demo.load(fn=update_data_exploration, inputs=controls, outputs=outputs)

    with gr.Tab("Train & Evaluate"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Hyperparameters")
                batch_size_slider = gr.Slider(
                    32, 512, value=128, step=32, label="Batch Size"
                )
                epochs_slider = gr.Slider(1, 20, value=5, step=1, label="Max Epochs")
                lr_slider = gr.Slider(
                    1e-5, 1e-2, value=1e-3, label="Learning Rate", step=1e-5
                )
                num_filters_slider = gr.Slider(
                    8, 64, value=32, step=8, label="Conv Filters"
                )
                hidden_size_slider = gr.Slider(
                    32, 256, value=64, step=32, label="Hidden Layer Size"
                )
                train_button = gr.Button("Start Training", variant="primary")

            with gr.Column(scale=3):
                gr.Markdown("## Evaluation Results")
                test_acc_box = gr.Textbox(label="Test Accuracy")
                with gr.Tabs() as eval_tabs:
                    with gr.TabItem("Learning Curves"):
                        with gr.Row():
                            train_loss_plot = gr.Plot(label="Training Loss")
                            val_acc_plot = gr.Plot(label="Validation Accuracy")
                    with gr.TabItem("Confusion & Calibration"):
                        with gr.Row():
                            cm_plot = gr.Image(label="Confusion Matrix")
                            calibration_plot = gr.Image(label="Calibration Curve")
                    with gr.TabItem("Per-Class Accuracy"):
                        pca_plot = gr.Image(label="Per-Class Accuracy")
                    with gr.TabItem("Misclassified Samples"):
                        misclassified_plot = gr.Image(label="Misclassified Samples")

    # Output order must match the tuple returned by train_and_evaluate.
    train_button.click(
        fn=train_and_evaluate,
        inputs=[
            batch_size_slider,
            epochs_slider,
            lr_slider,
            num_filters_slider,
            hidden_size_slider,
        ],
        outputs=[
            test_acc_box,
            train_loss_plot,
            val_acc_plot,
            cm_plot,
            pca_plot,
            misclassified_plot,
            calibration_plot,
        ],
    )
274
def main():
    """Start the web server for the module-level Gradio ``demo`` app."""
    app = demo
    app.launch()


if __name__ == "__main__":
    main()
def get_df_from_dataset(dataset):
    """
    Convert a torchvision image dataset into a flat pandas DataFrame.

    Each sample in ``dataset.data`` is flattened into one row of ``pixel{i}``
    columns, and ``dataset.targets`` is appended as a ``label`` column.

    Args:
        dataset: Object exposing ``data`` (a tensor with one entry per sample),
            ``targets`` (a tensor of labels) and ``__len__``, e.g. torchvision's
            FashionMNIST.

    Returns:
        DataFrame with ``len(dataset)`` rows: one column per pixel plus ``label``.
    """
    # reshape (unlike view) also works when the tensor is not contiguous.
    images = dataset.data.reshape(len(dataset), -1).numpy()
    labels = dataset.targets.numpy()
    df = pd.DataFrame(images, columns=[f"pixel{i}" for i in range(images.shape[1])])
    df["label"] = labels
    return df

Converts a torchvision dataset to a pandas DataFrame.

explore_datamodule = <my_project.dataset.FashionMNISTDataModule object>
train_df = pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 ... pixel778 pixel779 pixel780 pixel781 pixel782 pixel783 label 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 9 1 0 0 0 0 0 1 0 ... 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 3 4 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... 59995 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 5 59996 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 1 59997 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 3 59998 0 0 0 0 0 0 0 ... 0 1 0 0 0 0 0 59999 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 5 [60000 rows x 785 columns]
test_df = pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 ... pixel778 pixel779 pixel780 pixel781 pixel782 pixel783 label 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 9 1 0 0 0 0 0 0 0 ... 174 189 67 0 0 0 2 2 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 1 4 0 0 0 2 0 1 1 ... 0 0 0 0 0 0 6 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... 9995 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 9 9996 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 1 9997 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 8 9998 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 1 9999 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 5 [10000 rows x 785 columns]
def update_data_exploration(dataset_choice, class_filter):
    """
    Rebuild every output of the Data Exploration tab from the current controls.

    Args:
        dataset_choice: "Train" selects ``train_df``; any other value selects
            ``test_df``.
        class_filter: "All" samples the gallery from the whole selected set;
            otherwise a name from FASHION_CLASSES limits the gallery to the
            rows whose ``label`` equals that name's index.

    Returns:
        Tuple ``(stats_md, dist_plot_fig, dendrogram_fig, gallery_images)`` in
        the order of the tab's ``outputs`` list.
    """
    if dataset_choice == "Train":
        source = train_df
    else:
        source = test_df

    # Step 1 & 2: both plots describe the full selected set, ignoring the filter.
    distribution = plot_class_distribution(source, FASHION_CLASSES)
    similarity_tree = plot_class_correlation_dendrogram(source, FASHION_CLASSES)

    # Step 3: the gallery optionally narrows to a single class.
    pool = source
    if class_filter != "All":
        wanted = source["label"] == FASHION_CLASSES.index(class_filter)
        pool = source[wanted]
    samples = get_sample_images_for_gallery(pool, FASHION_CLASSES, n_samples=15)

    # Step 4: summary statistics for the selected set.
    stats_md = f"""
    ### Dataset Statistics
    - **Selected Set:** {dataset_choice}
    - **Total Samples:** {len(source)}
    - **Number of Classes:** {len(FASHION_CLASSES)}
    - **Image Size:** 28x28 pixels (grayscale)
    """

    return stats_md, distribution, similarity_tree, samples

Updates the components in the Data Exploration tab based on user selection.

def train_and_evaluate(
    batch_size: int,
    max_epochs: int,
    lr: float,
    num_filters: int,
    hidden_size: int,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Train a fresh model with the given hyperparameters and evaluate it on the test set.

    Connected to the "Start Training" button of the Gradio interface. Integer
    hyperparameters are cast with ``int()`` because slider values may arrive
    as floats.

    Args:
        batch_size: Batch size for the FashionMNISTDataModule.
        max_epochs: Number of training epochs.
        lr: Learning rate passed to ``Net``.
        num_filters: Number of convolution filters passed to ``Net``.
        hidden_size: Hidden layer size passed to ``Net``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Tuple of (test accuracy formatted to 4 decimals, training-loss figure
        or None, validation-accuracy figure or None, then the
        ``confusion_matrix``, ``per_class_accuracy``, ``misclassified_grid``
        and ``calibration_curve`` artifacts from ``evaluate_and_plot``).
    """
    progress(0, desc="Initializing DataModule...")
    # The data will be downloaded to the 'data/' directory if not present
    datamodule = FashionMNISTDataModule(data_dir="data/", batch_size=int(batch_size))

    progress(0.1, desc="Initializing Model...")
    model = Net(num_filters=int(num_filters), hidden_size=int(hidden_size), lr=lr)

    # Training occupies the 0.2 -> 0.85 band of the progress bar, between the
    # "Starting Training" and "Generating Learning Curves" updates.
    train_start, train_end = 0.2, 0.85

    class GradioProgressCallback(pl.Callback):
        """Advance the Gradio progress bar after each finished training epoch."""

        def on_train_epoch_end(self, trainer, pl_module):
            # current_epoch is the 0-based index of the epoch that just ended,
            # so report completion of epoch current_epoch + 1.
            done = trainer.current_epoch + 1
            fraction = min(done / max(trainer.max_epochs, 1), 1.0)
            progress(
                train_start + (train_end - train_start) * fraction,
                desc=f"Epoch {done}/{trainer.max_epochs}",
            )

    # Use a temporary directory for logs
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = CSVLogger(save_dir=tmpdir, name="gradio_logs")
        trainer = pl.Trainer(
            max_epochs=int(max_epochs),
            # Fall back to CPU instead of failing on machines without CUDA.
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            # One device keeps training inside this process; with "auto" on a
            # multi-GPU host Lightning would pick a multi-process (DDP)
            # strategy, which does not fit inside a running Gradio server.
            devices=1,
            logger=logger,
            callbacks=[GradioProgressCallback()],
            enable_checkpointing=False,
        )

        progress(train_start, desc="Starting Training...")
        trainer.fit(model, datamodule=datamodule)

        progress(train_end, desc="Generating Learning Curves...")
        metrics_path = os.path.join(logger.log_dir, "metrics.csv")
        train_loss_fig, val_acc_fig = None, None
        if os.path.exists(metrics_path):
            metrics_df = pd.read_csv(metrics_path)
            train_loss_fig, val_acc_fig = plot_learning_curves_from_df(metrics_df)

        progress(0.9, desc="Evaluating on Test Set...")
        # Manually call setup for the 'test' stage to ensure test_dataset is initialized.
        datamodule.setup(stage="test")

        # Plots are written outside tmpdir so the returned artifacts outlive it.
        artifacts = evaluate_and_plot(
            model, datamodule, out_dir="reports/figures/gradio"
        )

    return (
        f"{artifacts['test_accuracy']:.4f}",
        # Learning curves
        train_loss_fig,
        val_acc_fig,
        # Evaluation plots
        artifacts["confusion_matrix"],
        artifacts["per_class_accuracy"],
        artifacts["misclassified_grid"],
        artifacts["calibration_curve"],
    )

A function to train and evaluate the model with given hyperparameters. This will be connected to the Gradio interface.

def main():
    """Start the web server for the module-level Gradio ``demo`` app."""
    app = demo
    app.launch()

Launch Gradio interface.