In our toy problem, the images we want to register, $A$ and $B$, are diffeomorphisms.

To register them, we want to find $\Xi$ such that $A \circ \Xi = B$.

So, $\Xi = A^{-1} \circ B$.

A neural network that can compute $A^{-1} \circ B$ will have all the equivariances of $\Xi$.
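The step from the registration condition to the closed form, and one equivariance that follows directly from it, can be written out as below. The reparametrizations $U$ and $V$ are symbols introduced here for illustration only and are assumed to be diffeomorphisms of the same domain.

```latex
\begin{align*}
A \circ \Xi &= B \\
A^{-1} \circ A \circ \Xi &= A^{-1} \circ B \\
\Xi &= A^{-1} \circ B \\[1ex]
\Xi[A \circ U,\, B \circ V] &= (A \circ U)^{-1} \circ (B \circ V)
  = U^{-1} \circ A^{-1} \circ B \circ V
  = U^{-1} \circ \Xi[A, B] \circ V
\end{align*}
```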

Transformers naturally implement function inversion

The Key, Query, Value attention mechanism naturally lends itself to implementing function inversion in a manner appropriate for representing $\Xi$. As a first demo, we will invert the function $y = x^2$.

import torch

# Sample [0, 1) at 100 points; shape (batch, channel, position).
x = torch.arange(0, 1, 1/100)
x = x[None, None, :]
y = x**2

First, we process our function inputs and outputs into feature vectors. These representations are chosen so that $\mathrm{repr}(u) \cdot \mathrm{repr}(v) \simeq 1$ when $u \simeq v$.
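As a rough check of this property, the sketch below builds a featurization of the kind described (random frequencies and phases followed by a sine; the seed, the frequency scale of 59, and the variable names are assumptions for illustration) and compares the dot product of matching pairs with that of non-matching pairs. The raw dot product is not normalized, so the check compares relative magnitudes rather than testing for a value of exactly 1.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

# 100 sample values in [0, 1), shaped (batch, channel, position)
u = torch.arange(0, 1, 1 / 100)[None, None, :]

# Assumed featurization: a 1x1 convolution with random frequencies
# (std 59) and randomly initialized biases as phases, followed by a sine.
scale = torch.nn.Conv1d(1, 100, 1, bias=True)
with torch.no_grad():
    scale.weight[:] = (torch.randn(100) * 59)[:, None, None]
    features = torch.sin(scale(u))[0]  # (channels, positions)

# similarity[i, j] = repr(u_i) . repr(u_j)
similarity = features.T @ features

# Matching pairs sit on the diagonal, non-matching pairs off it.
diagonal_mean = similarity.diagonal().mean().item()
off_diagonal = similarity[~torch.eye(100, dtype=torch.bool)]
off_diagonal_mean = off_diagonal.mean().item()

print(f"mean similarity, u == v: {diagonal_mean:.2f}")
print(f"mean similarity, u != v: {off_diagonal_mean:.2f}")
```

What the softmax in the attention step needs is that matching pairs score far higher than non-matching ones, which is what the two printed means are meant to show.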

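# Random frequencies for the 100 sinusoidal feature channels.
scale_weight = (torch.randn(100) * 59)[:, None, None]

# A 1x1 convolution applies the same affine map (weight * value + bias) at every position.
scale = torch.nn.Conv1d(1, 100, 1, bias=True)

with torch.no_grad():
    scale.weight[:] = scale_weight
ft_x = torch.sin(scale(x))
ft_y = torch.sin(scale(y))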

Then we apply attention, with the function outputs as Keys, the function inputs as Values, and the points at which we want to evaluate the inverted function as Queries.

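# Queries are the features of x, Keys the features of y; softmax runs over the keys.
attention = torch.nn.functional.softmax((ft_x.permute(0, 2, 1) @ ft_y), dim=2)
# Values are x: each query receives a weighted average of the x whose y matches it.
output = attention @ x.permute(0, 2, 1)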

Voilà, the graph of $\sqrt{x}$.
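To put a number on how close the attention output is to the true inverse, the following sketch repeats the demo in a self-contained form and compares it with `torch.sqrt`. The fixed seed and the variable names are assumptions for illustration, and the exact error depends on the random feature draw, so the result is indicative only.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

x = torch.arange(0, 1, 1 / 100)[None, None, :]  # function inputs
y = x ** 2                                       # function outputs

scale = torch.nn.Conv1d(1, 100, 1, bias=True)
with torch.no_grad():
    scale.weight[:] = (torch.randn(100) * 59)[:, None, None]
    ft_x = torch.sin(scale(x))  # query features
    ft_y = torch.sin(scale(y))  # key features

    # Rows index queries, columns index keys; softmax runs over the keys.
    attention = torch.softmax(ft_x.permute(0, 2, 1) @ ft_y, dim=2)
    # Weighted average of the inputs x, one value per query.
    approx_sqrt = (attention @ x.permute(0, 2, 1))[0, :, 0]

exact_sqrt = torch.sqrt(x[0, 0])
mean_abs_error = (approx_sqrt - exact_sqrt).abs().mean().item()
print(f"mean |attention output - sqrt(x)| = {mean_abs_error:.4f}")
```

Because the keys are only available on the sampled grid, the output is at best a soft nearest-neighbour lookup, so some discretization error is expected, and it may be larger near $x = 0$ where many outputs of $x^2$ cluster together.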

Registering two images using a neural network

We define some (1-D) images to register:

Image A is $A: [0, 1] \rightarrow [0, 1],\ x \mapsto \cos(\frac{\pi}{2} x)$

Image B is $B: [0, 1] \rightarrow [0, 1],\ x \mapsto x + 0.07 \sin(3\pi x)$

import matplotlib.pyplot as plt

A = torch.cos(.5 * torch.pi * x)
B = x + .07 * torch.sin(3 * torch.pi * x)

plt.plot(A[0, 0])
plt.plot(B[0, 0])
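Registration by inversion relies on both images being invertible maps of $[0, 1]$. On a sampled grid, strict monotonicity is a convenient proxy for that, and the sketch below checks it for the two images defined above (the variable names `A_steps` and `B_steps` are introduced here for illustration).

```python
import torch

x = torch.arange(0, 1, 1 / 100)[None, None, :]
A = torch.cos(0.5 * torch.pi * x)
B = x + 0.07 * torch.sin(3 * torch.pi * x)

# Differences between consecutive samples along the position axis
A_steps = torch.diff(A, dim=-1)
B_steps = torch.diff(B, dim=-1)

# A should fall monotonically from 1 towards 0, B should rise from 0 towards 1.
A_is_decreasing = bool((A_steps < 0).all())
B_is_increasing = bool((B_steps > 0).all())
print(A_is_decreasing, B_is_increasing)
```

Note that $A$ is decreasing, so it reverses orientation; the inversion argument only needs it to be one-to-one.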

We know analytically that A and B are registered by

$\Xi[A, B] = A^{-1} \circ B = \frac{2}{\pi} \cos^{-1}(x + 0.07 \sin(3\pi x))$

We create a neural network that inverts $A$ and applies the result to $B$, and verify that our neural network correctly implements $\Xi$ on this specific example.

class AttentionRegistration(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Sample grid on [0, 1), shape (1, 1, 100); used as the attention Values.
        self.x = torch.arange(0, 1, 1/100)
        self.x = self.x[None, None, :]
        scale_weight = (torch.randn(100) * 59)[:, None, None]

        self.scale = torch.nn.Conv1d(1, 100, 1, bias=True)

        with torch.no_grad():
            self.scale.weight[:] = scale_weight

    def featurize(self, values):
        values = self.scale(values)
        return torch.sin(values)

    def forward(self, A, B):
        ft_A = self.featurize(A)
        ft_B = self.featurize(B)
        # Queries come from B, Keys from A, Values are the grid positions.
        attention = torch.nn.functional.softmax((ft_B.permute(0, 2, 1) @ ft_A), dim=2)
        output = attention @ self.x.permute(0, 2, 1)
        return output

ar = AttentionRegistration()

XiAB = ar(A, B)

# Network output (detached from the autograd graph for plotting), then the analytic map.
plt.plot(XiAB[0, :, 0].detach())
plt.plot((torch.arccos((x + .07 * torch.sin(3 * torch.pi * x))) * 2 / torch.pi)[0, 0])

Our neural network produces the map which we proved registers A to B.
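The agreement can also be checked numerically rather than by eye. The sketch below is a self-contained, functional restatement of the same attention computation (not the class itself), with a fixed seed and variable names chosen here for illustration. It compares the result against the analytic $\frac{2}{\pi}\cos^{-1}(B)$; the reported error depends on the random features and on the 100-point grid, so it should be read as indicative.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

x = torch.arange(0, 1, 1 / 100)[None, None, :]
A = torch.cos(0.5 * torch.pi * x)
B = x + 0.07 * torch.sin(3 * torch.pi * x)

scale = torch.nn.Conv1d(1, 100, 1, bias=True)
with torch.no_grad():
    scale.weight[:] = (torch.randn(100) * 59)[:, None, None]
    ft_A = torch.sin(scale(A))  # key features
    ft_B = torch.sin(scale(B))  # query features

    # For each sample of B, attend to the positions where A takes a similar value.
    attention = torch.softmax(ft_B.permute(0, 2, 1) @ ft_A, dim=2)
    xi_network = (attention @ x.permute(0, 2, 1))[0, :, 0]

# Analytic registration map for this pair of images
xi_analytic = (2 / torch.pi) * torch.arccos(B[0, 0])

mean_abs_error = (xi_network - xi_analytic).abs().mean().item()
print(f"mean |network - analytic| = {mean_abs_error:.4f}")
```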

Next steps: apply this architecture to medical images

