prepare_multitask_full_batch
- prepare_multitask_full_batch(batch_output, batch_locs, full_output_locs, output_locs, canvas_np=None, save_path='temp_fullbatch', memory_threshold=80, *, is_last)
Align patch predictions to the global output index and pad to cover gaps.
This helper prepares a full-sized set of outputs for the current batch by aligning patch-level predictions with the remaining global output locations. It uses the provided full_output_locs (the outstanding locations yet to be filled) to place each patch’s predictions at the correct indices, returning arrays sized to the current span. If this is the final batch (is_last=True), it pads the arrays with zeros to cover any remaining, unmatched output locations and appends those locations to output_locs.
- Concretely:
A lookup is built over full_output_locs so each row in batch_locs maps to a unique index (“match”).
For each head in batch_output, an appropriately sized zero-initialized array is created and the matched batch predictions are placed at the computed indices.
output_locs is extended by the portion of full_output_locs covered in this call; full_output_locs is advanced accordingly.
If is_last=True, the function also appends any remaining locations to output_locs and pads the per-head arrays with zeros so their first dimension matches the updated number of locations.
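The steps above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the name `align_batch_sketch` is hypothetical, the memory check and zarr caching are omitted, the tail padding uses each head's own dtype, and emptying `full_output_locs` after the final batch is an assumption.

```python
import numpy as np


def align_batch_sketch(batch_output, batch_locs, full_output_locs, output_locs, is_last=False):
    """Hypothetical stand-in that mirrors the documented alignment steps."""
    # Lookup from each remaining global location (as a tuple) to its row index.
    lookup = {tuple(loc): i for i, loc in enumerate(full_output_locs.tolist())}
    match_indices = np.array([lookup[tuple(loc)] for loc in batch_locs.tolist()], dtype=int)
    n_consumed = int(match_indices.max()) + 1

    # One zero-initialized array per head; matched predictions go to their indices.
    full_batch_output = []
    for head in batch_output:
        aligned = np.zeros((n_consumed, *head.shape[1:]), dtype=head.dtype)
        aligned[match_indices] = head
        full_batch_output.append(aligned)

    # Extend the accumulated locations and advance the remaining ones.
    output_locs = np.concatenate([output_locs, full_output_locs[:n_consumed]], axis=0)
    full_output_locs = full_output_locs[n_consumed:]

    if is_last and len(full_output_locs) > 0:
        n_tail = len(full_output_locs)
        output_locs = np.concatenate([output_locs, full_output_locs], axis=0)
        # Assumption: pad with zeros of each head's dtype.
        full_batch_output = [
            np.concatenate([arr, np.zeros((n_tail, *arr.shape[1:]), dtype=arr.dtype)], axis=0)
            for arr in full_batch_output
        ]
        # Assumption: nothing remains outstanding after the final batch.
        full_output_locs = full_output_locs[:0]

    return full_batch_output, full_output_locs, output_locs


# Four global locations; the batch covers the first two, in swapped order.
full_locs = np.array([[0, 0, 2, 2], [2, 0, 4, 2], [4, 0, 6, 2], [6, 0, 8, 2]])
heads = (np.full((2, 2, 2, 1), 7, dtype=np.float32),)
outs, remaining, seen = align_batch_sketch(
    heads, full_locs[[1, 0]], full_locs, np.empty((0, 4), dtype=full_locs.dtype), is_last=True
)
```

In this example the two matched rows hold the predictions, and the two trailing rows are zero padding for the locations that no batch filled.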
- Parameters:
batch_output (tuple[np.ndarray]) – Tuple of per-head patch predictions for the current batch. Each element has shape (N, H, W, C) (head-specific), where N is the number of patches in the batch.
batch_locs (np.ndarray) – Array of output locations (e.g., patch output boxes) corresponding to batch_output. Each row must uniquely identify a location and match rows in full_output_locs.
full_output_locs (np.ndarray) – The remaining global output location array, carrying the canonical order of all locations that should be filled. This is progressively consumed from the front as batches are placed.
output_locs (np.ndarray) – Accumulated output location array across previous batches. This is extended with the portion of full_output_locs filled in this call and, when is_last=True, with any remaining tail (whose outputs are zero-padded).
canvas_np (tuple[np.ndarray | zarr.Array] | None) – Accumulated canvas arrays from previous batches. Used to check the total memory footprint when deciding between numpy and zarr.
save_path (Path | str) – Path to a directory; a unique temp subfolder will be created within it to store the temporary full-batch zarr for this batch.
memory_threshold (int) – Memory usage threshold, as a percentage, that triggers caching behavior.
is_last (bool) – Whether this is the final batch. When True, any locations left in full_output_locs after placing matches are appended to output_locs, and the per-head output arrays are padded with zeros to match the total number of output locations.
- Returns:
full_batch_output (list[np.ndarray]): One array per head containing the aligned outputs for this call. Each has shape (M, H, W, C), where M is the number of locations consumed (and possibly padded to include the remaining tail when is_last=True).
full_output_locs (np.ndarray): Updated remaining global output locations (the unconsumed tail).
output_locs (np.ndarray): Updated accumulated output locations including those added by this call (and any final tail when is_last=True).
- Return type:
Notes
Ordering is defined by full_output_locs. The number of rows consumed during this call equals max(match_indices) + 1.
Padding on the last batch is performed with zeros of the same dtype as each head’s predictions (uint8 for the padded section in the implementation).
This function is agnostic to the semantic meaning of locations; it only ensures that per-head arrays and the accumulated location index remain consistent across batches.
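As a small worked example of the consumed-row rule, assume a batch whose locations match rows 0 and 2 of full_output_locs. Three rows are then consumed, and the unmatched row 1 stays zero. The values below are made up for illustration, with each prediction flattened to a single number.

```python
import numpy as np

match_indices = np.array([0, 2])            # batch rows map to global rows 0 and 2
n_consumed = int(match_indices.max()) + 1   # rows consumed in this call

head = np.array([[10.0], [30.0]])           # hypothetical per-patch predictions
aligned = np.zeros((n_consumed, 1), dtype=head.dtype)
aligned[match_indices] = head               # row 1 is a gap and remains zero
```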