# Add mass WOP

import sys
import os
sys.path.append(os.path.abspath("../.."))

import numpy as np
import numpy.ma as ma
from scipy.fftpack import dctn, idctn
import matplotlib.pyplot as plt
from time import time
from BFOT.BFOT import BFM
from skimage.transform import resize
from skimage.filters import unsharp_mask

plt.rcParams['figure.figsize'] = (13, 8)
plt.rcParams['image.cmap'] = 'viridis'

# Function definitions

def initialize_kernel(n1, n2):
    """Build the (n2, n1) array that update_potential divides DCT coefficients by.

    Entry [j, i] equals
    ``2*n1**2 * (1 - cos(pi*i/n1)) + 2*n2**2 * (1 - cos(pi*j/n2))``.
    Entry [0, 0] of that formula is zero; it is overwritten with 1 so that
    element-wise division by the kernel is always defined.
    """
    # A row vector (varies with i) and a column vector (varies with j);
    # broadcasting them produces the same (n2, n1) grid a meshgrid would.
    angle_i = np.linspace(0, np.pi, n1, endpoint=False)[np.newaxis, :]
    angle_j = np.linspace(0, np.pi, n2, endpoint=False)[:, np.newaxis]
    symbol = 2 * n1 * n1 * (1 - np.cos(angle_i)) + 2 * n2 * n2 * (1 - np.cos(angle_j))
    symbol[0, 0] = 1  # placeholder for the zero entry, avoids 0-division
    return symbol

def dct2(a):
    """Return the orthonormal type-II DCT of ``a``, taken over every axis."""
    coefficients = dctn(a, type=2, norm='ortho')
    return coefficients

def idct2(a):
    """Return the orthonormal inverse type-II DCT of ``a``, taken over every axis."""
    samples = idctn(a, type=2, norm='ortho')
    return samples

def update_potential(phi, rho, nu, kernel, sigma):
    """Apply one kernel-preconditioned step to ``phi`` in place.

    ``rho`` is overwritten with the residual ``rho - nu``.  The residual is
    transformed with ``dct2``, divided element-wise by ``kernel``, has its
    [0, 0] coefficient zeroed, and is transformed back with ``idct2``.  The
    result, scaled by ``sigma``, is added to ``phi``.

    Returns
    -------
    float
        ``sum(correction * residual) / (n1 * n2)`` where ``(n1, n2)`` is
        ``nu.shape``; compute_ot reports this value as the "H1 err".
    """
    rows, cols = nu.shape
    rho -= nu  # in place: the caller's rho now holds the residual
    coeffs = dct2(rho) / kernel
    coeffs[0, 0] = 0  # discard the constant mode (kernel[0, 0] is only a placeholder)
    correction = idct2(coeffs)
    phi += sigma * correction
    return np.sum(correction * rho) / (rows * cols)

def compute_w2(phi, psi, mu, nu, x, y):
    """Evaluate the objective that compute_ot tracks between iterations.

    The per-cell integrand is
    ``0.5*(x^2 + y^2)*(mu + nu) - nu*phi - mu*psi + (m_mu - m_nu)^2``
    and the return value is its mean over the ``mu.shape`` grid.  ``phi`` and
    ``psi`` are reshaped to ``mu.shape`` first, so flat inputs are accepted.

    ``m_mu`` and ``m_nu`` are read from module scope; because the squared
    mass gap is added to every cell before averaging, it shifts the result
    by that constant.
    """
    rows, cols = mu.shape
    grid_shape = (rows, cols)
    half_sq_norm = 0.5 * (x * x + y * y)
    mass_gap_sq = (m_mu - m_nu) ** 2
    integrand = (
        half_sq_norm * (mu + nu)
        - nu * phi.reshape(grid_shape)
        - mu * psi.reshape(grid_shape)
        + mass_gap_sq
    )
    return np.sum(integrand) / (rows * cols)

# Tuning constants read by stepsize_update.
# scaleDown / scaleUp: multiplicative factors applied to sigma when the step
# is shrunk / enlarged (scaleUp is the exact reciprocal of scaleDown).
scaleDown = 0.95
scaleUp = 1 / scaleDown
# upper / lower: fractions of gradSq * sigma that the observed change in the
# objective is compared against to decide between enlarging, shrinking, or
# keeping sigma.
upper = 0.75
lower = 0.25

def stepsize_update(sigma, value, oldValue, gradSq):
    """Return the step size to use next, given the last change in the objective.

    The observed change ``value - oldValue`` is compared with the reference
    quantity ``gradSq * sigma``:

    * above ``upper`` times the reference -> ``sigma * scaleUp``
    * below ``lower`` times the reference -> ``sigma * scaleDown``
    * otherwise ``sigma`` is returned unchanged.
    """
    gain = value - oldValue
    reference = gradSq * sigma
    if gain > reference * upper:
        return sigma * scaleUp
    if gain < reference * lower:
        return sigma * scaleDown
    return sigma

def compute_ot(phi, psi, bf, sigma):
    """Run ``numIters + 1`` alternating updates of ``phi`` and ``psi``.

    Reads ``n1``, ``n2``, ``mu``, ``nu``, ``x``, ``y`` and ``numIters`` from
    module scope.  ``phi`` and ``psi`` are modified in place (by
    update_potential, and presumably by ``bf.ctransform`` as well -- its
    semantics are not visible here).  ``sigma`` is adapted locally after each
    half-step; nothing is returned.

    Every fifth iteration the current objective and the value returned by
    the last update_potential call are printed.
    """
    kernel = initialize_kernel(n1, n2)
    rho = np.copy(mu)  # working density; update_potential and bf.pushforward receive it as a buffer

    def objective():
        # phi / psi are mutated in place, so each call sees the current state.
        return compute_w2(phi, psi, mu, nu, x, y)

    previous = objective()
    for k in range(numIters + 1):
        # --- half-step on phi (residual taken against nu) ---
        gradSq = update_potential(phi, rho, nu, kernel, sigma)
        bf.ctransform(psi, phi)
        bf.ctransform(phi, psi)
        value = objective()
        sigma = stepsize_update(sigma, value, previous, gradSq)
        previous = value
        bf.pushforward(rho, phi, nu)

        # --- half-step on psi (residual taken against mu) ---
        gradSq = update_potential(psi, rho, mu, kernel, sigma)
        bf.ctransform(phi, psi)
        bf.ctransform(psi, phi)
        bf.pushforward(rho, psi, mu)
        value = objective()
        sigma = stepsize_update(sigma, value, previous, gradSq)
        previous = value

        if k % 5 != 0:
            continue
        print(f'iter {k:4d},   W2 value: {value:.6e},   H1 err: {gradSq:.2e}')

def process_image(image_array, output_size, sharpen=False, sharpen_radius=1.0, sharpen_amount=1.5):
    """
    Resize an image to ``output_size`` and optionally sharpen the result.

    Parameters:
        image_array (np.ndarray): Image to process.
        output_size (tuple): Target size, passed straight to ``resize``
            as (height, width).
        sharpen (bool): When True, run ``unsharp_mask`` on the resized image.
        sharpen_radius (float): ``radius`` argument for ``unsharp_mask``.
        sharpen_amount (float): ``amount`` argument for ``unsharp_mask``.

    Returns:
        np.ndarray: The resized (and possibly sharpened) image.

    Raises:
        ValueError: If ``image_array`` is not a NumPy array.
    """
    if not isinstance(image_array, np.ndarray):
        raise ValueError("The input image must be a NumPy array.")

    resized = resize(image_array, output_size, anti_aliasing=True)
    if not sharpen:
        return resized

    return unsharp_mask(resized, radius=sharpen_radius, amount=sharpen_amount)

def process_shrink_image(image_array, output_size, scale_factor, sharpen=False, sharpen_radius=1.0, sharpen_amount=1.5):
    """
    Shrink an image by ``scale_factor`` and centre it on a zero-filled canvas.

    The image is resized to ``int(output_size[i] * scale_factor)`` along each
    of the first two dimensions, optionally sharpened, and pasted into the
    middle of an all-zero array of shape ``output_size``.

    Parameters:
        image_array (np.ndarray): Image to shrink.
        output_size (tuple): Shape of the returned canvas as (height, width).
        scale_factor (float): Shrink ratio; must satisfy 0 < scale_factor <= 1
            and be large enough that the shrunken image keeps at least one
            pixel per dimension.
        sharpen (bool): When True, run ``unsharp_mask`` on the shrunken image.
        sharpen_radius (float): ``radius`` argument for ``unsharp_mask``.
        sharpen_amount (float): ``amount`` argument for ``unsharp_mask``.

    Returns:
        np.ndarray: Canvas of shape ``output_size`` holding the centred image.

    Raises:
        ValueError: If ``scale_factor`` is outside (0, 1] (this also rejects
            NaN) or so small that a shrunken dimension rounds down to zero.
    """
    # Fail fast: a factor above 1 would otherwise only surface as a broadcast
    # error on the final slice assignment, after the resize has already run,
    # and a non-positive one would hand a degenerate shape to ``resize``.
    if not 0 < scale_factor <= 1:
        raise ValueError(f"scale_factor must be in (0, 1], got {scale_factor!r}.")

    # Calculate new dimensions based on the scale factor (truncating)
    new_height = int(output_size[0] * scale_factor)
    new_width = int(output_size[1] * scale_factor)
    if new_height < 1 or new_width < 1:
        raise ValueError(
            f"scale_factor {scale_factor!r} is too small for output_size {tuple(output_size)!r}: "
            "the shrunken image would be empty."
        )

    # Resize the image to the new (smaller) dimensions
    resized_image = resize(image_array, (new_height, new_width), anti_aliasing=True)
    if sharpen:
        resized_image = unsharp_mask(resized_image, radius=sharpen_radius, amount=sharpen_amount)

    # Blank canvas with the full output dimensions, same dtype as the resized data
    final_image = np.zeros(output_size, dtype=resized_image.dtype)

    # Offsets that centre the resized image (any odd leftover pixel goes to
    # the bottom/right margin because of the floor division)
    offset_h = (output_size[0] - new_height) // 2
    offset_w = (output_size[1] - new_width) // 2

    final_image[offset_h:offset_h + new_height, offset_w:offset_w + new_width] = resized_image

    return final_image

#-----------------#
# Read from data, black is large pixel
# Only channel index 2 of each image is used, and it is inverted (1 - value),
# so dark pixels become large densities.
# NOTE(review): this assumes plt.imread returns floats in [0, 1] for these PNGs -- confirm.
mu = 1-plt.imread('../../images/star.png')[:, :, 2]
nu = 1-plt.imread('../../images/fivestar.png')[:, :, 2]

# Masses
# Also read as module globals by compute_w2.
m_mu = 2
m_nu = 3

# Grid of size n1 x n2
# NOTE(review): output_dim is handed to the resize helpers as (height, width) = (n1, n2),
# whereas the meshgrid-based arrays further down have shape (n2, n1).  The two
# agree only because n1 == n2 here -- confirm before using a non-square grid.
n1, n2 = 256, 256
output_dim = (n1, n2)

# Resize the images
# Equal masses: both images fill the whole grid.  Otherwise the image with the
# larger mass is shrunk by sqrt(smaller / larger) per side (so its footprint
# area scales by smaller / larger) and centred on a zero canvas, while the
# other image fills the grid.
if m_mu == m_nu:
    mu = process_image(mu,output_dim,sharpen=True)
    nu = process_image(nu,output_dim,sharpen=True)
elif m_mu > m_nu:
    scale_factor = np.sqrt(m_nu / m_mu)
    mu = process_shrink_image(mu, output_dim, scale_factor, sharpen=True)
    nu = process_image(nu,output_dim,sharpen=True)
else:
    scale_factor = np.sqrt(m_mu / m_nu)
    mu = process_image(mu,output_dim,sharpen=True)
    nu = process_shrink_image(nu, output_dim, scale_factor, sharpen=True)

# Rescale both densities in place so each sums to n1 * n2 (mean value 1).
# NOTE(review): after this step both measures carry the same total mass, so
# m_mu / m_nu only influence the footprint above and the constant term in
# compute_w2 -- confirm this is the intended unbalanced setup.
mu *= n1 * n2 / np.sum(mu)
nu *= n1 * n2 / np.sum(nu)



# Coordinates of the cell centres of an n1 x n2 grid on the unit square;
# both arrays have shape (n2, n1).
x, y = np.meshgrid(
    np.linspace(0.5 / n1, 1 - 0.5 / n1, n1),
    np.linspace(0.5 / n2, 1 - 0.5 / n2, n2),
)
# Both potentials start from 0.5 * |(x, y)|^2.  They are separate arrays
# because compute_ot updates each of them in place.
phi = 0.5 * (x * x + y * y)
psi = phi.copy()

#-----------------#
# Shared colour limits so the two densities are drawn on the same scale.
vmin = min(mu.min(), nu.min())
vmax = max(mu.max(), nu.max())

# Side-by-side preview of the two input measures.
fig, ax = plt.subplots(1, 2)
im0 = ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1), vmin=vmin, vmax=vmax)
ax[0].set_title("$\\mu$")
im1 = ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1), vmin=vmin, vmax=vmax)
ax[1].set_title("$\\nu$")

# One colourbar attached to both axes; it is built from im0, which is valid
# for both panels because they share vmin / vmax.
cbar = fig.colorbar(im0, ax=ax.ravel().tolist(), orientation='vertical', label='Pixel Intensity')

# numIters is read as a module global by compute_ot (which runs numIters + 1 iterations).
numIters = 40
# Initial step size: 4 divided by the largest density value in either measure.
sigma = 4 / np.maximum(mu.max(), nu.max())

# Time the solver construction together with the optimisation loop.
tic = time()
bf = BFM(n1, n2, mu)
# phi and psi are updated in place; compute_ot returns nothing, and the
# adapted step size stays local to it (this module-level sigma is unchanged).
compute_ot(phi, psi, bf, sigma)
toc = time()
print(f'\nElapsed time: {toc - tic:.2f}s')

# my, mx = ma.masked_array(np.gradient(psi.reshape((n2, n1)) - 0.5 * (x * x + y * y), 1 / n2, 1 / n1), mask=((mu == 0), (mu == 0)))
# fig, ax = plt.subplots()
# ax.contourf(x, y, mu + nu)
# ax.set_aspect('equal')
# skip = (slice(None, None, n1 // 50), slice(None, None, n2 // 50))
# ax.quiver(x[skip], y[skip], mx[skip], my[skip], color='yellow', angles='xy', scale_units='xy', scale=1)

# fig, ax = plt.subplots(1, 2)
# ax[0].imshow((x + mx).reshape((n2, n1)), origin='lower', extent=(0, 1, 0, 1), cmap='plasma')
# x_masked = ma.masked_array(x, mask=(nu == 0))
# ax[1].imshow(x_masked, origin='lower', extent=(0, 1, 0, 1), cmap='plasma')

def plot_interpolation(mu, nu, phi, psi, bf, x, y, n_fig=6):
    """
    Plot a row of frames that blend mu (t=0) into nu (t=1).

    For each interior frame at time t, mu is pushed forward with the blended
    potential ``(1 - t) * phi_0 + t * psi`` and nu with
    ``t * phi_0 + (1 - t) * phi``, where ``phi_0 = 0.5 * (x**2 + y**2)``.
    The displayed frame is ``(1 - t) * forward + t * backward``.

    Parameters
    ----------
    mu : 2D array
        Density shown at t=0.
    nu : 2D array
        Density shown at t=1.
    phi, psi : 2D arrays
        Potentials to blend with ``phi_0``; must broadcast against x and y.
    bf : object with a ``pushforward(out, potential, density)`` method
        The BFM instance used by the solver.  The method is assumed to write
        its result into ``out`` -- confirm against the BFOT implementation.
    x, y : 2D arrays
        Grid coordinates.
    n_fig : int
        Number of frames to plot, including t=0 and t=1.  Must be >= 2.

    Raises
    ------
    ValueError
        If ``n_fig`` is smaller than 2.  With a single panel ``plt.subplots``
        returns a bare Axes instead of an array, so the code below could not
        iterate over or index it, and the frame times would be undefined.
    """
    if n_fig < 2:
        raise ValueError(f"n_fig must be at least 2 (got {n_fig!r}).")

    fig, ax = plt.subplots(1, n_fig, figsize=(20, 8))
    for axi in ax:
        axi.axis('off')

    # Use the same max for color scaling in every frame
    vmax = max(mu.max(), nu.max())

    # Show the initial and final measures
    im0 = ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
    ax[0].set_title("$t=0$")

    ax[n_fig - 1].imshow(nu, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
    ax[n_fig - 1].set_title("$t=1$")

    # Output buffers handed to bf.pushforward, reused across frames
    rho_fwd = np.zeros_like(mu)
    rho_bwd = np.zeros_like(mu)

    # "Reference" potential = 0.5*(x² + y²)
    phi_0 = 0.5 * (x**2 + y**2)

    # Generate intermediate frames
    for i in range(1, n_fig - 1):
        t = i / (n_fig - 1)

        # psi_t moves from phi_0 (t=0) to psi (t=1);
        # phi_t moves from phi (t=0) to phi_0 (t=1).
        psi_t = (1 - t) * phi_0 + t * psi
        phi_t = t * phi_0 + (1 - t) * phi

        # Pushforward mu by psi_t
        bf.pushforward(rho_fwd, psi_t, mu)

        # Pushforward nu by phi_t
        bf.pushforward(rho_bwd, phi_t, nu)

        # Convex combination of the two pushforwards
        interpolate = (1 - t) * rho_fwd + t * rho_bwd

        # Plot the interpolation
        ax[i].imshow(interpolate, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
        ax[i].set_title(f"$t={t:.2f}$")

    # Shrink the subplots to leave room on the right for a dedicated colorbar axis
    plt.subplots_adjust(right=0.85)

    # New axis for the colorbar (position: [left, bottom, width, height]).
    # The bar is built from the t=0 image; every frame shares its vmax.
    cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
    fig.colorbar(im0, cax=cbar_ax, label='Intensity')

    plt.show()

# Draw the interpolation frames using the potentials updated by compute_ot.
plot_interpolation(mu, nu, phi, psi, bf, x, y)
# NOTE(review): plot_interpolation already ends with plt.show(), so this
# second call looks redundant -- confirm and drop if so.
plt.show()
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 10
      8 import matplotlib.pyplot as plt
      9 from time import time
---> 10 from BFOT.BFOT import BFM
     11 from skimage.transform import resize
     12 from skimage.filters import unsharp_mask

File ~/work/BFOT/BFOT/BFOT/BFOT.py:2
      1 import numpy as np
----> 2 import numba
      3 from numba import prange
      4 from scipy.signal import convolve2d

ModuleNotFoundError: No module named 'numba'