Connecting the Dots with Optimal Transport - Particle Colocalization

In a previous notebook, we looked at a method for finding out whether two types of molecules are attracted to each other. This problem is common in chemistry and molecular biology, where such interactions are studied. We quantified what it means for molecule clouds to be close together, and inferred an attraction from images with particularly 'close clouds'.

We found that for really sharp images, a pixel-based approach is not satisfactory. Today's microscopes, however, reach very high resolutions even at small scales, so let's look at a method that takes advantage of this and treats molecules as points. This method lets us study the attraction between any number of molecule types, not just two.

Subgraph Recovery

Generating Data

First, we need some data to work on, so we'll generate it ourselves. We assume that some molecules are connected by a chemical bond, so they will be close together in an image taken of them. We model this by letting their relative position be normally distributed with a small variance, or a small spread. A large spread, conversely, implies a weak bond.

What we now do is traverse a tree structure of the chemical bonds that we want to simulate. So if molecules $A$, $B$ and $C$ are all attracted to each other, we might find $A$ attached to $B$ and $B$ attached to $C$. Each branch of the tree is a draw from a two-dimensional normal distribution, which models the random observed position of the two particles at its ends relative to each other.
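The traversal described above can be sketched as follows. This is a minimal illustration, not the code behind `ot.make_sample2d_s`: the function name `sample_cluster`, the choice of node 0 as root, and the uniform placement of the root in the unit square are all assumptions made for this example.

```python
import numpy as np

def sample_cluster(edges, spread, rng, root=0):
    """Place one cluster by walking the bond tree from the root outward (hypothetical helper)."""
    n_nodes = max(max(e) for e in edges) + 1
    # Build an undirected adjacency list from the list of bonds.
    neighbours = {v: [] for v in range(n_nodes)}
    for a, b in edges:
        neighbours[a].append(b)
        neighbours[b].append(a)

    # Assumption: the root is placed uniformly in the unit square.
    positions = {root: rng.uniform(0.0, 1.0, size=2)}
    stack = [root]
    while stack:
        parent = stack.pop()
        for child in neighbours[parent]:
            if child not in positions:
                # Each bond is one draw from an isotropic 2D normal distribution.
                positions[child] = positions[parent] + rng.normal(0.0, spread, size=2)
                stack.append(child)
    return np.array([positions[v] for v in range(n_nodes)])

rng = np.random.default_rng(0)
cluster = sample_cluster([(0, 1), (1, 2), (1, 3)], spread=0.01, rng=rng)
print(cluster)  # one row of (x, y) coordinates per molecule type
```

With a spread of 0.01, bonded particles end up within a few hundredths of a unit of each other, which is what makes the clusters visible by eye.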

Even if attraction between molecule types exists, however, we might not capture every particle cleanly. In practice, a fluorescent marker substance is attached to the molecules of interest, which might be a point of failure in multiple different ways. In any case, we have to deal with seeing only part of the tree of connected molecules. What we want to simulate is some particles being in a whole connection cluster, like $\{A, B, C\}$, some in a subset of the possible connections, like $\{A, C\}$, and some being all alone - $\{A\}$. This is why we specify a number of pairing clusters to be generated.
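To see what such a specification implies, the short sketch below counts how many points of each molecule type result from a dictionary of cluster counts. The dictionary is the `species` dictionary from the first data-generating cell of this notebook; the variable names `points_per_type` and `total_clusters` are introduced here only for illustration.

```python
# Keys are the molecule types present in a cluster, values are the number of such clusters.
species = {(0, 1, 2, 3): 30, (0, 1, 2): 20, (0, 2, 3): 20,
           (2, 3): 10, (0,): 10, (1,): 20, (3,): 10}

n_types = 4
points_per_type = [0] * n_types
for members, count in species.items():
    for molecule_type in members:
        # Every cluster contributes one point for each type it contains.
        points_per_type[molecule_type] += count

total_clusters = sum(species.values())
print(points_per_type, total_clusters)  # [80, 70, 80, 70] 120
```

So 120 clusters produce 300 points in total, and the solver only ever gets to see the four point clouds, not the grouping into clusters.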

The generated points can also be thought of as having a mass, which in this implementation is equal to $1$ for every point. As you might know, optimal transport is built on measure theory, so formally the transport plans push measures onto other measures. Here, the particle clouds are (probability) measures on $\mathbb R^2$, which describe how much the points inside a given region $S \subset \mathbb R^2$ weigh in total. For a probability measure, the weights should sum to 1.
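As a small illustration of a point cloud viewed as a measure, the sketch below normalises unit masses and evaluates the measure of a rectangular region. The positions are random placeholders and `measure_of_box` is a hypothetical helper; whether and where the `assets.ot` code rescales the unit masses is not shown in this notebook, so the normalisation here is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.uniform(0.0, 1.0, size=(50, 2))  # placeholder particle positions
weights = np.ones(len(points))                # unit mass for every particle
prob = weights / weights.sum()                # assumed normalisation to total mass 1

def measure_of_box(points, prob, lower, upper):
    """Total weight of the points inside the axis-aligned box [lower, upper]."""
    inside = np.all((points >= lower) & (points <= upper), axis=1)
    return prob[inside].sum()

whole = measure_of_box(points, prob, np.array([0.0, 0.0]), np.array([1.0, 1.0]))
left_half = measure_of_box(points, prob, np.array([0.0, 0.0]), np.array([0.5, 1.0]))
print(whole, left_half)
```

The whole unit square has measure 1, and any region inside it has a measure between 0 and 1.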

Let's look at this in action:

In [1]:
import sys
try:
    if "pyodide" in sys.modules:
        import piplite
        await piplite.install('ipywidgets == 7.7')
except ImportError:
    print("piplite not found, ignore this if you are not using Jupyterlite")
In [2]:
import scipy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import assets.ot as ot
from ipywidgets import interactive, FloatSlider, IntSlider, BoundedIntText, FloatLogSlider, Checkbox, widgets
from IPython.display import display

# These are the possible bonds
edges = [(0,1), (1,2), (1,3)]

spread = 0.01

# The number of generated pairing clusters. We want to know exactly which molecules connect in which combination.
species = {(0,1,2,3) : 30, (0,1,2) : 20, (0,2,3) : 20, (2,3) : 10, (0,) : 10, (1,) : 20, (3,) : 10}

def plot_clouds(ax, clouds, mu, scale=2):
    # One colour per molecule type; the marker area scales with each point's mass
    colors = ['midnightblue', 'mediumvioletred', 'lightpink', 'dodgerblue', 'teal', 'deeppink']

    for i, c in enumerate(clouds):
        ax.scatter(c[:,0], c[:,1], s=mu[i]*4*scale**2, c=colors[i], marker='.', alpha=0.75)

def plot_interactive(spread, show_connections, seed=1):
    mu, dists, X, true_coupling, edg_lines = ot.make_sample2d_s(seed=seed, 
                                                            edges=edges, 
                                                            spread=[spread/10 for _ in edges],
                                                            species=species,
                                                            add_ghost_copies=True)
    fig = plt.figure(figsize=(8,8), dpi=200)
    ax = plt.gca()

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_axis_off()

    plot_clouds(ax, X, mu, scale=4)
    
    if show_connections:
        lc = mpl.collections.LineCollection(edg_lines, linewidths=1, zorder=0, linestyles=(0,(1,1)), color="hotpink")
        ax.add_collection(lc)

    plt.tight_layout()
    plt.show()


# Creating widgets for interactive visualization
spread_widget = FloatSlider(description='Spread', 
                     min=0.05, max=0.25, 
                     step=0.01, value=0.1,
                     continuous_update=False)

seed_widget = BoundedIntText(description='Seed', 
                             value=1,min=0)
connection_widget = Checkbox(description="Show Tree Structure",
                             value = True)

ui = widgets.VBox([spread_widget, connection_widget, seed_widget])
widget = interactive(plot_interactive, spread=0.1, show_connections=True)
out = widgets.interactive_output(plot_interactive, {'spread': spread_widget, 'show_connections': connection_widget, 'seed': seed_widget})

display(ui, out)

Getting back what's lost

Now, we use an algorithm based on optimal transport to guess the dotted trees in the picture above. What matters here is not that exactly the trees above are found, but only that the result reflects how close the generated particles originally were to each other, that is, their colocalization. For a small enough spread, however, it turns out that we often find exactly which particles came from the same tree. When the method is applied to real measurements, there is no such original to compare against. In any case, the algorithm only sees the dots and tries to construct weighted graphs that optimally match this input. A high weight on a connection means that the algorithm is confident that the connection exists.

For any set of graphs (or the stochastic couplings on them, to be exact) connecting the points, the solver can judge how good a solution it is by applying a cost function. This is necessary for optimal transport to work; in fact, the very word 'optimal' stands for 'minimizing a cost function'. The solver takes the following aspects into account:

  • Dots connected by a branch should not be farther apart than a fixed distance $d_{\text{max}}$. In practice, this would be an upper bound for a reasonable molecular interaction distance. The default value is the spread multiplied by the $0.99$-quantile of a $\chi$-distribution with 2 degrees of freedom, which is the distribution of the Euclidean norm of a standard normally distributed 2D point. This describes the length of our branches, so on average the default covers $99\%$ of the generated connections. The slider below simply multiplies this default by a constant.
  • As few points as possible should remain without a graph. To this end, there is a penalty parameter $\lambda$ for unmatched points. The higher $\lambda$ is, the more eagerly points are associated with each other.
  • As little imaginary mass as possible should be added to the input. We will discuss what this means below.
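The default for $d_{\text{max}}$ and the distance penalty can be checked with a few lines of NumPy. The sketch below uses the closed form of the quantile (the $\chi$-distribution with 2 degrees of freedom is the Rayleigh distribution, so the value agrees with `chi.ppf(0.99, 2)` from the notebook code) and the same cost expression as the solver cells; the sample size and the example distances are arbitrary choices.

```python
import numpy as np

spread = 0.01

# 0.99-quantile of the chi distribution with 2 degrees of freedom:
# solve 1 - exp(-r**2 / 2) = 0.99 for r.
quantile = np.sqrt(-2.0 * np.log(1.0 - 0.99))
d_max = spread * quantile

# Empirical check: lengths of simulated bonds with this spread.
rng = np.random.default_rng(0)
lengths = np.linalg.norm(rng.normal(0.0, spread, size=(100_000, 2)), axis=1)
covered = np.mean(lengths <= d_max)

# Distances above d_max receive the extra penalty lam, as in the notebook's cost.
lam = 100 * d_max
dist = np.array([0.5 * d_max, 2.0 * d_max])
cost = dist + lam * (dist > d_max)

print(quantile, covered, cost)
```

The quantile is roughly 3.03, so with a spread of 0.01 the default $d_{\text{max}}$ is about 0.03, and close to $99\%$ of the simulated bond lengths fall below it.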

Further, in order to help the algorithm converge at all, a method called entropic regularization is used. For practical purposes this means that we introduce an indecisiveness parameter $\epsilon$; if $\epsilon$ is high, a point of type $A$ being about equally close to two points of type $B$ will be matched to both of them equally. If $\epsilon$ is low however, it is decided that whichever $B$ is closer, by however narrow a margin, gets connected to $A$ and the other point $B$ will have to look for another cluster to be part of. Note that this also influences the speed of convergence.
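The effect of $\epsilon$ can be seen on a tiny example with plain Sinkhorn iterations. This is a sketch of standard entropic regularization between two clouds of equal mass, not the partial multi-marginal solver used in this notebook; the function `sinkhorn` and the cost values are made up for illustration.

```python
import numpy as np

def sinkhorn(cost, a, b, eps, n_iter=500):
    """Plain Sinkhorn iterations for entropically regularized transport."""
    K = np.exp(-cost / eps)      # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)          # rescale to match the row marginals
        v = b / (K.T @ u)        # rescale to match the column marginals
    return u[:, None] * K * v[None, :]

# Two points of type A (rows) and two of type B (columns).
# Each A is only slightly closer to "its" B than to the other one.
cost = np.array([[0.10, 0.12],
                 [0.12, 0.10]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])

plan_high = sinkhorn(cost, a, b, eps=1.0)   # indecisive: mass is split almost evenly
plan_low = sinkhorn(cost, a, b, eps=1e-3)   # decisive: mass goes to the closer partner
print(plan_high)
print(plan_low)
```

For $\epsilon = 1$ every entry of the plan is close to $0.25$, while for $\epsilon = 10^{-3}$ almost all mass sits on the diagonal.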

Below, you can see code that executes all this. The red connection lines on the left are the generated connections; the ones on the right are where the algorithm guessed there was a connection.

(The following code takes a while to run, so patience is needed to try different parameters.)

In [3]:
# A seed for reproducing the same graph with different parameters. 
# You can rerun this cell to see different configurations.

seed = np.random.randint(2**31)
In [4]:
from scipy.stats import chi

# Here, we fix the properties of the generated graphs
edges = [(0,1), (1,2), (1,3)]

spread = 0.01
# Feel free to change the values in this dictionary or add your own keys to change the generated configurations!
species = {(0,1,2,3) : 20, (0,1,2) : 15, (0,2,3) : 10, (2,3) : 5, (0,) : 5, (1,) : 5, (3,) : 5}

mu, dists, X, true_coupling, edg_lines = ot.make_sample2d_s(seed=seed, 
                                                            edges=edges, 
                                                            spread=[spread for _ in edges],
                                                            species=species,
                                                            add_ghost_copies=True)

def plot_primal(ax, clouds, mu, edges, plans, scale=2, threshold=1e-3):
    plot_clouds(ax, clouds, mu, scale)

    for k in range(len(edges)):
        X1 = clouds[edges[k][0]]
        X2 = clouds[edges[k][1]]
        n1 = plans[k].shape[0]
        n2 = plans[k].shape[1]

        # Keep only the pairs whose coupling weight exceeds the threshold
        pairs = [(i, j) for i in range(n1) for j in range(n2) if plans[k][i,j] > threshold]

        # Drawing trees which have been found by the algorithm; line width reflects the weight
        l = mpl.collections.LineCollection(
            [(X1[i], X2[j]) for i, j in pairs],
            linewidths=[scale * plans[k][i,j] for i, j in pairs],
            zorder=-1, color="red", linestyles="-", capstyle="round")

        ax.add_collection(l)

def plot_interactive(d, epsi):
    # These are parameters for the Optimal Transport solver, adjusted by slider values
    d_max = spread * chi.ppf(0.99, 2) * d
    lamp = d_max * 10
    lam = lamp * 10
    costs = [dist + lam * (dist > d_max) for dist in dists]
    sigmas = [np.ones_like(c) for c in costs]
    
    # This runs the solver
    p = ot.partial_multi_L1(mu, edges, costs, sigmas, epsi, lam, lamp, False)
    res = p.run_epsi_scale(min_step_size=1e-4)
    
    # plotting
    fontsize = 16
    fig, axs = plt.subplots(1, 2, figsize=(16,8), dpi=200)

    margin = max(0 - min(0, np.min(X)), max(1, np.max(X)) - 1) + .05
    
    axs[0].set_title("Generated Coupling", fontsize=fontsize)
    axs[0].set_aspect(1)
    axs[0].set_xlim(-margin, 1 + margin)
    axs[0].set_ylim(-margin, 1 + margin)
    plot_primal(axs[0], X, mu, edges, true_coupling, scale=3.5)

    lc = mpl.collections.LineCollection(edg_lines, linewidths=1, zorder=0, linestyles=(0,(1,1)), color="red")
    axs[0].add_collection(lc)
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    axs[1].set_title("Calculated Coupling", fontsize=fontsize)
    axs[1].set_aspect(1)
    axs[1].set_xlim(-margin, 1 + margin)
    axs[1].set_ylim(-margin, 1 + margin)
    plot_primal(axs[1], X, mu, edges, p.get_primal(), scale=3.5)
    axs[1].set_xticks([])
    axs[1].set_yticks([])

    # fig.canvas.draw()
    plt.show()
    
# Creating widgets for interactive visualization
maximal_distance_widget = FloatSlider(description='$d_{max}$', 
                     min=0.5, max=3, 
                     step=0.01, value=1,
                     continuous_update=False)
epsilon_widget = FloatLogSlider(description=r'$\epsilon$', 
                     min=-8, max=-1,
                     step=0.01, value=5e-5,
                     continuous_update=False)
ui = widgets.VBox([maximal_distance_widget, epsilon_widget])
widget = interactive(plot_interactive, d=1, epsi=1e-5)
out = widgets.interactive_output(plot_interactive, {'d': maximal_distance_widget, 'epsi': epsilon_widget})

display(ui, out)

Evaluation

Next, we can quantitatively assess the colocalization and see how it compares to the original trees. The recovered graphs above may already give a visual notion of how close the points are to each other, but in the end we simply want a statistic of how likely it is that the molecules are attracted to each other, not yet another picture. Here, since we know how many point clusters $\{A\}, \{A,B,C,D\}, \dots$ were generated, we can plot these counts next to the number of guessed clusters to gauge the algorithm's accuracy. Note that when the spread is too large, we can't fault the solver for matching points which weren't generated together, so this comparison is only an approximate way to gauge its performance.
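In the bar chart, cluster types are labelled with strings such as `A_CD`, where an underscore marks a missing molecule type. The sketch below mirrors the small `ctos` helper from the evaluation cell; the name `cluster_label` and the shortened `species` dictionary are chosen here for illustration only.

```python
import numpy as np

def cluster_label(members, n_types):
    """Encode a tuple of type indices as a string such as 'A_CD'."""
    present = np.zeros(n_types, dtype=bool)
    present[list(members)] = True
    return "".join(chr(ord("A") + j) if present[j] else "_" for j in range(n_types))

# Generated cluster counts, keyed by the molecule types each cluster contains.
species = {(0, 1, 2, 3): 15, (0, 2, 3): 20, (2, 3): 10, (0,): 5}
labels = {cluster_label(members, 4): count for members, count in species.items()}
print(labels)  # {'ABCD': 15, 'A_CD': 20, '__CD': 10, 'A___': 5}
```

These labels are what the generated counts are matched against when they are drawn next to the counts returned by the solver.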

In [5]:
# Here, we fix the properties of the generated graphs
edges = [(0,1), (1,2), (1,3)]
spread = 0.01
# Feel free to change the values in this dictionary or add your own keys to change the generated configurations!
species = {(0,1,2,3) : 15, (0,1,2) : 5, (0,2,3) : 20, (1,2,3) : 10, (2,3) : 10, (0,) : 5, (1,) : 5, (3,) : 5}

mu, dists, X, true_coupling, edg_lines = ot.make_sample2d_s(seed=seed, 
                                                            edges=edges, 
                                                            spread=[spread for _ in edges],
                                                            species=species,
                                                            add_ghost_copies=True)

def plot_interactive(d, epsi):
    # These are parameters for the Optimal Transport solver, adjusted by slider values
    d_max = spread * chi.ppf(0.99, 2) * d
    lamp = d_max * 10
    lam = lamp * 10
    costs = [dist + lam * (dist > d_max) for dist in dists]
    sigmas = [np.ones_like(c) for c in costs]
    
    # This runs the solver
    p = ot.partial_multi_L1(mu, edges, costs, sigmas, epsi, lam, lamp, False)
    res = p.run_epsi_scale(min_step_size=1e-4)
    pi = p.get_primal()
    
    # Counting match numbers
    match_cnts, match_str, deltas = p.get_match_counts()

    fig = plt.figure(figsize=(16,8), dpi=200)

    fontsize=16
    plt.title("Recovery of (partial) Tuple Counts", fontsize=fontsize)

    plt.bar(np.arange(len(deltas)) - 0.2, match_cnts.ravel(), width=0.4, color="purple")
    
    n = len(edges) + 1
    def ctos(c):
        a = np.zeros(n, dtype=int)
        a[list(c)] = 1
        return "".join([(chr(ord('A') + j) if a[j] == 1 else "_") for j in range(n)])

    true_x = list(map(match_str.index, list(map(ctos, species.keys()))))
    true_y = list(species.values())
    true_vals = np.zeros(len(deltas), dtype=int)
    true_vals[true_x] = true_y
    plt.bar(np.array(true_x) + 0.2, true_y, width=0.4, color="deeppink")

    plt.tight_layout()
    plt.xticks(range(len(deltas)), match_str, fontsize=fontsize, rotation=45)
    plt.yticks(fontsize=fontsize)

    for x in range(len(deltas)):
        mc = match_cnts.ravel()[x]
        if mc > 0.01:
            plt.text(x - 0.2, mc + 0.5, str(round(mc, 2)), ha="center")
        if true_vals[x] > 0:
            plt.text(x + 0.2, true_vals[x] + 0.5, str(true_vals[x]), ha="center")

    plt.legend(["Number of calculated Clusters", "Number of generated Clusters"], fontsize=fontsize)
    # fig.canvas.draw()
    plt.show()

widget = interactive(plot_interactive, d=1, epsi=5e-5)
out = widgets.interactive_output(plot_interactive, {'d': maximal_distance_widget, 'epsi': epsilon_widget})

display(ui, out)

So for the default parameters and most seeds, the solver seems to do quite well. If we were to run the program on experimental data and get this result, we would conclude that there is a strong interaction between molecules $C$ and $D$, or that the interaction $\{A, C, D\}$ is particularly likely.

Phantom Points

In an implementation of a tree-based optimal transport approach to this problem, there is one key difficulty we have not touched on yet. As a formal constraint, the solver needs all the vertices of the original tree structure to be present. However, as we have seen, clusters may also contain only a subset of them. So how was this done?

Now is a good time to introduce Thilo Stier, who wrote a Bachelor's thesis on the topic and the code at the heart of this notebook. (On a side note, don't hesitate to write an email to thilodaniel.stier AT stud DOT uni-goettingen.de if you have questions about this implementation.) His solution was to add phantom particles to the generated point clouds:

A phantom particle is a point with zero mass that is not part of the point cloud that originally needs to be connected. Say there is a point $A$ and we want to find trees with particles $\{A, B, C\}$. Then phantom points $B$ and $C$ hover right above $A$. If there aren't any real $B$s or $C$s anywhere near, the algorithm may choose a phantom point as a substitute. But because finding a whole graph is better than finding just a part of it, doing this has a cost.
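The idea can be illustrated with a strongly simplified, discrete analogue: every $A$ is matched either to a distinct real $B$ or to its own phantom $B$ at a fixed price. This is only a sketch under assumptions. The actual solver works with transport plans rather than hard assignments, and the names `match_with_phantoms` and `phantom_cost` as well as the exhaustive search (feasible only for a handful of points) are introduced here for illustration.

```python
import itertools
import numpy as np

def match_with_phantoms(points_a, points_b, phantom_cost):
    """Assign every A to a distinct real B or to its own phantom B, minimizing total cost."""
    n_a, n_b = len(points_a), len(points_b)
    dist = np.linalg.norm(points_a[:, None, :] - points_b[None, :, :], axis=2)
    best_cost, best = np.inf, None
    # Option j < n_b is the real point j; option n_b stands for "use the phantom".
    for choice in itertools.product(range(n_b + 1), repeat=n_a):
        real = [j for j in choice if j < n_b]
        if len(real) != len(set(real)):
            continue  # a real B may be used at most once
        total = sum(dist[i, j] if j < n_b else phantom_cost for i, j in enumerate(choice))
        if total < best_cost:
            best_cost, best = total, choice
    return [j if j < n_b else None for j in best], best_cost

points_a = np.array([[0.0, 0.0], [1.0, 0.0]])
points_b = np.array([[0.01, 0.0]])  # only one real B, right next to the first A
partners, cost = match_with_phantoms(points_a, points_b, phantom_cost=0.1)
print(partners, cost)  # first A gets the real B, second A falls back to its phantom
```

The first $A$ is matched to the nearby real $B$, while the second $A$ has no real partner in reach and pays the phantom price instead of being connected across the whole image.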

Conclusion

The field of optimal transport and its numerics is full of methods like the ones we looked at here - being discovered just now. Try a search for 'multi particle colocalization' - what you are looking at in this notebook is already the cutting edge, more or less!

If you want to be part of this research, too, the place this notebook comes from might be just what you're looking for. You can find the University of Göttingen's optimal transport group here and the Collaborative Research Center here. A notebook just like this one about researching Alzheimer's disease with optimal transport is available here. Thank you for reading!


Written by Thilo Stier and Lennart Finke. Published under the MIT license.