# Source code for torch_uncertainty.models.wrappers.deep_ensembles
import copy
from typing import Literal
import torch
from torch import nn
from torch.distributions import Distribution
from torch_uncertainty.utils.distributions import cat_dist
class _DeepEnsembles(nn.Module):
def __init__(
self,
models: list[nn.Module],
) -> None:
"""Create a classification deep ensembles from a list of models."""
super().__init__()
self.core_models = nn.ModuleList(models)
self.num_estimators = len(models)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""Return the logits of the ensemble.
Args:
x (Tensor): The input of the model.
Returns:
Tensor: The output of the model with shape :math:`(N \times B, C)`,
where :math:`B` is the batch size, :math:`N` is the number of
estimators, and :math:`C` is the number of classes.
"""
return torch.cat(
[model.forward(x) for model in self.core_models], dim=0
)
class _RegDeepEnsembles(_DeepEnsembles):
    def __init__(
        self,
        probabilistic: bool,
        models: list[nn.Module],
    ) -> None:
        """Create a regression deep ensembles from a list of models.

        Args:
            probabilistic (bool): Whether the members output distributions
                rather than point predictions.
            models (list[nn.Module]): The members of the ensemble.
        """
        super().__init__(models)
        self.probabilistic = probabilistic

    def forward(self, x: torch.Tensor) -> Distribution:
        r"""Return the logits of the ensemble.

        Args:
            x (Tensor): The input of the model.

        Returns:
            Distribution:
        """
        # Point-prediction members: fall back to plain concatenation.
        if not self.probabilistic:
            return super().forward(x)
        # Probabilistic members: merge the per-member distributions.
        member_dists = [member.forward(x) for member in self.core_models]
        return cat_dist(member_dists, dim=0)
def deep_ensembles(
    models: list[nn.Module] | nn.Module,
    num_estimators: int | None = None,
    task: Literal[
        "classification", "regression", "segmentation", "pixel_regression"
    ] = "classification",
    probabilistic: bool | None = None,
    reset_model_parameters: bool = False,
) -> _DeepEnsembles:
    """Build a Deep Ensembles out of the original models.

    Args:
        models (list[nn.Module] | nn.Module): The model to be ensembled.
        num_estimators (int | None): The number of estimators in the ensemble.
        task (Literal["classification", "regression", "segmentation", "pixel_regression"]): The model task.
            Defaults to "classification".
        probabilistic (bool | None): Whether the regression model is
            probabilistic. Required for the regression tasks.
        reset_model_parameters (bool): Whether to reset the model parameters
            when :attr:`models` is a module or a list of length 1.

    Returns:
        _DeepEnsembles: The ensembled model.

    Raises:
        ValueError: If :attr:`num_estimators` is not specified and :attr:`models`
            is a module (or singleton list).
        ValueError: If :attr:`num_estimators` is less than 2 and :attr:`models` is
            a module (or singleton list).
        ValueError: If :attr:`num_estimators` is defined while :attr:`models` is
            a (non-singleton) list.

    References:
        Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell.
        Simple and scalable predictive uncertainty estimation using deep
        ensembles. In NeurIPS, 2017.
    """
    if isinstance(models, list) and not models:
        raise ValueError("Models must not be an empty list.")
    if (isinstance(models, list) and len(models) == 1) or isinstance(
        models, nn.Module
    ):
        # A single prototype model: replicate it num_estimators times.
        if num_estimators is None:
            raise ValueError(
                "if models is a module, num_estimators must be specified."
            )
        if num_estimators < 2:
            raise ValueError(
                f"num_estimators must be at least 2. Got {num_estimators}."
            )
        if isinstance(models, list):
            models = models[0]
        # Deep-copy so the members do not share parameters.
        models = [copy.deepcopy(models) for _ in range(num_estimators)]
        if reset_model_parameters:
            # NOTE: only immediate children are reset; nested submodules
            # without a top-level reset_parameters are left untouched.
            for model in models:
                for layer in model.children():
                    if hasattr(layer, "reset_parameters"):
                        layer.reset_parameters()
    elif (
        isinstance(models, list)
        and len(models) > 1
        and num_estimators is not None
    ):
        raise ValueError(
            "num_estimators must be None if you provided a non-singleton list."
        )
    if task in ("classification", "segmentation"):
        return _DeepEnsembles(models=models)
    if task in ("regression", "pixel_regression"):
        # Regression wrappers need to know whether members are probabilistic.
        if probabilistic is None:
            raise ValueError(
                "probabilistic must be specified for regression models."
            )
        return _RegDeepEnsembles(probabilistic=probabilistic, models=models)
    raise ValueError(f"Unknown task: {task}.")