prepare_full_batch
- prepare_full_batch(batch_output, batch_locs, full_output_locs, output_locs, canvas_np=None, save_path='temp_fullbatch', memory_threshold=80, *, is_last)
Prepare full-sized output and count arrays for a batch of patch predictions.
This function aligns patch-level predictions with global output locations when a mask (e.g., from auto_get_mask) is applied, so that some global locations have no prediction. It initializes full-sized arrays and places each prediction at the index of its matching global location. If the batch is the last in the sequence, it pads the arrays so that they cover all remaining locations.
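The alignment step can be pictured with a small numpy sketch. It is not the library's implementation: the function name `align_batch_sketch` is hypothetical, locations are assumed to be matched by exact row equality, every row of `batch_locs` is assumed to appear in `full_output_locs`, the batch is assumed to be non-empty, and the memory-dependent zarr path is omitted. Rows for masked-out locations stay zero-filled.

```python
import numpy as np


def align_batch_sketch(batch_output, batch_locs, full_output_locs, output_locs, *, is_last):
    """Hypothetical, simplified alignment of patch predictions to global locations.

    Assumes exact row matching between batch_locs and full_output_locs.
    """
    # Index of each batch location within the remaining global locations.
    match_idx = np.array(
        [np.flatnonzero((full_output_locs == loc).all(axis=1))[0] for loc in batch_locs]
    )
    # Non-final batch: span up to the last matched location.
    # Final batch: span every remaining location (tail is zero-padded).
    n_rows = len(full_output_locs) if is_last else int(match_idx.max()) + 1
    full_batch_output = np.zeros((n_rows, *batch_output.shape[1:]), dtype=batch_output.dtype)
    # Place predictions; unmatched (masked-out) rows remain zero.
    full_batch_output[match_idx] = batch_output
    # Locations covered by this batch move to the accumulated array.
    covered = full_output_locs[:n_rows]
    output_locs = np.concatenate(
        [output_locs.reshape(-1, full_output_locs.shape[1]), covered]
    )
    return full_batch_output, full_output_locs[n_rows:], output_locs
```

With four global locations of which the second was masked out, a non-final batch yields three rows (the middle one zero) and leaves one location pending; with `is_last=True` the result spans all four.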
- Parameters:
batch_output (np.ndarray) – Patch-level model predictions of shape (N, H, W, C).
batch_locs (np.ndarray) – Output locations corresponding to batch_output.
full_output_locs (np.ndarray) – Remaining global output locations to be matched.
output_locs (np.ndarray) – Accumulated output location array across batches.
canvas_np (np.ndarray | zarr.Array | None) – Accumulated canvas array from previous batches. Used to estimate the total memory footprint when deciding whether to hold the full batch as a numpy array or a zarr array.
save_path (Path | str) – Path to a directory in which a unique temporary subfolder is created to store this batch's full-batch zarr array.
memory_threshold (int) – Memory usage threshold, as a percentage, above which the full batch is cached to a temporary zarr array under save_path instead of being kept in memory.
is_last (bool) – Flag indicating whether this is the final batch.
- Returns:
full_batch_output: Full-sized output array with predictions placed.
full_output_locs: Updated remaining global output locations.
output_locs: Updated accumulated output locations.
- Return type:
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
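The numpy-versus-zarr choice described by canvas_np and memory_threshold could work roughly as follows. This is an assumed heuristic, not the function's documented rule: `choose_backend_sketch` is a hypothetical helper, and the caller supplies current memory figures (in practice these might come from `psutil.virtual_memory()`, whose result exposes `used` and `total`).

```python
import numpy as np


def choose_backend_sketch(batch_nbytes, canvas_np, used_bytes, total_bytes, memory_threshold=80):
    """Hypothetical numpy-vs-zarr decision based on projected memory usage."""
    # Only an in-memory numpy canvas adds to RAM; a zarr canvas is assumed
    # to live on disk and is therefore not counted.
    canvas_nbytes = canvas_np.nbytes if isinstance(canvas_np, np.ndarray) else 0
    projected_percent = 100 * (used_bytes + canvas_nbytes + batch_nbytes) / total_bytes
    # Exceeding the threshold switches to on-disk zarr storage.
    return "zarr" if projected_percent > memory_threshold else "numpy"
```

Counting the projected total rather than current usage alone means a large accumulated canvas can push a batch to zarr even while the system still has headroom.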