.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorials/Segmentation/tutorial_muad_packed.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorials_Segmentation_tutorial_muad_packed.py:


Packed-Ensembles Segmentation Tutorial using the MUAD Dataset
=============================================================

This tutorial demonstrates how to train a segmentation model on the MUAD dataset using TorchUncertainty.
MUAD is a synthetic dataset designed for evaluating autonomous driving under diverse uncertainties. It includes
**10,413 images** across training, validation, and test sets, featuring adverse weather, lighting conditions, and
out-of-distribution (OOD) objects. The dataset supports tasks such as semantic segmentation, depth estimation, and
object detection. For details and access, visit the `MUAD Website `_.

1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~

First, we load the following utilities from TorchUncertainty:

- the TUTrainer, which mainly handles the interface with the hardware (accelerators, precision, etc.)
- the segmentation training & evaluation routine from torch_uncertainty.routines
- the datamodule handling the dataloaders: MUADDataModule from torch_uncertainty.datamodules

.. GENERATED FROM PYTHON SOURCE LINES 25-40

.. code-block:: Python

    import matplotlib.pyplot as plt
    import torch
    import torchvision.transforms.v2.functional as F
    from torch import optim
    from torch.optim import lr_scheduler
    from torchvision import tv_tensors
    from torchvision.transforms import v2
    from torchvision.utils import draw_segmentation_masks

    from torch_uncertainty import TUTrainer
    from torch_uncertainty.datamodules.segmentation import MUADDataModule
    from torch_uncertainty.models.segmentation.unet import packed_small_unet
    from torch_uncertainty.routines import SegmentationRoutine
    from torch_uncertainty.transforms import RepeatTarget

.. GENERATED FROM PYTHON SOURCE LINES 41-43

2. Initializing the DataModule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 43-89

.. code-block:: Python

    muad_mean = MUADDataModule.mean
    muad_std = MUADDataModule.std

    # Use the datamodule statistics for normalization so that the
    # un-normalization performed for visualization below is exact.
    train_transform = v2.Compose(
        [
            v2.Resize(size=(256, 512), antialias=True),
            v2.RandomHorizontalFlip(),
            v2.ToDtype(
                dtype={
                    tv_tensors.Image: torch.float32,
                    tv_tensors.Mask: torch.int64,
                    "others": None,
                },
                scale=True,
            ),
            v2.Normalize(mean=muad_mean, std=muad_std),
        ]
    )

    test_transform = v2.Compose(
        [
            v2.Resize(size=(256, 512), antialias=True),
            v2.ToDtype(
                dtype={
                    tv_tensors.Image: torch.float32,
                    tv_tensors.Mask: torch.int64,
                    "others": None,
                },
                scale=True,
            ),
            v2.Normalize(mean=muad_mean, std=muad_std),
        ]
    )

    # datamodule providing the dataloaders to the trainer
    datamodule = MUADDataModule(
        root="./data",
        batch_size=10,
        version="small",
        train_transform=train_transform,
        test_transform=test_transform,
        num_workers=4,
    )
    datamodule.prepare_data()
    datamodule.setup("fit")
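Before visualizing anything, it can help to check what the datamodule yields. The snippet below is a minimal sanity check, assuming the ``setup("fit")`` call above has completed and that samples are (image, mask) pairs as used in the rest of this tutorial:

.. code-block:: Python

    # Minimal sanity check (assumes datamodule.setup("fit") ran above).
    sample_img, sample_tgt = datamodule.train[0]
    print(len(datamodule.train))               # number of training samples
    print(sample_img.shape, sample_img.dtype)  # expected: (3, 256, 512), torch.float32
    print(sample_tgt.shape, sample_tgt.dtype)  # segmentation mask with torch.int64 class ids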
.. GENERATED FROM PYTHON SOURCE LINES 90-91

Visualize a training input sample (an RGB image).

.. GENERATED FROM PYTHON SOURCE LINES 91-105

.. code-block:: Python

    # Undo the normalization on the image and convert it to uint8.
    img, tgt = datamodule.train[0]
    t_muad_mean = torch.tensor(muad_mean, device=img.device)
    t_muad_std = torch.tensor(muad_std, device=img.device)
    img = img * t_muad_std[:, None, None] + t_muad_mean[:, None, None]
    img = F.to_dtype(img, torch.uint8, scale=True)
    img_pil = F.to_pil_image(img)

    plt.figure(figsize=(6, 6))
    plt.imshow(img_pil)
    plt.axis("off")
    plt.show()

.. image-sg:: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_001.png
   :alt: tutorial muad packed
   :srcset: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 106-107

Visualize the same image with its ground-truth segmentation overlaid.

.. GENERATED FROM PYTHON SOURCE LINES 107-120

.. code-block:: Python

    # Map the "ignore" value 255 to class id 21, then build one boolean
    # mask per class for draw_segmentation_masks.
    tmp_tgt = tgt.masked_fill(tgt == 255, 21)
    tgt_masks = tmp_tgt == torch.arange(22, device=tgt.device)[:, None, None]
    img_segmented = draw_segmentation_masks(
        img, tgt_masks, alpha=1, colors=datamodule.train.color_palette
    )
    img_pil = F.to_pil_image(img_segmented)

    plt.figure(figsize=(6, 6))
    plt.imshow(img_pil)
    plt.axis("off")
    plt.show()

.. image-sg:: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_002.png
   :alt: tutorial muad packed
   :srcset: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 121-124

3. Instantiating the Model
~~~~~~~~~~~~~~~~~~~~~~~~~~

We create the model easily using the blueprint from torch_uncertainty.models.

.. GENERATED FROM PYTHON SOURCE LINES 124-135

.. code-block:: Python

    model = packed_small_unet(
        in_channels=datamodule.num_channels,
        num_classes=datamodule.num_classes,
        alpha=2,
        num_estimators=4,
        gamma=1,
        bilinear=True,
    )

.. GENERATED FROM PYTHON SOURCE LINES 136-138

4. Compute class weights to mitigate class imbalance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 138-184

.. code-block:: Python

    def enet_weighing(dataloader, num_classes, c=1.02):
        """Compute class weights as described in the ENet paper.

        w_class = 1 / ln(c + p_class), where c is usually 1.02 and p_class is
        the propensity score of that class:
        propensity_score = freq_class / total_pixels.

        References:
            https://arxiv.org/abs/1606.02147

        Args:
            dataloader (``data.DataLoader``): A data loader to iterate over the
                dataset.
            num_classes (``int``): The number of classes.
            c (``float``, optional): An additional hyperparameter that restricts
                the interval of values of the weights. Defaults to 1.02.
        """
        class_count = 0
        total = 0
        for _, label in dataloader:
            label = label.cpu()

            # Flatten the label and drop ignored (255) and out-of-range values
            flat_label = label.flatten()
            flat_label = flat_label[flat_label != 255]
            flat_label = flat_label[flat_label < num_classes]

            # Sum up the number of pixels of each class and the total pixel
            # counts for each label
            class_count += torch.bincount(flat_label, minlength=num_classes)
            total += flat_label.size(0)

        # Compute the propensity score and then the weights for each class
        propensity_score = class_count / total
        return 1 / (torch.log(c + propensity_score))


    class_weights = enet_weighing(datamodule.val_dataloader(), datamodule.num_classes)
    print(class_weights)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([ 4.3817, 19.7927,  3.3011, 48.8031, 36.2141, 33.0049, 47.5130, 48.8560,
            12.4401, 48.0600, 14.4807, 30.8762,  4.7467, 19.3913, 50.4984])
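To get a feel for the inverse-log weighting, consider a tiny worked example on made-up class frequencies (illustrative only; these are not MUAD statistics): frequent classes receive small weights, rare classes large ones.

.. code-block:: Python

    # Illustrative only: ENet weights for fictional propensity scores.
    p_class = torch.tensor([0.70, 0.25, 0.05])  # made-up frequencies, sum to 1
    weights = 1 / torch.log(1.02 + p_class)
    print(weights)  # approx. [1.84, 4.18, 14.78]: the rarest class dominates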
.. GENERATED FROM PYTHON SOURCE LINES 185-186

Let's define the training parameters.

.. GENERATED FROM PYTHON SOURCE LINES 186-195

.. code-block:: Python

    BATCH_SIZE = 10
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 2e-4
    LR_DECAY_EPOCHS = 20
    LR_DECAY = 0.1
    NB_EPOCHS = 1
    NUM_ESTIMATORS = 4

.. GENERATED FROM PYTHON SOURCE LINES 196-198

5. The Loss, the Routine, and the Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 198-223

.. code-block:: Python

    # We build the optimizer
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE * NUM_ESTIMATORS, weight_decay=WEIGHT_DECAY
    )

    # Learning-rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, LR_DECAY_EPOCHS, LR_DECAY)

    packed_routine = SegmentationRoutine(
        model=model,
        num_classes=datamodule.num_classes,
        loss=torch.nn.CrossEntropyLoss(weight=class_weights),
        format_batch_fn=RepeatTarget(NUM_ESTIMATORS),  # Repeat the target 4 times for the ensemble
        optim_recipe={"optimizer": optimizer, "lr_scheduler": lr_updater},
    )

    trainer = TUTrainer(
        accelerator="gpu",
        devices=1,
        max_epochs=NB_EPOCHS,
        enable_progress_bar=False,
        precision="16-mixed",
    )

.. GENERATED FROM PYTHON SOURCE LINES 224-226

6. Training the model
~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 226-227

.. code-block:: Python

    trainer.fit(model=packed_routine, datamodule=datamodule)

.. GENERATED FROM PYTHON SOURCE LINES 228-230

7. Testing the model
~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 230-232

.. code-block:: Python

    results = trainer.test(datamodule=datamodule, ckpt_path="best")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃  Test metric ┃       Segmentation        ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │    Brier     │          0.67491          │
    │     NLL      │          1.61145          │
    │     mAcc     │          23.960%          │
    │     mIoU     │          12.944%          │
    │    pixAcc    │          52.616%          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃  Test metric ┃        Calibration        ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │     ECE      │          18.491%          │
    │     aECE     │          18.658%          │
    └──────────────┴───────────────────────────┘
    ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃  Test metric ┃ Selective Classification  ┃
    ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │    AUGRC     │          19.938%          │
    │     AURC     │          33.967%          │
    │  Cov@5Risk   │           nan%            │
    │  Risk@80Cov  │          43.517%          │
    └──────────────┴───────────────────────────┘

.. GENERATED FROM PYTHON SOURCE LINES 233-236

8. Uncertainty evaluations with MCP
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here, we simply use the maximum class probability (MCP) as the confidence score.
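As a reminder, the MCP at a pixel is the largest softmax probability over the classes. The sketch below shows the idea on random logits (illustrative only; the actual computation on the trained packed model follows):

.. code-block:: Python

    # Illustrative MCP on random logits, not on the trained model.
    dummy_logits = torch.randn(15, 256, 512)   # (classes, H, W)
    dummy_probs = dummy_logits.softmax(dim=0)  # per-pixel class probabilities
    mcp, mcp_pred = dummy_probs.max(dim=0)     # confidence map and predicted class
    print(mcp.shape, mcp_pred.shape)           # both (256, 512)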
.. GENERATED FROM PYTHON SOURCE LINES 236-288

.. code-block:: Python

    img, target = datamodule.test[0]
    batch_img = img.unsqueeze(0)
    batch_target = target.unsqueeze(0)

    model.eval()
    with torch.no_grad():
        # Forward propagation
        outputs = model(batch_img)
        outputs_proba = outputs.softmax(dim=1)
        # Packed models stack the estimators along the batch dimension:
        # averaging over dim 0 fuses the ensemble into a (C, H, W) tensor.
        outputs_proba = outputs_proba.mean(dim=0)

    confidence, pred = outputs_proba.max(0)

    # Undo the normalization on the image and convert it to uint8.
    mean = torch.tensor(muad_mean, device=img.device)
    std = torch.tensor(muad_std, device=img.device)
    img = img * std[:, None, None] + mean[:, None, None]
    img = F.to_dtype(img, torch.uint8, scale=True)

    tmp_target = target.masked_fill(target == 255, 21)
    target_masks = tmp_target == torch.arange(22, device=target.device)[:, None, None]
    img_segmented = draw_segmentation_masks(
        img, target_masks, alpha=1, colors=datamodule.test.color_palette
    )

    pred_masks = pred == torch.arange(22, device=pred.device)[:, None, None]
    pred_img = draw_segmentation_masks(img, pred_masks, alpha=1, colors=datamodule.test.color_palette)

    if confidence.ndim == 2:
        confidence = confidence.unsqueeze(0)

    img = F.to_pil_image(F.resize(img, 1024))
    img_segmented = F.to_pil_image(F.resize(img_segmented, 1024))
    pred_img = F.to_pil_image(F.resize(pred_img, 1024))
    confidence_img = F.to_pil_image(F.resize(confidence, 1024))

    fig, axs = plt.subplots(1, 4, figsize=(25, 7))
    images = [img, img_segmented, pred_img, confidence_img]
    for ax, im in zip(axs, images, strict=False):
        ax.imshow(im)
        ax.axis("off")
    plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01, wspace=0.05)
    plt.show()

.. image-sg:: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_003.png
   :alt: tutorial muad packed
   :srcset: /auto_tutorials/Segmentation/images/sphx_glr_tutorial_muad_packed_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 38.966 seconds)


.. _sphx_glr_download_auto_tutorials_Segmentation_tutorial_muad_packed.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tutorial_muad_packed.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_muad_packed.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tutorial_muad_packed.zip `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_