Shortcuts

Source code for torch_uncertainty.baselines.depth.bts

from typing import Literal

from torch import nn

from torch_uncertainty.models.depth.bts import bts_resnet50, bts_resnet101
from torch_uncertainty.routines import PixelRegressionRoutine


class BTSBaseline(PixelRegressionRoutine):
    """Pixel-regression routine wrapping a BTS depth model.

    The model is selected by ``version`` and ``arch`` from the builders
    registered in :attr:`versions`, then handed to
    ``PixelRegressionRoutine`` together with the loss.

    Args:
        loss: Loss module forwarded to the routine.
        version: Model variant to build. Only ``"std"`` is registered.
        arch: Backbone identifier; must be one of :attr:`archs`
            (``50`` -> ``bts_resnet50``, ``101`` -> ``bts_resnet101``).
        max_depth: Forwarded to the model builder as ``max_depth``.
        dist_family: Forwarded both to the model builder and to the
            routine as ``dist_family``. Defaults to ``None``.
        num_estimators: Forwarded to the routine. Defaults to ``1``.
        pretrained_backbone: Forwarded to the model builder as
            ``pretrained_backbone``. Defaults to ``True``.

    Raises:
        ValueError: If ``version`` is not a key of :attr:`versions` or
            ``arch`` is not listed in :attr:`archs`.
    """

    # Versions listed as "single".
    # NOTE(review): presumably means single-estimator versions; confirm.
    single = ["std"]
    # Model builders per version, positionally aligned with ``archs``.
    versions = {
        "std": [
            bts_resnet50,
            bts_resnet101,
        ]
    }
    # Supported ``arch`` values; the position selects the builder above.
    archs = [50, 101]

    def __init__(
        self,
        loss: nn.Module,
        version: Literal["std"],
        arch: int,
        max_depth: float,
        dist_family: str | None = None,
        num_estimators: int = 1,
        pretrained_backbone: bool = True,
    ) -> None:
        if version not in self.versions:
            raise ValueError(f"Unknown version {version}")
        # Validate explicitly: ``list.index`` would otherwise fail with an
        # opaque "<arch> is not in list" message.
        if arch not in self.archs:
            raise ValueError(
                f"Unknown arch {arch}. Supported archs: {self.archs}"
            )

        params = {
            "dist_family": dist_family,
            "max_depth": max_depth,
            "pretrained_backbone": pretrained_backbone,
        }
        # Batches are passed through unchanged.
        format_batch_fn = nn.Identity()

        model_builder = self.versions[version][self.archs.index(arch)]
        model = model_builder(**params)

        super().__init__(
            output_dim=1,
            model=model,
            loss=loss,
            num_estimators=num_estimators,
            format_batch_fn=format_batch_fn,
            dist_family=dist_family,
        )
        # The loss is excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["loss"])