merge_multitask_horizontal
- merge_multitask_horizontal(canvas, count, output_locs_y, canvas_np, output_locs, change_indices)
Horizontally merge a run of patch outputs into per-head row blocks.
This helper stitches patch predictions for multitask heads row by row. Each index in change_indices marks a point where the dataloader advanced to a new row of output patches. For each such index, the function consumes the leading segment of canvas_np (per head) up to that index and merges it into a horizontally concatenated row block for each head. The merged blocks and their per-pixel hit counts are appended to canvas and count as Dask arrays chunked to the merged row height, and the consumed portion is removed from canvas_np. The function also returns output_locs with the consumed locations removed, and accumulates the vertical extent of each merged row in output_locs_y (returned as output_locs_y_).
- For each row segment:
  - The row's horizontal span is taken from output_locs (min x0, max x1).
  - For each head, merge_batch_to_canvas places the segment's patch outputs into a contiguous row block and an aligned count map.
  - The row block and count map are wrapped as Dask arrays and appended to the running per-head entries of canvas and count.
  - The segment is removed from canvas_np and output_locs, and its vertical bounds (y0, y1) are appended to output_locs_y_.
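The per-segment bookkeeping above can be sketched in plain NumPy. This is a minimal, hypothetical sketch, not the library's implementation. place_row is an illustrative stand-in for merge_batch_to_canvas, whose real signature is not documented here. Dask wrapping is replaced by NumPy concatenation. The sketch also assumes that every patch spans exactly x1 - x0 columns and that all rows share the same width. Segments are sliced using indices relative to the arrays as they were at call time, and the consumed prefix is dropped once at the end. This should leave the same tail as removing each segment in turn.

```python
import numpy as np


def place_row(patches, locs):
    # Hypothetical stand-in for merge_batch_to_canvas: sums (N, H, W, C)
    # patch outputs into one (H, row_width, C) block and counts hits per pixel.
    x_min = int(locs[:, 0].min())
    row_width = int(locs[:, 2].max()) - x_min
    height, _, channels = patches.shape[1:]
    block = np.zeros((height, row_width, channels), dtype=np.float64)
    hits = np.zeros((height, row_width), dtype=np.float64)
    for patch, (x0, _, x1, _) in zip(patches, locs):
        lo, hi = int(x0) - x_min, int(x1) - x_min
        block[:, lo:hi] += patch  # assumes x1 - x0 equals the patch width
        hits[:, lo:hi] += 1
    return block, hits


def append_rows(acc, new):
    # Grow an accumulated array along axis 0; None means "nothing merged yet".
    return new if acc is None else np.concatenate([acc, new], axis=0)


def merge_horizontal_sketch(canvas, count, output_locs_y, canvas_np,
                            output_locs, change_indices):
    start = 0
    for end in change_indices:
        seg_locs = output_locs[start:end]
        for head, pending in enumerate(canvas_np):
            block, hits = place_row(pending[start:end], seg_locs)
            canvas[head] = append_rows(canvas[head], block)
            count[head] = append_rows(count[head], hits)
        # Record the vertical span [y0, y1] of the row just merged.
        y_span = np.array([[seg_locs[:, 1].min(), seg_locs[:, 3].max()]])
        output_locs_y = append_rows(output_locs_y, y_span)
        start = end
    # Drop the consumed prefix; the tail waits for the next call.
    canvas_np = [pending[start:] for pending in canvas_np]
    output_locs = output_locs[start:]
    return canvas, count, canvas_np, output_locs, output_locs_y


# Two heads (C=2 and C=1); patches 0-1 form one row, patch 2 starts the next.
locs = np.array([[0, 0, 4, 4], [2, 0, 6, 4], [0, 2, 4, 6]])
heads = [np.ones((3, 4, 4, 2)), np.ones((3, 4, 4, 1))]
canvas, count, heads, locs, locs_y = merge_horizontal_sketch(
    [None, None], [None, None], None, heads, locs, [2])
```

In the usage above, the two patches overlap in columns 2 to 3, so those columns of the count map hold 2 while the rest hold 1. The third patch is returned untouched as the pending tail.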
- Parameters:
  - canvas (list[da.Array] | list[None]) – Accumulated per-head row blocks (probability/logit sums) as Dask arrays. Each entry grows along the first axis with each merged row. Pass None for each head on the first call.
  - count (list[da.Array] | list[None]) – Accumulated per-head row count maps, aligned with canvas. Pass None for each head on the first call.
  - output_locs_y (np.ndarray) – Accumulated vertical extents of already-merged rows. Each appended element is [y0, y1], the merged row's vertical span. Pass None on the first call; it is initialized internally via concatenation.
  - canvas_np (list[np.ndarray]) – In-memory patch outputs awaiting merge, one entry per head. Each entry is a NumPy array of stacked patch outputs for the not-yet-merged part of the row, with shape (N_pending, H, W, C) and aligned with output_locs. The segment being merged occupies its front.
  - output_locs (np.ndarray) – Output placement boxes for the pending patches in canvas_np, with shape (N_pending, 4) and rows [x0, y0, x1, y1]. The function consumes boxes from the front up to each change_indices boundary and returns the remaining tail.
  - change_indices (np.ndarray | list[int]) – Sorted indices, relative to the current output_locs, where a row change occurs. Each index marks the end of a contiguous row segment to be merged in this call.
- Returns:
  - canvas: Updated list of per-head Dask arrays containing concatenated row blocks (values are sums; normalization happens later).
  - count: Updated list of per-head Dask arrays containing concatenated row hit counts for normalization.
  - canvas_np: Updated in-memory per-head arrays with the consumed segment removed.
  - output_locs: Updated placement boxes with the consumed segment removed.
  - output_locs_y_: Updated array of accumulated vertical row extents, with each merged row's [y0, y1] appended.
- Return type:
tuple[list[da.Array], list[da.Array], list[np.ndarray], np.ndarray, np.ndarray]
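The parameters do not say how change_indices is produced. One plausible way is sketched below. It assumes the dataloader emits boxes in row-major order, so that a new row begins wherever y0 changes. find_row_changes is a hypothetical helper, not part of the documented API.

```python
import numpy as np


def find_row_changes(output_locs):
    # Hypothetical helper: assumes boxes arrive in row-major order, so a new
    # row starts wherever y0 differs from the previous box's y0.
    y0 = np.asarray(output_locs)[:, 1]
    return np.flatnonzero(np.diff(y0) != 0) + 1


locs = np.array([
    [0, 0, 4, 4], [2, 0, 6, 4],  # row at y0 = 0
    [0, 2, 4, 6], [2, 2, 6, 6],  # row at y0 = 2
    [0, 4, 4, 8],                # first box of the row at y0 = 4 (incomplete)
])
change_indices = find_row_changes(locs)
print(change_indices.tolist())  # [2, 4]
```

With these indices, segments [0:2] and [2:4] would be merged as two rows. The box at index 4 would stay in output_locs as the tail awaiting the rest of its row.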
Notes
The merged row block shape per head is (row_height, row_width, C), where:
- row_height is the head's patch output height,
- row_width is max(x1) - min(x0) for the row,
- C is the number of channels for that head.
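As a worked illustration of this shape rule, the snippet computes the row block shape for two heads that share a row but differ in channel count. The box coordinates and channel counts are made up for illustration.

```python
import numpy as np

# Hypothetical boxes [x0, y0, x1, y1] for one row of three overlapping patches.
seg_locs = np.array([[0, 0, 256, 256], [164, 0, 420, 256], [328, 0, 584, 256]])
# Hypothetical per-head outputs (N_seg, H, W, C) with different channel counts.
canvas_np = [np.zeros((3, 256, 256, 5)), np.zeros((3, 256, 256, 3))]

# row_width = max(x1) - min(x0); height and channels come from each head.
row_width = int(seg_locs[:, 2].max() - seg_locs[:, 0].min())
row_shapes = [(p.shape[1], row_width, p.shape[3]) for p in canvas_np]
print(row_shapes)  # [(256, 584, 5), (256, 584, 3)]
```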
merge_batch_to_canvas handles placement and accumulation of overlapping patch outputs and produces a matching count map.
Normalization (division by counts) is not performed here; it happens later, during vertical merging, when the final probability maps are formed.
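To see why raw sums and counts are carried forward instead of per-row averages, consider two merged rows that overlap vertically. The values below are made up for illustration. The sketch only assumes that normalization divides accumulated sums by accumulated counts, as described above.

```python
import numpy as np

# Row A covers y = 0..4; every pixel is hit by two patches of value 1.0.
sum_a, cnt_a = np.full((4, 3, 1), 2.0), np.full((4, 3), 2.0)
# Row B covers y = 2..6; every pixel is hit by one patch of value 3.0.
sum_b, cnt_b = np.full((4, 3, 1), 3.0), np.full((4, 3), 1.0)

# Vertical merge: accumulate sums and counts over the full 6-pixel height.
total = np.zeros((6, 3, 1))
hits = np.zeros((6, 3))
total[0:4] += sum_a
hits[0:4] += cnt_a
total[2:6] += sum_b
hits[2:6] += cnt_b

# Normalize once at the end, leaving untouched pixels (count 0) at zero.
prob = np.divide(total, hits[..., None], out=np.zeros_like(total),
                 where=hits[..., None] > 0)
```

In the overlap (y = 2..3) the result is 5/3, the mean of the three contributing patch values (1, 1, 3). Averaging the two already-normalized rows would instead give 2 and would weight row B's single patch as heavily as row A's two.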
Dask chunking is set to the full row height to facilitate subsequent vertical concatenation and overlap handling.