Source code for tiatoolbox.models.dataset.classification
"""Define classes and methods for classification datasets."""
from __future__ import annotations
from typing import TYPE_CHECKING
from torchvision import transforms
if TYPE_CHECKING: # pragma: no cover
import numpy as np
import torch
from PIL import Image
class _TorchPreprocCaller:
    """Wrapper for applying PyTorch transforms.

    The transforms are chained with ``transforms.Compose``; calling the
    wrapper runs the chain on an image and returns the result with its
    axes reordered by ``permute((1, 2, 0))``.

    Args:
        preprocs (list):
            List of torchvision transforms for preprocessing the image.
            The transforms will be applied in the order that they are
            given in the list. For more information, visit the following
            link: https://pytorch.org/vision/stable/transforms.html.

    """

    def __init__(self: _TorchPreprocCaller, preprocs: list) -> None:
        """Compose `preprocs` into a single callable stored as `self.func`."""
        self.func = transforms.Compose(preprocs)

    def __call__(
        self: _TorchPreprocCaller,
        img: np.ndarray | Image.Image,
    ) -> torch.Tensor:
        """Run the composed transforms on `img` and reorder the output axes.

        Args:
            img (np.ndarray | Image.Image):
                Input image passed unchanged to the composed transforms.
                `Image.Image` is the PIL image class (`Image` alone is
                only the PIL module and is not a valid type).

        Returns:
            torch.Tensor:
                Output of the composed transforms with its axes permuted
                to the order ``(1, 2, 0)``.

        """
        tensor: torch.Tensor = self.func(img)
        # The first axis is moved to the end. NOTE(review): presumably this
        # restores channel-last layout after a channel-first `ToTensor`
        # output - confirm against the transforms actually supplied.
        return tensor.permute((1, 2, 0))
[docs]
def predefined_preproc_func(dataset_name: str) -> _TorchPreprocCaller:
    """Return the preprocessing callable associated with a dataset name.

    Args:
        dataset_name (str):
            Name of the dataset; it selects which predefined list of
            transforms is wrapped and returned.

    Returns:
        _TorchPreprocCaller:
            Callable that applies the selected transforms to input data.

    Raises:
        ValueError:
            If no predefined preprocessing exists for `dataset_name`.

    """
    known_pipelines = {
        "kather100k": [transforms.ToTensor()],
        "pcam": [transforms.ToTensor()],
    }

    # Every registered value is a list, so `None` can only mean "unknown name".
    pipeline = known_pipelines.get(dataset_name)
    if pipeline is None:
        msg = f"Predefined preprocessing for dataset `{dataset_name}` does not exist."
        raise ValueError(msg)

    return _TorchPreprocCaller(pipeline)