
Training Module

Comprehensive training pipeline with metrics tracking and visualization for food segmentation.

What We Are Tracking

Core Metrics

  • Loss Functions - Training and validation loss per epoch
  • Segmentation Accuracy - Pixel-wise accuracy and mean IoU
  • Learning Progress - Learning rate schedules and training time
  • Model Performance - Validation metrics and best model checkpointing

Experiment Tracking

  • Weights & Biases Integration - Hyperparameters, model architecture, and system metrics
  • Visualization Outputs - Training curves, loss plots, and prediction visualizations

Tracking Architecture

A hybrid approach combining:

  • Local logging with Rich console output
  • Weights & Biases for cloud-based experiment tracking
  • File-based visualization saves
  • Model checkpoint management

The module focuses on essential segmentation metrics while keeping the training pipeline simple.
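A minimal end-to-end sketch of how the module is typically driven (hyperparameter values are illustrative, and it assumes Weights & Biases credentials are already configured; the Trainer does not appear to create its output directories, so the sketch creates them up front):

```python
import os

from src.segmentation.train import Trainer

BASE_DIR = os.path.abspath(".")  # assumed project root containing the dataset

# The Trainer writes into <base_dir>/saved/{models,reports,predictions}.
for sub in ("models", "reports", "predictions"):
    os.makedirs(os.path.join(BASE_DIR, "saved", sub), exist_ok=True)

trainer = Trainer(
    lr=1e-3,              # illustrative hyperparameters, not project defaults
    epochs=10,
    batch_size=8,
    base_dir=BASE_DIR,
    enable_profiler=False,
    init_wandb=True,      # set to False to skip Weights & Biases logging
    prune_amount=0.2,     # prune 20% of Conv2d/Linear weights (global L1)
)

trainer.train()                       # training + validation loop, checkpoints best model
trainer.visualize_training_metrics()  # writes saved/reports/training_metrics.png
```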

src.segmentation.train

Trainer(lr=None, epochs=None, batch_size=None, base_dir=None, enable_profiler=None, init_wandb=True, prune_amount=0.2)

Initialize the Trainer and set up experiment logging with Weights and Biases.

Parameters:

Name             Type   Description                                                            Default
lr               float  Learning rate for the optimizer.                                       None
epochs           int    Number of training epochs.                                             None
batch_size       int    Batch size for training.                                               None
base_dir         str    Project's root directory.                                              None
enable_profiler  bool   Whether to enable the PyTorch profiler.                                None
init_wandb       bool   Whether to initialize Weights and Biases.                              True
prune_amount     float  Fraction of Conv2d/Linear weights to prune globally (L1 unstructured). 0.2
Source code in src/segmentation/train.py (lines 35–144)
def __init__(
    self,
    lr=None,
    epochs=None,
    batch_size=None,
    base_dir=None,
    enable_profiler=None,
    init_wandb=True,
    prune_amount=0.2,
):
    """
    Initialize the Trainer and set up experiment logging with Weights and Biases.

    Args:
        lr (float): Learning rate for the optimizer.
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for training.
        base_dir (str): Project's root directory.
        enable_profiler (bool): Whether to enable PyTorch profiler.
        init_wandb (bool): Whether to initialize Weights and Biases.
        prune_amount (float): Fraction of Conv2d and Linear weights to prune
            globally with L1-unstructured pruning.
    """

    super().__init__()

    # model
    self.model = MiniUNet().to("cuda" if torch.cuda.is_available() else "cpu")

    # Check if pruning is enabled
    if prune_amount > 0.0:
        logger.info(f"Applying pruning with amount: {prune_amount}")

        # Apply global unstructured pruning to Conv2d and Linear layers
        parameters_to_prune = []
        for module in self.model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                parameters_to_prune.append((module, "weight"))

        if parameters_to_prune:
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=prune_amount,
            )
            logger.info("Pruning applied to the model.")
        else:
            logger.warning("No Conv2d or Linear layers found for pruning.")

    # Model Parameters
    self.parameters = sum(
        p.numel() for p in self.model.parameters() if p.requires_grad
    )

    # Training Parameters
    self.epochs = epochs
    self.lr = lr
    self.batch_size = batch_size

    # Directories
    self.base_dir = base_dir
    self.saved_dir = os.path.join(self.base_dir, "saved")
    self.model_path = os.path.join(self.saved_dir, "models", "model.pth")
    self.plots_path = os.path.join(
        self.saved_dir, "reports", "training_metrics.png"
    )
    self.predictions = os.path.join(
        self.saved_dir, "predictions", "predictions.png"
    )
    # Data Loaders
    self.train_loader, self.test_loader = data_loaders(
        base_dir=self.base_dir,
        batch_size=self.batch_size,
    )

    # Loss function
    self.loss = nn.CrossEntropyLoss()
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

    # Training History
    self.train_losses = []
    self.test_losses = []
    self.train_accs = []
    self.test_accs = []
    self.train_ious = []
    self.test_ious = []
    self.best_test_loss = float("inf")

    logger.info(f"Model initialized with {self.parameters} trainable parameters.")

    self.init_wandb = init_wandb
    self.enable_profiler = enable_profiler
    if not self.enable_profiler:
        logger.info("Profiler is disabled. Training will run without profiling.")

    # Initialize Weights and Biases
    if self.init_wandb and wandb.run is None:
        wandb.init(
            project="Food-Segmentation",
            config={
                "epochs": self.epochs,
                "learning_rate": self.lr,
                "batch_size": self.batch_size,
                "base_dir": self.base_dir,
                "trainable_params": self.parameters,
                "model": "MiniUNet",
                "optimizer": "Adam",
                "loss_function": "CrossEntropyLoss",
            },
        )
        wandb.watch(self.model, log="all")
        logger.info("Weights and Biases initialized for tracking.")

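For intuition on the pruning step above: prune.global_unstructured with L1Unstructured zeroes the smallest-magnitude weights across all selected layers at once and re-parameterizes each layer with weight_orig and weight_mask. A standalone sketch (a toy model stands in for MiniUNet; layer sizes are arbitrary) that checks the resulting global sparsity:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy stand-in for MiniUNet; layer sizes are arbitrary.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3), nn.Linear(16, 104))

parameters_to_prune = [
    (m, "weight") for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))
]

# Zero the 20% smallest-magnitude weights across all layers combined.
prune.global_unstructured(
    parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2
)

zeros = sum(int((m.weight == 0).sum()) for m, _ in parameters_to_prune)
total = sum(m.weight.numel() for m, _ in parameters_to_prune)
print(f"global sparsity: {zeros / total:.1%}")  # ≈ 20%
print(hasattr(model[0], "weight_orig"))         # True: layer is re-parameterized
```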
calculate_iou(pred_mask, true_mask, num_classes=104)

Implements the Intersection over Union (IoU) metric for segmentation tasks.

Source code in src/segmentation/train.py (lines 532–551)
def calculate_iou(self, pred_mask, true_mask, num_classes=104):
    """Implements the Intersection over Union (IoU) metric for segmentation tasks."""
    ious = []
    pred_mask = pred_mask.view(-1)
    true_mask = true_mask.view(-1)

    for cls in range(num_classes):
        pred_inds = pred_mask == cls
        target_inds = true_mask == cls

        intersection = (pred_inds & target_inds).long().sum().item()
        union = (pred_inds | target_inds).long().sum().item()

        if union == 0:
            continue  # Skip if no pixels for this class

        iou = intersection / union
        ious.append(iou)

    return np.mean(ious) if ious else 0.0

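A quick worked example of the metric with hypothetical 2×2 masks: classes that appear in neither the prediction nor the target are skipped rather than counted as zero, and the result is the mean IoU over the remaining classes. The loop below mirrors calculate_iou:

```python
import torch

pred = torch.tensor([[0, 1], [1, 2]]).view(-1)  # predicted class per pixel
true = torch.tensor([[0, 1], [2, 2]]).view(-1)  # ground-truth class per pixel

ious = []
for cls in range(104):
    p, t = pred == cls, true == cls
    union = (p | t).long().sum().item()
    if union == 0:
        continue  # class absent from both masks: skipped, not scored as 0
    ious.append((p & t).long().sum().item() / union)

# class 0 -> 1.0, class 1 -> 0.5, class 2 -> 0.5; classes 3..103 skipped
print(sum(ious) / len(ious))  # ≈ 0.667
```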
forward(x)

Forward pass through the model.

Source code in src/segmentation/train.py (lines 164–166)
def forward(self, x):
    """Forward pass through the model."""
    return self.model(x)

remove_pruning()

Remove pruning from the model.

This method iterates through all modules in the model and removes pruning parameters if they exist. It is useful for restoring the original model state after pruning has been applied.

Source code in src/segmentation/train.py (lines 146–162)
def remove_pruning(self):
    """
    Remove pruning from the model.

    This method iterates through all modules in the model and removes
    pruning parameters if they exist. It is useful for restoring the
    original model state after pruning has been applied.
    """

    for module in self.model.modules():
        if hasattr(module, "weight_orig"):
            prune.remove(module, "weight")
        if hasattr(module, "bias_orig"):
            prune.remove(module, "bias")
            logger.info(f"Removed pruning from bias of {module}")

    logger.info("Pruning removed from the model.")

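In PyTorch, prune.remove folds the learned mask into the parameter and deletes the weight_orig/weight_mask re-parameterization, so checkpoints saved afterwards contain plain (permanently sparsified) tensors. A standalone sketch on a toy layer:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(hasattr(layer, "weight_orig"))  # True: weight is re-parameterized

prune.remove(layer, "weight")         # bake the mask into `weight` permanently
print(hasattr(layer, "weight_orig"))  # False: plain tensor again
print((layer.weight == 0).float().mean().item())  # sparsity preserved (0.5)
```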
train()

Execute the training loop with optional profiling and Weights and Biases logging.

Pipeline Steps:

  1. Initialize the device for training (GPU or CPU).
  2. Check if the model path is set.
  3. Create the profiler if enabled (see the sketch after this list).
  4. Loop through the number of epochs:
  5. Train the model on the training dataset.
  6. Calculate and log training metrics (loss, accuracy, IoU).
  7. Validate the model on the test dataset.
  8. Save the best model based on test loss.
  9. Log metrics to Weights and Biases.
  10. Finish the Weights and Biases run.

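Step 3 uses torch.profiler with schedule(wait=1, warmup=1, active=3, repeat=2), i.e. skip one step, warm up for one, record three, and repeat that cycle twice, writing TensorBoard traces to ./profiler_logs. A minimal standalone sketch of the same configuration (CPU-only so it runs anywhere; the matmul stands in for a training step):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs"),
) as prof:
    for step in range(12):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for one training step
        prof.step()  # advance the profiler schedule, as train() does per batch

# Inspect the traces with: tensorboard --logdir ./profiler_logs
# (requires the torch-tb-profiler TensorBoard plugin)
```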
Source code in src/segmentation/train.py (lines 168–350)
def train(self):
    """

    Execute the training loop with optional profiling and Weights and Biases logging.

    Pipeline Steps :
    1. Initialize the device for training (GPU or CPU).
    2. Check if the model path is set.
    3. Create profiler if enabled.
    4. Loop through the number of epochs:
    5. Train the model on the training dataset.
    6. Calculate and log training metrics (loss, accuracy, IoU).
    7. Validate the model on the test dataset.
    8. Save the best model based on test loss.
    9. Log metrics to Weights and Biases.
    10. Finish Weights and Biases run.

    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not self.model_path:
        logger.warning("Model path is not set. Cannot save the model.")
        return

    os.makedirs("./profiler_logs", exist_ok=True)

    prof = None

    if self.enable_profiler:
        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                "./profiler_logs"
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    for epoch in range(self.epochs):
        # Training
        self.model.train()

        running_loss = 0.0
        running_accuracy = 0.0
        running_iou = 0.0

        train_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch+1}/{self.epochs} [Train]",
            leave=False,
        )

        for batch_idx, (images, masks) in enumerate(train_bar):
            images = images.to(device)
            masks = masks.to(device).long()

            self.optimizer.zero_grad()
            outputs = self.forward(images)
            loss = self.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()

            # Calculate accuracy for current batch
            pred_classes = torch.argmax(outputs, dim=1)
            train_accuracy = torch.mean((pred_classes == masks).float())
            train_iou = self.calculate_iou(pred_classes, masks)

            # Accumulate loss and accuracy
            running_accuracy += train_accuracy.item() * images.size(0)
            running_loss += loss.item() * images.size(0)
            running_iou += train_iou * images.size(0)

            # Update progress bar
            train_bar.set_postfix(loss=loss.item())

            # Step profiler and limit batches for profiling
            if self.enable_profiler and prof is not None:
                prof.step()
                logger.info(f"Profiler step {batch_idx + 1}")
                if batch_idx >= 20:
                    logger.info("🔥 Profiling completed")
                    prof.stop()
                    break

        # Training metrics
        train_loss = running_loss / len(self.train_loader.dataset)
        train_acc = running_accuracy / len(self.train_loader.dataset)
        train_iou = running_iou / len(self.train_loader.dataset)
        self.train_losses.append(train_loss)
        self.train_accs.append(train_acc)
        self.train_ious.append(train_iou)

        # Validation
        self.model.eval()
        running_loss = 0.0
        running_accuracy = 0.0
        running_iou = 0.0

        test_bar = tqdm(
            self.test_loader,
            desc=f"Epoch {epoch+1}/{self.epochs} [Test]",
            leave=False,
        )

        with torch.no_grad():
            for images, masks in test_bar:
                images = images.to(device)
                masks = masks.to(device).long()

                outputs = self.forward(images)
                loss = self.loss(outputs, masks)

                pred_classes = torch.argmax(outputs, dim=1)
                test_accuracy = torch.mean((pred_classes == masks).float())
                test_iou = self.calculate_iou(pred_classes, masks)

                running_accuracy += test_accuracy.item() * images.size(0)
                running_loss += loss.item() * images.size(0)
                running_iou += test_iou * images.size(0)

                test_bar.set_postfix(test_loss=loss.item())

        test_loss = running_loss / len(self.test_loader.dataset)
        test_acc = running_accuracy / len(self.test_loader.dataset)
        test_iou = running_iou / len(self.test_loader.dataset)

        self.test_losses.append(test_loss)
        self.test_accs.append(test_acc)
        self.test_ious.append(test_iou)

        print(f"Epoch [{epoch+1}/{self.epochs}]:")
        print(
            f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train IoU: {train_iou:.4f}"
        )
        print(
            f"  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test IoU: {test_iou:.4f}"
        )

        wandb.log(
            {
                "Train Loss": train_loss,
                "Train Accuracy": train_acc,
                "Test Loss": test_loss,
                "Test Accuracy": test_acc,
                "epoch": epoch + 1,
                "Train IoU": train_iou,
                "Test IoU": test_iou,
            }
        )

        if test_loss < self.best_test_loss:
            self.best_test_loss = test_loss
            self.remove_pruning()
            torch.save(
                {
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "epoch": epoch + 1,
                    "best_test_loss": self.best_test_loss,
                    "final_train_loss": train_loss,
                    "final_test_loss": test_loss,
                    "final_train_acc": train_acc,
                    "final_test_acc": test_acc,
                },
                self.model_path,
            )
            logger.info(f"Best model saved with test_loss: {test_loss:.4f}")

            artifact = wandb.Artifact(
                "model",
                type="model",
                description="Best model based on test loss",
            )
            artifact.add_file(self.model_path)
            wandb.log_artifact(artifact)

        logger.info(f"Training complete. Model saved at {self.model_path}")
        print("Training complete.")

    wandb.finish()

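The best checkpoint written above is a plain dict (model_state_dict, optimizer_state_dict, epoch, best_test_loss, and the final metrics), so it can be reloaded later for evaluation or inference. A hedged sketch, assuming MiniUNet is importable from the package (the exact import path may differ):

```python
import os
import torch

from src.segmentation.model import MiniUNet  # assumed module path for MiniUNet

ckpt_path = os.path.join("saved", "models", "model.pth")
ckpt = torch.load(ckpt_path, map_location="cpu")

model = MiniUNet()
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

print(f"best epoch: {ckpt['epoch']}, best test loss: {ckpt['best_test_loss']:.4f}")
```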
visualize_training_metrics()

Visualize training and testing metrics (loss, accuracy, and IoU) from the in-memory training history. Creates three side-by-side graphs: loss, accuracy, and IoU comparison, and saves the figure to self.plots_path.

This method takes no arguments. It reads the metric history accumulated during train() (self.train_losses, self.test_losses, self.train_accs, self.test_accs, self.train_ious, self.test_ious) and writes the plot to self.plots_path.
Source code in src/segmentation/train.py (lines 352–466)
def visualize_training_metrics(self):
    """
    Visualize training and testing loss & accuracy from saved model checkpoint
    Creates two side-by-side graphs: Loss comparison and Accuracy comparison

    Args:
        base_dir: Base directory for saving plots
        model_path: Path to the saved model checkpoint (.pth file)
        plots_path: Directory to save the plots
    """

    # Extract metrics
    train_losses = self.train_losses
    test_losses = self.test_losses
    train_accs = self.train_accs
    test_accs = self.test_accs
    train_ious = self.train_ious
    test_ious = self.test_ious

    # Check if data exists
    if not train_losses or not test_losses:
        logger.warning("No loss data found in checkpoint!")
        return

    if not train_accs or not test_accs:
        logger.warning("No accuracy data found in checkpoint!")
        return

    if not self.plots_path:
        logger.warning("No plots path provided. Plot not saved.")
        return

    # Create figure with 3 subplots side by side
    _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))

    epochs = range(1, len(train_losses) + 1)

    # Graph 1: Training vs Testing Loss
    ax1.plot(
        epochs,
        train_losses,
        label="Training Loss",
        color="blue",
        marker="o",
        linewidth=2,
    )
    ax1.plot(
        epochs,
        test_losses,
        label="Testing Loss",
        color="red",
        marker="s",
        linewidth=2,
    )
    ax1.set_title("Loss Comparison")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(True)

    # Graph 2: Training vs Testing Accuracy
    ax2.plot(
        epochs,
        train_accs,
        label="Training Accuracy",
        color="green",
        marker="o",
        linewidth=2,
    )
    ax2.plot(
        epochs,
        test_accs,
        label="Testing Accuracy",
        color="orange",
        marker="s",
        linewidth=2,
    )
    ax2.set_title("Accuracy Comparison")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    ax2.legend()
    ax2.grid(True)

    # Graph 3: Training vs Testing IoU
    ax3.plot(
        epochs,
        train_ious,
        label="Training IoU",
        color="purple",
        marker="o",
        linewidth=2,
    )
    ax3.plot(
        epochs,
        test_ious,
        label="Testing IoU",
        color="brown",
        marker="s",
        linewidth=2,
    )
    ax3.set_title("IoU Comparison")
    ax3.set_xlabel("Epochs")
    ax3.set_ylabel("IoU")
    ax3.legend()
    ax3.grid(True)

    # Adjust layout and save the plot
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # Adjust top to make space for the title
    plt.suptitle("Training and Testing Metrics", fontsize=20)
    plt.savefig(self.plots_path, dpi=300)  # Save the plot with high resolution
    logger.info(f"Training metrics plot saved: {self.plots_path}")
    plt.show()  # Show the plot in interactive mode
    plt.close()  # Close the plot to free memory