Source code for tiatoolbox.utils.postproc_defs
"""Module to provide postprocessing classes."""
from __future__ import annotations
import colorsys
import warnings
import numpy as np
[docs]
class MultichannelToRGB:
    """Class to convert multi-channel images to RGB images.

    Each selected input channel is multiplied by its RGB color and the results
    are summed per pixel to form an RGB composite.

    Attributes:
        colors (np.ndarray | None): Array of RGB colors, one row per channel.
            Derived from ``color_dict`` whenever it is assigned.
        is_validated (bool): Whether ``colors`` has been checked against the
            channel count of an image. Reset whenever ``color_dict`` is
            reassigned, so new colors are re-checked on the next call.
        channels (list[int] | None): Indices of the channels to include in the
            composite.
        enhance (float): Multiplicative factor applied to the composite before
            clipping to the uint8 range.
        color_dict (dict[str, tuple[float, float, float]] | None): Mapping of
            channel names to RGB colors.

    """

    def __init__(
        self: MultichannelToRGB,
        color_dict: dict[str, tuple[float, float, float]] | None = None,
    ) -> None:
        """Initialize the MultichannelToRGB converter.

        Args:
            color_dict: Dict of channel names with RGB colors for each channel. If not
                provided, a set of distinct colors will be auto-generated.

        """
        self.colors: np.ndarray | None = None
        self.is_validated: bool = False
        self.channels: list[int] | None = None
        self.enhance: float = 1.0
        # Assigned last on purpose: __setattr__ derives ``colors`` (and
        # ``channels`` while still None) from it.
        self.color_dict = color_dict

    def validate(self: MultichannelToRGB, n: int) -> None:
        """Validate the colors against the number of channels in an image.

        Checks that ``n`` is either equal to the number of colors, or is one
        less. In the latter case it is assumed that the last color is background
        autofluorescence that is not in the image. That color is dropped, along
        with any channel index ``>= n``, and a warning is emitted.

        Args:
            n (int): Number of channels in the image.

        Raises:
            ValueError: If ``colors`` is not initialized, or if the number of
                colors is neither ``n`` nor ``n + 1``.

        """
        if self.colors is None:
            msg = "Colors must be initialized before validation."
            raise ValueError(msg)

        n_colors = len(self.colors)
        if self.channels is None:
            self.channels = list(range(n_colors))

        if n_colors == n:
            self.is_validated = True
            return

        if n_colors - 1 == n:
            self.colors = self.colors[:n]
            self.channels = [c for c in self.channels if c < n]
            self.is_validated = True
            msg = (
                "Number of channels in image is one less than number of channels in\n"
                "dict. Assuming last channel is background autofluorescence and "
                "ignoring\n"
                "it. If this is not the case please provide a manual color_dict."
            )
            warnings.warn(
                msg,
                stacklevel=2,
            )
            return

        msg = f"Number of colors: {n_colors} does not match channels in image: {n}."
        raise ValueError(msg)

    def generate_colors(self: MultichannelToRGB, n_channels: int) -> np.ndarray:
        """Generate a set of visually distinct colors.

        Hues are spaced evenly around the HSV color wheel at full saturation
        and value. The colors are stored in ``color_dict`` under the keys
        ``channel_0``, ``channel_1``, and so on. That assignment also refreshes
        ``colors``.

        Args:
            n_channels (int): Number of channels/colors to generate

        Returns:
            np.ndarray: float32 array of RGB colors, one row per channel, with
            components in the range [0, 1].

        """
        self.color_dict = {
            f"channel_{i}": colorsys.hsv_to_rgb(i / n_channels, 1, 1)
            for i in range(n_channels)
        }
        return np.array(list(self.color_dict.values()), dtype=np.float32)

    def __call__(self: MultichannelToRGB, image: np.ndarray) -> np.ndarray:
        """Convert a multi-channel image to an RGB image.

        Images with fewer than 5 channels are assumed to already be RGB(A) and
        are returned unchanged. ``uint16`` images are reduced to ``uint8`` by
        keeping their top 8 bits before compositing.

        Args:
            image (np.ndarray): Input image of shape (H, W, N)

        Returns:
            np.ndarray: RGB image of shape (H, W, 3) with dtype uint8.

        Raises:
            ValueError: If the number of colors does not match the number of
                image channels (see ``validate``).

        """
        n = image.shape[2]
        if n < 5:  # noqa: PLR2004
            # assume already rgb(a) so just return image
            return image

        if self.colors is None:
            self.generate_colors(n)
        if not self.is_validated:
            self.validate(n)
        # Read only after validation, which may have dropped a trailing color.
        colors = self.colors

        if image.dtype == np.uint16:
            # Same result as floor(image / 256), without a float64 temporary.
            image = (image >> 8).astype(np.uint8)

        # Per-pixel weighted sum of the selected channels' colors.
        rgb_image = (
            np.einsum(
                "hwn,nc->hwc",
                image[:, :, self.channels],
                colors[self.channels, :],
                optimize=True,
            )
            * self.enhance
        )
        # Clip to ensure in valid range and return
        return np.clip(rgb_image, 0, 255).astype(np.uint8)

    def __setattr__(
        self: MultichannelToRGB,
        name: str,
        value: dict[str, tuple[float, float, float]] | None,
    ) -> None:
        """Keep ``colors`` in sync whenever ``color_dict`` is assigned.

        This hook runs for every attribute assignment. Only a non-None
        ``color_dict`` triggers extra work: ``colors`` is rebuilt from the dict's
        values, and ``channels`` is set to all indices if it is still unset. The
        converter is then marked as not validated, so the new colors are checked
        against the next image.

        Args:
            name (str): Name of the attribute being set.
            value: New value of the attribute.

        """
        if name == "color_dict" and value is not None:
            self.colors = np.array(list(value.values()), dtype=np.float32)
            # getattr: channels may not exist yet (e.g. instance built via __new__).
            if getattr(self, "channels", None) is None:
                self.channels = list(range(len(value)))
            # A new set of colors must be re-checked against the next image.
            self.is_validated = False
        super().__setattr__(name, value)