In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
In [2]:
import icon_registration
In [3]:
!ntop
Sun Jan 14 09:14:13 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Quadro RTX 6000                On  | 00000000:17:00.0 Off |                  Off |
| 33%   30C    P8               8W / 260W |   3148MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce GTX 1050        On  | 00000000:B3:00.0 Off |                  N/A |
| 30%   26C    P8              N/A /  75W |    471MiB /  2048MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
|  GPU   PID     USER    GPU MEM  %CPU  %MEM      TIME  COMMAND               |
|    0 466061     root       4MiB   0.0   0.1   59 days  /usr/lib/xorg/Xorg v  |
|    0 3356427   tgreer    3140MiB   0.1   3.1   47 days  /playpen-raid1/tgree  |
|    1 466061     root     192MiB   0.0   0.1   59 days  /usr/lib/xorg/Xorg v  |
|    1 466294   tgreer     120MiB   0.1   0.7   59 days  /usr/bin/gnome-shell  |
|    1 466302   tgreer       1MiB   0.1   0.0   59 days  /opt/teamviewer/tv_b  |
|    1 1403059   tgreer     150MiB   4.2   1.9    5 days  /usr/lib/firefox/fir  |
+---------------------------------------------------------------------------------------+

In [4]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device

import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [5]:
ds1, ds2 = icon_registration.data.get_dataset_retina(include_boundary=False, scale=.8, fixed_vertical_offset=200)

sample_batch = next(iter(ds2))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
Out[5]:
<matplotlib.image.AxesImage at 0x7f657ca089a0>
[figure: 3×4 grid of sample retina images from ds2]
In [6]:
icon_registration.data.get_dataset_retina.__globals__['__file__']
Out[6]:
'/playpen-raid1/tgreer/flash_attention/venv/lib/python3.8/site-packages/icon_registration/data.py'
In [34]:
import alias_free_unet
import importlib

import icon_registration.network_wrappers as network_wrappers
importlib.reload(alias_free_unet)

# Three candidate feature extractors are constructed in sequence; each
# assignment overwrites the previous one, so only the final
# NoDownsampleNet actually feeds the attention registration below.
unet = alias_free_unet.GenericUNet(
    input_channels=1, 
    output_channels=64, 
    num_layers=3,
    channels=[[None, 16, 32, 64], [16, 32, 64]],
    init_zero=False, 
    regis_scale=False)
import unet as unet_
unet = unet_.GenericUNet(
    input_channels=1, 
    output_channels=12, 
    num_layers=3,
    channels=[[None, 64, 64, 64], [64, 64, 64]],
    init_zero=False, 
    regis_scale=False)
unet = alias_free_unet.NoDownsampleNet(output_dim=12)
unet.cuda()
image_A = sample_batch.cuda()
ufeatures = unet(image_A)
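# Quick shape check (a sketch): output_dim=12 requests 12 feature
# channels per pixel, and NoDownsampleNet (by its name) should
# preserve the spatial size of its input.
print(image_A.shape)
print(ufeatures.shape)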
In [35]:
from icon_registration.losses import ICONLoss, flips

class GradientICONLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, network, lmbda):
        super().__init__()
        self.regis_net = network
        self.lmbda = lmbda

    def compute_gradient_icon_loss(self, phi_AB, phi_BA):
        Iepsilon = (
            self.identity_map
            + torch.randn(*self.identity_map.shape).to(self.identity_map.device)
            * 1
            / self.identity_map.shape[-1]
        )

        # compute squared Frobenius of Jacobian of icon error

        direction_losses = []

        approximate_Iepsilon = phi_AB(phi_BA(Iepsilon))

        inverse_consistency_error = Iepsilon - approximate_Iepsilon

        delta = 0.001

        if len(self.identity_map.shape) == 4:
            dx = torch.Tensor([[[[delta]], [[0.0]]]]).to(self.identity_map.device)
            dy = torch.Tensor([[[[0.0]], [[delta]]]]).to(self.identity_map.device)
            direction_vectors = (dx, dy)

        elif len(self.identity_map.shape) == 5:
            dx = torch.Tensor([[[[[delta]]], [[[0.0]]], [[[0.0]]]]]).to(
                self.identity_map.device
            )
            dy = torch.Tensor([[[[[0.0]]], [[[delta]]], [[[0.0]]]]]).to(
                self.identity_map.device
            )
            dz = torch.Tensor([[[[[0.0]]], [[[0.0]]], [[[delta]]]]]).to(
                self.identity_map.device
            )
            direction_vectors = (dx, dy, dz)
        elif len(self.identity_map.shape) == 3:
            dx = torch.Tensor([[[delta]]]).to(self.identity_map.device)
            direction_vectors = (dx,)

        for d in direction_vectors:
            approximate_Iepsilon_d = phi_AB(phi_BA(Iepsilon + d))
            inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
            grad_d_icon_error = (
                inverse_consistency_error - inverse_consistency_error_d
            ) / delta
            direction_losses.append(torch.mean(grad_d_icon_error**2))

        inverse_consistency_loss = sum(direction_losses)

        return inverse_consistency_loss


    def forward(self, image_A, image_B):

        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]

        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True

        self.phi_AB = self.regis_net(image_A, image_B)
        self.phi_BA = self.regis_net(image_B, image_A)

        inverse_consistency_loss = self.compute_gradient_icon_loss(
            self.phi_AB, self.phi_BA
        )

        # Regularity = gradient-ICON term plus a strong penalty on the
        # overall displacement magnitude.
        return self.phi_AB, self.lmbda * inverse_consistency_loss + torch.mean(
            10000 * (self.phi_AB(self.identity_map) - self.identity_map) ** 2
        )

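# The loop above is a finite-difference estimate of the Jacobian of the
# inverse-consistency error: along each axis direction d it forms
# (err(x) - err(x + d)) / delta and penalizes its mean square. A minimal
# standalone 1-D sketch of the same idea, with f standing in for the
# composed map phi_AB(phi_BA(.)):
def _gradient_error_penalty(f, x, delta=1e-3):
    err = x - f(x)                      # inverse-consistency error at x
    err_d = (x + delta) - f(x + delta)  # error at a perturbed point
    return torch.mean(((err - err_d) / delta) ** 2)

# A perfectly inverse-consistent composition (the identity) scores zero.
print(_gradient_error_penalty(lambda t: t, torch.linspace(0, 1, 100)))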
class DiffusionLayerwiseRegularizer(network_wrappers.RegistrationModule):

    def __init__(self, network, lmbda):
        super().__init__()
        self.regis_net = network
        self.lmbda = lmbda
        
    def compute_diffusion_loss(self, phi_AB_vectorfield):
        # Mean squared forward difference of the displacement field
        # u = identity - phi along each spatial axis.
        phi_AB_vectorfield = self.identity_map - phi_AB_vectorfield
        if len(self.identity_map.shape) == 3:
            diffusion_energy = torch.mean((
                phi_AB_vectorfield[:, :, 1:]
                - phi_AB_vectorfield[:, :, :-1]
            )**2)

        elif len(self.identity_map.shape) == 4:
            diffusion_energy = torch.mean((
                phi_AB_vectorfield[:, :, 1:]
                - phi_AB_vectorfield[:, :, :-1]
            )**2) + torch.mean((
                phi_AB_vectorfield[:, :, :, 1:]
                - phi_AB_vectorfield[:, :, :, :-1]
            )**2)
        elif len(self.identity_map.shape) == 5:
            diffusion_energy = torch.mean((
                phi_AB_vectorfield[:, :, 1:]
                - phi_AB_vectorfield[:, :, :-1]
            )**2) + torch.mean((
                phi_AB_vectorfield[:, :, :, 1:]
                - phi_AB_vectorfield[:, :, :, :-1]
            )**2) + torch.mean((
                phi_AB_vectorfield[:, :, :, :, 1:]
                - phi_AB_vectorfield[:, :, :, :, :-1]
            )**2)

        # Scale by resolution so the penalty is roughly grid-independent.
        return diffusion_energy * self.identity_map.shape[2] ** 2

    def forward(self, image_A, image_B):

        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]

        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True

        self.phi_AB = self.regis_net(image_A, image_B)

        diffusion_loss = self.compute_diffusion_loss(self.phi_AB(self.identity_map))

        return self.phi_AB, self.lmbda * diffusion_loss

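# compute_diffusion_loss penalizes the mean squared forward difference
# of the displacement field u = identity - phi along each spatial axis,
# scaled by shape[2]**2 to be roughly grid-independent. A toy check of
# the 2-D stencil with a plain tensor:
_u = torch.randn(1, 2, 8, 8)  # [batch, xy-components, H, W]
print((torch.mean((_u[:, :, 1:] - _u[:, :, :-1]) ** 2)
       + torch.mean((_u[:, :, :, 1:] - _u[:, :, :, :-1]) ** 2)) * 8 ** 2)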
class TwoStepLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, phi, psi):
        super().__init__()
        self.phi = phi
        self.psi = psi
    def forward(self, image_A, image_B):
        phi_AB, loss1 = self.phi(image_A, image_B)
        a_circ_phi_AB = self.as_function(image_A)(phi_AB(self.identity_map))
        psi_AB, loss2 = self.psi(a_circ_phi_AB, image_B)
        return (lambda coords: phi_AB(psi_AB(coords))), loss1 + loss2

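# TwoStepLayerwiseRegularizer composes the stages: the first network
# predicts phi_AB, image_A is warped by it, the second network registers
# the warped image to image_B, and the per-stage regularizer losses add.
# A minimal sketch of the composition with plain coordinate maps
# (hypothetical constant offsets standing in for learned transforms):
_phi = lambda coords: coords + 0.10   # coarse stage
_psi = lambda coords: coords - 0.02   # refinement stage
_composed = lambda coords: _phi(_psi(coords))
print(_composed(torch.zeros(2)))      # tensor([0.0800, 0.0800])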
class CollectLayerwiseRegularizer(network_wrappers.RegistrationModule):
    def __init__(self, network, similarity):
        super().__init__()

        self.regis_net = network
        self.similarity = similarity

    def compute_similarity_measure(self, phi_AB_vectorfield, image_A, image_B):

        if getattr(self.similarity, "isInterpolated", False):
            # tag images during warping so that the similarity measure
            # can use information about whether a sample is interpolated
            # or extrapolated
            inbounds_tag = torch.zeros([image_A.shape[0]] + [1] + list(image_A.shape[2:]), device=image_A.device)
            if len(self.input_shape) - 2 == 3:
                inbounds_tag[:, :, 1:-1, 1:-1, 1:-1] = 1.0
            elif len(self.input_shape) - 2 == 2:
                inbounds_tag[:, :, 1:-1, 1:-1] = 1.0
            else:
                inbounds_tag[:, :, 1:-1] = 1.0
        else:
            inbounds_tag = None

        self.warped_image_A = self.as_function(
            torch.cat([image_A, inbounds_tag], axis=1) if inbounds_tag is not None else image_A
        )(phi_AB_vectorfield)
        
        similarity_loss = self.similarity(
            self.warped_image_A, image_B
        )
        return similarity_loss

    def forward(self, image_A, image_B) -> ICONLoss:

        assert self.identity_map.shape[2:] == image_A.shape[2:]
        assert self.identity_map.shape[2:] == image_B.shape[2:]

        # Tag used elsewhere for optimization.
        # Must be set at beginning of forward b/c not preserved by .cuda() etc
        self.identity_map.isIdentity = True

        self.phi_AB, regularity_loss = self.regis_net(image_A, image_B)
        self.phi_AB_vectorfield = self.phi_AB(self.identity_map)
        
        similarity_loss = 2 * self.compute_similarity_measure(
            self.phi_AB_vectorfield, image_A, image_B
        )

        all_loss = regularity_loss + similarity_loss

        transform_magnitude = torch.mean(
            (self.identity_map - self.phi_AB_vectorfield) ** 2
        )
        return ICONLoss(
            all_loss,
            regularity_loss,
            similarity_loss,
            transform_magnitude,
            flips(self.phi_AB_vectorfield),
        )

    def prepare_for_viz(self, image_A, image_B):
        # regis_net returns (transform, regularity_loss); keep the transform.
        self.phi_AB, _ = self.regis_net(image_A, image_B)
        self.phi_AB_vectorfield = self.phi_AB(self.identity_map)
        self.phi_BA, _ = self.regis_net(image_B, image_A)
        self.phi_BA_vectorfield = self.phi_BA(self.identity_map)

        self.warped_image_A = self.as_function(image_A)(self.phi_AB_vectorfield)
        self.warped_image_B = self.as_function(image_B)(self.phi_BA_vectorfield)
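CollectLayerwiseRegularizer is the outermost wrapper: it unpacks (phi_AB, regularity_loss) from the wrapped net, adds the doubled similarity term, and packs everything into the library's ICONLoss named tuple, which icon.train_datasets knows how to log. A quick sketch of that return type (the field count is grounded in the constructor call above; the all_loss field name is assumed from the installed icon_registration release):

demo_loss = ICONLoss(*(torch.tensor(0.0) for _ in range(5)))
print(demo_loss.all_loss, demo_loss[0])  # fields readable by name or index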
In [36]:
from icon_registration.mermaidlite import identity_map_multiN


def make_im(input_shape):
    # Build a [0, 1]-normalized identity coordinate map for the given shape.
    input_shape = np.array(input_shape)
    input_shape[0] = 1
    spacing = 1.0 / (input_shape[2:] - 1)

    _id = identity_map_multiN(input_shape, spacing)
    return _id

def pad_im(im, n):
    # Extend an identity map by n pixels on every spatial side, rescaling
    # the coordinates so that the original pixels keep their positions.
    new_shape = np.array(im.shape)
    old_shape = np.array(im.shape)

    new_shape[2:] += 2 * n

    new_im = np.array(make_im(new_shape))

    # Broadcast the per-axis spatial extents over the channel dimension.
    if len(new_shape) == 4:
        def expand(t):
            return t[None, 2:, None, None]
    else:
        def expand(t):
            return t[None, 2:, None, None, None]

    new_im *= expand(new_shape - 1) / expand(old_shape - 1)

    new_im -= n / expand(old_shape - 1)

    new_im = torch.tensor(new_im)

    return new_im

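# pad_im extends the [0, 1] identity map by n pixels per side while
# keeping existing pixels at their old coordinates, so the padded map
# spans [-n/(N-1), 1 + n/(N-1)]. A quick numeric check:
_id_check = pad_im(make_im((1, 2, 32, 32)), 9)
print(_id_check.shape)                                 # (1, 2, 50, 50)
print(float(_id_check.min()), float(_id_check.max()))  # ~ -9/31, 1 + 9/31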
class AttentionRegistration(icon_registration.RegistrationModule):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.dim = 12
        
        self.blur_kernel = torch.nn.Conv2d(2, 2, 5, padding="same", bias=False, groups=2)
        self.padding = 9

    def crop(self, x):
        padding = self.padding
        return x[:, :, padding:-padding, padding:-padding]
    
    def featurize(self, values, recrop=True):
        padding = self.padding
        x = torch.nn.functional.pad(values, [padding, padding, padding, padding])
        x = self.net(x)
        # Normalize features to (approximately) norm 4, so the attention
        # logits below have a controlled temperature.
        x = 4 * x / (.001 + torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)))
        if recrop:
            x = self.crop(x)
        return x

    def torch_attention(self, ft_A, ft_B):
        ft_A = ft_A.reshape(-1, 1, self.dim, (self.identity_map.shape[-1] + 2 * self.padding) * (self.identity_map.shape[-2] + 2 * self.padding))
        ft_B = ft_B.reshape(-1, 1, self.dim, self.identity_map.shape[-1] * self.identity_map.shape[-2])

        ft_A = ft_A.permute([0, 1, 3, 2]).contiguous()
        ft_B = ft_B.permute([0, 1, 3, 2]).contiguous()

        # Values are the padded identity-map coordinates: each output pixel
        # becomes a softmax-weighted average of candidate coordinates in A.
        im = pad_im(self.identity_map, self.padding).to(ft_A.device)
        x = im.reshape(-1, 1, 2, ft_A.shape[2]).permute(0, 1, 3, 2)

        # Requirements for the memory efficient attention kernel:
        # ft_A, ft_B, x all four dimensional [batch, head, len, feature]
        # feature dimension divisible by four
        # tensors all contiguous
        # batch dimension must match (cannot broadcast)
        x = torch.cat([x, x], axis=-1)
        x = x.expand(ft_B.shape[0], -1, -1, -1)
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            output = torch.nn.functional.scaled_dot_product_attention(ft_B, ft_A, x, scale=1)
        output = output[:, :, :, 2:]
        output = output.permute(0, 1, 3, 2)
        output = output.reshape(-1, 2, self.identity_map.shape[2], self.identity_map.shape[3])
        return output
        
    def forward(self, A, B):
        ft_A = self.featurize(A, recrop=False)
        ft_B = self.featurize(B)

        # Attention returns absolute coordinates; subtract the identity
        # map to get the displacement field FunctionFromVectorField expects.
        output = self.torch_attention(ft_A, ft_B)
        output = output - self.identity_map
        #output = self.blur_kernel(output)
        return output
ar = AttentionRegistration(unet)
ar.cuda()
0
Out[36]:
0
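torch_attention casts registration as cross-attention: the queries are features of image B, the keys are features of padded image A, and the values are the padded identity-map coordinates, so every output pixel receives a softmax-weighted average of the coordinates in A whose features match it. A standalone sketch of that coordinate-voting pattern, with hypothetical shapes:

q = torch.randn(2, 1, 100, 12)  # features of B   [batch, head, len, feat]
k = torch.randn(2, 1, 144, 12)  # features of padded A
v = torch.rand(2, 1, 144, 4)    # coordinates, widened to a multiple of 4
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1)
print(out.shape)                # torch.Size([2, 1, 100, 4])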
In [37]:
inner_net = icon.network_wrappers.DownsampleRegistration(
    icon.network_wrappers.DownsampleRegistration(
        icon.FunctionFromVectorField(ar), 2),
    2)

#inner_net = icon.FunctionFromVectorField(ar)
inner_net.assign_identity_map(sample_batch.shape)
inner_net.cuda()
0
Out[37]:
0
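Each DownsampleRegistration wrapper average-pools its inputs by a factor of two before invoking the wrapped module, so after two levels the attention registration operates at roughly quarter resolution while the returned transform still acts on full-resolution coordinates. A rough preview of what the innermost module sees (a sketch; the avg_pool2d probe near the end of the notebook uses the same idea):

quarter = torch.nn.functional.avg_pool2d(sample_batch, 4)
print(sample_batch.shape, quarter.shape)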
In [38]:
ts = icon.TwoStepRegistration(
    icon.FunctionFromVectorField(icon.networks.tallUNet2(dimension=2)),
    icon.FunctionFromVectorField(icon.networks.tallUNet2(dimension=2)))

ts = TwoStepLayerwiseRegularizer(
    DiffusionLayerwiseRegularizer(inner_net, 1.5),
    GradientICONLayerwiseRegularizer(ts, 1.5))
In [39]:
net = CollectLayerwiseRegularizer(ts, icon.LNCC(sigma=4))
#net = icon.losses.GradientICON(ts, icon.LNCC(sigma=4), lmbda=1.5)

#net = icon.losses.BendingEnergyNet(ts, icon.LNCC(sigma=4), lmbda=6e-4)
net.assign_identity_map(sample_batch.shape)
net.cuda()
0
Out[39]:
0
In [40]:
def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
[figure: before training: image_A, image_B, warped image_A with deformation-grid contours, and warped_image_A - image_B]
In [41]:
show(net.phi_AB_vectorfield[:, 1])
plt.colorbar()
Out[41]:
<matplotlib.colorbar.Colorbar at 0x7f667b11abe0>
[figure: channel 1 of net.phi_AB_vectorfield, with colorbar]
In [43]:
net.train()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.0003)
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=5)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:20<00:00,  4.15s/it]
Out[43]:
[<matplotlib.lines.Line2D at 0x7f667b068280>,
 <matplotlib.lines.Line2D at 0x7f667b05a100>,
 <matplotlib.lines.Line2D at 0x7f667b05a130>]
[figure: training loss curves, epochs 1-5]
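train_datasets returns one row of loss values per step; the plot takes the first three columns, which (assuming they follow the ICONLoss field order of the forward above) are the total, regularity, and similarity losses. Labeling them makes the curves easier to read:

plt.plot(np.array(curves)[:, :3])
plt.legend(["all_loss", "regularity_loss", "similarity_loss"])  # assumed column order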
In [44]:
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=45)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|██████████████████████████████████████████████████████████████████████████████████| 45/45 [03:05<00:00,  4.13s/it]
Out[44]:
[<matplotlib.lines.Line2D at 0x7f6573e22dc0>,
 <matplotlib.lines.Line2D at 0x7f6573e22e20>,
 <matplotlib.lines.Line2D at 0x7f6573e22e50>]
[figure: training loss curves, epochs 6-50]
In [45]:
curves = icon.train_datasets(net, optim, ds1, ds2, epochs=45)
plt.close()
plt.plot(np.array(curves)[:, :3])
100%|██████████████████████████████████████████████████████████████████████████████████| 45/45 [03:06<00:00,  4.14s/it]
Out[45]:
[<matplotlib.lines.Line2D at 0x7f6573e52100>,
 <matplotlib.lines.Line2D at 0x7f6573e52160>,
 <matplotlib.lines.Line2D at 0x7f6573e52190>]
[figure: training loss curves, epochs 51-95]
In [46]:
plt.figure(figsize=(20, 20))


def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
[figure: after training: image_A, image_B, warped image_A with deformation-grid contours, and warped_image_A - image_B]
In [47]:
firststep_net = icon.losses.DiffusionRegularizedNet(inner_net, icon.LNCC(sigma=4), lmbda=1.5)
firststep_net.assign_identity_map(sample_batch.shape)
firststep_net.cuda()

plt.figure(figsize=(20, 20))


def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds1))[0].to(device)
image_B = next(iter(ds2))[0].to(device)
firststep_net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(firststep_net.warped_image_A)
plt.contour(torchvision.utils.make_grid(firststep_net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(firststep_net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(firststep_net.warped_image_A - image_B)
plt.tight_layout()
[figure: first-stage network only: image_A, image_B, warped image_A with deformation-grid contours, and warped_image_A - image_B]
In [ ]:
import math
math.sqrt(1332.)
In [ ]:
37 * 36
In [ ]:
features = unet(torch.nn.functional.avg_pool2d(image_A, 2).cuda())
In [ ]:
for i in range(12):
    # unet emits 12 feature channels; show each one individually.
    show(features[:, i:i+1])
    plt.show()
In [ ]:
sample_batch.shape