2D mass transport without mass pushforward

import sys
import os
sys.path.append(os.path.abspath("../.."))

import numpy as np
import numpy.ma as ma
from scipy.fftpack import dctn, idctn
import matplotlib.pyplot as plt
from time import time
from BFOT.BFOT import BFM
import cProfile
import pstats
import math
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 10
      8 import matplotlib.pyplot as plt
      9 from time import time
---> 10 from BFOT.BFOT import BFM
     11 import cProfile
     12 import pstats

File ~/work/BFOT/BFOT/BFOT/BFOT.py:2
      1 import numpy as np
----> 2 import numba
      3 from numba import prange
      4 from scipy.signal import convolve2d

ModuleNotFoundError: No module named 'numba'
# Default figure size and colormap for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': (13, 8),
    'image.cmap': 'viridis',
})

# Function definitions
def initialize_kernel(n1, n2):
    """Build the DCT-domain symbol of the discrete negative Laplacian.

    The result has shape ``(n2, n1)``: axis 1 carries ``n1`` frequencies and
    axis 0 carries ``n2`` frequencies. Entry ``[0, 0]`` (the constant mode,
    whose eigenvalue is 0) is set to 1 so callers can divide by the kernel.
    """
    theta_cols = np.linspace(0, np.pi, n1, endpoint=False)
    theta_rows = np.linspace(0, np.pi, n2, endpoint=False)
    eig_cols = 2 * n1 * n1 * (1 - np.cos(theta_cols))
    eig_rows = 2 * n2 * n2 * (1 - np.cos(theta_rows))
    # Outer sum: kernel[r, c] = eig_cols[c] + eig_rows[r].
    kernel = eig_cols[np.newaxis, :] + eig_rows[:, np.newaxis]
    kernel[0, 0] = 1  # to avoid dividing by zero
    return kernel

def dct2(a):
    """Return the orthonormal type-II DCT of ``a`` taken over all of its axes."""
    transformed = dctn(a, type=2, norm='ortho')
    return transformed

def idct2(a):
    """Invert ``dct2``: orthonormal inverse of the type-II DCT over all axes."""
    restored = idctn(a, type=2, norm='ortho')
    return restored

def update_potential(phi, rho, nu, kernel, sigma):
    """Take one preconditioned gradient-ascent step on a potential.

    Side effects (both in place):
      * ``rho`` is overwritten with the residual ``rho - nu``;
      * ``phi`` is increased by ``sigma`` times the residual filtered through
        ``kernel`` in the DCT domain, with the constant mode removed.

    Returns the inner product of the filtered residual with the residual,
    divided by the number of grid cells (printed as "H1 err" by the callers).
    """
    rows, cols = nu.shape
    np.subtract(rho, nu, out=rho)  # same as ``rho -= nu``: caller sees the residual

    spectrum = dct2(rho) / kernel
    spectrum[0, 0] = 0  # drop the mean mode (kernel[0, 0] is a placeholder)
    correction = idct2(spectrum)

    phi += sigma * correction
    return np.sum(correction * rho) / (rows * cols)
    
def compute_w2(phi, psi, mu, nu, x, y):
    """Evaluate the dual objective for the squared-distance cost.

    Averages ``0.5*|x|^2 * (mu + nu) - nu*phi - mu*psi`` over the grid, where
    ``phi`` and ``psi`` are reshaped to ``mu.shape``.
    """
    shape = mu.shape
    rows, cols = shape
    quad = 0.5 * (x * x + y * y)
    integrand = quad * (mu + nu) - nu * phi.reshape(shape) - mu * psi.reshape(shape)
    return np.sum(integrand) / (rows * cols)
    
def compute_w2_wop(phi, psi, mu, nu, x, y, m_mu, m_nu):
    """Dual objective of ``compute_w2`` plus a mass-mismatch term.

    The term ``(m_mu - m_nu)**2`` is added at every grid cell before averaging,
    so the result equals the plain dual objective plus ``(m_mu - m_nu)**2``.
    """
    shape = mu.shape
    rows, cols = shape
    quad = 0.5 * (x * x + y * y)
    penalty = (m_mu - m_nu) ** 2
    integrand = (quad * (mu + nu) - nu * phi.reshape(shape)
                 - mu * psi.reshape(shape) + penalty)
    return np.sum(integrand) / (rows * cols)

def stepsize_update(sigma, value, oldValue, gradSq):
    """Adapt the step size from the observed increase of the objective.

    Let ``expected = gradSq * sigma``. If the increase ``value - oldValue``
    exceeds 75% of ``expected``, sigma grows by 1/0.95. If it falls below 25%,
    sigma shrinks by 0.95. Otherwise sigma is returned unchanged.
    """
    shrink = 0.95
    grow = 1 / shrink
    increase = value - oldValue
    expected = gradSq * sigma

    if increase > expected * 0.75:
        return sigma * grow
    if increase < expected * 0.25:
        return sigma * shrink
    return sigma
# Use the pushforward function from the C file.

import matplotlib.pyplot as plt
import numpy as np

def compute_ot(phi, psi, bf, sigma):
    """Run back-and-forth iterations on the potentials ``phi`` and ``psi``.

    ``phi`` and ``psi`` are updated in place. The function also reads the
    notebook globals ``n1``, ``n2``, ``mu``, ``nu``, ``x``, ``y`` and
    ``numIters``. ``bf`` must provide ``ctransform(out, in)`` and
    ``pushforward(rho, potential, density)``.

    Every 10 iterations it prints the dual objective (``compute_w2``) and the
    last value returned by ``update_potential``.
    """
    kernel = initialize_kernel(n2, n1)
    rho = np.copy(mu)
    oldValue = compute_w2(phi, psi, mu, nu, x, y)

    for k in range(numIters + 1):
        # --- phi half-step ---
        gradSq = update_potential(phi, rho, nu, kernel, sigma)
        bf.ctransform(psi, phi)
        bf.ctransform(phi, psi)
        value = compute_w2(phi, psi, mu, nu, x, y)
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value
        bf.pushforward(rho, phi, nu)

        # --- psi half-step ---
        gradSq = update_potential(psi, rho, mu, kernel, sigma)
        bf.ctransform(phi, psi)
        bf.ctransform(psi, phi)
        bf.pushforward(rho, psi, mu)
        # Re-evaluate the objective after the psi step. Without this,
        # stepsize_update below compared value with itself (diff == 0) and
        # therefore always shrank sigma.
        value = compute_w2(phi, psi, mu, nu, x, y)

        if k % 10 == 0:
            print(f'iter {k:4d},   W2 value: {value:.6e},   H1 err: {gradSq:.2e}')

        # Update stepsize and oldValue for the next iteration.
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value

def compute_ot_wop(phi, psi, bf, sigma):
    """Back-and-forth iterations using the mass-penalised objective.

    Identical to ``compute_ot`` except that the objective is
    ``compute_w2_wop``, which also reads the globals ``m_mu`` and ``m_nu``.
    ``phi`` and ``psi`` are updated in place. The function also reads the
    notebook globals ``n1``, ``n2``, ``mu``, ``nu``, ``x``, ``y`` and
    ``numIters``.
    """
    kernel = initialize_kernel(n2, n1)
    rho = np.copy(mu)
    oldValue = compute_w2_wop(phi, psi, mu, nu, x, y, m_mu, m_nu)

    for k in range(numIters + 1):
        # --- phi half-step ---
        gradSq = update_potential(phi, rho, nu, kernel, sigma)
        bf.ctransform(psi, phi)
        bf.ctransform(phi, psi)
        value = compute_w2_wop(phi, psi, mu, nu, x, y, m_mu, m_nu)
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value
        bf.pushforward(rho, phi, nu)

        # --- psi half-step ---
        gradSq = update_potential(psi, rho, mu, kernel, sigma)
        bf.ctransform(phi, psi)
        bf.ctransform(psi, phi)
        bf.pushforward(rho, psi, mu)
        # Re-evaluate the objective after the psi step. Without this,
        # stepsize_update below compared value with itself (diff == 0) and
        # therefore always shrank sigma.
        value = compute_w2_wop(phi, psi, mu, nu, x, y, m_mu, m_nu)

        if k % 10 == 0:
            print(f'iter {k:4d},   W2 value: {value:.6e},   H1 err: {gradSq:.2e}')

        # Update stepsize and oldValue for the next iteration.
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value
# functions for calculating the pushforward by mass 
import numpy as np
from scipy.signal import convolve2d

def sampling_pushforward_jit(rho, xMap, yMap, mu, n1, n2, totalMass):
    """
    Compute the pushforward of ``mu`` through a corner-sampled map, by sampling.

    Every cell with positive mass is subdivided into ``xSamples * ySamples``
    sub-points, with more sub-points where the map stretches the cell. Each
    sub-point is mapped by bilinear interpolation of the four corner values of
    ``xMap`` / ``yMap``. Its share of the cell's mass is then splatted onto the
    four nearest cell centres of ``rho`` with bilinear weights.

    Grid convention (the same one ``cal_pushforward_map`` uses): arrays of
    shape ``(n1, n2)`` are indexed ``[row, col]``. Rows run along y (``n1``
    cells) and columns along x (``n2`` cells). Mapped coordinates are in
    unit-square units. Points that land outside the grid are clamped onto the
    border cells.

    Parameters:
        rho (np.ndarray): (n1, n2) float array; overwritten in place.
        xMap (np.ndarray): (n1+1, n2+1) mapped x-coordinates at cell corners.
        yMap (np.ndarray): (n1+1, n2+1) mapped y-coordinates at cell corners.
        mu (np.ndarray): (n1, n2) source density.
        n1 (int): Number of rows (y-direction).
        n2 (int): Number of columns (x-direction).
        totalMass (float): Target *mean* value of ``rho``. After the call,
            ``rho.sum() / (n1 * n2) == totalMass`` unless ``rho`` stayed zero.

    Returns:
        None. The result is written into ``rho``.
    """
    # 1) Clear the output.
    rho[:, :] = 0.0

    # 2) Loop over the source cells that carry mass.
    for i in range(n1):
        for j in range(n2):
            mass = mu[i, j]
            if not mass > 0.0:  # also skips NaN, like the original `if mass > 0`
                continue

            x00, x01 = xMap[i, j], xMap[i, j + 1]
            x10, x11 = xMap[i + 1, j], xMap[i + 1, j + 1]
            y00, y01 = yMap[i, j], yMap[i, j + 1]
            y10, y11 = yMap[i + 1, j], yMap[i + 1, j + 1]

            # Sampling density follows how much the map stretches the cell.
            # x spans the columns (n2 cells); y spans the rows (n1 cells).
            xStretch = max(abs(x01 - x00), abs(x11 - x10))
            yStretch = max(abs(y10 - y00), abs(y11 - y01))
            xSamples = max(int(n2 * xStretch), 1)
            ySamples = max(int(n1 * yStretch), 1)
            share = mass * (1.0 / (xSamples * ySamples))

            # 3) Sub-points at the centres of a regular sub-grid of the cell.
            for l in range(ySamples):
                b = (l + 0.5) / ySamples
                for k in range(xSamples):
                    a = (k + 0.5) / xSamples

                    # Bilinear interpolation of the corner map values.
                    c00 = (1 - b) * (1 - a)
                    c01 = (1 - b) * a
                    c10 = b * (1 - a)
                    c11 = a * b
                    xPoint = c00 * x00 + c01 * x01 + c10 * x10 + c11 * x11
                    yPoint = c00 * y00 + c01 * y01 + c10 * y10 + c11 * y11

                    # Continuous position in cell-centre index units.
                    X = xPoint * n2 - 0.5
                    Y = yPoint * n1 - 0.5

                    # Use floor, not int(). int() truncates toward zero, so
                    # for -1 < X < 0 it produced xFrac < 0 and a negative
                    # splat weight on the neighbouring cell.
                    xIndex = int(np.floor(X))
                    yIndex = int(np.floor(Y))
                    xFrac = X - xIndex
                    yFrac = Y - yIndex

                    # Clamp the four target cells onto the grid.
                    xOther = min(max(xIndex + 1, 0), n2 - 1)
                    yOther = min(max(yIndex + 1, 0), n1 - 1)
                    xIndex = min(max(xIndex, 0), n2 - 1)
                    yIndex = min(max(yIndex, 0), n1 - 1)

                    # 4) Distribute the sub-point's mass with bilinear weights.
                    rho[yIndex, xIndex] += (1.0 - xFrac) * (1.0 - yFrac) * share
                    rho[yOther, xIndex] += (1.0 - xFrac) * yFrac * share
                    rho[yIndex, xOther] += xFrac * (1.0 - yFrac) * share
                    rho[yOther, xOther] += xFrac * yFrac * share

    # 5) Rescale so that the mean of rho equals totalMass.
    mean_rho = rho.sum() / (n1 * n2)
    if mean_rho > 0.0:
        rho *= totalMass / mean_rho

def cal_pushforward_map(dual, n1, n2):
    """
    Compute the map (its discrete gradient) from a potential on the grid corners.

    The potential is edge-padded and averaged over every 2x2 neighbourhood,
    which gives its values at the cell corners. For a ``(n1, n2)`` input this
    is an ``(n1+1, n2+1)`` array. Central differences of those corner values
    give the x-component (along axis 1, scaled by ``0.5 * n2``) and the
    y-component (along axis 0, scaled by ``0.5 * n1``). At the outermost
    corners, edge padding makes the difference one-sided.

    Parameters:
        dual (np.ndarray): 2D potential, shape (n1, n2) (often denoted phi).
        n1 (int): Number of rows (y-direction); 1/n1 is the row spacing.
        n2 (int): Number of columns (x-direction); 1/n2 is the column spacing.

    Returns:
        xMap (np.ndarray), yMap (np.ndarray): Mapped x / y coordinates at the
        cell corners, each of shape (n1+1, n2+1).
    """
    averaging = np.full((2, 2), 0.25)
    corners = convolve2d(np.pad(dual, 1, mode="edge"), averaging, mode="valid")
    ext = np.pad(corners, 1, mode="edge")

    inner_rows = ext[1:-1, :]   # every corner row, with padded columns
    inner_cols = ext[:, 1:-1]   # every corner column, with padded rows
    xMap = 0.5 * n2 * (inner_rows[:, 2:] - inner_rows[:, :-2])
    yMap = 0.5 * n1 * (inner_cols[2:, :] - inner_cols[:-2, :])
    return xMap, yMap

def pushforward_mass(phi, nu, n1, n2):
    """
    Push the density ``nu`` forward through the map derived from ``phi``.

    The map comes from ``cal_pushforward_map(phi, n1, n2)`` and is applied with
    ``sampling_pushforward_jit``. The result is rescaled so that its mean
    equals ``np.sum(nu) / (n1 * n2)``, the mean of ``nu``.

    Parameters:
        phi (np.ndarray): (n1, n2) potential whose discrete gradient defines the map.
        nu (np.ndarray): (n1, n2) density to push forward.
        n1 (int): Number of rows (y-direction).
        n2 (int): Number of columns (x-direction).

    Returns:
        np.ndarray: A new (n1, n2) array holding the pushed-forward density.
    """
    rho = np.zeros((n1, n2))
    totalMass = np.sum(nu) / (n1 * n2)  # mean of nu; the output keeps this mean

    # Map at the cell corners derived from phi.
    xMap, yMap = cal_pushforward_map(phi, n1, n2)

    # Splat nu through the map into rho (in place) and renormalise.
    sampling_pushforward_jit(rho, xMap, yMap, nu, n1, n2, totalMass)

    return rho
    
def plot_interpolation_wop(mu, nu, phi, psi, bf, x, y, n1, n2, m_mu, m_nu, n_fig=6):
    """
    Plot the interpolation between mu and nu when their total masses differ.

    Frames are taken at evenly spaced ``t_star`` in [0, 1]. Each is mapped to
    the blending time ``t = t_star / (r + t_star * (1 - r))`` with
    ``r = m_mu / m_nu``, so ``t`` also runs from 0 to 1. At each ``t``, ``mu``
    is pushed through ``psi_t`` and ``nu`` through ``phi_t``, and the frame
    shows ``(1 - t) * rho_fwd + t * rho_bwd``.

    Parameters
    ----------
    mu : 2D array
        Source measure.
    nu : 2D array
        Target measure.
    phi, psi : 2D arrays
        Kantorovich potentials (e.g. as updated by ``compute_ot_wop``).
    bf : BFM object
        Solver providing ``pushforward(rho, potential, density)``.
    x, y : 2D arrays
        Grid coordinates, same shape as ``mu``.
    n1, n2 : int
        Grid sizes. They are not used in the current body; kept for interface
        compatibility.
    m_mu, m_nu : float
        Total masses of the source and target measures; ``m_nu`` must be non-zero.
    n_fig : int
        Number of frames to plot, including t=0 and t=1. Must be at least 2.

    Raises
    ------
    ValueError
        If ``n_fig < 2``.
    """
    if n_fig < 2:
        raise ValueError("n_fig must be at least 2 (frames at t=0 and t=1)")

    fig, ax = plt.subplots(1, n_fig, figsize=(20, 8))
    for axi in ax:
        axi.axis('off')

    # Use the same max for colour scaling in every frame.
    vmax = mu.max()

    # Output buffers written in place by bf.pushforward.
    rho_fwd = np.zeros_like(mu)
    rho_bwd = np.zeros_like(mu)

    # "Reference" potential 0.5*(x^2 + y^2), whose gradient is the identity map.
    phi_0 = 0.5 * (x**2 + y**2)
    mass_ratio = m_mu / m_nu

    for i in range(n_fig):
        t_star = i / (n_fig - 1)
        t = t_star / (mass_ratio + t_star * (1 - mass_ratio))

        # psi_t goes from phi_0 (t=0) to psi (t=1);
        # phi_t goes from phi (t=0) to phi_0 (t=1).
        psi_t = (1 - t) * phi_0 + t * psi
        phi_t = t * phi_0 + (1 - t) * phi

        # Forward: mu through psi_t. Backward: nu through phi_t.
        bf.pushforward(rho_fwd, psi_t, mu)
        bf.pushforward(rho_bwd, phi_t, nu)

        # NOTE: this variant applies no additional mass-rescaling pushforward
        # (cf. ``pushforward_mass``) to rho_fwd / rho_bwd.

        # Convex combination of the two pushforwards.
        interpolate = (1 - t) * rho_fwd + t * rho_bwd

        ax[i].imshow(interpolate, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
        ax[i].set_title(f"$t={t:.2f}$")
    plt.show()

def plot_interpolation(mu, nu, phi, psi, bf, x, y, n_fig=6):
    """
    Plot the discrete geodesic interpolation between ``mu`` and ``nu``.

    The first and last panels show ``mu`` (t=0) and ``nu`` (t=1) directly. Each
    intermediate panel at ``t = i / (n_fig - 1)`` shows
    ``(1 - t) * rho_fwd + t * rho_bwd``, where ``rho_fwd`` is ``mu`` pushed
    through ``(1 - t) * phi_0 + t * psi`` and ``rho_bwd`` is ``nu`` pushed
    through ``t * phi_0 + (1 - t) * phi``, with ``phi_0 = 0.5 * (x**2 + y**2)``.

    Parameters
    ----------
    mu : 2D array (n2, n1)
        Source measure.
    nu : 2D array (n2, n1)
        Target measure.
    phi, psi : 2D arrays (n2, n1)
        Kantorovich potentials.
    bf : BFM object
        Solver providing ``pushforward(rho, potential, density)``.
    x, y : 2D arrays (n2, n1)
        Grid coordinates.
    n_fig : int
        Number of frames to plot (including t=0 and t=1).
    """
    _fig, axes = plt.subplots(1, n_fig, figsize=(20, 8))
    for panel in axes:
        panel.axis('off')

    # One shared colour scale, taken from mu.
    display = dict(origin='lower', extent=(0, 1, 0, 1), vmax=mu.max())
    last = n_fig - 1

    # Endpoints are drawn from the measures themselves.
    axes[0].imshow(mu, **display)
    axes[0].set_title("$t=0$")
    axes[last].imshow(nu, **display)
    axes[last].set_title("$t=1$")

    # Buffers filled in place by bf.pushforward.
    rho_fwd = np.zeros_like(mu)
    rho_bwd = np.zeros_like(mu)
    identity_potential = 0.5 * (x**2 + y**2)

    for frame in range(1, last):
        t = frame / last
        forward_potential = (1 - t) * identity_potential + t * psi
        backward_potential = t * identity_potential + (1 - t) * phi

        bf.pushforward(rho_fwd, forward_potential, mu)
        bf.pushforward(rho_bwd, backward_potential, nu)

        blended = (1 - t) * rho_fwd + t * rho_bwd
        axes[frame].imshow(blended, **display)
        axes[frame].set_title(f"$t={t:.2f}$")

    plt.show()
n1 = 128
n2 = 128

# Source mu varies along rows only; target nu is a product of cosines.
# Built with broadcasting instead of a Python double loop.
row_coord = np.arange(n1)[:, np.newaxis] / n1 + 0.5  # shape (n1, 1)
col_coord = np.arange(n2)[np.newaxis, :] / n2 + 0.5  # shape (1, n2)
mu = np.tile(np.cos(np.pi * (128 / n1) * row_coord) + 1, (1, n2))
nu = np.cos(np.pi * (64 / n1) * row_coord) * np.cos(np.pi * (64 / n1) * col_coord) + 1

fig, ax = plt.subplots(1, 2)
ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
ax[0].set_title(r"$\mu$")
ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1))
ax[1].set_title(r"$\nu$")

m_mu = np.sum(mu)
m_nu = np.sum(nu)
print(f'total mass of mu = {m_mu}, total mass of nu = {m_nu}')

# Rescale both measures by nu's mass: nu gets unit mass and mu keeps its
# mass relative to nu.
mu = mu/m_nu
nu = nu/m_nu
fig, ax = plt.subplots(1, 2)
ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
ax[0].set_title(r"$\mu$")
ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1))
ax[1].set_title(r"$\nu$")
m_mu = np.sum(mu)
m_nu = np.sum(nu)
print(f'total mass of mu = {m_mu}, total mass of nu = {m_nu}')


x, y = np.meshgrid(np.linspace(0.5 / n1, 1 - 0.5 / n1, n1), np.linspace(0.5 / n2, 1 - 0.5 / n2, n2))
phi = 0.5 * (x * x + y * y)
psi = 0.5 * (x * x + y * y)
phi_mu = 0.5 * (m_mu * (x * x) + m_mu * (y * y))
phi_nu = 0.5 * (x * x + y * y)
numIters = 30
sigma = 4 / np.maximum(mu.max(), nu.max())

# Normalize the two distributions to unit mass (m_mu, m_nu keep the
# relative masses computed above).
mu *= 1 / np.sum(mu)
nu *= 1 / np.sum(nu)
# Visualize the normalized distributions.
fig, ax = plt.subplots(1, 2)
ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
ax[0].set_title(r"$\bar{\mu}$")
ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1))
ax[1].set_title(r"$\bar{\nu}$")

# NOTE: no mass pushforward is applied in this variant, so these panels show
# the same normalized measures as above under the T_{m#} titles.
# Raw strings keep the "\#" in the titles without an invalid-escape warning.
fig, ax = plt.subplots(1, 2)
ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1))
ax[0].set_title(r"$T_{m_{\mu}\#}\mu$")
ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1))
ax[1].set_title(r"$T_{m_{\nu}\#}\nu$")


bf = BFM(n1, n2, mu)
compute_ot_wop(phi, psi, bf, sigma)
plot_interpolation_wop(mu, nu, phi, psi, bf, x, y, n1, n2, m_mu, m_nu)
plot_interpolation(mu, nu, phi, psi, bf, x, y)
    
total mass of mu = 5954.14525356209, total mass of nu = 16384.5
total mass of mu = 0.36340109576502727, total mass of nu = 1.0
iter    0,   W2 value: 4.052583e-01,   H1 err: 4.89e-11
iter   10,   W2 value: 4.052587e-01,   H1 err: 8.54e-12
iter   20,   W2 value: 4.052588e-01,   H1 err: 1.34e-12
iter   30,   W2 value: 4.052588e-01,   H1 err: 9.09e-13
../_images/1283f933c6fee83f6ec2a245a24f3fd7845f35d08ff0002e69c1e5e371048dc6.png ../_images/1283f933c6fee83f6ec2a245a24f3fd7845f35d08ff0002e69c1e5e371048dc6.png ../_images/071aaebfc01cd5d5384056cd79c826656e95af1088bb1b6efbc837a9ece5746a.png ../_images/5b9e64de9959d8e575f418a1ff826739c6051f86798cac379a065406044fbf3b.png ../_images/987adbc13e6f3d9895ea025d5cff72bcaa0f5d42ef0b2533433172c0af7df210.png ../_images/8117b7312a6825ccff77a9ee9edf786feb179d0d6ebb96cc2e9c1b0aab03f329.png