In [2]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [3]:
# Load MNIST training-split images of the digit 2 (the second returned loader
# is not used) and preview the first 12 images of one batch as a 4-column grid.
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)
sample_batch = next(iter(ds))[0]
preview = torchvision.utils.make_grid(sample_batch[:12], nrow=4)
plt.imshow(preview[0])
Out[3]:
<matplotlib.image.AxesImage at 0x7fabc477fc70>
In [4]:
# model.py
import icon_registration.constricon as constricon

# Default shape for the network's initial identity map: (batch, channel, H, W).
# Every sub-network below is built with dimension=2, so this must be a 2-D image
# shape (a 5-D [1, 1, D, H, W] shape would build a 3-D identity map instead).
# 28x28 is the native torchvision MNIST resolution -- NOTE(review): confirm it
# matches what icon_registration.data.get_dataset_mnist yields. The notebook
# re-assigns the identity map from a real batch shape right after this anyway.
input_shape = [1, 1, 28, 28]


def make_network(shape=None, lmbda=0.03):
    """Build the 2-D constricon registration network wrapped in its training loss.

    Four ``ConsistentFromMatrix(ConvolutionalMatrixNet(dimension=2))`` steps are
    composed with nested ``TwoStepInverseConsistent`` wrappers (first step
    outermost), the chain is wrapped in ``FirstTransform``, and the result is
    wrapped in ``BendingEnergyNet`` with an ``LNCC(5)`` similarity term.

    Args:
        shape: shape passed to ``assign_identity_map``. ``None`` (the default)
            uses the module-level ``input_shape``.
        lmbda: bending-energy weight forwarded to ``BendingEnergyNet``
            (default 0.03, the value previously hard-coded).

    Returns:
        The ``BendingEnergyNet`` with its identity map assigned.
    """
    if shape is None:
        shape = input_shape

    # Build the steps in the same order the explicit nesting evaluated them, so
    # random weight initialisation consumes the RNG identically.
    steps = [
        constricon.ConsistentFromMatrix(networks.ConvolutionalMatrixNet(dimension=2))
        for _ in range(4)
    ]
    # Right fold: TwoStep(s0, TwoStep(s1, TwoStep(s2, s3))).
    chain = steps[-1]
    for step in reversed(steps[:-1]):
        chain = constricon.TwoStepInverseConsistent(step, chain)

    net = constricon.FirstTransform(chain)
    net = icon.losses.BendingEnergyNet(net, icon.LNCC(5), lmbda=lmbda)
    net.assign_identity_map(shape)
    return net


net = make_network()
In [5]:
# Rebuild the network's identity map from the shape of a real training batch.
identity_shape = sample_batch.shape
net.assign_identity_map(identity_shape)
In [6]:
# Switch to training mode on the target device, then fit with Adam on the
# current loader (passed as both dataset arguments of train_datasets).
net.train().to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.001)
curves = icon.train_datasets(net, optim, ds, ds, epochs=5)

# Replace any open figure with a plot of the first three logged curve columns.
plt.close()
loss_columns = np.array(curves)[:, :3]
plt.plot(loss_columns)
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:34<00:00, 6.83s/it]
Out[6]:
[<matplotlib.lines.Line2D at 0x7fabc0c87c70>, <matplotlib.lines.Line2D at 0x7fabc0c87dc0>, <matplotlib.lines.Line2D at 0x7fabc0c87f10>]
In [7]:
plt.close()


def show(tensor, n=6, nrow=3):
    """Draw channel 0 of a grid of the first ``n`` images of ``tensor`` on the current axes.

    Args:
        tensor: image batch accepted by ``torchvision.utils.make_grid``
            (presumably (batch, channel, H, W) -- TODO confirm).
        n: number of leading images to tile (default 6).
        nrow: images per grid row (default 3).

    Axis ticks are hidden. The tensor is moved to CPU and detached before plotting.
    """
    grid = torchvision.utils.make_grid(tensor[:n], nrow=nrow)
    plt.imshow(grid[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])


# Draw both images from ONE iterator so A and B are consecutive, distinct
# batches whether or not the loader shuffles (two fresh iter(ds) calls return
# the same first batch for an unshuffled loader).
# NOTE(review): assumes the loader yields at least two batches.
batches = iter(ds)
image_A = next(batches)[0].to(device)
image_B = next(batches)[0].to(device)
net(image_A, image_B)

plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
# Overlay contours of channels 0 and 1 of phi_AB using the same 6-image,
# 3-per-row layout as show(); build the grid once instead of twice.
phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()
plt.contour(phi_grid[0])
plt.contour(phi_grid[1])
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [8]:
# Keep fine-tuning the same network and optimizer on further digits.
# plt.show() displays each loss figure before the next iteration's
# plt.close() runs; without it the digit-5 figure was discarded unseen.
for digit in (5, 8):
    ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=digit)
    curves = icon.train_datasets(net, optim, ds, ds, epochs=5)
    plt.close()
    plt.plot(np.array(curves)[:, :3])
    plt.title(f"First three training-curve columns, digit {digit}")
    plt.show()
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:30<00:00, 6.07s/it] 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:32<00:00, 6.51s/it]
Out[8]:
[<matplotlib.lines.Line2D at 0x7fabc1e275b0>, <matplotlib.lines.Line2D at 0x7fabc1e27700>, <matplotlib.lines.Line2D at 0x7fabc1e27850>]
In [9]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=6)
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
plt.show()
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=1)
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [ ]: