# Source code for brainlit.algorithms.connect_fragments.viterbrain

import numpy as np
import math
import zarr
from joblib import Parallel, delayed
from tqdm import tqdm
from brainlit.viz.swc2voxel import Bresenham3D
from brainlit.preprocessing import image_process
import networkx as nx
from typing import List, Tuple, Callable
from pathlib import Path
import pickle
import os
import copy
import itertools
from scipy.spatial import cKDTree


def _curv_dist(
    res: List[float],
    pt1: List[int],
    orientation1: List[int],
    pt2: List[int],
    orientation2: List[int],
):
    """Compute the distance and curvature components of a transition cost.

    The displacement from ``pt1`` to ``pt2`` is scaled element-wise by ``res``.
    Each curvature term is ``1 - cos(angle)`` between that displacement and
    one of the orientations; the curvature cost is the mean of both terms.

    Args:
        res (list of floats): resolution of image
        pt1 (list of ints): first coordinate
        orientation1 (list of ints): orientation at first coordinate
        pt2 (list of ints): second coordinate
        orientation2 (list of ints): orientation at second coordinate

    Raises:
        ValueError: if the two points coincide or an orientation is not unit length
        ValueError: if distance or curvature cost is nan

    Returns:
        tuple: ``(dist, k_cost)``, or ``(np.inf, np.inf)`` when the points are
            more than 15 apart (in ``res`` units) or the alignment test fails.
    """
    displacement = np.multiply(np.subtract(pt2, pt1), res)
    dist = np.linalg.norm(displacement)

    # Transitions across large gaps are rejected outright.
    if dist > 15:
        return np.inf, np.inf

    # Valid input needs distinct points and unit-length orientations.
    inputs_ok = (
        dist != 0
        and math.isclose(np.linalg.norm(orientation1), 1, abs_tol=1e-5)
        and math.isclose(np.linalg.norm(orientation2), 1, abs_tol=1e-5)
    )
    if not inputs_ok:
        raise ValueError(
            f"pt1: {pt1} pt2: {pt2} dist: {dist}, o1: {orientation1} o2: {orientation2}"
        )

    k1_sq = 1 - np.dot(displacement, orientation1) / dist
    k2_sq = 1 - np.dot(displacement, orientation2) / dist
    k_cost = np.mean([k1_sq, k2_sq])

    if np.isnan(dist) or np.isnan(k_cost):
        raise ValueError(f"NAN cost: distance - {dist}, curv - {k_cost}")

    # Reject when the cosine between the displacement and either orientation
    # drops below -0.87.
    # NOTE(review): the previous comment described 45/30 degree limits, which
    # does not obviously match this threshold - confirm the intended angle.
    too_sharp = 1 - k1_sq < -0.87 or 1 - k2_sq < -0.87
    if too_sharp:
        return np.inf, np.inf
    return dist, k_cost


def _dist_simple(
    res: List[float],
    pt0: List[int],
    pt1: List[int],
    orientation1: List[int],
    pt2: List[int],
    orientation2: List[int],
) -> float:
    """Compute a distance-only transition cost between two fragment states.

    Args:
        res (list of floats): resolution of image
        pt0 (list of ints): first endpoint of the source state
        pt1 (list of ints): second endpoint of the source state
        orientation1 (list of ints): accepted but not used
        pt2 (list of ints): first endpoint of the target state
        orientation2 (list of ints): accepted but not used

    Returns:
        float: ``|pt1 - pt0| + |pt2 - pt1|**2`` with displacements scaled by
            ``res``, or ``np.inf`` when the ``pt1`` to ``pt2`` gap exceeds 20.
    """
    gap = np.linalg.norm(np.multiply(np.subtract(pt2, pt1), res))
    if gap > 20:
        return np.inf

    span = np.linalg.norm(np.multiply(np.subtract(pt1, pt0), res))
    return span + gap**2


def _compute_dist_cost(pair, res, coef_dist=10, coef_curv=1000):
    """Compute the distance cost of transitioning between a pair of states.

    Args:
        pair: two ``(state, attributes)`` entries; the attributes mappings are
            read for ``"fragment"``, ``"type"``, ``"point1"``, ``"point2"``,
            ``"orientation1"`` and ``"orientation2"``.
        res (list of floats): resolution of image
        coef_dist: currently unused (see note below). Defaults to 10.
        coef_curv: currently unused (see note below). Defaults to 1000.

    Raises:
        ValueError: if the states are on different fragments and either one
            is not of type ``"fragment"``

    Returns:
        tuple: ``(state1, state2, cost)``; the cost is ``np.inf`` when both
            states belong to the same fragment.
    """
    source, target = pair[0], pair[1]
    state1, attrs1 = source[0], source[1]
    state2, attrs2 = target[0], target[1]

    # States of the same fragment are never connected to each other.
    if attrs1["fragment"] == attrs2["fragment"]:
        return (state1, state2, np.inf)

    if not (attrs1["type"] == "fragment" and attrs2["type"] == "fragment"):
        raise ValueError("no two fragments?")

    # NOTE: a curvature-aware alternative based on _curv_dist,
    # coef_dist * dist**2 + coef_curv * k_cost, is disabled; that is why
    # coef_dist and coef_curv are not used here.
    cost = _dist_simple(
        res=res,
        pt0=attrs1["point1"],
        pt1=attrs1["point2"],
        orientation1=attrs1["orientation2"],
        pt2=attrs2["point1"],
        orientation2=attrs2["orientation1"],
    )
    return (state1, state2, cost)


def _line_int_coord(loc1: List[int], loc2: List[int], tiered_path: str):
    """Sum the tiered image along the voxel line strictly between two points.

    Args:
        loc1 (list of ints): first endpoint, voxel coordinate
        loc2 (list of ints): second endpoint, voxel coordinate
        tiered_path (str): path to tiered image, opened read-only with zarr

    Returns:
        Sum of the image values on the Bresenham3D line from ``loc1`` to
        ``loc2``, excluding both endpoints.
    """
    image_tiered = zarr.open(tiered_path, mode="r")
    n_dims = len(loc1)

    # Only the bounding box of the two points is read from the image.
    lower = [np.amin([loc1[d], loc2[d]]) for d in range(n_dims)]
    upper = [np.amax([loc1[d], loc2[d]]) for d in range(n_dims)]
    cutout = image_tiered[
        lower[0] : upper[0] + 1,
        lower[1] : upper[1] + 1,
        lower[2] : upper[2] + 1,
    ]

    # Shift both endpoints into the cutout's local coordinates.
    start = [int(loc1[d]) - lower[d] for d in range(n_dims)]
    end = [int(loc2[d]) - lower[d] for d in range(n_dims)]

    x0, y0, z0 = (int(v) for v in start[:3])
    x1, y1, z1 = (int(v) for v in end[:3])
    xlist, ylist, zlist = Bresenham3D(x0, y0, z0, x1, y1, z1)

    # exclude first and last points because they are included in the component intensity sum
    line_total = np.sum(cutout[xlist[1:-1], ylist[1:-1], zlist[1:-1]])
    return line_total


def _line_int_zero(loc1: List[int], loc2: List[int], tiered_path: str):
    """Return a zero line-intensity cost, ignoring all arguments.

    Takes the same parameters as ``_line_int_coord`` but never opens the image.
    """
    no_cost = 0
    return no_cost


def _compute_int_cost(pair, tiered_path):
    """Compute the image-intensity cost of transitioning between two states.

    Args:
        pair: two ``(state, attributes)`` entries; the attributes mappings are
            read for ``"fragment"``, ``"type"``, ``"point1"`` and ``"point2"``.
        tiered_path (str): path to tiered image, forwarded to the line integral

    Raises:
        ValueError: if the states are on different fragments and either one
            is not of type ``"fragment"``

    Returns:
        tuple: ``(state1, state2, cost)``; the cost is ``np.inf`` when both
            states belong to the same fragment.
    """
    source, target = pair[0], pair[1]
    state1, attrs1 = source[0], source[1]
    state2, attrs2 = target[0], target[1]

    # States of the same fragment are never connected to each other.
    if attrs1["fragment"] == attrs2["fragment"]:
        return (state1, state2, np.inf)

    if not (attrs1["type"] == "fragment" and attrs2["type"] == "fragment"):
        raise ValueError("No two fragments?")

    # The zero-cost line integral is used here; the image-based
    # _line_int_coord variant is not called.
    int_cost = _line_int_zero(
        attrs1["point2"], attrs2["point1"], tiered_path=tiered_path
    )
    return (state1, state2, int_cost)


class ViterBrain:
    """State graph with transition costs and shortest-path queries between coordinates."""

    def __init__(
        self,
        G: nx.Graph,
        tiered_path: str,
        fragment_path: str,
        resolution: List[float],
        coef_curv: float,
        coef_dist: float,
        coef_int: float,
        parallel: int = 1,
    ) -> None:
        """Initialize ViterBrain object

        Args:
            G (nx.Graph): networkx graph representation of states
            tiered_path (str): path to tiered image
            fragment_path (str): path to fragments image
            resolution (list of floats): resolution of images
            coef_curv (float): curvature coefficient
            coef_dist (float): distance coefficient
            coef_int (float): image likelihood coefficient
            parallel (int, optional): Number of threads to use for parallelization. Defaults to 1.
        """
        self.nxGraph = G
        self.num_states = G.number_of_nodes()
        self.tiered_path = tiered_path
        self.fragment_path = fragment_path
        self.resolution = resolution
        self.coef_curv = coef_curv
        self.coef_dist = coef_dist
        self.coef_int = coef_int
        self.parallel = parallel

        # Map each soma fragment label to the coordinates stored on its node.
        self.soma_fragment2coords = {
            attrs["fragment"]: attrs["soma_coords"]
            for _, attrs in G.nodes(data=True)
            if attrs["type"] == "soma"
        }

        # Map each fragment label to the list of states on that fragment.
        # A plain dict is kept so unknown labels still raise KeyError on lookup.
        comp_to_states = {}
        for node in G.nodes:
            comp_to_states.setdefault(G.nodes[node]["fragment"], []).append(node)
        self.comp_to_states = comp_to_states

    def frag_soma_dist(
        self,
        point: List[float],
        orientation: List[float],
        soma_lbl: int,
        verbose: bool = False,
    ) -> Tuple[float, List]:
        """Compute cost of transition from fragment state to soma state

        Args:
            point (list of floats): coordinate on fragment
            orientation (list of floats): orientation at fragment
            soma_lbl (int): label of soma component
            verbose (bool, optional): Prints cost values. Defaults to False.

        Raises:
            ValueError: if either distance or curvature cost is nan
            ValueError: if the computed closest soma coordinate is not on the soma

        Returns:
            [float]: cost of transition
            [list of floats]: closest soma coordinate
        """
        coords = self.soma_fragment2coords[soma_lbl]
        image_fragment = zarr.open_array(self.fragment_path, mode="r")

        # Displacement from the fragment point to every soma coordinate,
        # scaled by the resolution; the closest soma coordinate is used.
        difs = np.multiply(np.subtract(coords, point), self.resolution)
        dists = np.linalg.norm(difs, axis=1)
        argmin = np.argmin(dists)
        dif = difs[argmin, :]
        dist = dists[argmin]

        # Curvature penalty: 1 - cosine between displacement and orientation.
        dot = np.dot(dif, orientation) / (
            np.linalg.norm(dif) * np.linalg.norm(orientation)
        )
        k_cost = 1 - dot

        if np.isnan(k_cost) or np.isnan(dist):
            raise ValueError(f"NAN cost: distance - {dist}, curv - {k_cost}")

        if dist > 15:
            cost = np.inf
        else:
            cost = k_cost * self.coef_curv + self.coef_dist * (dist**2)

        # Sanity check: the chosen coordinate must carry the soma's label.
        nonline_point = coords[argmin, :]
        if (
            image_fragment[nonline_point[0], nonline_point[1], nonline_point[2]]
            != soma_lbl
        ):
            raise ValueError("Soma point is not on soma")

        if verbose:
            print(
                f"Distance: {dist}, Curv penalty: {k_cost}, Total cost: {cost}, connection point: {nonline_point}"
            )

        return cost, nonline_point

    def compute_all_costs_dist(self) -> None:
        """Splits up transition computation tasks then assembles them into networkx graph

        Candidate pairs are fragment states whose "point2" lies within 15 (in
        resolution units) of another fragment state's "point1". Each pair with
        a finite cost becomes an edge carrying a ``dist_cost`` attribute.

        Soma states are reported with a warning and left out of both KD-trees.
        Tree rows are mapped back to state ids explicitly, so skipped states
        cannot shift the ids of the states that follow them.
        """
        parallel = self.parallel
        G = self.nxGraph

        # States are assumed to be labelled 0..num_states-1.
        fragment_states = []
        for state in range(self.num_states):
            state_type = G.nodes[state]["type"]
            if state_type == "fragment":
                fragment_states.append(state)
            elif state_type == "soma":
                print(
                    "Warning: Component of type soma is encountered which will not be connected to the graph"
                )

        if not fragment_states:
            # Nothing to connect; np.stack would fail on an empty list.
            print(f"0 for {self.num_states} states")
            print(f"{len(G.edges)} edges")
            return

        end_points = np.stack(
            [
                np.multiply(G.nodes[state]["point2"], self.resolution)
                for state in fragment_states
            ],
            axis=0,
        )
        start_points = np.stack(
            [
                np.multiply(G.nodes[state]["point1"], self.resolution)
                for state in fragment_states
            ],
            axis=0,
        )
        kdt1 = cKDTree(end_points)
        kdt2 = cKDTree(start_points)

        # results[i] lists the rows of kdt2 within r of row i of kdt1.
        results = kdt1.query_ball_tree(kdt2, r=15)

        pairs = []
        for row, nbrs in enumerate(tqdm(results, desc="constructing pairs")):
            state1 = fragment_states[row]
            state1_data = (state1, G.nodes[state1])
            for col in nbrs:
                state2 = fragment_states[col]
                pairs.append((state1_data, (state2, G.nodes[state2])))
        print(f"{len(pairs)} for {self.num_states} states")

        chunk_size = 100000
        for start in tqdm(range(0, len(pairs), chunk_size), desc="pair chunks"):
            # Slicing jumps straight to the chunk; tqdm takes its total from len().
            pairs_chunk = pairs[start : start + chunk_size]
            cost_data = Parallel(n_jobs=parallel)(
                delayed(_compute_dist_cost)(pair, self.resolution)
                for pair in tqdm(pairs_chunk, desc="pair", leave=False)
            )

            for cost in tqdm(cost_data, desc="adding edges"):
                if np.isfinite(cost[-1]):
                    G.add_edge(cost[0], cost[1], dist_cost=cost[-1])

        print(f"{len(G.edges)} edges")

    def compute_all_costs_int(self) -> None:
        """Splits up transition computation tasks then assembles them into networkx graph

        For every existing edge with a finite intensity cost, stores
        ``int_cost`` and ``total_cost = dist_cost + int_cost`` on the edge.
        Edges are expected to already carry ``dist_cost``.
        """
        parallel = self.parallel
        G = self.nxGraph

        pairs = [((e[0], G.nodes[e[0]]), (e[1], G.nodes[e[1]])) for e in G.edges]

        chunk_size = 100000
        for start in tqdm(range(0, len(pairs), chunk_size), desc="pair chunks"):
            pairs_chunk = pairs[start : start + chunk_size]
            cost_data = Parallel(n_jobs=parallel)(
                delayed(_compute_int_cost)(pair, self.tiered_path)
                for pair in tqdm(pairs_chunk, desc="pair", leave=False)
            )

            for cost in tqdm(cost_data, desc="adding edges"):
                if np.isfinite(cost[-1]):
                    edge = G.edges[cost[0], cost[1]]
                    edge["int_cost"] = cost[-1]
                    edge["total_cost"] = edge["dist_cost"] + edge["int_cost"]

    def shortest_path(self, coord1: List[int], coord2: List[int]) -> List[List[int]]:
        """Compute coordinate path from one coordinate to another.

        Args:
            coord1 (list): voxel coordinate of start point
            coord2 (list): voxel coordinate of end point

        Raises:
            nx.NetworkXNoPath: if no state of the start fragment can reach any
                state of the end fragment
            ValueError: if state sequence contains a soma state that is not at the end

        Returns:
            list: list of voxel coordinates of path
        """
        fragments = zarr.open_array(self.fragment_path, mode="r")

        # Compute labels of coordinates
        labels = []
        for coord in [coord1, coord2]:
            local_labels, new_coord = get_valid_bbox(fragments, coord, radius=20)
            label = image_process.label_points(
                local_labels,
                [new_coord],
                res=self.resolution,
            )[1][0]
            labels.append(label)

        # find shortest path for all state combinations
        states1 = self.comp_to_states[labels[0]]
        states2 = self.comp_to_states[labels[1]]
        min_cost = None
        states = None
        for state1 in states1:
            for state2 in states2:
                try:
                    cost = nx.shortest_path_length(
                        self.nxGraph, state1, state2, weight="total_cost"
                    )
                except nx.NetworkXNoPath:
                    continue
                if min_cost is None or cost < min_cost:
                    min_cost = cost
                    states = nx.shortest_path(
                        self.nxGraph, state1, state2, weight="total_cost"
                    )

        if states is None:
            raise nx.NetworkXNoPath(f"No path found between {coord1} and {coord2}")

        nodes = self.nxGraph.nodes
        coords = [coord1]
        coords.append(list(nodes[states[0]]["point2"]))
        for i, state in enumerate(states[1:]):
            state_type = nodes[state]["type"]
            if state_type == "fragment":
                coords.append(list(nodes[state]["point1"]))
                coords.append(list(nodes[state]["point2"]))
            elif state_type == "soma":
                # states[i] is the state just before `state`, because the
                # enumeration starts at states[1].
                # NOTE(review): presumably that state stores the soma
                # connection point as "soma_pt" - confirm against graph builder.
                coords.append(list(nodes[states[i]]["soma_pt"]))
                if i != len(states) - 2:
                    raise ValueError("Soma state is not last state")

        coords.append(coord2)
        return coords
def explain_viterbrain(vb, c1, c2):
    """Print the state sequence and coordinates of the best path between two points.

    Fragment labels are read directly from the fragment image at ``c1`` and
    ``c2``. The lowest-cost state path is found the same way as in
    ``vb.shortest_path``. Each path coordinate except the last is printed with
    its fragment label and state; the edge attributes are printed wherever
    consecutive coordinates sit on different fragments.

    Args:
        vb: object providing ``shortest_path``, ``comp_to_states``,
            ``fragment_path`` and ``nxGraph``
        c1 (list): voxel coordinate of start point, assumed to fall on a fragment
        c2 (list): voxel coordinate of end point, assumed to fall on a fragment

    Raises:
        nx.NetworkXNoPath: if no state path connects the two fragments
    """
    # assume c1,c2 fall on a fragment
    path_coords = vb.shortest_path(c1, c2)
    comp_to_states = vb.comp_to_states
    # Read-only: this function only inspects the fragment image.
    z_frags = zarr.open_array(vb.fragment_path, mode="r")

    states1 = comp_to_states[z_frags[c1[0], c1[1], c1[2]]]
    states2 = comp_to_states[z_frags[c2[0], c2[1], c2[2]]]
    min_cost = None
    states = None
    for state1 in states1:
        for state2 in states2:
            try:
                cost = nx.shortest_path_length(
                    vb.nxGraph, state1, state2, weight="total_cost"
                )
            except nx.NetworkXNoPath:
                continue
            if min_cost is None or cost < min_cost:
                min_cost = cost
                states = nx.shortest_path(
                    vb.nxGraph, state1, state2, weight="total_cost"
                )

    if states is None:
        raise nx.NetworkXNoPath(f"No path found between {c1} and {c2}")

    print(f"{len(states)} states")
    print(f"{len(path_coords)} coordinates")
    for coord_idx, c in enumerate(path_coords[:-1]):
        # Two consecutive path coordinates are attributed to each state.
        state_idx = coord_idx // 2
        state = states[state_idx]
        if coord_idx > 0:
            prev_c = path_coords[coord_idx - 1]
            if z_frags[c[0], c[1], c[2]] != z_frags[prev_c[0], prev_c[1], prev_c[2]]:
                e = vb.nxGraph.edges[states[state_idx - 1], state]
                print(f"Transition: {states[state_idx-1]}->{state}: {e}")
        print(f"{coord_idx}: {c} f{z_frags[c[0],c[1],c[2]]} s{state}")


def get_valid_bbox(array, coord, radius):
    """Cut out the box around a coordinate, clipped to the array bounds.

    Args:
        array: 3D array-like supporting ``.shape`` and slice indexing
        coord (list): voxel coordinate at the center of the box
        radius (int): half-width of the box along every axis

    Returns:
        tuple: ``(subvol, local_coord)``. ``subvol`` is a squeezed numpy copy
            of ``array[c - radius : c + radius]`` on each axis, clipped to the
            array; ``local_coord`` is ``coord`` relative to the box's lower
            corner.
    """
    x1 = np.amax([coord[0] - radius, 0])
    y1 = np.amax([coord[1] - radius, 0])
    z1 = np.amax([coord[2] - radius, 0])
    x2 = np.amin([coord[0] + radius, array.shape[0]])
    y2 = np.amin([coord[1] + radius, array.shape[1]])
    z2 = np.amin([coord[2] + radius, array.shape[2]])

    subvol = np.array(np.squeeze(array[x1:x2, y1:y2, z1:z2]))
    return subvol, [coord[0] - x1, coord[1] - y1, coord[2] - z1]