"""Defines IOConfig for Model Engines."""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.type_hints import Resolution, Units
@dataclass
class ModelIOConfigABC:
    """Defines a data class for holding a deep learning model's I/O information.

    Enforcing such that following attributes must always be defined by
    the subclass.

    Args:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction. Defaults to
            `patch_input_shape` (non-overlapping patches) when omitted.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().

    Attributes:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        resolution_unit (str):
            Resolution units shared by all input resolutions, taken from
            the first input resolution.
        highest_input_resolution (dict):
            Highest resolution to process the image based on input and
            output resolutions. This helps to read the image at the optimal
            resolution and improves performance.

    Examples:
        >>> # Defining io for a base network and converting to baseline.
        >>> ioconfig = ModelIOConfigABC(
        ...     input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        ...     output_resolutions=[{"units": "mpp", "resolution": 1.0}],
        ...     patch_input_shape=(224, 224),
        ...     stride_shape=(224, 224),
        ... )
        >>> ioconfig = ioconfig.to_baseline()

    """

    input_resolutions: list[dict]
    patch_input_shape: list[int] | np.ndarray | tuple[int, int]
    stride_shape: list[int] | np.ndarray | tuple[int, int] | None = None
    output_resolutions: list[dict] = field(default_factory=list)

    def __post_init__(self: ModelIOConfigABC) -> None:
        """Perform post initialization tasks."""
        if self.stride_shape is None:
            # Default to non-overlapping patch extraction.
            self.stride_shape = self.patch_input_shape

        self.resolution_unit = self.input_resolutions[0]["units"]

        # For "mpp" (microns per pixel) a *smaller* value means a higher
        # resolution; for "power"/"baseline" a *larger* value does.
        if self.resolution_unit == "mpp":
            self.highest_input_resolution = min(
                self.input_resolutions,
                key=lambda x: x["resolution"],
            )
        else:
            self.highest_input_resolution = max(
                self.input_resolutions,
                key=lambda x: x["resolution"],
            )

        self._validate()

    def _validate(self: ModelIOConfigABC) -> None:
        """Validate the data format.

        Raises:
            ValueError:
                If input and output resolutions mix different units, or
                the unit is not one of "power", "baseline" or "mpp".

        """
        resolutions = self.input_resolutions + self.output_resolutions
        units = {v["units"] for v in resolutions}
        if len(units) != 1:
            msg = (
                f"Multiple resolution units found: `{units}`. "
                f"Mixing resolution units is not allowed."
            )
            raise ValueError(msg)
        # Capture the unit before reporting it: popping inside the f-string
        # would leave an empty set in the error message.
        unit = units.pop()
        if unit not in [
            "power",
            "baseline",
            "mpp",
        ]:
            msg = f"Invalid resolution units `{unit}`."
            raise ValueError(msg)

    @staticmethod
    def scale_to_highest(
        resolutions: list[dict[Units, Resolution]],
        units: Units,
    ) -> np.ndarray:
        """Get the scaling factor from input resolutions.

        This will convert resolutions to a scaling factor with respect to
        the highest resolution found in the input resolutions list. If a model
        requires images at multiple resolutions. This helps to read the image a
        single resolution. The image will be read at the highest required resolution
        and will be scaled for low resolution requirements using interpolation.

        Args:
            resolutions (list(dict(Units, Resolution))):
                A list of resolutions where one is defined as
                `{'resolution': value, 'unit': value}`
            units (Units):
                Resolution units.

        Returns:
            :class:`numpy.ndarray`:
                A 1D array of scaling factors having the same length as
                `resolutions`. For "baseline" units the resolution values
                are returned unchanged.

        Raises:
            ValueError:
                If `units` is not one of "baseline", "mpp" or "power".

        Examples:
            >>> ModelIOConfigABC.scale_to_highest(
            ...     [
            ...         {"units": "mpp", "resolution": 0.25},
            ...         {"units": "mpp", "resolution": 0.5},
            ...     ],
            ...     "mpp",
            ... )
            array([1. , 0.5])
            >>> ModelIOConfigABC.scale_to_highest(
            ...     [
            ...         {"units": "mpp", "resolution": 0.5},
            ...         {"units": "mpp", "resolution": 0.25},
            ...     ],
            ...     "mpp",
            ... )
            array([0.5, 1. ])

        """
        old_vals = [v["resolution"] for v in resolutions]
        if units not in {"baseline", "mpp", "power"}:
            msg = (
                f"Unknown units `{units}`. "
                f"Units should be one of 'baseline', 'mpp' or 'power'."
            )
            raise ValueError(msg)
        if units == "baseline":
            # Baseline values are already expressed relative to the full
            # resolution image, so they are returned as-is.
            return old_vals
        if units == "mpp":
            # Smaller mpp == higher resolution -> highest is the minimum.
            return np.min(old_vals) / np.array(old_vals)
        # "power": larger objective power == higher resolution.
        return np.array(old_vals) / np.max(old_vals)

    def to_baseline(self: ModelIOConfigABC) -> ModelIOConfigABC:
        """Returns a new config object converted to baseline form.

        This will return a new :class:`ModelIOConfigABC` where
        resolutions have been converted to baseline format with the
        highest possible resolution found in both input and output as
        reference.

        """
        resolutions = self.input_resolutions + self.output_resolutions
        # Subclasses may define `save_resolution`; include it so every
        # entry is scaled against the same reference resolution.
        save_resolution = getattr(self, "save_resolution", None)
        if save_resolution is not None:
            resolutions.append(save_resolution)

        scale_factors = self.scale_to_highest(resolutions, self.resolution_unit)
        num_input_resolutions = len(self.input_resolutions)
        num_output_resolutions = len(self.output_resolutions)

        input_resolutions = [
            {"units": "baseline", "resolution": v}
            for v in scale_factors[:num_input_resolutions]
        ]
        output_resolutions = [
            {"units": "baseline", "resolution": v}
            for v in scale_factors[
                num_input_resolutions : num_input_resolutions + num_output_resolutions
            ]
        ]
        return replace(
            self,
            input_resolutions=input_resolutions,
            output_resolutions=output_resolutions,
        )
@dataclass
class IOSegmentorConfig(ModelIOConfigABC):
    """Contains semantic segmentor input and output information.

    Args:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        patch_output_shape (:class:`numpy.ndarray`, list(int)):
            Shape of the largest output in (height, width).
        save_resolution (dict):
            Resolution to save all output.
        tile_shape (tuple(int, int)):
            Tile shape to process the WSI.
        margin (int):
            Tile margin to accumulate the output.

    Attributes:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        patch_output_shape (:class:`numpy.ndarray`, list(int)):
            Shape of the largest output in (height, width).
        save_resolution (dict):
            Resolution to save all output.
        highest_input_resolution (dict):
            Highest resolution to process the image based on input and
            output resolutions. This helps to read the image at the optimal
            resolution and improves performance.
        tile_shape (tuple(int, int)):
            Tile shape to process the WSI.
        margin (int):
            Tile margin to accumulate the output.

    Examples:
        >>> # Defining io for a network having 1 input and 1 output at the
        >>> # same resolution
        >>> ioconfig = IOSegmentorConfig(
        ...     input_resolutions=[{"units": "baseline", "resolution": 1.0}],
        ...     output_resolutions=[{"units": "baseline", "resolution": 1.0}],
        ...     patch_input_shape=(2048, 2048),
        ...     patch_output_shape=(1024, 1024),
        ...     stride_shape=(512, 512),
        ... )
        ...
        >>> # Defining io for a network having 3 input and 2 output
        >>> # at the same resolution, the output is then merged at a
        >>> # different resolution.
        >>> ioconfig = IOSegmentorConfig(
        ...     input_resolutions=[
        ...         {"units": "mpp", "resolution": 0.25},
        ...         {"units": "mpp", "resolution": 0.50},
        ...         {"units": "mpp", "resolution": 0.75},
        ...     ],
        ...     output_resolutions=[
        ...         {"units": "mpp", "resolution": 0.25},
        ...         {"units": "mpp", "resolution": 0.50},
        ...     ],
        ...     patch_input_shape=(2048, 2048),
        ...     patch_output_shape=(1024, 1024),
        ...     stride_shape=(512, 512),
        ...     save_resolution={"units": "mpp", "resolution": 4.0},
        ... )

    """

    patch_output_shape: list[int] | np.ndarray | tuple[int, int] | None = None
    save_resolution: dict | None = None
    tile_shape: tuple[int, int] | None = None
    margin: int | None = None

    def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig:
        """Returns a new config object converted to baseline form.

        This will return a new :class:`IOSegmentorConfig` where
        resolutions have been converted to baseline format with the
        highest possible resolution found in both input and output as
        reference.

        """
        # Convert input/output resolutions via the parent implementation.
        new_config = super().to_baseline()

        # Recompute the scale factors with `save_resolution` appended last,
        # so its baseline value is the final entry of `scale_factors`.
        resolutions = self.input_resolutions + self.output_resolutions
        if self.save_resolution is not None:
            resolutions.append(self.save_resolution)
        scale_factors = self.scale_to_highest(resolutions, self.resolution_unit)

        save_resolution = None
        if self.save_resolution is not None:
            save_resolution = {"units": "baseline", "resolution": scale_factors[-1]}

        return replace(
            self,
            input_resolutions=new_config.input_resolutions,
            output_resolutions=new_config.output_resolutions,
            save_resolution=save_resolution,
        )
class IOPatchPredictorConfig(ModelIOConfigABC):
    """Contains patch predictor input and output information.

    This subclass adds no new fields; it exists to give patch prediction
    engines a distinct configuration type.

    Args:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().

    Attributes:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        highest_input_resolution (dict):
            Highest resolution to process the image based on input and
            output resolutions. This helps to read the image at the optimal
            resolution and improves performance.

    Examples:
        >>> # Defining io for a patch predictor network
        >>> ioconfig = IOPatchPredictorConfig(
        ...     input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        ...     output_resolutions=[{"units": "mpp", "resolution": 0.5}],
        ...     patch_input_shape=(224, 224),
        ...     stride_shape=(224, 224),
        ... )

    """
@dataclass
class IOInstanceSegmentorConfig(IOSegmentorConfig):
    """Contains instance segmentor input and output information.

    This subclass adds no new fields over :class:`IOSegmentorConfig`; it
    exists to give instance segmentation engines a distinct configuration
    type.

    Args:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        patch_output_shape (:class:`numpy.ndarray`, list(int)):
            Shape of the largest output in (height, width).
        save_resolution (dict):
            Resolution to save all output.
        margin (int):
            Tile margin to accumulate the output.
        tile_shape (tuple(int, int)):
            Tile shape to process the WSI.

    Attributes:
        input_resolutions (list(dict)):
            Resolution of each input head of model inference, must be in
            the same order as `target model.forward()`.
        patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Shape of the largest input in (height, width).
        stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)):
            Stride in (x, y) direction for patch extraction.
        output_resolutions (list(dict)):
            Resolution of each output head from model inference, must be
            in the same order as target model.infer_batch().
        patch_output_shape (:class:`numpy.ndarray`, list(int)):
            Shape of the largest output in (height, width).
        save_resolution (dict):
            Resolution to save all output.
        highest_input_resolution (dict):
            Highest resolution to process the image based on input and
            output resolutions. This helps to read the image at the optimal
            resolution and improves performance.
        margin (int):
            Tile margin to accumulate the output.
        tile_shape (tuple(int, int)):
            Tile shape to process the WSI.

    Examples:
        >>> # Defining io for a network having 1 input and 1 output at the
        >>> # same resolution
        >>> ioconfig = IOInstanceSegmentorConfig(
        ...     input_resolutions=[{"units": "baseline", "resolution": 1.0}],
        ...     output_resolutions=[{"units": "baseline", "resolution": 1.0}],
        ...     patch_input_shape=(2048, 2048),
        ...     patch_output_shape=(1024, 1024),
        ...     stride_shape=(512, 512),
        ...     margin=128,
        ...     tile_shape=(1024, 1024),
        ... )
        >>> # Defining io for a network having 3 input and 2 output
        >>> # at the same resolution, the output is then merged at a
        >>> # different resolution.
        >>> ioconfig = IOInstanceSegmentorConfig(
        ...     input_resolutions=[
        ...         {"units": "mpp", "resolution": 0.25},
        ...         {"units": "mpp", "resolution": 0.50},
        ...         {"units": "mpp", "resolution": 0.75},
        ...     ],
        ...     output_resolutions=[
        ...         {"units": "mpp", "resolution": 0.25},
        ...         {"units": "mpp", "resolution": 0.50},
        ...     ],
        ...     patch_input_shape=(2048, 2048),
        ...     patch_output_shape=(1024, 1024),
        ...     stride_shape=(512, 512),
        ...     save_resolution={"units": "mpp", "resolution": 4.0},
        ...     margin=128,
        ...     tile_shape=(1024, 1024),
        ... )

    """

    def to_baseline(self: IOInstanceSegmentorConfig) -> IOInstanceSegmentorConfig:
        """Returns a new config object converted to baseline form.

        This will return a new :class:`IOInstanceSegmentorConfig` where
        resolutions have been converted to baseline format with the
        highest possible resolution found in both input and output as
        reference.

        """
        # The parent conversion already handles `save_resolution`; this
        # override only narrows the return type for callers.
        return super().to_baseline()