# Source code for tiatoolbox.models.architecture
"""Define a set of models to be used within tiatoolbox."""
from __future__ import annotations
from pathlib import Path
from pydoc import locate
from typing import TYPE_CHECKING
import timm
from huggingface_hub import hf_hub_download
from tiatoolbox import rcParam
from tiatoolbox.models.dataset.classification import predefined_preproc_func
from tiatoolbox.models.models_abc import load_torch_model
from .vanilla import CNNBackbone, TimmBackbone, timm_arch_dict, torch_cnn_backbone_dict
if TYPE_CHECKING: # pragma: no cover
import torch
from tiatoolbox.models.engine.io_config import ModelIOConfigABC
__all__ = ["fetch_pretrained_weights", "get_pretrained_model"]
PRETRAINED_INFO = rcParam["pretrained_model_info"]
# [docs]
def fetch_pretrained_weights(
    model_name: str,
    save_path: str | Path | None = None,
    *,
    overwrite: bool = False,
) -> str:
    """Download (and cache) pretrained weights for a named model.

    The weight file is fetched from the Hugging Face Hub repository
    registered for `model_name` in the pretrained model info yml file.

    Args:
        model_name (str):
            Refer to :py:meth:`get_pretrained_model` for all supported
            model names.
        save_path (str | Path):
            Path to the directory in which the pretrained weights will be
            cached. Defaults to ``TIATOOLBOX_HOME/models``.
        overwrite (bool):
            Overwrite existing downloaded weights (force downloading).

    Returns:
        str:
            The local path to the cached pretrained weights after
            downloading.

    Raises:
        ValueError:
            If `model_name` is not a registered pretrained model.

    """
    if model_name not in PRETRAINED_INFO:
        msg = f"Pretrained model `{model_name}` does not exist"
        raise ValueError(msg)

    info = PRETRAINED_INFO[model_name]
    hf_repo_id = info["hf_repo_id"]
    file_name = f"{model_name}.pth"

    # Default cache location lives under the tiatoolbox home directory.
    local_dir = (
        rcParam["TIATOOLBOX_HOME"] / "models" if save_path is None else Path(save_path)
    )

    # hf_hub_download skips the download when the file is already cached,
    # unless force_download re-fetches it.
    return hf_hub_download(
        repo_id=hf_repo_id,
        filename=file_name,
        local_dir=local_dir,
        force_download=overwrite,
    )
# [docs]
def _locate_class(base_package: str, class_path: str) -> type:
    """Resolve a dotted ``module.ClassName`` path relative to `base_package`.

    For example, ``_locate_class("tiatoolbox.models.architecture",
    "vanilla.CNNModel")`` imports ``tiatoolbox.models.architecture.vanilla``
    and returns its ``CNNModel`` attribute.
    """
    module_name, _, class_name = class_path.rpartition(".")
    module = locate(f"{base_package}.{module_name}")
    return getattr(module, class_name)


def get_pretrained_model(
    pretrained_model: str | None = None,
    pretrained_weights: str | Path | None = None,
    *,
    overwrite: bool = False,
) -> tuple[torch.nn.Module, ModelIOConfigABC | None]:
    """Load a predefined PyTorch model with the appropriate pretrained weights.

    Args:
        pretrained_model (str):
            Name of the existing models support by tiatoolbox for
            processing the data. The models currently supported:
            - alexnet
            - resnet18
            - resnet34
            - resnet50
            - resnet101
            - resnext50_32x4d
            - resnext101_32x8d
            - wide_resnet50_2
            - wide_resnet101_2
            - densenet121
            - densenet161
            - densenet169
            - densenet201
            - mobilenet_v2
            - mobilenet_v3_large
            - mobilenet_v3_small
            - googlenet
            Each model has been trained on the Kather100K and PCam
            datasets. The format of pretrained_model is
            <model_name>-<dataset_name>. For example, to use a resnet18
            model trained on Kather100K, use `resnet18-kather100k` and to
            use an alexnet model trained on PCam, use `alexnet-pcam`.
            By default, the corresponding pretrained weights will also be
            downloaded. However, you can override with your own set of
            weights via the `pretrained_weights` argument. Argument is
            case-insensitive.
        pretrained_weights (str):
            Path to the weight of the corresponding `pretrained_model`.
        overwrite (bool):
            Whether to force re-downloading of existing weights.

    Returns:
        tuple:
            The instantiated model with weights loaded, and the matching
            IO configuration (``None`` for plain backbones).

    Raises:
        TypeError:
            If `pretrained_model` is not a string.
        ValueError:
            If `pretrained_model` is not a known model name.

    Examples:
        >>> # get mobilenet pretrained on Kather100K dataset by the TIA team
        >>> model = get_pretrained_model(pretrained_model='mobilenet_v2-kather100k')
        >>> # get mobilenet defined by TIA team, but loaded with user defined weights
        >>> model = get_pretrained_model(
        ...     pretrained_model='mobilenet_v2-kather100k',
        ...     pretrained_weights='/A/B/C/my_weights.tar',
        ... )
        >>> # get resnet34 pretrained on PCam dataset by TIA team
        >>> model = get_pretrained_model(pretrained_model='resnet34-pcam')

    """
    if not isinstance(pretrained_model, str):
        msg = "pretrained_model must be a string."
        raise TypeError(msg)

    # Plain torchvision backbone: no tiatoolbox weights or IO config attached.
    if pretrained_model in torch_cnn_backbone_dict:
        return CNNBackbone(pretrained_model), None

    # timm backbone: either a tiatoolbox alias or any architecture timm lists.
    if pretrained_model in [*timm_arch_dict, *timm.list_models()]:
        return TimmBackbone(pretrained_model, pretrained=True), None

    if pretrained_model not in PRETRAINED_INFO:
        msg = f"Pretrained model `{pretrained_model}` does not exist."
        raise ValueError(msg)

    info = PRETRAINED_INFO[pretrained_model]

    # Instantiate the architecture class named in the registry entry.
    arch_info = info["architecture"]
    model_class = _locate_class("tiatoolbox.models.architecture", arch_info["class"])
    model = model_class(**arch_info["kwargs"])

    # TODO(TBC): Dictionary of dataset specific or transformation? # noqa: FIX002,TD003
    if "dataset" in info:
        # ! this is a hack currently, need another PR to clean up
        # ! associated pre-processing coming from dataset (Kumar, Kather, etc.)
        model.preproc_func = predefined_preproc_func(info["dataset"])

    if pretrained_weights is None:
        pretrained_weights = fetch_pretrained_weights(
            pretrained_model,
            overwrite=overwrite,
        )
    model = load_torch_model(model=model, weights=pretrained_weights)

    # Build the matching IO configuration from the registry entry.
    io_info = info["ioconfig"]
    ioconfig_class = _locate_class("tiatoolbox.models.engine", io_info["class"])
    ioconfig = ioconfig_class(**io_info["kwargs"])

    return model, ioconfig