Skip to content

yixuan/regot-python

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

125 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

RegOT-Pythonregot

RegOT is a collection of state-of-the-art solvers for regularized optimal transport (OT) problems, implemented in efficient C++ code. This repository is the Python interface to RegOT.

πŸ“ Formulation

RegOT mainly solves two types of regularized OT problems: the entropic-regularized OT (EROT) and the quadratically regularized OT (QROT).

EROT, also known as the Sinkhorn-type OT, considers the following optimization problem:

$$\begin{align*} \min_{T\in\mathbb{R}^{n\times m}}\quad & \langle T,M\rangle-\eta h(T),\\\ \text{subject to}\quad & T\mathbf{1}_{m}=a,T^{T}\mathbf{1}_{n}=b,T\ge0, \end{align*}$$

where $a\in\mathbb{R}^n$ and $b\in\mathbb{R}^m$ are two given probability vectors with $a_i>0$, $b_j>0$, $\sum_{i=1}^n a_i=\sum_{j=1}^m b_j=1$, and $M\in\mathbb{R}^{n\times m}$ is a given cost matrix. The function $h(T)=\sum_{i=1}^{n}\sum_{j=1}^{m}T_{ij}(1-\log T_{ij})$ is the entropy term, and $\eta>0$ is a regularization parameter.

QROT, also known as the Euclidean-regularized OT, is concerned with the problem

$$\begin{align*} \min_{T\in\mathbb{R}^{n\times m}}\quad & \langle T,M\rangle+(\gamma/2) \Vert T \Vert_F^2,\\\ \text{subject to}\quad & T\mathbf{1}_{m}=a,T^{T}\mathbf{1}_{n}=b,T\ge0. \end{align*}$$

πŸ”§ Solvers

Currently RegOT contains the following solvers for EROT (methods marked with 🌟 are developed by our group!):

  • sinkhorn_bcd: the block coordinate descent (BCD) algorithm, equivalent to the well-known Sinkhorn algorithm.
  • sinkhorn_apdagd: the adaptive primal-dual accelerate gradient descent (APDAGD) algorithm (link to paper).
  • sinkhorn_lbfgs_dual: the L-BFGS algorithm applied to the dual problem of EROT.
  • sinkhorn_newton: Newton's method applied to the dual problem of EROT.
  • 🌟sinkhorn_sparse_newton: Newton-type method using sparsified Hessian matrix, as described in our SPLR paper.
  • 🌟sinkhorn_ssns: the safe and sparse Newton method for Sinkhorn-type OT (SSNS, link to paper).
  • 🌟sinkhorn_splr: the sparse-plus-low-rank quasi-Newton method for the dual problem of EROT (SPLR, link to paper).

The following solvers are available for the QROT problem:

  • qrot_bcd: the BCD algorithm.
  • qrot_gd: the line search gradient descent algorithm applied to the dual problem of QROT.
  • qrot_apdagd: the APDAGD algorithm (link to paper).
  • qrot_pdaam: the primal-dual accelerated alternating minimization (PDAAM) algorithm (link to paper).
  • qrot_lbfgs_dual: the L-BFGS algorithm applied to the dual problem of QROT.
  • qrot_lbfgs_semi_dual: the L-BFGS algorithm applied to the semi-dual problem of QROT (link to paper).
  • qrot_assn: the adaptive semi-smooth Newton (ASSN) method applied to the dual problem of QROT (link to paper).
  • qrot_grssn: the globalized and regularized semi-smooth Newton (GRSSN) method applied to the dual problem of QROT (link to paper).

All the solvers above return an object containing fields:

  • niter: number of iterations used.
  • dual: final dual variables.
  • plan: computed transport plan.
  • obj_vals: history of dual objective function values.
  • mar_errs: history of marginal errors.
  • run_times: cumulative runtimes of iterations in milliseconds.

A specialized primal-dual interior-point solver is also available for the QROT problem:

  • qrot_pdip: the primal-dual interior-point method. Pass inner_solver="cg" for a CG-based inner solver, or inner_solver="fp" for the fixed-point method inner solver. Default is inner_solver="cg".
  • pdip_cg / pdip_fp: aliases for qrot_pdip with inner_solver="cg" and "fp", respectively.
  • For cg: default tol=1e-8 for normalized primal/dual gaps and mu. When cg_stop_gap_mu_only is false (default), marginal-error stopping uses cg_mar_tol (default 1e-10), independent of tol. Pass cg_mar_tol in kwargs to override. Set cg_stop_gap_mu_only=True to require only gaps + mu (like the FP solver).
  • For fp: default tol=1e-8. By default, stopping uses only normalized primal/dual gaps and mu (fp_stop_gap_mu_only=True). Pass fp_stop_gap_mu_only=False if you also want to stop when marginal error mar_err falls below tol.

πŸ’½ Installation

Using pip

You can simply install RegOT using the pip command:

pip install regot

Building from source

A C++ compiler is needed to build RegOT from source. Enter the source directory and run

pip install . -r requirements.txt

Developer / Profiling builds

An optional environment variable REGOT_PDIP_DEV can be set during installation (e.g., REGOT_PDIP_DEV=1 pip install .) to enable additional profiling functions. See src/pdip_dev_flags.h for details.

πŸ“— Example

The code below shows minimal examples computing EROT and QROT transport plans given $a$, $b$, $M$, and $\eta$.

import numpy as np
from scipy.stats import expon, norm
import regot
import matplotlib.pyplot as plt

# OT between two discretized distributions
# One is exponential, the other is mixture normal
def example(n=100, m=80):
    x1 = np.linspace(0.0, 5.0, num=n)
    x2 = np.linspace(0.0, 5.0, num=m)
    distr1 = expon(scale=1.0)
    distr2 = norm(loc=1.0, scale=0.2)
    distr3 = norm(loc=3.0, scale=0.5)
    a = distr1.pdf(x1)
    a = a / np.sum(a)
    b = 0.2 * distr2.pdf(x2) + 0.8 * distr3.pdf(x2)
    b = b / np.sum(b)
    M = np.square(x1.reshape(n, 1) - x2.reshape(1, m))
    return M, a, b

# Source and target distribution vectors `a` and `b`
# Cost matrix `M`
# Regularization parameter `reg`
np.random.seed(123)
M, a, b = example(n=100, m=80)
reg = 0.1

# EROT transport plans
# Algorithm: block coordinate descent (the Sinkhorn algorithm)
res1 = regot.sinkhorn_bcd(
    M, a, b, reg, tol=1e-6, max_iter=1000, verbose=1)

# Algorithm: SSNS
reg = 0.01
res2 = regot.sinkhorn_ssns(
    M, a, b, reg, tol=1e-6, max_iter=1000, verbose=0)

# QROT transport plans
res3 = regot.qrot_pdip(
    M, a, b, reg=0.1, tol=1e-8, max_iter=1000, inner_solver="cg")
res4 = regot.qrot_pdip(
    M, a, b, reg=0.01, tol=1e-8, max_iter=1000, inner_solver="fp")

We can retrieve the computed transport plans and visualize them using heatmaps:

def vis_plan(T, title=""):
    fig = plt.figure(figsize=(8, 8))
    plt.imshow(T, interpolation="nearest")
    plt.title(title, fontsize=20)
    plt.show()

vis_plan(res1.plan, title="EROT (BCD), reg=0.1")
vis_plan(res2.plan, title="EROT (SSNS), reg=0.01")
vis_plan(res3.plan, title="QROT (PDIP-CG), reg=0.1")
vis_plan(res4.plan, title="QROT (PDIP-FP), reg=0.01")

🌟 Fun fact: The logo sticker of RegOT also uses the package itself to compute the transport pattern between point clouds. You can use this code to reproduce the image.

RegOT sticker

πŸ“ƒ Bibliography

Please consider to cite our work if you find our algorithms or software useful in your research and applications.

@inproceedings{tang2024safe,
  title={Safe and sparse Newton method for entropic-regularized optimal transport},
  author={Tang, Zihao and Qiu, Yixuan},
  booktitle={Advances in Neural Information Processing Systems},
  volume={37},
  pages={129914--129943},
  year={2024}
}

@inproceedings{wang2025sparse,
  title={The Sparse-Plus-Low-Rank quasi-Newton method for entropic-regularized optimal transport},
  author={Wang, Chenrui and Qiu, Yixuan},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025}
}

About

A collection of state-of-the-art solvers for regularized optimal transport (OT) problems, implemented in efficient C++ code.

Resources

License

Stars

Watchers

Forks

Packages

 
 
 

Contributors