In [1]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [2]:
# Load an MNIST training loader. `number=2` is forwarded to get_dataset_mnist
# (presumably the digit class to select — confirm against the library); the
# second returned value is not used in this notebook.
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)

# Keep the first element of the first batch; its shape is used later when
# assigning the network's identity map.
sample_batch = next(iter(ds))[0]

# Preview the first 12 images tiled 4 per row; channel 0 of the grid is displayed.
preview_grid = torchvision.utils.make_grid(sample_batch[:12], nrow=4)
plt.imshow(preview_grid[0])
Out[2]:
<matplotlib.image.AxesImage at 0x7fc8f0482e30>
In [3]:
import icon_registration.carl as carl
class Equivariantize(torch.nn.Module):
    """Wrap a network and combine its outputs over flipped copies of the input.

    The forward pass evaluates the wrapped network on the input as given, on the
    input flipped along dim 2, and on the input flipped along dim 3. Each flipped
    result is flipped back along the same dim before being added to the total.

    NOTE(review): three terms are summed but the total is divided by 2, so the
    result is 1.5x the mean of the three terms — confirm this scaling is intended.
    """

    def __init__(self, net):
        """Store the wrapped module `net`."""
        super().__init__()
        self.net = net

    def forward(self, a):
        """Return (net(a) + un-flipped net(flip_2(a)) + un-flipped net(flip_3(a))) / 2."""
        total = self.net(a)
        for axis in (2, 3):
            mirrored_input = a.flip(dims=(axis,))
            # Undo the flip on the output so it lines up with the unflipped term.
            total = total + self.net(mirrored_input).flip(dims=(axis,))
        return total / 2
# Feature network: a 2-D NoDownsampleNet wrapped so its output is combined over
# flipped copies of the input (see Equivariantize).
unet = Equivariantize(carl.NoDownsampleNet(dimension=2))

# Attention-based registration module built on top of the feature network.
ar = carl.AttentionRegistration(unet, dimension=2)

# Wrap the module with FunctionFromVectorField, then with DownsampleNet.
# NOTE(review): the positional `2` is presumably the dimension — confirm.
vector_field_function = icon.FunctionFromVectorField(ar)
ts = icon.network_wrappers.DownsampleNet(vector_field_function, 2)

# Loss wrapper: LNCC(sigma=4) as the similarity term, with lmbda=10.
# NOTE(review): lmbda presumably weights the regularisation term — confirm.
similarity = icon.LNCC(sigma=4)
net = icon.losses.DiffusionRegularizedNet(ts, similarity, lmbda=10)

# Assign the identity map for the sample batch shape, then wrap the network
# with carl.augmentify. `net` now refers to the wrapper.
net.assign_identity_map(sample_batch.shape)
net = carl.augmentify(net)
In [4]:
# `net` was re-bound to the result of carl.augmentify at the end of the previous
# cell, so this call assigns the identity map on that wrapper, using the shape
# of the sample batch.
# NOTE(review): presumably required because the wrapper keeps its own identity
# map — confirm against carl.augmentify.
net.assign_identity_map(sample_batch.shape)
In [8]:
# Put the network in training mode and move it to the configured device before
# creating the optimizer over its parameters.
net.train()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.001)

# Train for 10 epochs. The same loader `ds` is passed for both dataset arguments.
curves = icon.train_datasets(net, optim, ds, ds, epochs=10)

# Plot the first three columns of the returned training curves.
# NOTE(review): the meaning of each column is defined by icon.train_datasets —
# confirm before labelling the individual lines.
plt.close()
plt.plot(np.array(curves)[:, :3])
plt.title("Training curves (first three logged columns)")
# End on plt.show() so the cell renders the figure without echoing the
# list of Line2D objects as its output.
plt.show()
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [11:39<00:00, 69.98s/it]
Out[8]:
[<matplotlib.lines.Line2D at 0x7fc8981831f0>, <matplotlib.lines.Line2D at 0x7fc898183340>, <matplotlib.lines.Line2D at 0x7fc898183490>]
In [9]:
# Close the current figure (if any) before drawing a new one.
plt.close()


def show(tensor):
    """Draw the first six entries of `tensor` as a 3-per-row image grid.

    Only channel 0 of the assembled grid is displayed. The grid is moved to the
    CPU and detached before plotting, and axis ticks are hidden. Draws on the
    current matplotlib axes and returns None.
    """
    grid = torchvision.utils.make_grid(tensor[:6], nrow=3)
    plt.imshow(grid[0].cpu().detach())
    # Pixel coordinates are not informative here, so drop both tick sets.
    plt.xticks([])
    plt.yticks([])
# Take the first batch from two fresh iterators over `ds`.
# NOTE(review): the two batches only differ if the loader shuffles — confirm.
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)

# This forward pass is only for visualisation: skip autograd bookkeeping so no
# graph is kept alive through the attributes the network stores on itself.
with torch.no_grad():
    net(image_A, image_B)

# Build the grid of the first six vector fields once; channels 0 and 1 are
# contoured separately below.
phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()

plt.subplot(2, 2, 1)
show(image_A)
plt.title("image_A")

plt.subplot(2, 2, 2)
show(image_B)
plt.title("image_B")

# Warped image with contour lines of both vector-field channels overlaid.
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(phi_grid[0])
plt.contour(phi_grid[1])
plt.title("warped image_A + phi_AB contours")

# Residual between the warped image and the target.
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.title("warped image_A - image_B")

plt.tight_layout()
In [10]:
# Continue training the same network with the same optimizer for one epoch on
# each of two further datasets (number=5, then number=8).
# NOTE(review): `number` presumably selects the MNIST digit class — confirm.
# Each run gets its own figure and is shown immediately; previously the first
# plot was closed by the second plt.close() before the cell could display it.
for number in (5, 8):
    ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=number)
    curves = icon.train_datasets(net, optim, ds, ds, epochs=1)

    plt.figure()
    # First three columns of the returned training curves.
    plt.plot(np.array(curves)[:, :3])
    plt.title(f"Training curves, number={number} (first three logged columns)")
    plt.show()
100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:56<00:00, 56.13s/it] 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:59<00:00, 59.09s/it]
Out[10]:
[<matplotlib.lines.Line2D at 0x7fc8880e3d60>, <matplotlib.lines.Line2D at 0x7fc8880e3eb0>, <matplotlib.lines.Line2D at 0x7fc8880e3d30>]
In [11]:
def show_registration_for_number(number, batch_limit):
    """Register two batches from one MNIST loader and plot a 2x2 summary figure.

    Parameters
    ----------
    number : int
        Forwarded to get_dataset_mnist as `number`.
        NOTE(review): presumably the digit class to select — confirm.
    batch_limit : int
        Maximum number of images kept from each batch. The plots display at
        most the first six of them.

    Returns
    -------
    tuple
        (loader, image_A, image_B), so the caller can keep them in the notebook
        namespace.

    The panels are: image_A, image_B, the warped image_A with contours of both
    vector-field channels, and the residual warped image_A - image_B.
    """
    loader, _ = icon_registration.data.get_dataset_mnist(split="train", number=number)
    # First batch from two fresh iterators.
    # NOTE(review): the two batches only differ if the loader shuffles — confirm.
    image_A = next(iter(loader))[0].to(device)[:batch_limit]
    image_B = next(iter(loader))[0].to(device)[:batch_limit]

    # Visualisation only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        net(image_A, image_B)

    # Build the vector-field grid once and contour its two channels.
    phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()

    plt.subplot(2, 2, 1)
    show(image_A)
    plt.title("image_A")

    plt.subplot(2, 2, 2)
    show(image_B)
    plt.title("image_B")

    plt.subplot(2, 2, 3)
    show(net.warped_image_A)
    plt.contour(phi_grid[0])
    plt.contour(phi_grid[1])
    plt.title("warped image_A + phi_AB contours")

    plt.subplot(2, 2, 4)
    show(net.warped_image_A - image_B)
    plt.title("warped image_A - image_B")

    plt.tight_layout()
    plt.show()
    return loader, image_A, image_B


# Single-pair example on the number=6 dataset.
ds, image_A, image_B = show_registration_for_number(number=6, batch_limit=1)

# Up to 12 pairs from the number=1 dataset (the first six are displayed).
ds, image_A, image_B = show_registration_for_number(number=1, batch_limit=12)
In [ ]:
In [ ]: