import numpy as np
import math
import zarr
from joblib import Parallel, delayed
from tqdm import tqdm
from brainlit.viz.swc2voxel import Bresenham3D
from brainlit.preprocessing import image_process
import networkx as nx
from typing import List, Tuple, Callable
class ViterBrain:
    """Minimum-cost tracing over a graph of fragment and soma states.

    Each node of ``nxGraph`` is a state with (at least) a ``"type"`` attribute
    (``"fragment"`` or ``"soma"``) and a ``"fragment"`` component label.
    Transition edges receive a geometric ``"dist_cost"`` from
    :meth:`compute_all_costs_dist` and an image ``"int_cost"`` from
    :meth:`compute_all_costs_int`; their sum, ``"total_cost"``, is the edge
    weight used by :meth:`shortest_path`.
    """

    # Transitions whose physical length (voxel offset scaled by ``resolution``)
    # exceeds this are treated as impossible. Override per instance if needed.
    max_dist = 15

    def __init__(
        self,
        G: "nx.Graph",
        tiered_path: str,
        fragment_path: str,
        resolution: List[float],
        coef_curv: float,
        coef_dist: float,
        coef_int: float,
        parallel: int = 1,
    ) -> None:
        """Initialize ViterBrain object.

        Args:
            G (nx.Graph): networkx graph representation of states. Every node
                has ``"type"`` (``"fragment"`` or ``"soma"``) and ``"fragment"``
                attributes; soma nodes also carry ``"soma_coords"``. The cost
                computations iterate ``range(G.number_of_nodes())``, so nodes
                are assumed to be labeled ``0 .. N-1``.
            tiered_path (str): path to tiered image (opened with zarr)
            fragment_path (str): path to fragments image (opened with zarr)
            resolution (list of floats): resolution of images, used to scale voxel offsets
            coef_curv (float): curvature coefficient
            coef_dist (float): distance coefficient
            coef_int (float): image likelihood coefficient
            parallel (int, optional): Number of joblib jobs used for parallelization. Defaults to 1.
        """
        self.nxGraph = G
        self.num_states = G.number_of_nodes()
        self.tiered_path = tiered_path
        self.fragment_path = fragment_path
        self.resolution = resolution
        self.coef_curv = coef_curv
        self.coef_dist = coef_dist
        # NOTE(review): coef_int is stored but not applied anywhere in this
        # class; presumably the tiered image / "image_cost" values are already
        # scaled upstream — confirm.
        self.coef_int = coef_int
        self.parallel = parallel

        # soma component label -> voxel coordinates of that soma
        self.soma_fragment2coords = {
            data["fragment"]: data["soma_coords"]
            for _, data in G.nodes(data=True)
            if data["type"] == "soma"
        }

        # component label -> states belonging to it, in graph node order
        comp_to_states = {}
        for node, data in G.nodes(data=True):
            comp_to_states.setdefault(data["fragment"], []).append(node)
        self.comp_to_states = comp_to_states

    @staticmethod
    def _soma_connection_point(edge_data: dict, node_data: dict) -> List[int]:
        """Return the soma connection point of a fragment->soma transition.

        The per-edge value is preferred because a fragment endpoint can be
        close to several somas; the per-node value is a fallback for graphs
        whose edges were built without it.
        """
        if "soma_pt" in edge_data:
            return edge_data["soma_pt"]
        return node_data["soma_pt"]

    @staticmethod
    def _label_window(coord: List[int], radius: int) -> Tuple[tuple, List[int]]:
        """Return the cutout slices around ``coord`` and ``coord``'s index inside it.

        The cutout start is clamped at 0, so near the volume origin the point
        is not at ``radius`` within the cutout; the returned local index
        accounts for that.
        """
        starts = [max(int(c) - radius, 0) for c in coord]
        window = tuple(slice(s, int(c) + radius) for s, c in zip(starts, coord))
        local = [int(c) - s for c, s in zip(coord, starts)]
        return window, local

    def frag_frag_dist(
        self,
        pt1: List[float],
        orientation1: List[float],
        pt2: List[float],
        orientation2: List[float],
        verbose: bool = False,
    ) -> float:
        """Compute cost of transition between two fragment states.

        The cost is ``coef_curv * k + coef_dist * d**2`` where ``d`` is the
        physical length of ``pt2 - pt1`` and ``k`` is the mean, over both
        orientations, of ``1 - cos(angle to the displacement)``.

        Args:
            pt1 (list of floats): first coordinate
            orientation1 (list of floats): orientation at first coordinate
            pt2 (list of floats): second coordinate
            orientation2 (list of floats): orientation at second coordinate
            verbose (bool, optional): Print transition cost information. Defaults to False.

        Raises:
            ValueError: if the points coincide or an orientation is not unit length
            ValueError: if distance or curvature cost is nan

        Returns:
            [float]: cost of transition; ``np.inf`` if the points are more than
            ``max_dist`` apart or either orientation points more than ~150
            degrees away from the displacement.
        """
        dif = np.multiply(np.subtract(pt2, pt1), self.resolution)
        dist = np.linalg.norm(dif)
        if (
            dist == 0
            or not math.isclose(np.linalg.norm(orientation1), 1, abs_tol=1e-5)
            or not math.isclose(np.linalg.norm(orientation2), 1, abs_tol=1e-5)
        ):
            raise ValueError(
                f"pt1: {pt1} pt2: {pt2} dist: {dist}, o1: {orientation1} o2: {orientation2}"
            )
        if dist > self.max_dist:
            return np.inf

        # 1 - cos(angle between the displacement and each unit orientation)
        k1_sq = 1 - np.dot(dif, orientation1) / dist
        k2_sq = 1 - np.dot(dif, orientation2) / dist
        k_cost = np.mean([k1_sq, k2_sq])
        if np.isnan(dist) or np.isnan(k_cost):
            raise ValueError(f"NAN cost: distance - {dist}, curv - {k_cost}")

        # Reject if either cosine is below -0.87, i.e. an orientation points
        # more than ~150 degrees (cos 150 deg ~ -0.866) away from the displacement.
        if 1 - k1_sq < -0.87 or 1 - k2_sq < -0.87:
            return np.inf

        cost = k_cost * self.coef_curv + self.coef_dist * (dist**2)
        if verbose:
            print(
                f"Distance: {dist}, Curv penalty: {k_cost} (dots {1-k1_sq}, {1-k2_sq}, from dif-{dif}), Total cost: {cost}"
            )
        return cost

    def frag_soma_dist(
        self,
        point: List[float],
        orientation: List[float],
        soma_lbl: int,
        verbose: bool = False,
    ) -> Tuple[float, List]:
        """Compute cost of transition from fragment state to soma state.

        The soma voxel closest (in physical units) to ``point`` is used as the
        connection point.

        Args:
            point (list of floats): coordinate on fragment
            orientation (list of floats): orientation at fragment
            soma_lbl (int): label of soma component
            verbose (bool, optional): Prints cost values. Defaults to False.

        Raises:
            ValueError: if either distance or curvature cost is nan
            ValueError: if the computed closest soma coordinate is not on the soma

        Returns:
            [float]: cost of transition (``np.inf`` beyond ``max_dist``)
            [list of floats]: closest soma coordinate
        """
        # asarray so that soma coordinates stored as nested lists also work
        coords = np.asarray(self.soma_fragment2coords[soma_lbl])
        image_fragment = zarr.open(self.fragment_path, mode="r")

        difs = np.multiply(np.subtract(coords, point), self.resolution)
        dists = np.linalg.norm(difs, axis=1)
        argmin = np.argmin(dists)
        dif = difs[argmin, :]
        dist = dists[argmin]

        dot = np.dot(dif, orientation) / (
            np.linalg.norm(dif) * np.linalg.norm(orientation)
        )
        k_cost = 1 - dot
        if np.isnan(k_cost) or np.isnan(dist):
            raise ValueError(f"NAN cost: distance - {dist}, curv - {k_cost}")

        if dist > self.max_dist:
            cost = np.inf
        else:
            cost = k_cost * self.coef_curv + self.coef_dist * (dist**2)

        # sanity check (runs regardless of cost): the chosen point must lie on the soma
        nonline_point = coords[argmin, :]
        if (
            image_fragment[nonline_point[0], nonline_point[1], nonline_point[2]]
            != soma_lbl
        ):
            raise ValueError("Soma point is not on soma")

        if verbose:
            print(
                f"Distance: {dist}, Curv penalty: {k_cost}, Total cost: {cost}, connection point: {nonline_point}"
            )
        return cost, nonline_point

    def _compute_out_costs_dist(
        self, states: List[int], frag_frag_func: Callable, frag_soma_func: Callable
    ) -> List[tuple]:
        """Compute outgoing distance costs for specified list of states.

        Args:
            states (list of ints): list of states from which to compute transition costs.
            frag_frag_func (function): function that computes transition cost between fragments
            frag_soma_func (function): function that computes transition cost from a fragment to a soma

        Raises:
            ValueError: if cannot compute transition cost between two fragments
            ValueError: if a state type pair is not handled

        Returns:
            [list]: ``(state1, state2, dist_cost, soma_pt)`` tuples for every
            finite-cost transition; ``soma_pt`` is None unless state2 is a soma.
        """
        G = self.nxGraph
        results = []
        for state1 in tqdm(states, desc="computing state costs (geometry)"):
            node1 = G.nodes[state1]
            # soma states have no outgoing transitions
            if node1["type"] == "soma":
                continue
            for state2 in range(self.num_states):
                node2 = G.nodes[state2]
                soma_pt = None
                if node1["fragment"] == node2["fragment"]:
                    continue
                if node1["type"] == "fragment" and node2["type"] == "fragment":
                    try:
                        dist_cost = frag_frag_func(
                            node1["point2"],
                            node1["orientation2"],
                            node2["point1"],
                            node2["orientation1"],
                        )
                    except Exception as err:
                        raise ValueError(
                            f"Cant compute cost between fragments: state1: {state1}, state2: {state2}, node1: {node1}, node2 = {node2}"
                        ) from err
                elif node1["type"] == "fragment" and node2["type"] == "soma":
                    dist_cost, soma_pt = frag_soma_func(
                        node1["point2"],
                        node1["orientation2"],
                        node2["fragment"],
                    )
                else:
                    # without this, dist_cost would silently carry over from the
                    # previous iteration (or be undefined)
                    raise ValueError("No cases caught dist")
                if np.isfinite(dist_cost):
                    results.append((state1, state2, dist_cost, soma_pt))
        return results

    def compute_all_costs_dist(
        self, frag_frag_func: Callable, frag_soma_func: Callable
    ) -> None:
        """Splits up transition computation tasks then assembles them into networkx graph.

        Adds an edge with ``"dist_cost"`` for every finite-cost transition.
        Fragment->soma edges also get the ``"soma_pt"`` connection point,
        stored on the edge (a fragment may be close to several somas) and, for
        backward compatibility, on the source node.

        Args:
            frag_frag_func (function): function that computes transition cost between fragments
            frag_soma_func (function): function that computes transition cost from a fragment to a soma
        """
        parallel = self.parallel
        G = self.nxGraph
        state_sets = np.array_split(np.arange(self.num_states), parallel)
        results_tuple = Parallel(n_jobs=parallel)(
            delayed(self._compute_out_costs_dist)(
                states, frag_frag_func, frag_soma_func
            )
            for states in state_sets
        )
        for result in results_tuple:
            for state1, state2, dist_cost, soma_pt in result:
                if dist_cost == np.inf:
                    continue
                if soma_pt is None:
                    G.add_edge(state1, state2, dist_cost=dist_cost)
                else:
                    G.add_edge(state1, state2, dist_cost=dist_cost, soma_pt=soma_pt)
                    G.nodes[state1]["soma_pt"] = soma_pt

    def _line_int(self, loc1: List[int], loc2: List[int]) -> float:
        """Compute line integral of image likelihood costs between two coordinates.

        Sums the tiered image along a 3D Bresenham line from ``loc1`` to
        ``loc2``, excluding both endpoints.

        Args:
            loc1 (list of ints): first coordinate
            loc2 (list of ints): second coordinate

        Returns:
            [float]: sum of image likelihood costs
        """
        image_tiered = zarr.open(self.tiered_path, mode="r")
        corner1 = [np.amin([a, b]) for a, b in zip(loc1, loc2)]
        corner2 = [np.amax([a, b]) for a, b in zip(loc1, loc2)]
        # read only the bounding box spanned by the two points
        cutout = image_tiered[
            corner1[0] : corner2[0] + 1,
            corner1[1] : corner2[1] + 1,
            corner1[2] : corner2[2] + 1,
        ]
        start = [int(v) - c for v, c in zip(loc1, corner1)]
        end = [int(v) - c for v, c in zip(loc2, corner1)]
        xlist, ylist, zlist = Bresenham3D(
            int(start[0]),
            int(start[1]),
            int(start[2]),
            int(end[0]),
            int(end[1]),
            int(end[2]),
        )
        # exclude first and last points because they are included in the component intensity sum
        total = np.sum(cutout[xlist[1:-1], ylist[1:-1], zlist[1:-1]])
        return total

    def _compute_out_int_costs(self, states: List[int]) -> List[tuple]:
        """Compute image likelihood costs for the existing outgoing edges of ``states``.

        Args:
            states (list of ints): list of states

        Raises:
            ValueError: Cases did not catch the particular state type pair

        Returns:
            [list]: ``(state1, state2, int_cost)`` tuples
        """
        G = self.nxGraph
        results = []
        for state1 in tqdm(states, desc="Computing state costs (intensity)"):
            node1 = G.nodes[state1]
            # soma states have no outgoing transitions
            if node1["type"] == "soma":
                continue
            # only existing edges get an intensity cost, so walk the adjacency
            # instead of testing every state for an edge
            for state2 in G.adj[state1]:
                node2 = G.nodes[state2]
                if node1["fragment"] == node2["fragment"]:
                    continue
                if node1["type"] == "fragment" and node2["type"] == "fragment":
                    line_int_cost = self._line_int(node1["point2"], node2["point1"])
                    results.append(
                        (state1, state2, line_int_cost + node2["image_cost"])
                    )
                elif node1["type"] == "fragment" and node2["type"] == "soma":
                    soma_pt = self._soma_connection_point(
                        G.edges[state1, state2], node1
                    )
                    line_int_cost = self._line_int(node1["point2"], soma_pt)
                    results.append((state1, state2, line_int_cost))
                else:
                    raise ValueError("No cases caught int")
        return results

    def compute_all_costs_int(self) -> None:
        """Splits up transition computation tasks then assembles them into networkx graph.

        Sets ``"int_cost"`` and ``"total_cost"`` (= int_cost + dist_cost) on
        every edge. Both are written even when infinite: networkx treats a
        missing weight attribute as weight 1, which would make an impossible
        transition look cheap.
        """
        parallel = self.parallel
        G = self.nxGraph
        state_sets = np.array_split(np.arange(self.num_states), parallel)
        results_tuple = Parallel(n_jobs=parallel)(
            delayed(self._compute_out_int_costs)(states) for states in state_sets
        )
        for result in results_tuple:
            for state1, state2, int_cost in result:
                edge = G.edges[state1, state2]
                edge["int_cost"] = int_cost
                edge["total_cost"] = int_cost + edge["dist_cost"]

    def shortest_path(self, coord1: List[int], coord2: List[int]) -> List[List[int]]:
        """Compute coordinate path from one coordinate to another.

        Both coordinates are mapped to fragment labels; the cheapest
        ``"total_cost"`` path over all pairs of their states is converted to
        a list of voxel coordinates.

        Args:
            coord1 (list): voxel coordinate of start point
            coord2 (list): voxel coordinate of end point

        Raises:
            ValueError: if state sequence contains a soma state that is not at the end
            nx.NetworkXNoPath: if no finite-cost path connects the two coordinates
            KeyError: if a coordinate's label has no states in the graph

        Returns:
            list: list of voxel coordinates of path
        """
        G = self.nxGraph
        fragments = zarr.open(self.fragment_path, mode="r")

        # Compute labels of coordinates
        labels = []
        radius = 20
        for coord in (coord1, coord2):
            window, center = self._label_window(coord, radius)
            label = image_process.label_points(
                fragments[window], [center], res=self.resolution
            )[1][0]
            labels.append(label)

        # find shortest path for all state combinations
        best_cost = math.inf
        best_states = None
        for state1 in self.comp_to_states[labels[0]]:
            for state2 in self.comp_to_states[labels[1]]:
                try:
                    cost, path = nx.single_source_dijkstra(
                        G, state1, state2, weight="total_cost"
                    )
                except nx.NetworkXNoPath:
                    continue
                if cost < best_cost:
                    best_cost, best_states = cost, path
        if best_states is None:
            raise nx.NetworkXNoPath(
                f"No finite-cost path between {coord1} (label {labels[0]}) and {coord2} (label {labels[1]})"
            )

        # create coordinate list
        coords = [coord1, G.nodes[best_states[0]]["point2"]]
        for i, (prev, state) in enumerate(zip(best_states, best_states[1:])):
            node = G.nodes[state]
            if node["type"] == "fragment":
                coords.append(node["point1"])
                coords.append(node["point2"])
            elif node["type"] == "soma":
                coords.append(
                    self._soma_connection_point(G.edges[prev, state], G.nodes[prev])
                )
                if i != len(best_states) - 2:
                    raise ValueError("Soma state is not last state")
        coords.append(coord2)
        return coords