# Source code for brainlit.map_neurons.diffeo_gen

import numpy as np
import torch
import matplotlib.pyplot as plt

def interp(x, I, phii, **kwargs):
    """Interpolate a function given original and transformed coordinates.

    Samples ``I`` at the positions ``phii`` with
    ``torch.nn.functional.grid_sample``. Positions are given in the same units
    as the pixel locations ``x`` and are rescaled to grid_sample's [-1, 1]
    range before sampling.

    Args:
        x (list): Original pixel locations of image, one 1D array per spatial
            axis (three axes).
        I (torch.tensor): Function to be interpolated, with components at the
            beginning: (channels, rows, cols, slices), or
            (batch, channels, rows, cols, slices).
        phii (torch.tensor): Transformed coordinates, with components at the
            end in the same axis order as ``x``: (rows', cols', slices', 3),
            with a leading batch axis when ``I`` is batched.
        **kwargs: Extra keyword arguments forwarded to ``grid_sample``
            (e.g. ``mode``). ``align_corners`` defaults to True and
            ``padding_mode`` to ``"border"``; either may be overridden.

    Raises:
        ValueError: If ``I`` is not four or five dimensional.

    Returns:
        torch.tensor: Interpolated function, batched only if ``I`` was batched.
    """
    # note: components of phii are at the end, components of I at the
    # beginning; this makes composing two transformations a bit awkward.
    # Start by scaling each coordinate from [x[i][0], x[i][-1]] to [-1, 1].
    phii_ = torch.clone(phii)
    for i in range(3):
        phii_[..., i] -= x[i][0]
        phii_[..., i] /= x[i][-1] - x[i][0]
    phii_ *= 2.0
    phii_ -= 1.0

    # grid_sample needs a batch axis; add one for unbatched (4D) input.
    if I.ndim == 4:
        add_batch = True
    elif I.ndim == 5:
        add_batch = False
    else:
        raise ValueError("Image should be 4 or 5 dim")

    Iin = I[None] if add_batch else I
    grid = phii_[None] if add_batch else phii_

    # Defaults that callers may override through kwargs.
    options = {"align_corners": True, "padding_mode": "border"}
    options.update(kwargs)
    # grid_sample wants the last axis ordered (W, H, D), i.e. xyz rather
    # than the zyx/array order used by phii, hence the flip.
    output = torch.nn.functional.grid_sample(
        Iin, torch.flip(grid, (-1,)), **options
    )
    # remove the batch dimension we added
    return output[0] if add_batch else output

def _middle_slice_rgb(field):
    """Return the middle axis-0 slice of an unbatched field, rescaled for imshow.

    ``field`` is a (rows, cols, slices, 3) tensor. The result is a numpy copy
    (``np.array`` copies), so the tensor itself is never modified.
    """
    img = np.array(field[field.shape[0] // 2].detach().cpu())
    img -= np.min(img, axis=(0, 1, 2))
    img /= np.max(img, axis=(0, 1, 2))
    return img


def expR(xv, v0, K, n=10, visualize=False, return_forward=True):
    """Riemannian exponential: shoot an initial velocity to a diffeomorphism.

    The momentum ``p0`` (``v0`` divided by ``K`` in the Fourier domain) is
    transported over ``n`` Euler steps. At each step it is pulled back by the
    current inverse map ``phii``, then multiplied by the transposed Jacobian
    of ``phii`` and by its determinant. The velocity ``v`` is that momentum
    smoothed by ``K``, and the inverse map is updated to ``phii(x - v / n)``.

    Args:
        xv (list of arrays): Location of pixels in v, one 1D array per axis.
        v0 (torch.tensor): Velocity at time 0, shaped
            (rows, cols, slices, 3), or (batch, rows, cols, slices, 3).
        K (array or torch.tensor): Kernel in the FFT domain, broadcastable to
            (rows, cols, slices).
        n (int, optional): Number of timesteps. Defaults to 10.
        visualize (bool, optional): Whether to plot the momentum and velocity
            at each step (unbatched input only). Defaults to False.
        return_forward (bool, optional): If True, return the forward map
            ``phi``; otherwise return the inverse map ``phii``.
            Defaults to True.

    Returns:
        torch.tensor: Generated diffeomorphism sampled on the grid ``xv``,
        components in the last axis (leading batch axis if ``v0`` is batched).
    """
    use_batch = v0.ndim == 5
    # permute0 moves the vector components to the front (as interp expects);
    # permute1 moves them back to the end.
    if use_batch:
        permute0 = (0, -1, -4, -3, -2)
        permute1 = (0, -3, -2, -1, -4)
    else:
        permute0 = (-1, -4, -3, -2)
        permute1 = (-3, -2, -1, -4)
    spatial = (-2, -3, -4)  # the three spatial axes of v0

    # Use K as a tensor on v0's device instead of relying on implicit
    # numpy/tensor mixing when callers pass a numpy array.
    K = torch.as_tensor(K, device=v0.device)

    # initialize momentum p at time 0
    p0 = torch.fft.ifftn(
        torch.fft.fftn(v0, dim=spatial) / K[..., None], dim=spatial
    ).real
    # identity map on the sampling grid (components last)
    XV = torch.stack(
        torch.meshgrid(
            [torch.as_tensor(x, device=v0.device) for x in xv], indexing="ij"
        ),
        -1,
    )
    # grid spacing along each axis, for the finite-difference Jacobian
    dv = [x[1].item() - x[0].item() for x in xv]

    phii = XV.clone()
    if use_batch:
        phii = phii[None].repeat(v0.shape[0], 1, 1, 1, 1)
        XV = XV[None].repeat(v0.shape[0], 1, 1, 1, 1)

    show = visualize and not use_batch
    if show:
        fig, ax = plt.subplots(1, 3)

    vsave = []  # per-step velocities, needed to rebuild the forward map
    for _ in range(n):
        # momentum at this time: first deform p0 by the current inverse map ...
        p = interp(xv, p0.permute(*permute0), phii).permute(*permute1)
        # ... then apply the transposed Jacobian of phii and its determinant
        Dphii = torch.stack(torch.gradient(phii, dim=(-4, -3, -2), spacing=dv), -1)
        detDphii = torch.linalg.det(Dphii)
        p = (Dphii.transpose(-1, -2) @ p[..., None])[..., 0] * detDphii[..., None]
        # velocity is the momentum smoothed by K
        v = torch.fft.ifftn(
            torch.fft.fftn(p, dim=spatial) * K[..., None], dim=spatial
        ).real
        if return_forward:
            vsave.append(v)
        # Euler step on the inverse map: phii <- phii o (x - v / n)
        Xs = XV - v / n
        phii = interp(xv, (phii - XV).permute(*permute0), Xs).permute(*permute1) + Xs

        if show:
            ax[0].cla()
            ax[0].imshow(_middle_slice_rgb(p))
            ax[1].cla()
            ax[1].imshow(_middle_slice_rgb(v))
            fig.canvas.draw()

    if not return_forward:
        return phii

    # Rebuild the forward map by composing the saved steps in reverse,
    # phi = (x + v_{n-1}/n) o ... o (x + v_0/n); each step reads the
    # accumulated phi (not the inverse map phii).
    phi = XV.clone()
    for v in reversed(vsave):
        Xs = XV + v / n
        phi = interp(xv, (phi - XV).permute(*permute0), Xs).permute(*permute1) + Xs
    return phi

def diffeo_gen_ara(sigma):
    """Return random diffeomorphism generated by sampling Gaussian noise then passing through Riemannian exponential.

    The domain is a fixed 132 x 80 x 114 grid with 100 micron spacing,
    centered on the origin. White noise is smoothed by the kernel
    ``K = 1 / (1 - a^2 * Laplacian)^(2p)`` (discrete Laplacian, a = 100,
    p = 2) in the Fourier domain and then shot with ``expR``. The global
    numpy random state is used for sampling.

    Args:
        sigma (float): standard deviation of noise in microns.

    Returns:
        List: list of sampled points in spatial domain (one 1D array per axis).
        np.array: range of the forward diffeomorphism at the sampled points,
            shaped (132, 80, 114, 3).
    """
    # a domain for sampling the velocity and deformation
    dv = np.array([100.0, 100.0, 100.0])  # units are every 100 microns
    nv = np.array([132, 80, 114])
    xv = [np.arange(n) * d - (n - 1) * d / 2 for n, d in zip(nv, dv)]

    # a frequency domain matching xv
    fv = [np.arange(n) / n / d for n, d in zip(nv, dv)]
    FV = np.stack(np.meshgrid(*fv, indexing="ij"), -1)
    a = 100.0
    p = 2.0
    # (1 - a^2 * discrete Laplacian)^(2p); its reciprocal is the smoothing kernel
    LL = (
        1.0 - 2.0 * a**2 * np.sum((np.cos(2.0 * np.pi * FV * dv) - 1) / dv**2, -1)
    ) ** (2 * p)
    K = 1.0 / LL

    # sample white noise
    Lm = np.random.randn(*FV.shape) * sigma

    # smooth it with K (rather than sqrt(K)) for a somewhat smoother field
    v = np.fft.ifftn(
        np.fft.fftn(Lm, axes=(0, 1, 2)) * K[..., None], axes=(0, 1, 2)
    ).real

    # shoot it with the Riemannian exponential; the default returns the
    # forward map
    phi = expR(
        [torch.tensor(x) for x in xv], torch.tensor(v), torch.as_tensor(K), n=10
    )
    return xv, phi.detach().cpu().numpy()