merge_multitask_horizontal

merge_multitask_horizontal(canvas, count, output_locs_y, canvas_np, output_locs, change_indices)

Horizontally merge runs of patch outputs into per-head row blocks.

This helper performs row-wise stitching of patch predictions for multitask heads. Each index in change_indices marks a point where the dataloader advanced to a new row of output patches. For each such index, the function takes the leading segment of canvas_np (for every head) that ends at that index and merges it into one horizontally concatenated row block per head. The merged blocks and their per-pixel hit counts are appended to canvas and count as Dask arrays chunked to the merged row height, and the consumed patches are removed from canvas_np and output_locs. The vertical extent of each merged row is appended to output_locs_y. All five updated values are returned.
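The consume-from-the-front bookkeeping can be sketched in plain NumPy. The helper split_row_segments below is hypothetical (it is not part of this module) and assumes that change_indices are increasing indices into output_locs, each marking the exclusive end of a row; whatever follows the last index is returned as the pending tail.

```python
import numpy as np


def split_row_segments(canvas_np, output_locs, change_indices):
    """Split pending per-head patch outputs into complete row segments.

    Hypothetical helper: assumes change_indices are increasing indices into
    output_locs as passed in, each marking the exclusive end of a row.
    Everything after the last index stays pending (the unmerged tail).
    """
    segments = []
    start = 0
    for stop in change_indices:
        # One entry per head, each sliced to the same patch range.
        heads = [head[start:stop] for head in canvas_np]
        segments.append((heads, output_locs[start:stop]))
        start = stop
    # The tail belongs to a row the dataloader has not finished yet.
    tail_np = [head[start:] for head in canvas_np]
    tail_locs = output_locs[start:]
    return segments, tail_np, tail_locs


# Two heads, five pending 4x4 patches; rows end after patches 2 and 4.
canvas_np = [np.zeros((5, 4, 4, 2)), np.zeros((5, 4, 4, 1))]
output_locs = np.array([
    [0, 0, 4, 4], [2, 0, 6, 4],
    [0, 2, 4, 6], [2, 2, 6, 6],
    [0, 4, 4, 8],
])
segments, tail_np, tail_locs = split_row_segments(canvas_np, output_locs, [2, 4])
print(len(segments), tail_np[0].shape, tail_locs.tolist())
# 2 (1, 4, 4, 2) [[0, 4, 4, 8]]
```

Every head is sliced with the same bounds, so the per-head arrays stay aligned with output_locs after each segment is consumed.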

For each row segment:
  1. The function determines the row’s horizontal span from output_locs (min x0, max x1).

  2. For each head, it calls merge_batch_to_canvas to place the segment’s patch outputs into a contiguous row block and an aligned count map.

  3. The row block and count map are wrapped as Dask arrays and appended to the running lists in canvas and count (one list per head).

  4. The segment is removed from canvas_np and output_locs, and the segment’s vertical bounds (y0, y1) are appended to output_locs_y.
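Steps 1 to 3 for a single head can be approximated as below. merge_row_segment is a hypothetical stand-in for the merge_batch_to_canvas call, not its actual implementation. It assumes each box is exactly as wide as its patch, uses a single-channel count map, and replaces the Dask wrapping with np.concatenate to stay self-contained.

```python
import numpy as np


def merge_row_segment(blocks, locs):
    """Stitch one row segment of one head into a row block plus count map.

    Hypothetical stand-in for the per-head merge_batch_to_canvas call:
    blocks has shape (N_seg, H, W, C); locs holds [x0, y0, x1, y1] boxes
    whose width matches W. Overlapping pixels are summed, not averaged.
    """
    # Step 1: the row's horizontal span.
    row_x0 = int(locs[:, 0].min())
    row_x1 = int(locs[:, 2].max())
    _, height, _, channels = blocks.shape
    row_block = np.zeros((height, row_x1 - row_x0, channels), dtype=np.float64)
    row_count = np.zeros((height, row_x1 - row_x0, 1), dtype=np.int64)
    # Step 2: place each patch at its offset and accumulate sums and hits.
    for block, (x0, _, x1, _) in zip(blocks, locs):
        x0, x1 = int(x0) - row_x0, int(x1) - row_x0
        row_block[:, x0:x1] += block
        row_count[:, x0:x1] += 1
    # Vertical extent of the row (assumes all patches in a row share y0, y1).
    y_extent = [int(locs[0, 1]), int(locs[0, 3])]
    return row_block, row_count, y_extent


rng = np.random.default_rng(0)
blocks = rng.random((2, 4, 4, 3))
locs = np.array([[0, 0, 4, 4], [2, 0, 6, 4]])
row_block, row_count, y_extent = merge_row_segment(blocks, locs)

# Step 3 (NumPy instead of Dask): start from None, then grow along axis 0.
canvas = None
canvas = row_block if canvas is None else np.concatenate([canvas, row_block], axis=0)
print(row_block.shape, row_count[0, :, 0].tolist(), y_extent)
# (4, 6, 3) [1, 1, 2, 2, 1, 1] [0, 4]
```

The two patches overlap in columns 2 and 3, which is where the count map reaches 2; the values there are the sum of both patches until normalization happens later.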

Parameters:
  • canvas (list[da.Array] | list[None]) – Accumulated per-head row blocks (probability/logit sums) as Dask arrays. Each entry grows along the first axis with each merged row. Pass None for each head on the first call.

  • count (list[da.Array] | list[None]) – Accumulated per-head row count maps, aligned with canvas. Pass None for each head on the first call.

  • output_locs_y (np.ndarray | None) – Accumulated vertical extents of already-merged rows, stored as rows of [y0, y1], one per merged row. Pass None on the first call; the array is then created internally by concatenation.

  • canvas_np (list[np.ndarray]) – In-memory patch outputs awaiting merge, one list entry per head. Each head’s entry stacks the pending patch outputs along the first axis with shape (N_pending, H, W, C), so that entry i corresponds to output_locs[i]. Only the leading segments delimited by change_indices are merged in this call.

  • output_locs (np.ndarray) – Output placement boxes for the pending patches in canvas_np, shaped (N_pending, 4) as [x0, y0, x1, y1]. The function consumes boxes from the front up to each change_indices boundary and returns the remaining tail.

  • change_indices (np.ndarray | list[int]) – Sorted indices (relative to the current output_locs) at which a new row begins. Each index is the exclusive end of a contiguous row segment to be merged in this call.
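One plausible way to obtain change_indices is to look for positions where y0 changes between consecutive boxes. find_row_changes is a hypothetical helper that assumes output_locs is ordered row by row, left to right; the dataloader’s actual mechanism may differ.

```python
import numpy as np


def find_row_changes(output_locs):
    """Return indices where a new row of output patches begins.

    Hypothetical helper, assuming rows are identified by their y0 value and
    that output_locs is ordered row by row, left to right.
    """
    y0 = output_locs[:, 1]
    # np.diff compares each box with the previous one; +1 points at the
    # first box of the new row, which is the exclusive end of the old one.
    return np.flatnonzero(np.diff(y0) != 0) + 1


output_locs = np.array([
    [0, 0, 4, 4], [2, 0, 6, 4],  # row starting at y0 = 0
    [0, 2, 4, 6], [2, 2, 6, 6],  # row starting at y0 = 2
    [0, 4, 4, 8],                # first box of a row still being filled
])
print(find_row_changes(output_locs).tolist())
# [2, 4]
```

Because a row can only be recognised as finished once the next row starts, the last run of boxes (index 4 onward here) stays pending, which matches the tail that this function leaves in canvas_np and output_locs.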

Returns:

  • canvas: Updated list of per-head Dask arrays containing concatenated row blocks (values are sums; normalization happens later).

  • count: Updated list of per-head Dask arrays containing concatenated row hit counts for normalization.

  • canvas_np: Updated in-memory per-head arrays with the consumed segments removed.

  • output_locs: Updated placement boxes with the consumed segments removed.

  • output_locs_y: Updated array of accumulated vertical row extents, with [y0, y1] appended for each newly merged row.

Return type:

tuple[list[da.Array], list[da.Array], list[np.ndarray], np.ndarray, np.ndarray]

Notes

  • The merged row block shape per head is (row_height, row_width, C), where:

    • row_height is the head’s patch output height,

    • row_width is max(x1) - min(x0) for the row,

    • C is the number of channels for that head.

  • merge_batch_to_canvas handles placement and accumulation of overlapping patch outputs and produces a matching count map.

  • Normalization (division by counts) is not performed here; it is done later during vertical merging to form the final probability maps.

  • Dask chunking is set to the full row height to facilitate subsequent vertical concatenation and overlap handling.
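Carrying sums and counts instead of averages means a single division at the end yields the correct mean wherever patches overlap, whether the overlap is horizontal (inside a row) or vertical (between rows). The sketch below shows that final step under simplifying assumptions: merge_rows_vertically is hypothetical, every row block is assumed to span the same x range, and NumPy stands in for Dask.

```python
import numpy as np


def merge_rows_vertically(row_blocks, row_counts, y_extents, height):
    """Sum row blocks into a full-height canvas, then normalize by counts.

    Hypothetical sketch: assumes every row block spans the same x range
    and that y_extents holds one [y0, y1] per row in canvas coordinates.
    """
    width, channels = row_blocks[0].shape[1], row_blocks[0].shape[2]
    canvas = np.zeros((height, width, channels))
    count = np.zeros((height, width, 1))
    for block, hits, (y0, y1) in zip(row_blocks, row_counts, y_extents):
        # Rows may overlap vertically, so accumulate rather than assign.
        canvas[y0:y1] += block
        count[y0:y1] += hits
    # Normalize only where at least one patch contributed; elsewhere stay 0.
    return np.divide(canvas, count, out=np.zeros_like(canvas), where=count > 0)


rows = [np.full((4, 6, 1), 1.0), np.full((4, 6, 1), 3.0)]
hits = [np.ones((4, 6, 1)), np.ones((4, 6, 1))]
probs = merge_rows_vertically(rows, hits, [[0, 4], [2, 6]], height=6)
print(probs[:, 0, 0].tolist())
# [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```

In the overlapping rows 2 and 3 the summed value 4.0 is divided by a count of 2, giving the mean of the two rows rather than either row alone.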