# Source code for tiatoolbox.models.models_abc
"""Define Abstract Base Class for Models defined in tiatoolbox."""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import torch
import torch._dynamo
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001
if TYPE_CHECKING: # pragma: no cover
from collections.abc import Callable
from pathlib import Path
import numpy as np
def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module:
    """Helper function to load a torch model.

    Args:
        model (torch.nn.Module):
            A torch model.
        weights (str or Path):
            Path to pretrained weights.

    Returns:
        torch.nn.Module:
            Torch model with pretrained weights loaded on CPU.

    """
    # ! assume to be saved in single GPU mode
    # always load on to the CPU
    # `weights_only=True` restricts unpickling to tensors/primitives, which
    # prevents arbitrary code execution from an untrusted checkpoint file;
    # a plain state dict never needs more than that.
    saved_state_dict = torch.load(weights, map_location="cpu", weights_only=True)
    model.load_state_dict(saved_state_dict, strict=True)
    return model
def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
    """Transfers model to cpu/gpu.

    Args:
        model (torch.nn.Module):
            PyTorch defined model.
        device (str):
            Transfers model to the specified device. Default is "cpu".

    Returns:
        torch.nn.Module:
            The model after being moved to cpu/gpu. On a multi-GPU CUDA
            host this is a :class:`DistributedDataParallel` (or
            :class:`torch.nn.DataParallel`) wrapper around the model.

    """
    torch_device = torch.device(device)
    # Use DDP if multiple GPUs and not on CPU
    if (
        device == "cuda"
        and torch.cuda.device_count() > 1
        and is_torch_compile_compatible()
    ):  # pragma: no cover
        # This assumes a single-process DDP setup for inference
        model = model.to(torch_device)
        # Do not clobber an address/port the caller may have configured.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        # Guard against "process group already initialized" errors when
        # this helper is called more than once in the same process.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", rank=0, world_size=1)
        model = DistributedDataParallel(model, device_ids=[torch_device.index])
    elif device != "cpu":
        # DataParallel works only for CUDA devices.
        model = torch.nn.DataParallel(model)
        model = model.to(torch_device)
    else:
        model = model.to(torch_device)
    return model
class ModelABC(ABC, torch.nn.Module):
    """Abstract base class for models used in tiatoolbox.

    Subclasses must implement :meth:`forward` and :meth:`infer_batch`.
    Instance-level pre-/post-processing hooks are exposed through the
    :attr:`preproc_func` and :attr:`postproc_func` properties.

    """

    def __init__(self: ModelABC) -> None:
        """Initialize Abstract class ModelABC."""
        super().__init__()
        # Hooks default to the class-level static methods until replaced
        # via the `preproc_func`/`postproc_func` setters.
        self._postproc = self.postproc
        self._preproc = self.preproc
        # Optional mapping from class index to class name; populated by
        # subclasses/engines that know the label set.
        self.class_dict = None

    @abstractmethod
    # This is generic abc, else pylint will complain
    def forward(
        self: ModelABC, *args: tuple[Any, ...], **kwargs: dict
    ) -> None | torch.Tensor:
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(
        model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str
    ) -> np.ndarray | tuple[np.ndarray, ...] | dict:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (np.ndarray | torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Transfers model to the specified device.

        Returns:
            np.ndarray:
                The inference results as a numpy array.
            dict:
                Returns a dictionary of predictions and other expected outputs
                depending on the network architecture.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of model.

        The default implementation is the identity function.

        """
        return image

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model.

        The default implementation is the identity function.

        """
        return image

    @property
    def preproc_func(self: ModelABC) -> Callable:
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: ModelABC, func: Callable) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model = ModelABC()
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(image=np.ndarray)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)
        # `None` resets the hook back to the class default.
        self._preproc = self.preproc if func is None else func

    @property
    def postproc_func(self: ModelABC) -> Callable:
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self: ModelABC, func: Callable) -> None:
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model = ModelABC()
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(image=np.ndarray)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)
        # `None` resets the hook back to the class default.
        self._postproc = self.postproc if func is None else func

    def to(  # type: ignore[override]
        self: ModelABC,
        device: str = "cpu",
        dtype: torch.dtype | None = None,
        *,
        non_blocking: bool = False,
    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
        """Transfers model to cpu/gpu.

        Args:
            self (ModelABC):
                PyTorch defined model.
            device (str):
                Transfers model to the specified device. Default is "cpu".
            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
                the parameters and buffers in this module.
            non_blocking (bool): When set, it tries to convert/move asynchronously
                with respect to the host if possible, e.g., moving CPU Tensors with
                pinned memory to CUDA devices.

        Returns:
            torch.nn.Module | torch.nn.DataParallel:
                The model after being moved to cpu/gpu.

        """
        torch_device = torch.device(device)
        model = super().to(torch_device, dtype=dtype, non_blocking=non_blocking)
        # If the target device is CUDA and more than one GPU is
        # available, wrap the model in DataParallel.
        if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
            return torch.nn.DataParallel(model)  # pragma: no cover
        return model

    def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
        """Helper function to load a torch model.

        Args:
            self (ModelABC):
                A torch model as :class:`ModelABC`.
            weights (str or Path):
                Path to pretrained weights.

        Returns:
            torch.nn.Module:
                The model itself (``self``) with pretrained weights
                loaded on CPU.

        """
        # ! assume to be saved in single GPU mode
        # always load on to the CPU
        # `weights_only=True` restricts unpickling to tensors/primitives,
        # preventing arbitrary code execution from untrusted checkpoints.
        saved_state_dict = torch.load(weights, map_location="cpu", weights_only=True)
        # Return the model, as documented, rather than the
        # `_IncompatibleKeys` named tuple that `load_state_dict` returns.
        super().load_state_dict(saved_state_dict, strict=True)
        return self