In [1]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [2]:
# Training data: MNIST images of the digit 2. Only the first of the two
# returned loaders is used in this notebook.
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)

# Preview twelve images from one batch. `make_grid` tiles the batch into a
# single (C, H, W) image; `[0]` keeps one channel so imshow gets a 2-D array.
sample_batch = next(iter(ds))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
plt.title("Sample training images (digit 2)")
plt.axis("off");
Out[2]:
<matplotlib.image.AxesImage at 0x7f19de2434f0>
In [3]:
# Registration-network hyper-parameters (2-D MNIST images).
DIMENSION = 2
N_WRAP_LEVELS = 3   # how many times the network is wrapped by the loop below
LNCC_SIGMA = 4      # `sigma` passed to icon.LNCC (the similarity term)
ICON_LAMBDA = 0.5   # `lmbda` weight passed to icon.GradientICON

# Build the network from the inside out: start from a single UNet, then on each
# iteration wrap the current network in `DownsampleRegistration` and pair it
# with a fresh UNet via `TwoStepRegistration`.
# NOTE(review): presumably this yields a coarse-to-fine cascade (downsampled
# estimate first, refinement at the outer resolution) — confirm in the
# icon_registration docs.
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=DIMENSION))
for _ in range(N_WRAP_LEVELS):
    inner_net = icon.TwoStepRegistration(
        icon.DownsampleRegistration(inner_net, dimension=DIMENSION),
        icon.FunctionFromVectorField(networks.tallUNet2(dimension=DIMENSION)),
    )
net = icon.GradientICON(inner_net, icon.LNCC(sigma=LNCC_SIGMA), lmbda=ICON_LAMBDA)
In [4]:
# Tell the network the input shape it will see, taken from the sample batch
# loaded above. This is called once, before any forward pass.
# NOTE(review): presumably this builds the identity coordinate grid used by the
# registration, which fixes the expected image size — confirm in icon_registration.
net.assign_identity_map(sample_batch.shape)
In [5]:
LEARNING_RATE = 0.001
N_EPOCHS = 5

net.train()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
# The same loader is passed as both image sources.
# NOTE(review): presumably train_datasets pairs images drawn from the two
# loaders as (A, B) registration pairs — verify against icon_registration.
curves = icon.train_datasets(net, optim, ds, ds, epochs=N_EPOCHS)

plt.close()
loss_history = np.array(curves)
# NOTE(review): assumes each record in `curves` is a namedtuple of losses; if it
# is not, the legend falls back to plain column indices.
loss_names = getattr(curves[0], "_fields", None) or [
    f"column {i}" for i in range(loss_history.shape[1])
]
loss_lines = plt.plot(loss_history[:, :3])
plt.xlabel("training step")
plt.ylabel("loss")
plt.title("Training on digit 2")
plt.legend(loss_lines, loss_names[:3]);
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:56<00:00, 11.25s/it]
Out[5]:
[<matplotlib.lines.Line2D at 0x7f19dc1d7460>, <matplotlib.lines.Line2D at 0x7f19dc1d75b0>, <matplotlib.lines.Line2D at 0x7f19dc1d7700>]
In [6]:
plt.close()


def show(tensor, n=6, nrow=3):
    """Draw the first `n` images of a batch as a grid on the current axes.

    Args:
        tensor: image batch accepted by `torchvision.utils.make_grid`
            (a (B, C, H, W) tensor). Only channel 0 of the tiled grid is drawn.
        n: number of leading images to include (default 6).
        nrow: images per grid row (default 3).

    Tick marks are removed so the panel reads as a picture.
    """
    grid = torchvision.utils.make_grid(tensor[:n], nrow=nrow)
    # Move to CPU and drop autograd history so matplotlib can consume it.
    plt.imshow(grid[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
# Register one batch from the training loader `ds` onto another and visualise it.
# A single iterator guarantees image_A and image_B are different batches
# (two fresh `iter(ds)` calls only differ if the loader shuffles).
batches = iter(ds)
image_A = next(batches)[0].to(device)
image_B = next(batches)[0].to(device)
# The forward pass is only needed to populate net.warped_image_A and
# net.phi_AB_vectorfield; no_grad skips building an autograd graph never used.
with torch.no_grad():
    net(image_A, image_B)

plt.subplot(2, 2, 1)
show(image_A)
plt.title("image_A")
plt.subplot(2, 2, 2)
show(image_B)
plt.title("image_B")
plt.subplot(2, 2, 3)
show(net.warped_image_A)
# Overlay contour lines of the first two channels of phi_AB (tiled like the
# images) on the warped images.
# NOTE(review): assumes these channels hold transformed coordinates, so the
# contours trace the deformed grid.
phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()
plt.contour(phi_grid[0])
plt.contour(phi_grid[1])
plt.title("warped A + phi_AB")
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.title("warped A - B")
plt.tight_layout()
In [7]:
FINETUNE_EPOCHS = 5

# Keep training the same network and optimizer on new digits: 5, then 8.
# Each stage draws its loss curves on a fresh figure and renders it with
# plt.show() before the next stage starts, so no stage's plot is discarded.
# A separate loader name keeps the digit-2 training loader `ds` intact.
for digit in (5, 8):
    digit_ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=digit)
    curves = icon.train_datasets(net, optim, digit_ds, digit_ds, epochs=FINETUNE_EPOCHS)
    plt.figure()
    plt.plot(np.array(curves)[:, :3])
    plt.xlabel("training step")
    plt.ylabel("loss")
    plt.title(f"Training on digit {digit}")
    plt.show()
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:49<00:00, 9.91s/it] 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:55<00:00, 11.09s/it]
Out[7]:
[<matplotlib.lines.Line2D at 0x7f19db7df8b0>, <matplotlib.lines.Line2D at 0x7f19db7dfa00>, <matplotlib.lines.Line2D at 0x7f19db7dfb50>]
In [8]:
# Apply the network to digits it was not trained on (6 and 1): register one
# batch onto another and visualise the result, one figure per digit.
for digit in (6, 1):
    digit_ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=digit)
    # A single iterator guarantees image_A and image_B are different batches.
    batches = iter(digit_ds)
    image_A = next(batches)[0].to(device)
    image_B = next(batches)[0].to(device)
    # The forward pass is only needed to populate net.warped_image_A and
    # net.phi_AB_vectorfield; no_grad skips building an unused autograd graph.
    with torch.no_grad():
        net(image_A, image_B)

    plt.subplot(2, 2, 1)
    show(image_A)
    plt.title(f"image_A (digit {digit})")
    plt.subplot(2, 2, 2)
    show(image_B)
    plt.title("image_B")
    plt.subplot(2, 2, 3)
    show(net.warped_image_A)
    # Contour lines of the first two channels of phi_AB, tiled like the images.
    # NOTE(review): assumes these channels hold transformed coordinates.
    phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()
    plt.contour(phi_grid[0])
    plt.contour(phi_grid[1])
    plt.title("warped A + phi_AB")
    plt.subplot(2, 2, 4)
    show(net.warped_image_A - image_B)
    plt.title("warped A - B")
    plt.tight_layout()
    plt.show()
In [ ]:
In [ ]: