In [2]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [3]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)
sample_batch = next(iter(ds))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
Out[3]:
<matplotlib.image.AxesImage at 0x7f0451808d90>
In [9]:
import icon_registration.carl as carl
unet = carl.NoDownsampleNet(dimension=2)
ar = carl.AttentionRegistration(unet, dimension=2)
ts = icon.FunctionFromVectorField(ar)
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
for _ in range(2):
inner_net = icon.TwoStepRegistration(
icon.DownsampleRegistration(inner_net, dimension=2),
icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
)
inner_net = icon.TwoStepRegistration(ts, inner_net)
net = icon.losses.DiffusionRegularizedNet(inner_net, icon.LNCC(sigma=4), lmbda=.5)
In [10]:
net.assign_identity_map(sample_batch.shape)
In [11]:
net.train()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=0.001)
curves = icon.train_datasets(net, optim, ds, ds, epochs=5)
plt.close()
plt.plot(np.array(curves)[:, :3])
/usr/lib/python3.10/contextlib.py:103: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature. self.gen = func(*args, **kwds) 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:40<00:00, 56.15s/it]
Out[11]:
[<matplotlib.lines.Line2D at 0x7f04485a6860>, <matplotlib.lines.Line2D at 0x7f04485a69b0>, <matplotlib.lines.Line2D at 0x7f04485a6b00>]
In [12]:
plt.close()
def show(tensor):
plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
plt.xticks([])
plt.yticks([])
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [15]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=5)
curves = icon.train_datasets(net, optim, ds, ds, epochs=1)
plt.close()
plt.plot(np.array(curves)[:, :3])
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=8)
curves = icon.train_datasets(net, optim, ds, ds, epochs=1)
plt.close()
plt.plot(np.array(curves)[:, :3])
0%| | 0/1 [00:00<?, ?it/s]
--------------------------------------------------------------------------- OutOfMemoryError Traceback (most recent call last) Cell In[15], line 2 1 ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=5) ----> 2 curves = icon.train_datasets(net, optim, ds, ds, epochs=1) 3 plt.close() 4 plt.plot(np.array(curves)[:, :3]) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/train.py:112, in train_datasets(net, optimizer, d1, d2, epochs) 109 image_B = B[0].to(icon_registration.config.device) 110 optimizer.zero_grad() --> 112 loss_object = net(image_A, image_B) 114 loss_object.all_loss.backward() 115 optimizer.step() File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/losses.py:459, in BendingEnergyNet.forward(self, image_A, image_B) 455 # Tag used elsewhere for optimization. 456 # Must be set at beginning of forward b/c not preserved by .cuda() etc 457 self.identity_map.isIdentity = True --> 459 self.phi_AB = self.regis_net(image_A, image_B) 460 self.phi_AB_vectorfield = self.phi_AB(self.identity_map) 462 similarity_loss = 2 * self.compute_similarity_measure( 463 self.phi_AB_vectorfield, image_A, image_B 464 ) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/network_wrappers.py:212, in TwoStepRegistration.forward(self, image_A, image_B) 206 def forward(self, image_A, image_B): 207 208 # Tag for shortcutting hack. Must be set at the beginning of 209 # forward because it is not preserved by .to(config.device) 210 self.identity_map.isIdentity = True --> 212 phi = self.netPhi(image_A, image_B) 213 psi = self.netPsi( 214 self.as_function(image_A)(phi(self.identity_map)), 215 image_B 216 ) 217 return lambda tensor_of_coordinates: phi(psi(tensor_of_coordinates)) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/network_wrappers.py:112, in FunctionFromVectorField.forward(self, image_A, image_B) 111 def forward(self, image_A, image_B): --> 112 tensor_of_displacements = self.net(image_A, image_B) 113 displacement_field = self.as_function(tensor_of_displacements) 115 def transform(coordinates): File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/carl.py:397, in AttentionRegistration.forward(self, A, B) 395 def forward(self, A, B): 396 ft_A = self.featurize(A, recrop=False) --> 397 ft_B = self.featurize(B) 398 output = self.torch_attention(ft_A, ft_B) 399 output = output - self.identity_map File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/carl.py:324, in AttentionRegistration.featurize(self, values, recrop) 322 x = torch.nn.functional.pad(values, [padding, padding, padding, padding]) 323 x = self.net(x) --> 324 x = 4 * x / (0.001 + torch.sqrt(torch.sum(x**2, dim=1, keepdims=True))) 325 if recrop: 326 x = self.crop(x) File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/_tensor.py:39, in _handle_torch_function_and_wrap_type_error_to_not_implemented.<locals>.wrapped(*args, **kwargs) 37 if has_torch_function(args): 38 return handle_torch_function(wrapped, args, *args, **kwargs) ---> 39 return f(*args, **kwargs) 40 except TypeError: 41 return NotImplemented OutOfMemoryError: CUDA out of memory. Tried to allocate 134.00 MiB. GPU 0 has a total capacity of 7.92 GiB of which 17.06 MiB is free. Including non-PyTorch memory, this process has 6.94 GiB memory in use. Of the allocated memory 5.98 GiB is allocated by PyTorch, and 815.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [14]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=6)
image_A = next(iter(ds))[0].to(device)[:12]
image_B = next(iter(ds))[0].to(device)[:12]
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
plt.show()
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=1)
image_A = next(iter(ds))[0].to(device)[:12]
image_B = next(iter(ds))[0].to(device)[:12]
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
In [ ]:
In [ ]: