# Source code for tiatoolbox.models.models_abc

"""Define Abstract Base Class for Models defined in tiatoolbox."""

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import torch
import torch._dynamo
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

# Fall back to eager execution instead of raising when torch.compile/dynamo
# cannot compile a model (keeps inference working on unsupported graphs).
torch._dynamo.config.suppress_errors = True  # skipcq: PYL-W0212  # noqa: SLF001

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Callable
    from pathlib import Path

    import numpy as np


def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module:
    """Helper function to load a torch model.

    Args:
        model (torch.nn.Module):
            A torch model.
        weights (str or Path):
            Path to pretrained weights.

    Returns:
        torch.nn.Module:
            Torch model with pretrained weights loaded on CPU.

    """
    # Weights are assumed to have been saved from a single-GPU session, so
    # everything is mapped onto the CPU; this makes loading host-agnostic.
    # NOTE(review): torch.load unpickles arbitrary objects — weights files
    # must come from a trusted source.
    state = torch.load(weights, map_location="cpu")
    model.load_state_dict(state, strict=True)
    return model
def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
    """Transfers model to cpu/gpu.

    Args:
        model (torch.nn.Module):
            PyTorch defined model.
        device (str):
            Transfers model to the specified device. Default is "cpu".

    Returns:
        torch.nn.Module:
            The model after being moved to cpu/gpu.

    """
    target = torch.device(device)

    # Use DDP only when asked for "cuda", more than one GPU is present and
    # the torch.compile toolchain is usable (short-circuits on CPU hosts).
    use_ddp = (
        device == "cuda"
        and torch.cuda.device_count() > 1
        and is_torch_compile_compatible()
    )
    if use_ddp:  # pragma: no cover
        # Single-process DDP setup, intended for inference only.
        model = model.to(target)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group(backend="nccl", rank=0, world_size=1)
        return DistributedDataParallel(model, device_ids=[target.index])

    if device == "cpu":
        return model.to(target)

    # DataParallel works only for CUDA devices.
    model = torch.nn.DataParallel(model)
    return model.to(target)
class ModelABC(ABC, torch.nn.Module):
    """Abstract base class for models used in tiatoolbox."""

    def __init__(self: ModelABC) -> None:
        """Initialize Abstract class ModelABC."""
        super().__init__()
        # Pre/post-processing default to the no-op static methods below;
        # callers may replace them via the property setters.
        self._postproc = self.postproc
        self._preproc = self.preproc
        self.class_dict = None

    @abstractmethod  # This is generic abc, else pylint will complain
    def forward(
        self: ModelABC, *args: tuple[Any, ...], **kwargs: dict
    ) -> None | torch.Tensor:
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(
        model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str
    ) -> np.ndarray | tuple[np.ndarray, ...] | dict:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (np.ndarray | torch.Tensor):
                A batch of data generated by `torch.utils.data.DataLoader`.
            device (str):
                Transfers model to the specified device.

        Returns:
            np.ndarray:
                The inference results as a numpy array.
            dict:
                Returns a dictionary of predictions and other expected outputs
                depending on the network architecture.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of model."""
        return image

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model."""
        return image

    @property
    def preproc_func(self: ModelABC) -> Callable:
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: ModelABC, func: Callable) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model = ModelABC()
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(image=np.ndarray)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)

        self._preproc = self.preproc if func is None else func

    @property
    def postproc_func(self: ModelABC) -> Callable:
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self: ModelABC, func: Callable) -> None:
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable and behave as follows:

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model = ModelABC()
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(image=np.ndarray)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)

        self._postproc = self.postproc if func is None else func

    def to(  # type: ignore[override]
        self: ModelABC,
        device: str = "cpu",
        dtype: torch.dtype | None = None,
        *,
        non_blocking: bool = False,
    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
        """Transfers model to cpu/gpu.

        Args:
            self (ModelABC):
                PyTorch defined model.
            device (str):
                Transfers model to the specified device. Default is "cpu".
            dtype (:class:`torch.dtype`):
                the desired floating point or complex dtype of the parameters
                and buffers in this module.
            non_blocking (bool):
                When set, it tries to convert/move asynchronously with respect
                to the host if possible, e.g., moving CPU Tensors with pinned
                memory to CUDA devices.

        Returns:
            torch.nn.Module | torch.nn.DataParallel:
                The model after being moved to cpu/gpu.

        """
        torch_device = torch.device(device)
        model = super().to(torch_device, dtype=dtype, non_blocking=non_blocking)

        # If target device is torch.cuda and more
        # than one GPU is available, use DataParallel
        if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
            return torch.nn.DataParallel(model)  # pragma: no cover

        return model

    def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
        """Helper function to load a torch model.

        Args:
            self (ModelABC):
                A torch model as :class:`ModelABC`.
            weights (str or Path):
                Path to pretrained weights.

        Returns:
            torch.nn.Module:
                Torch model with pretrained weights loaded on CPU.

        """
        # ! assume to be saved in single GPU mode
        # always load on to the CPU
        saved_state_dict = torch.load(weights, map_location="cpu")
        # `strict=True` raises on any key/shape mismatch, so reaching this
        # point means loading succeeded. Return the model itself, as promised
        # by the docstring — `load_state_dict` returns an `_IncompatibleKeys`
        # named tuple, not the model.
        super().load_state_dict(saved_state_dict, strict=True)
        return self