Source code for torch_uncertainty.baselines.depth.bts
from typing import Literal
from torch import nn
from torch_uncertainty.models.depth.bts import bts_resnet
from torch_uncertainty.routines import PixelRegressionRoutine
[docs]
class BTSBaseline(PixelRegressionRoutine):
archs = [50, 101]
def __init__(
self,
loss: nn.Module,
version: Literal["std"],
arch: int,
max_depth: float,
dist_family: str | None = None,
pretrained_backbone: bool = True,
) -> None:
params = {
"arch": arch,
"dist_family": dist_family,
"max_depth": max_depth,
"pretrained_backbone": pretrained_backbone,
}
format_batch_fn = nn.Identity()
if version not in self.versions:
raise ValueError(f"Unknown version {version}")
model = bts_resnet(**params)
super().__init__(
model=model,
output_dim=1,
loss=loss,
format_batch_fn=format_batch_fn,
dist_family=dist_family,
)
self.save_hyperparameters(ignore=["loss"])