import sys
import os
sys.path.append(os.path.abspath("../.."))
import numpy as np
import numpy.ma as ma
from scipy.fftpack import dctn, idctn
import matplotlib.pyplot as plt
from time import time
from BFOT.BFOT import BFM
from skimage.transform import resize
from skimage.filters import unsharp_mask
# Global matplotlib defaults for every figure created by this script:
# a 13x8 figure size and the viridis colormap for images.
plt.rcParams.update({
    'figure.figsize': (13, 8),
    'image.cmap': 'viridis',
})
# Function definitions
def initialize_kernel(n1, n2):
    """
    Build the array that `update_potential` divides DCT coefficients by.

    Parameters:
    n1 (int): Number of samples along the second array axis (columns).
    n2 (int): Number of samples along the first array axis (rows).

    Returns:
    np.ndarray: Array of shape (n2, n1) with entries
        2*n1^2*(1 - cos(a)) + 2*n2^2*(1 - cos(b)), where a and b are sampled
        uniformly on [0, pi) with the endpoint excluded. Entry [0, 0] is
        overwritten with 1.
    """
    col_angles = np.linspace(0, np.pi, n1, endpoint=False)
    row_angles = np.linspace(0, np.pi, n2, endpoint=False)
    grid_x, grid_y = np.meshgrid(col_angles, row_angles)

    term_x = 2 * n1 * n1 * (1 - np.cos(grid_x))
    term_y = 2 * n2 * n2 * (1 - np.cos(grid_y))
    kernel = term_x + term_y

    # The [0, 0] entry would otherwise be 0; callers divide by this array.
    kernel[0, 0] = 1
    return kernel
def dct2(a):
    """Return scipy's `dctn` of `a` with orthonormal scaling (norm='ortho')."""
    coefficients = dctn(a, norm='ortho')
    return coefficients
def idct2(a):
    """Return scipy's `idctn` of `a` with orthonormal scaling (norm='ortho')."""
    values = idctn(a, norm='ortho')
    return values
def update_potential(phi, rho, nu, kernel, sigma):
    """
    Apply one step to `phi` driven by the residual `rho - nu`.

    The residual is transformed with `dct2`, divided entry-wise by `kernel`,
    its [0, 0] coefficient is zeroed, and the result is transformed back with
    `idct2`. `sigma` times that array is then added to `phi`.

    Parameters:
    phi (np.ndarray): Potential; modified in place.
    rho (np.ndarray): Density; overwritten in place with `rho - nu`.
    nu (np.ndarray): 2D array subtracted from `rho`; its shape supplies the
        normalisation of the returned value.
    kernel (np.ndarray): Divisor for the DCT coefficients (see `initialize_kernel`).
    sigma (float): Step size.

    Returns:
    float: sum(correction * residual) divided by the number of grid entries;
        `compute_ot` prints this value as "H1 err".
    """
    rows, cols = nu.shape

    # In place on purpose: the caller's `rho` buffer now holds the residual.
    rho -= nu

    coefficients = dct2(rho) / kernel
    coefficients[0, 0] = 0
    correction = idct2(coefficients)

    # In place on purpose: the caller's potential is updated.
    phi += sigma * correction

    return np.sum(correction * rho) / (rows * cols)
def compute_w2(phi, psi, mu, nu, x, y):
    """
    Evaluate the objective that `compute_ot` reports as the "W2 value".

    The result is the grid average of
        0.5*(x^2 + y^2)*(mu + nu) - nu*phi - mu*psi + (m_mu - m_nu)^2.

    Note: `m_mu` and `m_nu` are NOT parameters; they are read from module-level
    globals, so they must be defined before this function is called.

    Parameters:
    phi, psi (np.ndarray): Potentials; reshaped to `mu.shape` before use.
    mu, nu (np.ndarray): 2D density arrays of identical shape.
    x, y (np.ndarray): Grid coordinates broadcastable against `mu`.

    Returns:
    float: Sum of the integrand divided by the number of grid entries.
    """
    n1, n2 = mu.shape
    grid_shape = (n1, n2)

    quadratic_term = 0.5 * (x * x + y * y) * (mu + nu)
    nu_phi = nu * phi.reshape(grid_shape)
    mu_psi = mu * psi.reshape(grid_shape)
    # Scalar added to every entry, so it contributes exactly once to the average.
    mass_gap = (m_mu - m_nu) ** 2

    integrand = quadratic_term - nu_phi - mu_psi + mass_gap
    return np.sum(integrand) / (n1 * n2)
# Tuning constants for `stepsize_update`.
scaleDown = 0.95         # factor applied when the step size is decreased
scaleUp = 1 / scaleDown  # factor applied when the step size is increased
upper = 0.75             # fraction of gradSq * sigma above which sigma grows
lower = 0.25             # fraction of gradSq * sigma below which sigma shrinks


def stepsize_update(sigma, value, oldValue, gradSq):
    """
    Adapt the step size from the change in the objective value.

    Parameters:
    sigma (float): Current step size.
    value (float): Objective value after the latest update.
    oldValue (float): Objective value before the latest update.
    gradSq (float): Value returned by `update_potential` for that update.

    Returns:
    float: `sigma * scaleUp` if `value - oldValue` exceeds `gradSq * sigma * upper`,
        `sigma * scaleDown` if it is below `gradSq * sigma * lower`,
        otherwise `sigma` unchanged.
    """
    change = value - oldValue
    reference = gradSq * sigma

    if change > reference * upper:
        return sigma * scaleUp
    if change < reference * lower:
        return sigma * scaleDown
    return sigma
def compute_ot(phi, psi, bf, sigma):
    """
    Run the alternating potential-update loop on `phi` and `psi`.

    Besides its arguments, this function reads the module-level globals
    `n1`, `n2`, `mu`, `nu`, `x`, `y` and `numIters`.

    Parameters:
    phi, psi (np.ndarray): Potentials. `phi` and `psi` are updated in place by
        `update_potential`; they are also handed to `bf.ctransform`
        (presumably modified in place there as well - confirm against BFM).
    bf: Solver object providing `ctransform(a, b)` and
        `pushforward(out, potential, density)`.
    sigma (float): Initial step size. It is adapted locally on every half-step
        and the final value is not returned.

    Returns:
    None. Progress is printed every 5th iteration.
    """
    kernel = initialize_kernel(n1, n2)
    # Work on a copy so the global `mu` is not overwritten by the in-place
    # subtraction inside update_potential.
    rho = np.copy(mu)
    oldValue = compute_w2(phi, psi, mu, nu, x, y)
    # Note: the loop runs numIters + 1 times (k = 0 .. numIters).
    for k in range(numIters + 1):
        # --- first half-step: update phi from the residual rho - nu ---
        gradSq = update_potential(phi, rho, nu, kernel, sigma)
        bf.ctransform(psi, phi)
        bf.ctransform(phi, psi)
        value = compute_w2(phi, psi, mu, nu, x, y)
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value
        # Refresh rho for the next half-step (presumably written into `rho`
        # in place by bf.pushforward - confirm against BFM).
        bf.pushforward(rho, phi, nu)
        # --- second half-step: update psi from the residual rho - mu ---
        gradSq = update_potential(psi, rho, mu, kernel, sigma)
        bf.ctransform(phi, psi)
        bf.ctransform(psi, phi)
        bf.pushforward(rho, psi, mu)
        value = compute_w2(phi, psi, mu, nu, x, y)
        sigma = stepsize_update(sigma, value, oldValue, gradSq)
        oldValue = value
        if k % 5 == 0:
            print(f'iter {k:4d}, W2 value: {value:.6e}, H1 err: {gradSq:.2e}')
def process_image(image_array, output_size, sharpen=False, sharpen_radius=1.0, sharpen_amount=1.5):
    """
    Resize an image to `output_size` and optionally sharpen the result.

    Parameters:
    image_array (np.ndarray): Input image.
    output_size (tuple): Target size, passed to `resize` as (height, width).
    sharpen (bool): If True, run `unsharp_mask` on the resized image. Default is False.
    sharpen_radius (float): `radius` argument for `unsharp_mask`. Default is 1.0.
    sharpen_amount (float): `amount` argument for `unsharp_mask`. Default is 1.5.

    Returns:
    np.ndarray: The resized (and, if requested, sharpened) image.

    Raises:
    ValueError: If `image_array` is not a NumPy array.
    """
    # Guard clause: reject anything that is not an ndarray up front.
    if not isinstance(image_array, np.ndarray):
        raise ValueError("The input image must be a NumPy array.")

    resized = resize(image_array, output_size, anti_aliasing=True)
    if not sharpen:
        return resized
    return unsharp_mask(resized, radius=sharpen_radius, amount=sharpen_amount)
def process_shrink_image(image_array, output_size, scale_factor, sharpen=False, sharpen_radius=1.0, sharpen_amount=1.5):
    """
    Shrinks an image by the given scale factor and embeds it in the center of a blank image
    of size 'output_size'. The blank space is filled with zeros.

    Parameters:
    image_array (np.ndarray): Input image as a NumPy array.
    output_size (tuple): Size of the returned canvas as (height, width).
    scale_factor (float): Fraction of `output_size` the embedded image occupies
        along each axis; must satisfy 0 < scale_factor <= 1.
    sharpen (bool): Whether to apply sharpening to the resized image. Default is False.
    sharpen_radius (float): Radius for the unsharp mask filter. Default is 1.0.
    sharpen_amount (float): Amount for the unsharp mask filter. Default is 1.5.

    Returns:
    np.ndarray: Array of shape `output_size` with the resized image centered
        and zeros everywhere else.

    Raises:
    ValueError: If `image_array` is not a NumPy array, if `scale_factor` is
        outside (0, 1], or if the scaled size truncates to zero pixels.
    """
    # Same input contract as process_image.
    if not isinstance(image_array, np.ndarray):
        raise ValueError("The input image must be a NumPy array.")
    # A factor > 1 would not fit on the canvas (negative offsets); a factor <= 0
    # or NaN has no meaningful size. The chained comparison is False for NaN.
    if not 0 < scale_factor <= 1:
        raise ValueError("scale_factor must satisfy 0 < scale_factor <= 1.")
    # Calculate new dimensions based on the scale factor (truncating to int)
    new_height = int(output_size[0] * scale_factor)
    new_width = int(output_size[1] * scale_factor)
    if new_height < 1 or new_width < 1:
        raise ValueError("scale_factor is too small: the shrunken image would have zero size.")
    # Resize the image to the new (smaller) dimensions
    resized_image = resize(image_array, (new_height, new_width), anti_aliasing=True)
    if sharpen:
        resized_image = unsharp_mask(resized_image, radius=sharpen_radius, amount=sharpen_amount)
    # Create a blank image (all zeros) with the full output dimensions
    final_image = np.zeros(output_size, dtype=resized_image.dtype)
    # Compute offsets so that the resized image is centered
    offset_h = (output_size[0] - new_height) // 2
    offset_w = (output_size[1] - new_width) // 2
    # Place the resized image in the center of the blank canvas
    final_image[offset_h:offset_h + new_height, offset_w:offset_w + new_width] = resized_image
    return final_image
#-----------------#
# Read the two images from disk and keep channel index 2 only.
# Taking 1 - value makes dark pixels carry large density.
mu = 1-plt.imread('../../images/star.png')[:, :, 2]
nu = 1-plt.imread('../../images/fivestar.png')[:, :, 2]
# Masses (also read as globals by compute_w2)
m_mu = 2
m_nu = 3
# Grid of size n1 x n2
n1, n2 = 256, 256
# Arrays are laid out as (rows, cols) = (n2, n1): that is the shape np.meshgrid
# gives x and y below, so the images must be resized to (n2, n1) for the
# element-wise products in compute_w2 to line up on non-square grids.
output_dim = (n2, n1)
# Resize the images; the measure with the smaller mass is shrunk by
# sqrt(mass ratio) per axis and centered on a zero canvas.
if m_mu == m_nu:
    mu = process_image(mu,output_dim,sharpen=True)
    nu = process_image(nu,output_dim,sharpen=True)
elif m_mu > m_nu:
    scale_factor = np.sqrt(m_nu / m_mu)
    mu = process_shrink_image(mu, output_dim, scale_factor, sharpen=True)
    nu = process_image(nu,output_dim,sharpen=True)
else:
    scale_factor = np.sqrt(m_mu / m_nu)
    mu = process_image(mu,output_dim,sharpen=True)
    nu = process_shrink_image(nu, output_dim, scale_factor, sharpen=True)
# Rescale both densities so that each sums to n1 * n2 (grid average of 1).
mu *= n1 * n2 / np.sum(mu)
nu *= n1 * n2 / np.sum(nu)
# Cell-centered coordinates on the unit square; x and y have shape (n2, n1).
x, y = np.meshgrid(np.linspace(0.5 / n1, 1 - 0.5 / n1, n1), np.linspace(0.5 / n2, 1 - 0.5 / n2, n2))
# Both potentials start from 0.5 * (x^2 + y^2), kept as 2D arrays.
phi = 0.5 * (x * x + y * y)
psi = 0.5 * (x * x + y * y)
#-----------------#
# Shared color limits so both densities are drawn on the same scale.
vmin = min(mu.min(), nu.min())
vmax = max(mu.max(), nu.max())
# Side-by-side preview of the two input densities.
fig, ax = plt.subplots(1, 2)
im0 = ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1), vmin=vmin, vmax=vmax)
ax[0].set_title("$\\mu$")
im1 = ax[1].imshow(nu, origin='lower', extent=(0, 1, 0, 1), vmin=vmin, vmax=vmax)
ax[1].set_title("$\\nu$")
# One colorbar for both panels; valid for both because vmin/vmax are shared.
cbar = fig.colorbar(im0, ax=ax.ravel().tolist(), orientation='vertical', label='Pixel Intensity')
# Read as a global by compute_ot, whose loop runs numIters + 1 times.
numIters = 40
# Initial step size handed to compute_ot.
sigma = 4 / np.maximum(mu.max(), nu.max())
# Time the solver construction together with the iteration loop.
tic = time()
bf = BFM(n1, n2, mu)
# compute_ot updates phi and psi in place and returns nothing.
compute_ot(phi, psi, bf, sigma)
toc = time()
print(f'\nElapsed time: {toc - tic:.2f}s')
# my, mx = ma.masked_array(np.gradient(psi.reshape((n2, n1)) - 0.5 * (x * x + y * y), 1 / n2, 1 / n1), mask=((mu == 0), (mu == 0)))
# fig, ax = plt.subplots()
# ax.contourf(x, y, mu + nu)
# ax.set_aspect('equal')
# skip = (slice(None, None, n1 // 50), slice(None, None, n2 // 50))
# ax.quiver(x[skip], y[skip], mx[skip], my[skip], color='yellow', angles='xy', scale_units='xy', scale=1)
# fig, ax = plt.subplots(1, 2)
# ax[0].imshow((x + mx).reshape((n2, n1)), origin='lower', extent=(0, 1, 0, 1), cmap='plasma')
# x_masked = ma.masked_array(x, mask=(nu == 0))
# ax[1].imshow(x_masked, origin='lower', extent=(0, 1, 0, 1), cmap='plasma')
def plot_interpolation(mu, nu, phi, psi, bf, x, y, n_fig=6):
    """
    Plot frames interpolating between mu (t=0) and nu (t=1)
    using the given potentials phi and psi.

    For each intermediate time t, the potentials are blended with the
    reference potential 0.5*(x^2 + y^2), mu and nu are pushed forward with
    the blended potentials, and the frame shown is
    (1 - t) * pushforward(mu) + t * pushforward(nu).

    Parameters
    ----------
    mu : 2D array
        Source measure, shown at t=0.
    nu : 2D array
        Target measure, shown at t=1.
    phi, psi : 2D arrays
        Potentials, same shape as mu and nu.
    bf : object
        Solver exposing ``pushforward(out, potential, density)``; the script
        passes its BFM instance. ``out`` is presumably filled in place
        (confirm against BFM).
    x, y : 2D arrays
        Grid coordinates, same shape as mu and nu.
    n_fig : int
        Number of frames to plot (including t=0 and t=1); must be >= 2.

    Raises
    ------
    ValueError
        If n_fig is smaller than 2.
    """
    # With fewer than two panels there is no room for both end points, and
    # plt.subplots(1, 1) returns a single Axes that cannot be indexed below.
    if n_fig < 2:
        raise ValueError("n_fig must be at least 2 (one panel each for t=0 and t=1).")
    fig, ax = plt.subplots(1, n_fig, figsize=(20, 8))
    for axi in ax:
        axi.axis('off')
    # Use the same max for color scaling
    vmax = max(mu.max(), nu.max())
    # Show the initial and final measures
    im0 = ax[0].imshow(mu, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
    ax[0].set_title("$t=0$")
    ax[n_fig - 1].imshow(nu, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
    ax[n_fig - 1].set_title("$t=1$")
    # Output buffers handed to bf.pushforward
    rho_fwd = np.zeros_like(mu)
    rho_bwd = np.zeros_like(mu)
    # "Reference" potential = 0.5*(x² + y²)
    phi_0 = 0.5 * (x**2 + y**2)
    # Generate intermediate frames
    for i in range(1, n_fig - 1):
        t = i / (n_fig - 1)
        # psi_t interpolates from phi_0 (t=0) to psi (t=1)
        psi_t = (1 - t) * phi_0 + t * psi
        # phi_t interpolates from phi (t=0) to phi_0 (t=1)
        phi_t = t * phi_0 + (1 - t) * phi
        # Pushforward mu by psi_t
        bf.pushforward(rho_fwd, psi_t, mu)
        # Pushforward nu by phi_t
        bf.pushforward(rho_bwd, phi_t, nu)
        # Convex combination of the two pushforwards
        interpolate = (1 - t) * rho_fwd + t * rho_bwd
        # Plot the interpolation
        ax[i].imshow(interpolate, origin='lower', extent=(0, 1, 0, 1), vmax=vmax)
        ax[i].set_title(f"$t={t:.2f}$")
    # Adjust the main subplot layout to leave room for a dedicated colorbar axis
    plt.subplots_adjust(right=0.85)  # Shrink subplots to make room on the right
    # Create a new axis for the colorbar (position: [left, bottom, width, height])
    cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
    fig.colorbar(im0, cax=cbar_ax, label='Intensity')
    plt.show()
# Draw the interpolation frames using the potentials that compute_ot updated
# in place. plot_interpolation calls plt.show() itself before returning.
plot_interpolation(mu, nu, phi, psi, bf, x, y)
# Second show call at the end of the script.
plt.show()