Source code for megafish.register

import numpy as np
import pandas as pd
import zarr
import xarray as xr
from dask.diagnostics import ProgressBar
import dask.array as da

from skimage.feature import SIFT, match_descriptors
from skimage.measure import ransac

from .config import USE_GPU, show_resource

if USE_GPU:
    import cupy as cp
    from cucim.skimage.registration import phase_cross_correlation
    from cucim.skimage.transform import ProjectiveTransform, AffineTransform, \
        warp
    from cupyx.scipy.ndimage import shift as nd_shift
else:
    from skimage.registration import phase_cross_correlation
    from skimage.transform import ProjectiveTransform, AffineTransform, warp
    from scipy.ndimage import shift as nd_shift


def _get_yx_dims(arr):
    """
    Returns the trailing two-dimensional plane of an array.

    Every axis except the last two is indexed at position 0, so an array of
    shape (..., n_y, n_x) is reduced to shape (n_y, n_x).

    Args:
        arr (array-like): Array with at least two dimensions that supports
            tuple indexing and exposes ``shape``.

    Returns:
        array-like: The 2-D slice ``arr[0, ..., 0, :, :]``.
    """
    n_leading = len(arr.shape) - 2
    index = (0,) * n_leading + (slice(None), slice(None))
    return arr[index]


def shift_cycle_cYXyx(
        zarr_path, group, sift_kwargs=None, match_kwargs=None,
        ransac_kwargs=None, subfooter="", footer="_shift_cycle"):
    """
    Calculates and stores cycle shifts for aligning image tiles based on
    phase correlation and feature matching.

    Every (cycle, tile_y, tile_x) tile is aligned against the tile at the
    same position in cycle 0. If the SIFT detector parameters are not
    provided, only phase correlation is used to calculate shifts.

    Args:
        zarr_path (str): Path to the Zarr file containing image data.
        group (str): Group name in the Zarr file where image data is stored.
        sift_kwargs (dict, optional): Parameters for the SIFT detector;
            defaults to None (phase correlation only).
        match_kwargs (dict, optional): Parameters for matching descriptors;
            defaults to None, which is treated as no extra parameters.
        ransac_kwargs (dict, optional): Parameters for RANSAC
            transformation; defaults to None, which is treated as no extra
            parameters (``ransac`` itself still requires ``min_samples`` and
            ``residual_threshold`` if it is reached).
        subfooter (str, optional): String to append before the footer in the
            output CSV filename; defaults to an empty string.
        footer (str, optional): String appended to the output CSV filename;
            defaults to "_shift_cycle".

    Returns:
        None: The function saves shift information as a CSV file with
        columns ``cycle``, ``tile_y``, ``tile_x`` and ``shift_0`` ...
        ``shift_8`` (a flattened 3x3 matrix).
    """

    def _shift_cycle(
            mov_tiles, ref_tiles, sift_kwargs=None, match_kwargs=None,
            ransac_kwargs=None):
        """
        Calculates the shift matrix for a single cycle using phase
        correlation and SIFT feature matching.

        Args:
            mov_tiles (ndarray): Moving image tiles (to be aligned).
            ref_tiles (ndarray): Reference image tiles (alignment target).
            sift_kwargs (dict, optional): Parameters for the SIFT detector.
            match_kwargs (dict, optional): Parameters for matching
                descriptors.
            ransac_kwargs (dict, optional): Parameters for RANSAC
                transformation.

        Returns:
            ndarray: A float32 block whose trailing shape is (2, 9); the
            first row holds the flattened 3x3 matrix, the second row stays
            zero.
        """

        def _detect_keypoints(img):
            # SIFT raises RuntimeError when it cannot find any feature;
            # treat that as "no keypoints" instead of failing the block.
            detector = SIFT(**sift_kwargs)
            try:
                detector.detect_and_extract(img)
            except RuntimeError:
                return np.zeros((0, 2), dtype=np.float32)
            return detector.keypoints  # (row, col)

        # Extract image data and replace NaNs with zeros
        ref_img = np.nan_to_num(_get_yx_dims(ref_tiles))
        mov_img = np.nan_to_num(_get_yx_dims(mov_tiles))

        # Move data to GPU if available
        if USE_GPU:
            ref_img = cp.asarray(ref_img)
            mov_img = cp.asarray(mov_img)

        # Normalize images for phase correlation
        ref_img = (ref_img - ref_img.min()) / (ref_img.max() - ref_img.min())
        mov_img = (mov_img - mov_img.min()) / (mov_img.max() - mov_img.min())

        # A constant image divides by zero above; replace the resulting
        # NaNs after normalization
        if USE_GPU:
            ref_img = cp.nan_to_num(ref_img)
            mov_img = cp.nan_to_num(mov_img)
        else:
            ref_img = np.nan_to_num(ref_img)
            mov_img = np.nan_to_num(mov_img)

        # Apply phase correlation to estimate shift
        shift, _error, _diffphase = phase_cross_correlation(
            ref_img, mov_img, normalization=None)  # (y, x)
        H_shift = AffineTransform(translation=(-shift[1], -shift[0]))

        # Initialize the shift matrix with the same number of leading axes
        # as the moving block
        keep_dims = len(mov_tiles.shape) - 2
        shift_matrix_shape = (1,) * keep_dims + (2, 9)
        shift_matrix = np.zeros(shift_matrix_shape, dtype=np.float32)
        append_slices = tuple([0] * (keep_dims + 1) + [slice(None)])

        # If SIFT parameters are not provided, return the basic shift matrix
        if sift_kwargs is None:
            if USE_GPU:
                shift_matrix[append_slices] = H_shift.params.flatten().get()
            else:
                shift_matrix[append_slices] = H_shift.params.flatten()
            return shift_matrix

        # Shift moving image based on phase correlation result; SIFT runs
        # on the CPU, so bring both images back to host memory
        if USE_GPU:
            mov_img = nd_shift(mov_img, shift).get()
            ref_img = ref_img.get()
        else:
            mov_img = nd_shift(mov_img, shift)

        # Detect keypoints using SIFT
        keypoints_ref = _detect_keypoints(ref_img)
        keypoints_mov = _detect_keypoints(mov_img)

        # Identity is the fallback whenever no refinement can be estimated
        H = AffineTransform(translation=(0, 0))
        if len(keypoints_ref) > 0 and len(keypoints_mov) > 0:
            # NOTE(review): keypoint coordinates (not SIFT descriptors) are
            # passed to match_descriptors, so matching is done by position
            # after the coarse shift - confirm this is intended.
            matches = match_descriptors(
                keypoints_ref, keypoints_mov, **(match_kwargs or {}))
            match_keys_ref = keypoints_ref[matches[:, 0]]  # (row, col)
            match_keys_mov = keypoints_mov[matches[:, 1]]  # (row, col)

            # Apply RANSAC if sufficient matches are found
            if len(match_keys_ref) >= 4:
                # Flip (row, col) to (x, y) as expected by the transform
                model, _inliers = ransac(
                    (np.flip(match_keys_mov, axis=-1),
                     np.flip(match_keys_ref, axis=-1)),
                    ProjectiveTransform, **(ransac_kwargs or {}))
                # ransac returns None when no consensus model is found
                if model is not None:
                    H = model

        # Compute the final transformation and update the shift matrix
        if USE_GPU:
            H_inv = cp.linalg.inv(cp.asarray(H.params))
            H = ProjectiveTransform(H_shift.params @ H_inv)
            shift_matrix[append_slices] = H.params.flatten().get()
        else:
            H_inv = np.linalg.inv(H.params)
            H = ProjectiveTransform(H_shift.params @ H_inv)
            shift_matrix[append_slices] = H.params.flatten()

        return shift_matrix

    # Open Zarr file and get data array
    root = zarr.open(zarr_path)
    zr = root[group + "/0"]["data"]
    da_zr = da.from_zarr(zr)
    n_cycle, n_tile_y, n_tile_x, n_y, n_x = da_zr.shape

    # Reference tiles for alignment (first cycle)
    ref_tiles = da_zr[0, :, :, :, :]

    # Map shifts across cycles; each block yields a (1, 1, 1, 2, 9) result
    shifts = da.map_blocks(
        _shift_cycle, da_zr, ref_tiles,
        sift_kwargs, match_kwargs, ransac_kwargs,
        dtype=np.float32, chunks=(1, 1, 1, 2, 9))

    print("Calculating cycle shifts: " + group + show_resource())
    with ProgressBar():
        shift_matrix = shifts.compute()

    # Reshape shift matrix for saving as CSV; only the first 9 of the 18
    # values per tile carry the matrix
    n_rows = n_cycle * n_tile_y * n_tile_x
    n_cols = 9 * 2
    n_indices = 3
    shift_matrix_reshape = shift_matrix.reshape(n_rows, n_cols)[:, :9]
    index_matrix = np.indices(
        (n_cycle, n_tile_y, n_tile_x)).reshape(n_indices, n_rows).T
    shift_cols = ["shift_" + str(i) for i in range(9)]
    index_shift_matrix = np.concatenate(
        (index_matrix, shift_matrix_reshape), axis=1)
    shifts_df = pd.DataFrame(
        index_shift_matrix,
        columns=["cycle", "tile_y", "tile_x"] + shift_cols)

    # Save the shift matrix as a CSV file
    shifts_df.to_csv(zarr_path.replace(
        ".zarr", subfooter + footer + ".csv"), index=False)
def shift_tile_cYXyx(
        zarr_path, group_mov, group_stitched, max_shift=100,
        sift_kwargs=None, match_kwargs=None, ransac_kwargs=None,
        subfooter="", footer="_shift_tile"):
    """
    Calculates and stores tile shifts for aligning image tiles based on
    phase correlation and feature matching.

    The first cycle of ``group_mov`` is aligned tile by tile against
    ``group_stitched``. If the SIFT detector parameters are not provided,
    only phase correlation is used to calculate shifts.

    Args:
        zarr_path (str): Path to the Zarr file containing image data.
        group_mov (str): Group name in the Zarr file for the moving image
            data.
        group_stitched (str): Group name in the Zarr file for the reference
            stitched image data.
        max_shift (int, optional): Maximum allowed shift (in pixels);
            defaults to 100. Only checked when SIFT refinement is used.
        sift_kwargs (dict, optional): Parameters for the SIFT detector;
            defaults to None (phase correlation only).
        match_kwargs (dict, optional): Parameters for matching descriptors;
            defaults to None, which is treated as no extra parameters.
        ransac_kwargs (dict, optional): Parameters for RANSAC
            transformation; defaults to None, which is treated as no extra
            parameters (``ransac`` itself still requires ``min_samples`` and
            ``residual_threshold`` if it is reached).
        subfooter (str, optional): String to append before the footer in the
            output CSV filename; defaults to an empty string.
        footer (str, optional): String appended to the output CSV filename;
            defaults to "_shift_tile".

    Returns:
        None: The function saves the shift matrix as a CSV file without
        returning any value.
    """

    def _shift_tile(img, max_shift, sift_kwargs, match_kwargs,
                    ransac_kwargs):
        """
        Calculates the shift matrix for a single tile using phase
        correlation and SIFT feature matching.

        Args:
            img (xarray.DataArray): Image block holding the reference and
                moving tile along the ``refmov`` dimension.
            max_shift (int): Maximum allowed shift (in pixels).
            sift_kwargs (dict, optional): Parameters for the SIFT detector.
            match_kwargs (dict, optional): Parameters for matching
                descriptors.
            ransac_kwargs (dict, optional): Parameters for RANSAC
                transformation.

        Returns:
            xarray.DataArray: A (1, 1, 9) shift matrix for the current tile.
        """

        def _detect_keypoints(tile):
            # SIFT raises RuntimeError when it cannot find any feature;
            # treat that as "no keypoints" instead of failing the block.
            detector = SIFT(**sift_kwargs)
            try:
                detector.detect_and_extract(tile)
            except RuntimeError:
                return np.zeros((0, 2), dtype=np.float32)
            return detector.keypoints  # (row, col)

        # Extract and preprocess reference and moving images
        ref_img = np.nan_to_num(
            _get_yx_dims(img.sel(refmov="ref").values))
        mov_img = np.nan_to_num(
            _get_yx_dims(img.sel(refmov="mov").values))

        # Move to GPU if available
        if USE_GPU:
            ref_img = cp.asarray(ref_img)
            mov_img = cp.asarray(mov_img)

        # Normalize images for phase correlation
        ref_img = (ref_img - ref_img.min()) / (ref_img.max() - ref_img.min())
        mov_img = (mov_img - mov_img.min()) / (mov_img.max() - mov_img.min())

        # A constant image divides by zero above; replace the resulting
        # NaNs after normalization
        if USE_GPU:
            ref_img = cp.nan_to_num(ref_img)
            mov_img = cp.nan_to_num(mov_img)
        else:
            ref_img = np.nan_to_num(ref_img)
            mov_img = np.nan_to_num(mov_img)

        # Apply phase correlation to estimate shift
        shift, _error, _diffphase = phase_cross_correlation(
            ref_img, mov_img, normalization=None)  # (y, x)
        H_shift = AffineTransform(translation=(-shift[1], -shift[0]))

        # Initialize the shift matrix
        shift_matrix = np.zeros((1, 1, 9), dtype=np.float32)

        if sift_kwargs is None:
            # NOTE(review): max_shift is not enforced on this
            # phase-correlation-only path - confirm whether it should be.
            if USE_GPU:
                shift_matrix[0, 0, :] = H_shift.params.flatten().get()
            else:
                shift_matrix[0, 0, :] = H_shift.params.flatten()
        else:
            # Shift moving image based on phase correlation result; SIFT
            # runs on the CPU, so bring both images back to host memory
            if USE_GPU:
                mov_img = nd_shift(mov_img, shift).get()
                ref_img = ref_img.get()
            else:
                mov_img = nd_shift(mov_img, shift)

            # Detect keypoints using SIFT
            keypoints_ref = _detect_keypoints(ref_img)
            keypoints_mov = _detect_keypoints(mov_img)

            # Identity is the fallback whenever no refinement can be
            # estimated
            H = AffineTransform(translation=(0, 0))
            if len(keypoints_ref) > 0 and len(keypoints_mov) > 0:
                # NOTE(review): keypoint coordinates (not SIFT descriptors)
                # are passed to match_descriptors, so matching is done by
                # position after the coarse shift - confirm this is
                # intended.
                matches = match_descriptors(
                    keypoints_ref, keypoints_mov, **(match_kwargs or {}))
                match_keys_ref = keypoints_ref[matches[:, 0]]  # (row, col)
                match_keys_mov = keypoints_mov[matches[:, 1]]  # (row, col)

                if len(match_keys_ref) >= 4:
                    # Flip (row, col) to (x, y) as expected by the
                    # transform
                    model, _inliers = ransac(
                        (np.flip(match_keys_mov, axis=-1),
                         np.flip(match_keys_ref, axis=-1)),
                        ProjectiveTransform, **(ransac_kwargs or {}))
                    # ransac returns None when no consensus model is found
                    if model is not None:
                        H = model

            # Apply the transformation and enforce max shift constraint:
            # a translation larger than max_shift is replaced by identity
            if USE_GPU:
                H_inv = cp.linalg.inv(cp.asarray(H.params))
                H = ProjectiveTransform(H_shift.params @ H_inv)
                if cp.linalg.norm(H.params[0:2, 2]) > max_shift:
                    H = AffineTransform(translation=(0, 0))
                shift_matrix[0, 0, :] = H.params.flatten().get()
            else:
                H_inv = np.linalg.inv(H.params)
                H = ProjectiveTransform(H_shift.params @ H_inv)
                if np.linalg.norm(H.params[0:2, 2]) > max_shift:
                    H = AffineTransform(translation=(0, 0))
                shift_matrix[0, 0, :] = H.params.flatten()

        # Return the shift matrix as an xarray DataArray
        res = xr.DataArray(
            shift_matrix,
            dims=["tile_y", "tile_x", "shift"],
            coords={
                "tile_y": img.coords["tile_y"],
                "tile_x": img.coords["tile_x"],
                "shift": np.arange(9)})
        return res

    # Open Zarr groups for reference and moving images
    group_ref = group_stitched + "/0"
    root = xr.open_zarr(zarr_path, group=group_ref)
    xar_ref = root["data"]
    xar_ref = xar_ref.expand_dims(dim={"refmov": ["ref"]}, axis=[0])

    group_mov = group_mov + "/0"
    root = xr.open_zarr(zarr_path, group=group_mov)
    xar_mov = root["data"]
    xar_mov = xar_mov.isel(cycle=0)  # TODO
    n_tile_y, n_tile_x, n_y, n_x = xar_mov.shape
    xar_mov = xar_mov.expand_dims(dim={"refmov": ["mov"]}, axis=[0])

    # Concatenate reference and moving images so that every block holds
    # both versions of one whole tile
    xar_in = xr.concat([xar_ref, xar_mov], dim="refmov")
    xar_in = xar_in.chunk({
        "refmov": 2, "tile_y": 1, "tile_x": 1, "y": n_y, "x": n_x})

    # Set up a template for the output shifts
    new_dims = ["tile_y", "tile_x", "shift"]
    new_coords = {
        "tile_y": np.arange(n_tile_y),
        "tile_x": np.arange(n_tile_x),
        "shift": np.arange(9)}
    template = xr.DataArray(
        da.empty((n_tile_y, n_tile_x, 9), dtype=np.float32,
                 chunks=(1, 1, 9)),
        dims=new_dims, coords=new_coords)

    # Map shifts across tiles
    res = xar_in.map_blocks(
        _shift_tile,
        kwargs={
            "max_shift": max_shift,
            "sift_kwargs": sift_kwargs,
            "match_kwargs": match_kwargs,
            "ransac_kwargs": ransac_kwargs},
        template=template)

    print("Calculating tile shifts: " + group_mov + show_resource())
    with ProgressBar():
        res = res.compute()

    # Reshape and save the shift matrix as a CSV
    n_tile_y, n_tile_x, n_shift = res.shape
    n_rows = n_tile_y * n_tile_x
    n_cols = 9
    shift_matrix = res.values.reshape(n_rows, n_cols)
    index_matrix = np.indices((n_tile_y, n_tile_x)).reshape(
        2, n_rows).T
    shift_cols = ["shift_" + str(i) for i in range(9)]
    index_shift_matrix = np.concatenate(
        (index_matrix, shift_matrix), axis=1)
    shifts_df = pd.DataFrame(
        index_shift_matrix,
        columns=["tile_y", "tile_x"] + shift_cols)

    shifts_df.to_csv(zarr_path.replace(
        ".zarr", subfooter + footer + ".csv"), index=False)
def dummy_shift_tile(zarr_path, shift_cycle_footer, subfooter="",
                     footer="_shift_tile"):
    """
    Creates a dummy tile shifts CSV file with identity transformation
    values.

    The tile layout is taken from the rows of the cycle shift CSV whose
    ``cycle`` equals 0; the ``cycle`` column is dropped and ``shift_0`` ...
    ``shift_8`` are overwritten with a flattened 3x3 identity matrix.

    Args:
        zarr_path (str): Path to the Zarr file associated with the image
            data.
        shift_cycle_footer (str): Footer of the shift cycle CSV file to read
            and modify.
        subfooter (str, optional): String to append before the footer in the
            output CSV filename; defaults to an empty string.
        footer (str, optional): String appended to the output CSV filename;
            defaults to "_shift_tile".

    Returns:
        None: The function saves a dummy shifts CSV file.
    """
    print("Creating dummy tile shifts csv")

    # Load the cycle shifts and keep one row per tile (cycle 0 only)
    cycle_df = pd.read_csv(
        zarr_path.replace(".zarr", shift_cycle_footer + ".csv"))
    first_cycle = cycle_df["cycle"] == 0
    tile_df = cycle_df.loc[first_cycle].drop(columns=["cycle"])

    # Row-major 3x3 identity, written as integers
    identity = (1, 0, 0,
                0, 1, 0,
                0, 0, 1)
    for position, value in enumerate(identity):
        tile_df["shift_" + str(position)] = value

    # Write the result next to the Zarr file
    out_path = zarr_path.replace(".zarr", subfooter + footer + ".csv")
    tile_df.to_csv(out_path, index=False)
def merge_shift_cYXyx(
        zarr_path, group, subfooter="", cycle_footer="_shift_cycle",
        tile_footer="_shift_tile", footer="_shift_tile_cycle"):
    """
    Merges cycle and tile shift transformations and saves the combined
    shifts as a CSV file.

    For every (cycle, tile_y, tile_x) position the 3x3 cycle matrix and the
    3x3 tile matrix are read from their CSV files and combined as the matrix
    product ``H_cycle @ H_tile``.

    Args:
        zarr_path (str): Path to the Zarr file associated with the image
            data.
        group (str): Group name in the Zarr file where image data is stored;
            only the shape of its ``data`` array is used.
        subfooter (str, optional): String to append before the footer in the
            input and output CSV filenames; defaults to an empty string.
        cycle_footer (str, optional): Footer of the cycle shift CSV file;
            defaults to "_shift_cycle".
        tile_footer (str, optional): Footer of the tile shift CSV file;
            defaults to "_shift_tile".
        footer (str, optional): String appended to the output CSV filename;
            defaults to "_shift_tile_cycle".

    Returns:
        None: The function saves the merged shifts as a CSV file.
    """
    # Load tile and cycle shift tables
    tile_df = pd.read_csv(
        zarr_path.replace(".zarr", subfooter + tile_footer + ".csv"))
    cycle_df = pd.read_csv(
        zarr_path.replace(".zarr", subfooter + cycle_footer + ".csv"))

    # The image array is opened only to learn the grid dimensions
    data = xr.open_zarr(zarr_path, group=group + "/0")["data"]
    n_cycle, n_tile_y, n_tile_x, n_y, n_x = data.shape

    shift_cols = ["shift_" + str(i) for i in range(9)]

    positions = []
    merged = []
    # np.ndindex walks the grid in the same order as nested
    # cycle / tile_y / tile_x loops
    for cycle, tile_y, tile_x in np.ndindex(n_cycle, n_tile_y, n_tile_x):
        tile_rows = tile_df[
            (tile_df["tile_y"] == tile_y) &
            (tile_df["tile_x"] == tile_x)]
        cycle_rows = cycle_df[
            (cycle_df["cycle"] == cycle) &
            (cycle_df["tile_y"] == tile_y) &
            (cycle_df["tile_x"] == tile_x)]

        # Each selection is expected to hold exactly one row of 9 values
        H_tile = tile_rows[shift_cols].values.reshape(3, 3)
        H_cycle = cycle_rows[shift_cols].values.reshape(3, 3)

        merged.append((H_cycle @ H_tile).flatten())
        positions.append((cycle, tile_y, tile_x))

    # Combine indices and matrices into one table
    index_array = np.array(positions).reshape(-1, 3)
    table = np.concatenate((index_array, np.array(merged)), axis=1)
    shifts_df = pd.DataFrame(
        table, columns=["cycle", "tile_y", "tile_x"] + shift_cols)

    # Save the merged shift data as a CSV file
    shifts_df.to_csv(zarr_path.replace(
        ".zarr", subfooter + footer + ".csv"), index=False)
def get_edges(n_cycle, n_tile_y, n_tile_x, df_shift, n_y_stitched,
              n_x_stitched, n_y, n_x, margin=500):
    """
    Calculates the edges of tiles after applying transformations, taking
    into account tile positions and transformations.

    The four corners of each tile are placed at the tile's position in the
    stitched grid, mapped through the inverse of the tile's 3x3 matrix and
    normalized by the homogeneous coordinate, so that projective (or
    uniformly scaled) matrices yield correct pixel coordinates. The
    computation runs in NumPy regardless of the GPU setting because only
    four points per tile are involved.

    Args:
        n_cycle (int): Number of cycles.
        n_tile_y (int): Number of tiles along the y-axis.
        n_tile_x (int): Number of tiles along the x-axis.
        df_shift (pandas.DataFrame): DataFrame containing shift
            transformation matrices for each cycle and tile, with columns
            ``cycle``, ``tile_y``, ``tile_x`` and ``shift_0`` ...
            ``shift_8``.
        n_y_stitched (int): Height of each stitched tile.
        n_x_stitched (int): Width of each stitched tile.
        n_y (int): Height of each individual tile.
        n_x (int): Width of each individual tile.
        margin (int, optional): Margin added to the edges for buffer space;
            defaults to 500.

    Returns:
        pandas.DataFrame: A DataFrame containing the min and max x and y
        coordinates of each tile after transformation, including the cycle,
        tile_y, and tile_x.

    Raises:
        IndexError: If ``df_shift`` has no row for a requested
            (cycle, tile_y, tile_x) combination.
    """
    shift_cols = ["shift_" + str(i) for i in range(9)]

    # Corners of an untransformed tile as (x, y)
    corners = ((0, 0), (n_x, 0), (0, n_y), (n_x, n_y))

    edges = []
    for cycle in range(n_cycle):
        for tile_y in range(n_tile_y):
            for tile_x in range(n_tile_x):
                # Position of the tile in the stitched grid
                off_x = tile_x * n_x_stitched
                off_y = tile_y * n_y_stitched

                # Homogeneous corner coordinates, shape (3, 4): (x, y, 1)
                points = np.array(
                    [(x + off_x, y + off_y, 1) for x, y in corners],
                    dtype=float).T

                # Get the transformation matrix H for the current tile
                shift = df_shift[
                    (df_shift["cycle"] == cycle) &
                    (df_shift["tile_y"] == tile_y) &
                    (df_shift["tile_x"] == tile_x)]
                H_mat = shift[shift_cols].values[0].reshape(3, 3)
                H_inv = np.linalg.inv(H_mat)

                # Apply the inverse transformation, then divide by the
                # homogeneous coordinate to get Cartesian (x, y)
                projected = H_inv @ points
                xy = (projected[:2, :] / projected[2, :]).T

                # Calculate the bounding box with margin
                max_x = float(xy[:, 0].max() + margin)
                min_x = float(xy[:, 0].min() - margin)
                max_y = float(xy[:, 1].max() + margin)
                min_y = float(xy[:, 1].min() - margin)

                edges.append(
                    [cycle, tile_y, tile_x, min_y, max_y, min_x, max_x])

    # Create a DataFrame to store all edges
    edges = pd.DataFrame(
        edges,
        columns=["cycle", "tile_y", "tile_x",
                 "min_y", "max_y", "min_x", "max_x"])
    return edges
def create_chunk_dataframe(shape, chunk_size):
    """
    Creates a DataFrame representing the coordinates of chunks within a grid
    based on the specified shape and chunk size.

    Each axis is cut into pieces of ``chunk_size``; the last piece of an
    axis holds whatever remains. Rows are ordered with ``chunk_y`` as the
    outer index and ``chunk_x`` as the inner index.

    Args:
        shape (tuple of int): The dimensions of the grid (height, width).
        chunk_size (tuple of int): The size of each chunk
            (chunk_height, chunk_width).

    Returns:
        pandas.DataFrame: A DataFrame containing the chunk indices and
        coordinates, with columns: 'chunk_y', 'chunk_x', 'upper_y',
        'lower_y', 'left_x', 'right_x'. The lower/right values are
        exclusive end positions.
    """

    def _axis_bounds(length, step):
        """
        Splits one axis into consecutive (start, stop) intervals.

        Args:
            length (int): Total extent of the axis.
            step (int): Nominal chunk extent along the axis.

        Returns:
            list of tuple of int: One (start, stop) pair per chunk.
        """
        n_pieces = (length + step - 1) // step
        sizes = [step] * (n_pieces - 1)
        sizes.append(length - step * (n_pieces - 1))

        bounds = []
        start = 0
        for size in sizes:
            bounds.append((start, start + size))
            start += size
        return bounds

    per_axis = [_axis_bounds(shape[axis], chunk_size[axis])
                for axis in range(len(shape))]
    y_bounds, x_bounds = per_axis[0], per_axis[1]

    records = [
        (iy, ix, upper, lower, left, right)
        for iy, (upper, lower) in enumerate(y_bounds)
        for ix, (left, right) in enumerate(x_bounds)]

    return pd.DataFrame(records, columns=[
        'chunk_y', 'chunk_x', 'upper_y', 'lower_y', 'left_x', 'right_x'])
def get_overlap(chunk_sel, tiles_df, cycle):
    """
    Finds the overlapping tiles within a specified chunk for a given cycle.

    A tile overlaps the chunk along one axis when its upper bound lies
    inside the chunk, its lower bound lies inside the chunk, or it spans
    the whole chunk. All comparisons are inclusive. A tile is returned only
    if it overlaps along both x and y.

    Args:
        chunk_sel (dict): A dictionary containing the coordinates of the
            chunk with keys: 'left_x', 'right_x', 'upper_y', 'lower_y'.
        tiles_df (pandas.DataFrame): A DataFrame containing tile information
            with columns: 'cycle', 'min_x', 'max_x', 'min_y', 'max_y'.
        cycle (int): The cycle number for which to find overlapping tiles.

    Returns:
        pandas.DataFrame: A DataFrame of tiles that overlap with the
        specified chunk in the given cycle.
    """

    def _axis_hit(tile_lo, tile_hi, start, stop):
        # Inclusive interval test between [tile_lo, tile_hi] and
        # [start, stop], split into the three possible configurations
        hi_inside = (start <= tile_hi) & (tile_hi <= stop)
        lo_inside = (start <= tile_lo) & (tile_lo <= stop)
        covers = (tile_lo <= start) & (stop <= tile_hi)
        return hi_inside | lo_inside | covers

    # Restrict to the tiles of the requested cycle
    in_cycle = tiles_df[tiles_df["cycle"] == cycle]

    hit_x = _axis_hit(
        in_cycle["min_x"], in_cycle["max_x"],
        chunk_sel["left_x"], chunk_sel["right_x"])
    hit_y = _axis_hit(
        in_cycle["min_y"], in_cycle["max_y"],
        chunk_sel["upper_y"], chunk_sel["lower_y"])

    return in_cycle[hit_x & hit_y]
def _register_chunk(input_img, zarr_path, group_name, df_chunk, df_tile, df_H,
                    n_y, n_x, chunk_size, block_info=None):
    """
    Registers and stitches image chunks based on transformation matrices
    and tile offsets.

    Every tile whose bounding box overlaps the current chunk is warped into
    the chunk's coordinate frame and merged by taking the pixel-wise
    maximum.

    Args:
        input_img (numpy.ndarray): The input block; its contents are not
            read here, the block only drives the Dask mapping.
        zarr_path (str): Path to the Zarr file containing the image data.
        group_name (str): Name of the group in the Zarr file where data is
            stored.
        df_chunk (pandas.DataFrame): DataFrame containing chunk information
            (coordinates and dimensions).
        df_tile (pandas.DataFrame): DataFrame containing tile information
            for each cycle.
        df_H (pandas.DataFrame): DataFrame containing the transformation
            matrices for each tile.
        n_y (int): Height of each individual tile.
        n_x (int): Width of each individual tile.
        chunk_size (tuple of int): The size of each chunk
            (chunk_height, chunk_width).
        block_info (dict, optional): Information about the current block
            being processed in Dask.

    Returns:
        numpy.ndarray: The registered image chunk of shape
        (1, chunk_height, chunk_width); pixels outside the chunk's real
        extent stay zero.
    """
    # Locate the block: (cycle, chunk row, chunk column)
    location = block_info[0]["chunk-location"]
    cycle = location[0]
    chunk_y = location[1]
    chunk_x = location[2]

    # Lazy handle on the tile images
    dar_img = da.from_zarr(zarr_path, component=group_name + "/0/data")

    # Coordinates of the current chunk
    is_current = (df_chunk["chunk_y"] == chunk_y) & (
        df_chunk["chunk_x"] == chunk_x)
    chunk_sel = df_chunk[is_current].iloc[0]

    # Tiles touching this chunk, joined with their matrices
    overlap = get_overlap(chunk_sel, df_tile, cycle).merge(
        df_H, on=["cycle", "tile_y", "tile_x"])

    # Tile origin relative to the chunk origin
    overlap["offset_y"] = overlap["tile_y"] * n_y - chunk_sel["upper_y"]
    overlap["offset_x"] = overlap["tile_x"] * n_x - chunk_sel["left_x"]

    # Accumulator for the merged chunk image
    if USE_GPU:
        tile_img = cp.zeros(chunk_size)
    else:
        tile_img = np.zeros(chunk_size)

    shift_cols = ["shift_" + str(k) for k in range(9)]

    for _, row in overlap.iterrows():
        H_mat = row[shift_cols].values.reshape(3, 3).astype(np.float32)
        translation = [-row["offset_x"], -row["offset_y"]]

        if USE_GPU:
            H_mat = cp.array(H_mat)
            offset = cp.array(translation)
            H_offset = cp.eye(3)
        else:
            offset = np.array(translation)
            H_offset = np.eye(3)

        # Fold the chunk-relative translation into the tile's matrix
        H_offset[:2, 2] = offset
        H = AffineTransform(matrix=H_mat @ H_offset)

        # Load the tile pixels for this cycle
        tile_img_add = dar_img[
            cycle, int(row["tile_y"]), int(row["tile_x"])].compute()
        if USE_GPU:
            tile_img_add = cp.asarray(tile_img_add)

        # Nearest-neighbour warp into the chunk frame, keeping intensities
        tile_img_add = warp(tile_img_add, H, output_shape=chunk_size,
                            preserve_range=True, order=0)

        # Overlapping tiles are merged by maximum intensity
        tile_img = np.maximum(tile_img, tile_img_add)

    # Crop to the real chunk extent when it is smaller than chunk_size
    height = chunk_sel["lower_y"] - chunk_sel["upper_y"]
    width = chunk_sel["right_x"] - chunk_sel["left_x"]
    if height != chunk_size[0] or width != chunk_size[1]:
        tile_img = tile_img[:height, :width]

    # Place the result into a full-size, zero-padded output block
    out_img = np.zeros((1, chunk_size[0], chunk_size[1]))
    if USE_GPU:
        tile_img = tile_img.get()
    out_img[0, :tile_img.shape[0], :tile_img.shape[1]] = tile_img

    return out_img
def registration_cYXyx(zarr_path, group_tile, group_ref, chunk_size,
                       subfooter="", shift_footer="_shift_tile_cycle",
                       footer="_reg"):
    """
    Registers and stitches image tiles based on transformation matrices,
    creating a registered dataset in Zarr format.

    The size of the stitched canvas is derived from the shape of the
    reference group. The canvas is processed chunk by chunk with
    ``_register_chunk`` and written to the group
    ``group_tile + footer + "/0"``.

    Args:
        zarr_path (str): Path to the Zarr file containing the image data.
        group_tile (str): Group name in the Zarr file for the tiles to be
            registered.
        group_ref (str): Group name in the Zarr file for the reference
            stitched images.
        chunk_size (tuple of int): Size of each chunk
            (chunk_height, chunk_width).
        subfooter (str, optional): String to append before the shift footer
            in the shift CSV filename; defaults to an empty string.
        shift_footer (str, optional): Footer of the shift CSV file; defaults
            to "_shift_tile_cycle".
        footer (str, optional): String appended to the output Zarr group
            name; defaults to "_reg".

    Returns:
        None: The function saves the registered and stitched images to a new
        Zarr group.
    """
    # Transformation matrices for every (cycle, tile_y, tile_x)
    df_H = pd.read_csv(
        zarr_path.replace(".zarr", subfooter + shift_footer + ".csv"))

    # Tile images: (cycle, tile_y, tile_x, y, x)
    dar_img = da.from_zarr(zarr_path, component=group_tile + "/0/data")
    n_cycle, n_tile_y, n_tile_x, n_y, n_x = dar_img.shape

    # Reference stitched images: (tile_y, tile_x, y, x)
    dar_stitched = da.from_zarr(zarr_path, component=group_ref + "/0/data")
    n_ref_tile_y, n_ref_tile_x, n_y_stitched, n_x_stitched = \
        dar_stitched.shape

    # Bounding box of every transformed tile
    df_tile = get_edges(n_cycle, n_tile_y, n_tile_x, df_H,
                        n_y_stitched, n_x_stitched, n_y, n_x)

    # Chunk grid covering the whole stitched canvas
    canvas_shape = (n_y_stitched * n_ref_tile_y,
                    n_x_stitched * n_ref_tile_x)
    df_chunk = create_chunk_dataframe(canvas_shape, chunk_size)

    # Output extent rounded up to whole chunks
    chunk_h = (df_chunk["chunk_y"].max() + 1) * chunk_size[0]
    chunk_w = (df_chunk["chunk_x"].max() + 1) * chunk_size[1]
    block_chunks = (1, chunk_size[0], chunk_size[1])

    # Placeholder array that only defines the block layout for map_blocks
    dar_chunk = da.zeros((n_cycle, chunk_h, chunk_w),
                         dtype=dar_img.dtype, chunks=block_chunks)

    # Register every block lazily
    dar_res = da.map_blocks(
        _register_chunk, dar_chunk, zarr_path, group_tile, df_chunk, df_tile,
        df_H, n_y, n_x, chunk_size, dtype=dar_img.dtype,
        chunks=block_chunks)

    print("Registering: " + group_tile + show_resource())
    with ProgressBar():
        # Wrap the result with named dimensions and integer coordinates
        out = xr.DataArray(
            dar_res,
            dims=["cycle", "y", "x"],
            coords={
                "cycle": range(n_cycle),
                "y": range(chunk_h),
                "x": range(chunk_w)})
        out = out.to_dataset(name="data")
        out = out.chunk(
            chunks={"cycle": 1, "y": chunk_size[0], "x": chunk_size[1]})

        # Writing triggers the actual computation
        out.to_zarr(zarr_path, group=group_tile + footer + "/0", mode="w")
def registration_cYXyx_noref(zarr_path, group_tile, stitched_shape,
                             chunk_size, subfooter="",
                             shift_footer="_shift_tile_cycle",
                             footer="_reg"):
    """
    Registers and stitches image tiles based on transformation matrices,
    creating a registered dataset in Zarr format, without reading a
    reference stitched image group.

    Differences from `registration_cYXyx`:
        - No reference stitched image group (`group_ref`) is opened.
        - The stitched layout is passed in directly as `stitched_shape`.

    Args:
        zarr_path (str): Path to the Zarr file containing the image data.
        group_tile (str): Group name in the Zarr file for the tiles to be
            registered.
        stitched_shape (tuple of int): Shape of the stitched image
            (n_tile_stitched_y, n_tile_stitched_x, n_y_stitched,
            n_x_stitched).
        chunk_size (tuple of int): Size of each chunk
            (chunk_height, chunk_width).
        subfooter (str, optional): String to append before the shift footer
            in the shift CSV filename; defaults to an empty string.
        shift_footer (str, optional): Footer of the shift CSV file; defaults
            to "_shift_tile_cycle".
        footer (str, optional): String appended to the output Zarr group
            name; defaults to "_reg".

    Returns:
        None: The function saves the registered and stitched images to a new
        Zarr group.
    """
    print("Registering: " + group_tile + show_resource())

    # Transformation matrices for every (cycle, tile_y, tile_x)
    df_H = pd.read_csv(
        zarr_path.replace(".zarr", subfooter + shift_footer + ".csv"))

    # Tile images: (cycle, tile_y, tile_x, y, x)
    dar_img = da.from_zarr(zarr_path, component=group_tile + "/0/data")
    n_cycle, n_tile_y, n_tile_x, n_y, n_x = dar_img.shape

    # Stitched layout supplied by the caller
    n_ref_tile_y, n_ref_tile_x, n_y_stitched, n_x_stitched = stitched_shape

    # Bounding box of every transformed tile
    df_tile = get_edges(n_cycle, n_tile_y, n_tile_x, df_H,
                        n_y_stitched, n_x_stitched, n_y, n_x)

    # Chunk grid covering the whole stitched canvas
    canvas_shape = (n_y_stitched * n_ref_tile_y,
                    n_x_stitched * n_ref_tile_x)
    df_chunk = create_chunk_dataframe(canvas_shape, chunk_size)

    # Output extent rounded up to whole chunks
    chunk_h = (df_chunk["chunk_y"].max() + 1) * chunk_size[0]
    chunk_w = (df_chunk["chunk_x"].max() + 1) * chunk_size[1]
    block_chunks = (1, chunk_size[0], chunk_size[1])

    # Placeholder array that only defines the block layout for map_blocks
    dar_chunk = da.zeros((n_cycle, chunk_h, chunk_w),
                         dtype=dar_img.dtype, chunks=block_chunks)

    # Register every block lazily
    dar_res = da.map_blocks(
        _register_chunk, dar_chunk, zarr_path, group_tile, df_chunk, df_tile,
        df_H, n_y, n_x, chunk_size, dtype=dar_img.dtype,
        chunks=block_chunks)

    with ProgressBar():
        # Wrap the result with named dimensions and integer coordinates
        out = xr.DataArray(
            dar_res,
            dims=["cycle", "y", "x"],
            coords={
                "cycle": range(n_cycle),
                "y": range(chunk_h),
                "x": range(chunk_w)})
        out = out.to_dataset(name="data")
        out = out.chunk(
            chunks={"cycle": 1, "y": chunk_size[0], "x": chunk_size[1]})

        # Writing triggers the actual computation
        out.to_zarr(zarr_path, group=group_tile + footer + "/0", mode="w")