In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
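# Note: device masking must happen before torch initializes CUDA. A quick
# check (a sketch, not part of the recorded run):
#   import torch; assert torch.cuda.device_count() == 1  # only GPU 0 visible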
In [2]:
import icon_registration
In [3]:
!ntop
Sun Jan 14 09:14:13 2024
NVIDIA-SMI 545.23.08   Driver Version: 545.23.08   CUDA Version: 12.3
GPU 0  Quadro RTX 6000          30C    8 W / 260 W   3148 MiB / 24576 MiB   0% util
GPU 1  NVIDIA GeForce GTX 1050  26C    N/A / 75 W     471 MiB /  2048 MiB   0% util
Processes: Xorg and desktop sessions on both GPUs; this notebook's kernel
(tgreer) holds 3140 MiB on GPU 0.
In [4]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [5]:
ds1, ds2 = icon_registration.data.get_dataset_retina(include_boundary=False, scale=.8, fixed_vertical_offset=200)
sample_batch = next(iter(ds2))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
Out[5]:
<matplotlib.image.AxesImage at 0x7f657ca089a0>
In [6]:
icon_registration.data.get_dataset_retina.__globals__['__file__']
Out[6]:
'/playpen-raid1/tgreer/flash_attention/venv/lib/python3.8/site-packages/icon_registration/data.py'
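The __globals__['__file__'] trick recovers the defining module's path;
inspect.getsourcefile is the more idiomatic route (a sketch, expected to give
the same result in this environment):

import inspect
inspect.getsourcefile(icon_registration.data.get_dataset_retina)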
In [34]:
import alias_free_unet
import importlib
import icon_registration.network_wrappers as network_wrappers
importlib.reload(alias_free_unet)
# Two feature extractors were tried and then discarded; only the final
# NoDownsampleNet assignment below is actually used.
unet = alias_free_unet.GenericUNet(
    input_channels=1,
    output_channels=64,
    num_layers=3,
    channels=[[None, 16, 32, 64], [16, 32, 64]],
    init_zero=False,
    regis_scale=False)
import unet as unet_
unet = unet_.GenericUNet(
    input_channels=1,
    output_channels=12,
    num_layers=3,
    channels=[[None, 64, 64, 64], [64, 64, 64]],
    init_zero=False,
    regis_scale=False)
# Final choice: no downsampling, 12 output feature channels.
unet = alias_free_unet.NoDownsampleNet(output_dim=12)
unet.cuda()
image_A = sample_batch.cuda()
ufeatures = unet(image_A)
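# A sketch of the expected output: ufeatures should be (batch, 12, H, W),
# since NoDownsampleNet keeps full spatial resolution and output_dim=12.
#   ufeatures.shape  # e.g. torch.Size([batch, 12, H, W])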
In [35]:
from icon_registration.losses import ICONLoss, flips
class GradientICONLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, network, lmbda):
        super().__init__()
        self.regis_net = network
        self.lmbda = lmbda

    def compute_gradient_icon_loss(self, phi_AB, phi_BA):
        Iepsilon = (
            self.identity_map
            + torch.randn(*self.identity_map.shape).to(self.identity_map.device)
            * 1
            / self.identity_map.shape[-1]
        )
        # compute squared Frobenius norm of the Jacobian of the ICON error
        direction_losses = []
        approximate_Iepsilon = phi_AB(phi_BA(Iepsilon))
        inverse_consistency_error = Iepsilon - approximate_Iepsilon
        delta = 0.001
        if len(self.identity_map.shape) == 4:
            dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(self.identity_map.device)
            dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(self.identity_map.device)
            direction_vectors = (dx, dy)
        elif len(self.identity_map.shape) == 5:
            dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(
                self.identity_map.device
            )
            dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(
                self.identity_map.device
            )
            dz = torch.Tensor([[[[[0.0]]], [[[0.0]]], [[[delta]]]]]).to(
                self.identity_map.device
            )
            direction_vectors = (dx, dy, dz)
        elif len(self.identity_map.shape) == 3:
            dx = torch.Tensor([[[delta]]]).to(self.identity_map.device)
            direction_vectors = (dx,)
        for d in direction_vectors:
            approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d))
            inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
            grad_d_icon_error = (
                inverse_consistency_error - inverse_consistency_error_d
            ) / delta
            direction_losses.append(torch.mean(grad_d_icon_error**2))
        inverse_consistency_loss = sum(direction_losses)
        return inverse_consistency_loss

    def forward(self, image_A, image_B):
        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]
        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True
        self.phi_AB = self.regis_net(image_A, image_B)
        self.phi_BA = self.regis_net(image_B, image_A)
        inverse_consistency_loss = self.compute_gradient_icon_loss(
            self.phi_AB, self.phi_BA
        )
        # In addition to the ICON term, strongly penalize (weight 10000) the
        # mean squared displacement of the map from the identity grid.
        return self.phi_AB, (
            self.lmbda * inverse_consistency_loss
            + torch.mean(10000 * (self.phi_AB(self.identity_map) - self.identity_map) ** 2)
        )
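# The loss above is a finite-difference estimate of the squared Frobenius norm
# of the Jacobian of the inverse-consistency error
#   E(x) = x - phi_AB(phi_BA(x)),
# i.e. sum_d || (E(x) - E(x + delta * e_d)) / delta ||^2, averaged over
# randomly perturbed grid points, as in GradientICON.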
class DiffusionLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, network, lmbda):
        super().__init__()
        self.regis_net = network
        self.lmbda = lmbda

    def compute_diffusion_loss(self, phi_AB_vectorfield):
        # Penalize squared first differences of the displacement field
        # (a first-order diffusion regularizer, not second-order bending energy).
        phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield
        if len(self.identity_map.shape) == 3:
            diffusion_energy = torch.mean((
                - phi_AB_vectorfield[:, :, 1:]
                + phi_AB_vectorfield[:, :, :-1]
            )**2)
        elif len(self.identity_map.shape) == 4:
            diffusion_energy = torch.mean((
                - phi_AB_vectorfield[:, :, 1:]
                + phi_AB_vectorfield[:, :, :-1]
            )**2) + torch.mean((
                - phi_AB_vectorfield[:, :, :, 1:]
                + phi_AB_vectorfield[:, :, :, :-1]
            )**2)
        elif len(self.identity_map.shape) == 5:
            diffusion_energy = torch.mean((
                - phi_AB_vectorfield[:, :, 1:]
                + phi_AB_vectorfield[:, :, :-1]
            )**2) + torch.mean((
                - phi_AB_vectorfield[:, :, :, 1:]
                + phi_AB_vectorfield[:, :, :, :-1]
            )**2) + torch.mean((
                - phi_AB_vectorfield[:, :, :, :, 1:]
                + phi_AB_vectorfield[:, :, :, :, :-1]
            )**2)
        # Scale by the grid size squared so the finite differences approximate
        # spatial derivatives independent of resolution.
        return diffusion_energy * self.identity_map.shape[2] ** 2

    def forward(self, image_A, image_B):
        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]
        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True
        self.phi_AB = self.regis_net(image_A, image_B)
        diffusion_loss = self.compute_diffusion_loss(self.phi_AB(self.identity_map))
        return self.phi_AB, self.lmbda * diffusion_loss
class TwoStepLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, phi, psi):
        super().__init__()
        self.phi = phi
        self.psi = psi

    def forward(self, image_A, image_B):
        # First stage: register A to B, collecting that stage's regularity loss.
        phi_AB, loss1 = self.phi(image_A, image_B)
        # Warp A by the first-stage transform, then refine with the second stage.
        a_circ_phi_AB = self.as_function(image_A)(phi_AB(self.identity_map))
        psi_AB, loss2 = self.psi(a_circ_phi_AB, image_B)
        # Compose the two transforms; the per-stage regularity losses simply add.
        return (lambda coords: phi_AB(psi_AB(coords))), loss1 + loss2
class CollectLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, network, similarity):
        super().__init__()
        self.regis_net = network
        self.similarity = similarity

    def compute_similarity_measure(self, phi_AB_vectorfield, image_A, image_B):
        if getattr(self.similarity, "isInterpolated", False):
            # tag images during warping so that the similarity measure
            # can use information about whether a sample is interpolated
            # or extrapolated
            inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device)
            if len(self.input_shape) - 2 == 3:
                inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0
            elif len(self.input_shape) - 2 == 2:
                inbounds_tag[:, :, 1:-1, 1:-1] = 1.0
            else:
                inbounds_tag[:, :, 1:-1] = 1.0
        else:
            inbounds_tag = None
        self.warped_image_A = self.as_function(
            torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A
        )(phi_AB_vectorfield)
        similarity_loss = self.similarity(self.warped_image_A, image_B)
        return similarity_loss

    def forward(self, image_A, image_B) -> ICONLoss:
        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]
        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True
        # regis_net is a layerwise regularizer: it returns both the transform
        # and the accumulated regularity loss.
        self.phi_AB, regularity_loss = self.regis_net(image_A, image_B)
        self.phi_AB_vectorfield = self.phi_AB(self.identity_map)
        similarity_loss = 2 * self.compute_similarity_measure(
            self.phi_AB_vectorfield, image_A, image_B
        )
        all_loss = regularity_loss + similarity_loss
        transform_magnitude = torch.mean(
            (self.identity_map - self.phi_AB_vectorfield) ** 2
        )
        return ICONLoss(
            all_loss,
            regularity_loss,
            similarity_loss,
            transform_magnitude,
            flips(self.phi_AB_vectorfield),
        )

    def prepare_for_viz(self, image_A, image_B):
        # regis_net returns (transform, loss) here, so unpack and discard the loss.
        self.phi_AB, _ = self.regis_net(image_A, image_B)
        self.phi_AB_vectorfield = self.phi_AB(self.identity_map)
        self.phi_BA, _ = self.regis_net(image_B, image_A)
        self.phi_BA_vectorfield = self.phi_BA(self.identity_map)
        self.warped_image_A = self.as_function(image_A)(self.phi_AB_vectorfield)
        self.warped_image_B = self.as_function(image_B)(self.phi_BA_vectorfield)
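All three wrappers share one contract: forward(image_A, image_B) returns a
(transform, regularity_loss) pair, which is what lets TwoStepLayerwiseRegularizer
chain them and CollectLayerwiseRegularizer terminate the chain. A minimal sanity
check of that contract, using a hypothetical DummyRegistration that returns the
identity transform (for which both terms of the gradient-ICON loss vanish):

class DummyRegistration(network_wrappers.RegistrationModule):
    def forward(self, image_A, image_B):
        return lambda coords: coords  # identity transform

check = GradientICONLayerwiseRegularizer(DummyRegistration(), lmbda=1.5)
check.assign_identity_map(sample_batch.shape)
check.cuda()
_, loss = check(image_A, image_A)
print(loss)  # expected: 0, since the identity map is perfectly inverse-consistent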
In [36]:
from icon_registration.mermaidlite import identity_map_multiN
def make_im(input_shape):
    # Build a [0, 1]-normalized identity coordinate map for the given shape.
    input_shape = np.array(input_shape)
    input_shape[0] = 1
    spacing = 1.0 / (input_shape[2::] - 1)
    _id = identity_map_multiN(input_shape, spacing)
    return _id

def pad_im(im, n):
    # Extend the identity map by n pixels on each side, keeping the original
    # image on the unit square: padded coordinates run past [0, 1].
    new_shape = np.array(im.shape)
    old_shape = np.array(im.shape)
    new_shape[2:] += 2 * n
    new_im = np.array(make_im(new_shape))
    if len(new_shape) == 4:
        def expand(t):
            # reshape the per-axis sizes for broadcasting over a 2-D map
            return t[None, 2:, None, None]
    else:
        def expand(t):
            return t[None, 2:, None, None, None]
    new_im *= expand(new_shape - 1) / expand(old_shape - 1)
    new_im -= n / expand(old_shape - 1)
    new_im = torch.tensor(new_im)
    return new_im
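# Coordinate sanity check (a sketch): padding a 2-D identity map by n grows
# each spatial dim by 2n while the original image keeps the unit square.
# For example:
#   pad_im(make_im((1, 1, 64, 64)), 9).shape  # -> torch.Size([1, 2, 82, 82])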
class AttentionRegistration(icon_registration.RegistrationModule):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.dim = 12
        # Learnable blur, currently unused (its application below is commented out).
        self.blur_kernel = torch.nn.Conv2d(2, 2, 5, padding="same", bias=False, groups=2)
        self.padding = 9

    def crop(self, x):
        padding = self.padding
        return x[:, :, padding:-padding, padding:-padding]

    def featurize(self, values, recrop=True):
        padding = self.padding
        x = torch.nn.functional.pad(values, [padding, padding, padding, padding])
        x = self.net(x)
        # Normalize feature vectors to (approximately) fixed length 4.
        x = 4 * x / (.001 + torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)))
        if recrop:
            x = self.crop(x)
        return x

    def torch_attention(self, ft_A, ft_B):
        ft_A = ft_A.reshape(-1, 1, self.dim, (self.identity_map.shape[-1] + 2 * self.padding) * (self.identity_map.shape[-2] + 2 * self.padding))
        ft_B = ft_B.reshape(-1, 1, self.dim, self.identity_map.shape[-1] * self.identity_map.shape[-2])
        ft_A = ft_A.permute([0, 1, 3, 2]).contiguous()
        ft_B = ft_B.permute([0, 1, 3, 2]).contiguous()
        im = pad_im(self.identity_map, self.padding).to(ft_A.device)
        x = im.reshape(-1, 1, 2, ft_A.shape[2]).permute(0, 1, 3, 2)
        # Requirements for the memory efficient attention kernel:
        # ft_A, ft_B, x all four dimensional [batch, head, len, feature]
        # feature dimension divisible by four
        # tensors all contiguous
        # batch dimension must match (cannot broadcast)
        x = torch.cat([x, x], axis=-1)  # duplicate coords so the value dim is 4
        x = x.expand(ft_B.shape[0], -1, -1, -1)  # match the batch dimension
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            output = torch.nn.functional.scaled_dot_product_attention(ft_B, ft_A, x, scale=1)
        output = output[:, :, :, 2:]  # drop the duplicated coordinate copy
        output = output.permute(0, 1, 3, 2)
        output = output.reshape(-1, 2, self.identity_map.shape[2], self.identity_map.shape[3])
        return output

    def forward(self, A, B):
        ft_A = self.featurize(A, recrop=False)
        ft_B = self.featurize(B)
        output = self.torch_attention(ft_A, ft_B)
        # Return a displacement field: attention output minus the identity map.
        output = output - self.identity_map
        # output = self.blur_kernel(output)
        return output

ar = AttentionRegistration(unet)
ar.cuda()
0  # suppress the long module repr in the cell output
Out[36]:
0
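As a quick smoke test (a sketch, not part of the recorded run), the attention
module can be exercised on its own once it has an identity map at the right
resolution; below it runs at 4x downsampling, via two nested 2x wrappers:

small = torch.nn.functional.avg_pool2d(image_A, 4)
ar.assign_identity_map(small.shape)
ar.cuda()
print(ar(small, small).shape)  # expected: (batch, 2, H/4, W/4)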
In [37]:
inner_net = icon.network_wrappers.DownsampleRegistration(
    icon.network_wrappers.DownsampleRegistration(
        icon.FunctionFromVectorField(ar), 2), 2)
# inner_net = icon.FunctionFromVectorField(ar)
inner_net.assign_identity_map(sample_batch.shape)
inner_net.cuda()
0
Out[37]:
0
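inner_net now maps an image pair to a coordinate transform at full resolution;
applying the transform to the identity map materializes the dense vector field
(a sketch, not part of the recorded run):

phi = inner_net(image_A, image_A)
print(phi(inner_net.identity_map).shape)  # expected: same shape as identity_map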
In [38]:
ts = icon.TwoStepRegistration(
    icon.FunctionFromVectorField(icon.networks.tallUNet2(dimension=2)),
    icon.FunctionFromVectorField(icon.networks.tallUNet2(dimension=2)))
# Coarse-to-fine: the attention step (diffusion-regularized) runs first,
# then the UNet two-step refines under the gradient-ICON regularizer.
ts = TwoStepLayerwiseRegularizer(
    DiffusionLayerwiseRegularizer(inner_net, 1.5),
    GradientICONLayerwiseRegularizer(ts, 1.5))
In [39]:
net = CollectLayerwiseRegularizer(ts, icon.LNCC(sigma=4))
#net = icon.losses.GradientICON(ts, icon.LNCC(sigma=4), lmbda=1.5)
#net = icon.losses.BendingEnergyNet(ts, icon.LNCC(sigma=4), lmbda=6e-4)
net.assign_identity_map(sample_batch.shape)
net.cuda()
0
Out[39]:
0
In [40]:
def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [41]:
show(net.phi_AB_vectorfield[:, 1:2])  # keep the channel dim so make_grid tiles all six maps
plt.colorbar()
Out[41]:
<matplotlib.colorbar.Colorbar at 0x7f667b11abe0>
In [43]:
net.train()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.0003)
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=5)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:20<00:00, 4.15s/it]
Out[43]:
[<matplotlib.lines.Line2D at 0x7f667b068280>, <matplotlib.lines.Line2D at 0x7f667b05a100>, <matplotlib.lines.Line2D at 0x7f667b05a130>]
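Assuming train_datasets logs the ICONLoss fields in order per step, the first
three columns of curves are all_loss, regularity_loss, and similarity_loss; a
labeled version of the plot (a sketch):

plt.plot(np.array(curves)[:, :3])
plt.legend(["all_loss", "regularity_loss", "similarity_loss"])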
In [44]:
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=45)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|██████████████████████████████████████████████████████████████████████████████████| 45/45 [03:05<00:00, 4.13s/it]
Out[44]:
[<matplotlib.lines.Line2D at 0x7f6573e22dc0>, <matplotlib.lines.Line2D at 0x7f6573e22e20>, <matplotlib.lines.Line2D at 0x7f6573e22e50>]
In [45]:
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=45)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|██████████████████████████████████████████████████████████████████████████████████| 45/45 [03:06<00:00, 4.14s/it]
Out[45]:
[<matplotlib.lines.Line2D at 0x7f6573e52100>, <matplotlib.lines.Line2D at 0x7f6573e52160>, <matplotlib.lines.Line2D at 0x7f6573e52190>]
In [46]:
plt.figure(figsize=(20, 20))
def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [47]:
firststep_net = icon.losses.DiffusionRegularizedNet(inner_net, icon.LNCC(sigma=4), lmbda=1.5)
firststep_net.assign_identity_map(sample_batch.shape)
firststep_net.cuda()
plt.figure(figsize=(20, 20))
def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
firststep_net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(firststep_net.warped_image_A)
plt.contour(torchvision.utils.make_grid(firststep_net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(firststep_net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(firststep_net.warped_image_A - image_B)
plt.tight_layout()
In [ ]:
import math
math.sqrt(1332.)
In [ ]:
37 * 36
In [ ]:
features = unet(torch.nn.functional.avg_pool2d(image_A, 2).cuda())
In [ ]:
# The feature network emits 12 channels (output_dim=12); show one per figure.
for i in range(12):
    show(features[:, i:i+1])
    plt.show()
In [ ]:
sample_batch.shape