"""Define dataset abstract classes."""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
import cv2
import numpy as np
import torch
from tiatoolbox import logger
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils import imread
from tiatoolbox.utils.exceptions import DimensionMismatchError
from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader
if TYPE_CHECKING: # pragma: no cover
from collections.abc import Callable, Iterable
from typing import TypeGuard
from tiatoolbox.type_hints import IntPair, Resolution, Units
# Accepted forms of a dataset's `inputs`: a list whose items are paths
# (`str`/`Path`) or image arrays, or a single ndarray.
input_type = list[str | Path | np.ndarray] | np.ndarray
# [docs]
class PatchDatasetABC(ABC, torch.utils.data.Dataset):
    """Define abstract base class for patch dataset.

    Subclasses must implement :meth:`__getitem__`. This base class stores
    the inputs and labels, validates the inputs, loads images from disk and
    manages an optional per-patch pre-processing function.

    Attributes:
        inputs:
            Image paths or image arrays in "patch" mode; otherwise a
            list/array (of patch coordinates, per the validation message).
        labels:
            Labels associated with `inputs`.
        data_is_npy_alike (bool):
            `True` when `inputs` holds arrays rather than paths.

    """

    inputs: input_type
    labels: list[int] | np.ndarray

    def __init__(
        self: PatchDatasetABC,
    ) -> None:
        """Initialize :class:`PatchDatasetABC`."""
        super().__init__()
        # Start with the class-level identity pre-processing.
        self._preproc = self.preproc
        self.data_is_npy_alike = False
        self.inputs = []
        self.labels = []

    @staticmethod
    def _check_shape_integrity(shapes: list | np.ndarray) -> None:
        """Check that every shape is HWC and that all shapes are equal.

        Args:
            shapes (list or np.ndarray):
                One shape (sequence of ints) per sample.

        Raises:
            ValueError:
                If `shapes` is empty, if any shape does not have exactly
                three entries, or if the shapes differ from each other.

        """
        # An empty collection used to surface as an opaque NumPy reduction
        # error (also a ValueError); report it explicitly instead.
        if len(shapes) == 0:
            msg = "At least one image is required."
            raise ValueError(msg)

        if any(len(v) != 3 for v in shapes):  # noqa: PLR2004
            msg = "Each sample must be an array of the form HWC."
            raise ValueError(msg)

        # All rows have length 3 here, so this is a regular (N, 3) array.
        shape_array = np.asarray(shapes)
        if np.any(shape_array != shape_array[0]):
            msg = "Images must have the same dimensions."
            raise ValueError(msg)

    @staticmethod
    def _are_paths(inputs: input_type) -> TypeGuard[Iterable[Path]]:
        """TypeGuard to check that input array contains only paths."""
        return all(isinstance(v, (Path, str)) for v in inputs)

    @staticmethod
    def _are_npy_like(inputs: input_type) -> TypeGuard[Iterable[np.ndarray]]:
        """TypeGuard to check that input array contains only np.ndarray."""
        return all(isinstance(v, np.ndarray) for v in inputs)

    def _check_input_integrity(self: PatchDatasetABC, mode: str) -> None:
        """Check that variables received during init are valid.

        In "patch" mode these checks include:
            - Input is of a singular data type, such as a list of paths.
            - Every path exists and every image can be loaded.
            - All images are HWC and share the same dimensions.
            - If `inputs` is an ndarray, its dtype is numerical.

        In any other mode, `inputs` only has to be a list or an ndarray.

        Args:
            mode (str):
                "patch" for full image validation; any other value selects
                the container-type check only.

        Raises:
            ValueError:
                If any of the checks above fails.

        """
        if mode != "patch":
            if not isinstance(self.inputs, (list, np.ndarray)):
                msg = "`inputs` should be a list of patch coordinates."
                raise ValueError(msg)
            return

        self.data_is_npy_alike = False

        msg = (
            "Input must be either a list/array of images "
            "or a list of valid image paths."
        )

        if self._are_paths(self.inputs):
            if any(not Path(v).exists() for v in self.inputs):
                # At least one of the paths is invalid.
                raise ValueError(msg)
            # Preload every image as a sanity check and to collect shapes.
            shapes = [self.load_img(v).shape for v in self.inputs]
        elif self._are_npy_like(self.inputs):
            shapes = [v.shape for v in self.inputs]
            self.data_is_npy_alike = True
        else:
            raise ValueError(msg)

        self._check_shape_integrity(shapes)

        if isinstance(self.inputs, np.ndarray):
            # Reject arrays of mixed / non-numerical data types.
            if not np.issubdtype(self.inputs.dtype, np.number):
                msg = "Provided input array is non-numerical."
                raise ValueError(msg)
            self.data_is_npy_alike = True

    @staticmethod
    def load_img(path: str | Path) -> np.ndarray:
        """Load an image from a provided path.

        Args:
            path (str or Path): Path to an image file.

        Returns:
            :class:`numpy.ndarray`:
                Image as a numpy array.

        Raises:
            TypeError:
                If the file suffix is not one of ".npy", ".jpg", ".jpeg",
                ".tif", ".tiff" or ".png" (the comparison is
                case-sensitive).

        """
        path = Path(path)

        if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"):
            msg = f"Cannot load image data from `{path.suffix}` files."
            raise TypeError(msg)

        return imread(path, as_uint8=False)

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of loader.

        The default implementation returns `image` unchanged.

        """
        return image

    @property
    def preproc_func(self: PatchDatasetABC) -> Callable:
        """Return the current pre-processing function of this instance.

        The returned function is expected to behave as follows:

        >>> transformed_img = func(img)

        """
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: PatchDatasetABC, func: Callable | None) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable and behaves as
        follows:

        >>> transformed_img = func(img)

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        """
        if func is None:
            self._preproc = self.preproc
        elif callable(func):
            self._preproc = func
        else:
            msg = f"{func} is not callable!"
            raise ValueError(msg)

    def __len__(self: PatchDatasetABC) -> int:
        """Return the number of entries in `inputs`."""
        return len(self.inputs)

    @abstractmethod
    def __getitem__(self: PatchDatasetABC, idx: int) -> dict:
        """Get an item from the dataset."""
        ...  # pragma: no cover
# [docs]
class WSIPatchDataset(PatchDatasetABC):
    """Define a WSI-level patch dataset.

    Patch coordinates are generated once at construction time (optionally
    filtered by a tissue mask) and each patch is read from the slide on
    demand in :meth:`__getitem__`.

    Attributes:
        reader (:class:`.WSIReader`):
            A WSI Reader or Virtual Reader for reading pyramidal image
            or large tile in pyramidal way.
        inputs:
            List of coordinates to read from the `reader`, each
            coordinate is of the form `[start_x, start_y, end_x,
            end_y]`.
        patch_input_shape:
            A tuple (int, int) or ndarray of shape (2,). Expected size to
            read from `reader` at requested `resolution` and `units`.
            Expected to be `(height, width)`.
        resolution:
            See (:class:`.WSIReader`) for details.
        units:
            See (:class:`.WSIReader`) for details.
        preproc_func:
            Preprocessing function used to transform the input data. It will
            be called on each patch before returning it.

    """

    def __init__(  # skipcq: PY-R1000
        self: WSIPatchDataset,
        input_img: str | Path | WSIReader,
        mask_path: str | Path | None = None,
        patch_input_shape: IntPair = None,
        patch_output_shape: IntPair = None,
        stride_shape: IntPair = None,
        resolution: Resolution = None,
        units: Units = None,
        min_mask_ratio: float = 0,
        preproc_func: Callable | None = None,
        *,
        auto_get_mask: bool = True,
    ) -> None:
        """Create a WSI-level patch dataset.

        Args:
            input_img (str or Path or WSIReader):
                Valid path to a whole-slide image, or a
                :class:`WSIReader` instance.
            mask_path (str or Path):
                Valid mask image.
            patch_input_shape:
                A tuple (int, int) or ndarray of shape (2,). Expected
                shape to read from `reader` at requested `resolution`
                and `units`. Expected to be positive and of (height,
                width). Note, this is not at `resolution` coordinate
                space.
            patch_output_shape:
                A tuple (int, int) or ndarray of shape (2,). Expected
                output shape from the model at requested `resolution`
                and `units`. Expected to be positive and of (height,
                width). Note, this is not at `resolution` coordinate
                space.
            stride_shape:
                A tuple (int, int) or ndarray of shape (2,). Expected
                stride shape to read at requested `resolution` and
                `units`. Expected to be positive and of (height, width).
                Note, this is not at level 0.
            resolution (Resolution):
                Requested resolution corresponding to units. Check
                (:class:`WSIReader`) for details.
            units (Units):
                Units in which `resolution` is defined.
            auto_get_mask (bool):
                If `True`, then automatically get simple threshold mask using
                WSIReader.tissue_mask() function.
            min_mask_ratio (float):
                Only patches with positive area percentage above this value are
                included. Defaults to 0.
            preproc_func (Callable):
                Preprocessing function used to transform the input data. If
                supplied, the function will be called on each patch before
                returning it.

        Raises:
            ValueError:
                If `input_img` is neither an existing file nor a
                `WSIReader`, if the patch/stride shapes are invalid, if
                `mask_path` is not a file, or if no patch coordinates
                remain after filtering.

        Examples:
            >>> # A user defined preproc func and expected behavior
            >>> preproc_func = lambda img: img/2  # reduce intensity by half
            >>> transformed_img = preproc_func(img)
            >>> # Create a dataset to get patches from WSI with above
            >>> # preprocessing function
            >>> ds = WSIPatchDataset(
            ...     input_img='/A/B/C/wsi.svs',
            ...     patch_input_shape=[512, 512],
            ...     stride_shape=[256, 256],
            ...     auto_get_mask=False,
            ...     preproc_func=preproc_func
            ... )

        """
        super().__init__()

        given_reader = isinstance(input_img, WSIReader)
        if not given_reader and not (
            isinstance(input_img, (str, Path)) and Path(input_img).is_file()
        ):
            msg = "`input_img` must be a valid file path or a `WSIReader` instance."
            raise ValueError(msg)

        patch_input_shape = np.array(patch_input_shape)
        stride_shape = np.array(stride_shape)
        _validate_patch_stride_shape(patch_input_shape, stride_shape)

        self.preproc_func = preproc_func

        if given_reader:
            self.img_path = Path(input_img.input_path)
            reader = input_img
        else:
            self.img_path = Path(input_img)
            reader = WSIReader.open(self.img_path)

        # On Windows the reader is not kept on the instance (per the
        # original note: supports multi-threading / pickling via the path);
        # it is re-opened lazily from `img_path` in `__getitem__`.
        self.reader = None if os.name == "nt" else reader

        # Slide size at the requested read resolution/units.
        wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
        # Must be set before `_setup_mask_reader`, which copies it.
        self.reader_info = reader.info

        # Shapes are given as (height, width) and reversed for the extractor.
        coordinate_kwargs = {
            "image_shape": wsi_shape,
            "patch_input_shape": patch_input_shape[::-1],
            "stride_shape": stride_shape[::-1],
        }
        if patch_output_shape is None:
            self.inputs = PatchExtractor.get_coordinates(**coordinate_kwargs)
        else:
            self.inputs, self.outputs = PatchExtractor.get_coordinates(
                **coordinate_kwargs,
                patch_output_shape=patch_output_shape,
            )
            self.full_outputs = self.outputs

        mask_reader = self._setup_mask_reader(
            mask_path=mask_path,
            reader=reader,
            auto_get_mask=auto_get_mask,
        )

        if mask_reader is not None:
            keep = PatchExtractor.filter_coordinates(
                mask_reader,  # must be at the same resolution
                self.inputs,  # must already be at requested resolution
                wsi_shape=wsi_shape,
                min_mask_ratio=min_mask_ratio,
            )
            self.inputs = self.inputs[keep]
            if hasattr(self, "outputs"):
                # Keep the unfiltered list before narrowing `outputs`.
                self.full_outputs = self.outputs
                self.outputs = self.outputs[keep]

        self._check_inputs()

        self.patch_input_shape = patch_input_shape
        self.resolution = resolution
        self.units = units

        # Perform check on the input
        self._check_input_integrity(mode="wsi")

    def _setup_mask_reader(
        self: WSIPatchDataset,
        mask_path: str | Path | None,
        reader: WSIReader,
        *,
        auto_get_mask: bool,
    ) -> VirtualWSIReader | None:
        """Create a mask reader for WSIPatchDataset if requested.

        Args:
            mask_path (str or Path or None):
                Path to a mask image. When given it takes precedence over
                `auto_get_mask`.
            reader (WSIReader):
                Reader of the slide; used to derive a tissue mask when no
                `mask_path` is given and `auto_get_mask` is `True`.
            auto_get_mask (bool):
                Whether to generate a tissue mask from `reader`.

        Returns:
            A mask reader whose `info` is set to `self.reader_info`, or
            `None` when no mask is requested.

        Raises:
            ValueError:
                If `mask_path` is given but is not an existing file.

        """
        if mask_path is None:
            if not auto_get_mask:
                return None
            # No mask provided: generate a basic tissue mask on the fly.
            try:
                mask_reader = reader.tissue_mask(resolution=1.25, units="power")
            except ValueError:
                # Fall back to mpp when the "power" request is rejected.
                mask_reader = reader.tissue_mask(resolution=6.0, units="mpp")
            mask_reader.info = self.reader_info
            return mask_reader

        mask_file = Path(mask_path)
        if not mask_file.is_file():
            msg = "`mask_path` must be a valid file path."
            raise ValueError(msg)

        # Convert to a single channel and binarise: non-zero pixels become 1.
        gray = cv2.cvtColor(imread(mask_file), cv2.COLOR_RGB2GRAY)
        binary = np.array(gray > 0, dtype=np.uint8)

        mask_reader = VirtualWSIReader(binary)
        mask_reader.info = self.reader_info
        return mask_reader

    def _check_inputs(self: WSIPatchDataset) -> None:
        """Raise `ValueError` if no patch coordinates are left."""
        if len(self.inputs) > 0:
            return
        msg = "No patch coordinates remain after filtering."
        raise ValueError(msg)

    def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
        """Return the stored reader, or open one from `img_path` if unset."""
        if self.reader:
            return self.reader
        return WSIReader.open(img_path)

    def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
        """Get an item from the dataset.

        Args:
            idx (int): Index into the patch coordinates.

        Returns:
            dict:
                "image" (pre-processed patch) and "coords" (input
                coordinates as an array), plus "output_locs" when output
                coordinates were generated.

        """
        coords = self.inputs[idx]
        output_locs = self.outputs[idx] if hasattr(self, "outputs") else None

        # Reuse the reader, opening it first if it was not kept at init.
        self.reader = self._get_reader(self.img_path)
        raw_patch = self.reader.read_bounds(
            coords,
            resolution=self.resolution,
            units=self.units,
            pad_constant_values=255,
            coord_space="resolution",
        )

        sample = {"image": self._preproc(raw_patch), "coords": np.array(coords)}
        if output_locs is not None:
            sample["output_locs"] = output_locs
        return sample
# [docs]
class PatchDataset(PatchDatasetABC):
    """Define PatchDataset for torch inference.

    Define a simple patch dataset, which inherits from the
    `torch.utils.data.Dataset` class.

    Attributes:
        inputs (list or np.ndarray):
            Either a list of patches, where each patch is a ndarray or a
            list of valid path with its extension be (".jpg", ".jpeg",
            ".tif", ".tiff", ".png") pointing to an image.
        labels (list):
            List of labels for sample at the same index in `inputs`.
            Default is `None`.
        patch_input_shape (tuple):
            Size of patches input to the model. Patches are at
            requested read resolution, not with respect to level 0,
            and must be positive. When `None`, patch sizes are not
            checked in :meth:`__getitem__`.

    Examples:
        >>> # A user defined preproc func and expected behavior
        >>> preproc_func = lambda img: img/2  # reduce intensity by half
        >>> transformed_img = preproc_func(img)
        >>> # create a dataset to get patches preprocessed by the above function
        >>> ds = PatchDataset(
        ...     inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'],
        ...     labels=["labels1", "labels2"],
        ...     patch_input_shape=(224, 224),
        ... )

    """

    def __init__(
        self: PatchDataset,
        inputs: np.ndarray | list,
        labels: list | None = None,
        patch_input_shape: IntPair | None = None,
    ) -> None:
        """Initialize :class:`PatchDataset`.

        Args:
            inputs (list or np.ndarray):
                Image arrays or paths to image files.
            labels (list or None):
                Optional labels, indexed like `inputs`.
            patch_input_shape:
                Optional expected patch size, compared against every
                dimension of a patch except the last.

        Raises:
            ValueError:
                If `inputs` fails the "patch" mode integrity checks.

        """
        super().__init__()

        self.data_is_npy_alike = False

        self.inputs = inputs
        self.labels = labels
        self.patch_input_shape = patch_input_shape

        # perform check on the input
        self._check_input_integrity(mode="patch")

    def __getitem__(self: PatchDataset, idx: int) -> dict:
        """Get an item from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict:
                "image" (the pre-processed patch) and, when labels were
                supplied, "label".

        Raises:
            DimensionMismatchError:
                If `patch_input_shape` is set and the patch size (all
                dimensions but the last) differs from it.

        """
        patch = self.inputs[idx]

        # Inputs that are not arrays are paths and must be loaded first.
        if not self.data_is_npy_alike:
            patch = self.load_img(patch)

        # `patch_input_shape` defaults to None; calling `tuple(None)` used to
        # raise TypeError here, so only validate when a shape was provided.
        if self.patch_input_shape is not None:
            expected_dims = tuple(self.patch_input_shape)
            actual_dims = patch.shape[:-1]
            if actual_dims != expected_dims:
                msg = (
                    f"Patch size is not compatible with the model. "
                    f"Expected dimensions {expected_dims}, but got "
                    f"{actual_dims}."
                )
                logger.error(msg=msg)
                raise DimensionMismatchError(
                    expected_dims=expected_dims,
                    actual_dims=actual_dims,
                )

        # Apply preprocessing to selected patch
        patch = self._preproc(patch)

        data = {
            "image": patch,
        }
        if self.labels is not None:
            data["label"] = self.labels[idx]

        return data
def _validate_patch_stride_shape(
    patch_input_shape: np.ndarray, stride_shape: np.ndarray
) -> None:
    """Validate patch and stride shape arrays.

    Each array must have an integer dtype, hold at most two elements and
    contain no negative values. `patch_input_shape` is checked first.

    Args:
        patch_input_shape (np.ndarray):
            Shape of the input patch (e.g., height, width).
        stride_shape (np.ndarray):
            Stride dimensions used for patch extraction.

    Raises:
        ValueError:
            If either array fails one of the conditions above.

    """

    def _is_invalid(shape: np.ndarray) -> bool:
        """Return True when `shape` breaks any of the three rules."""
        # The dtype test comes first so non-numeric arrays are never compared.
        if not np.issubdtype(shape.dtype, np.integer):
            return True
        if np.size(shape) > 2:  # noqa: PLR2004
            return True
        return bool(np.any(shape < 0))

    if _is_invalid(patch_input_shape):
        msg = f"Invalid `patch_input_shape` value {patch_input_shape}."
        raise ValueError(msg)

    if _is_invalid(stride_shape):
        msg = f"Invalid `stride_shape` value {stride_shape}."
        raise ValueError(msg)