mdot-tnt 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +57 -0
- mdot_tnt/mdot.py +139 -0
- mdot_tnt/rounding.py +62 -0
- mdot_tnt/truncated_newton.py +283 -0
- mdot_tnt-0.1.0.dist-info/LICENSE +26 -0
- mdot_tnt-0.1.0.dist-info/METADATA +71 -0
- mdot_tnt-0.1.0.dist-info/RECORD +9 -0
- mdot_tnt-0.1.0.dist-info/WHEEL +5 -0
- mdot_tnt-0.1.0.dist-info/top_level.txt +1 -0
mdot_tnt/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
|
|
2
|
+
import torch as th
|
|
3
|
+
|
|
4
|
+
from mdot_tnt.mdot import mdot
|
|
5
|
+
from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=True, log=False):
    """
    Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
    :param r: n-dimensional row marginal.
    :param c: m-dimensional column marginal.
    :param C: n x m cost matrix. Recommended use is to scale the entries to be in [0, 1].
    :param gamma_f: The temperature (inverse of the regularization weight). For many problems, stable up to 2^18.
        Higher values return more accurate solutions but take longer to converge. Use double precision if gamma_f large.
    :param drop_tiny: If either marginal is known to be sparse, set this to True to drop tiny entries for a speedup.
        If return_plan is True, the returned plan will be in the original dimensions.
        NOTE: this path assumes r and c carry a leading batch dimension (indexed as r[:, keep]) — confirm with callers.
    :param return_plan: If True, return the optimal transport plan rather than the cost.
    :param round: If True, use the rounding algorithm of Altschuler et al. (2017) to (a) return a feasible plan
        if return_plan is True and (b) the cost of the rounded plan if return_plan is False.
        (Parameter name shadows the builtin ``round``; kept for backward compatibility.)
    :param log: If True, additionally return a dictionary containing logs of the optimization process.
    :return: If return_plan is True, return the optimal transport plan as a torch tensor. Otherwise, return the cost.
    """
    assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
    dtype = r.dtype
    if gamma_f > 2 ** 10:
        # Large temperatures need double precision to keep exp/logsumexp stable.
        r, c, C = r.double(), c.double(), C.double()
    if drop_tiny:
        # Bug fix: these names were previously unimported in this module, so
        # drop_tiny=True raised NameError. Imported locally so the branch is
        # self-contained and the package import stays lightweight.
        import math
        from mdot_tnt.mdot import preprocess_marginals

        drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
        (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)

        u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)

        # Scatter the reduced duals back to the original dimensions; dropped
        # entries get -inf so they carry zero mass in exp(u + v - gamma * C).
        u = -th.ones_like(r) * float('inf')
        u[:, r_keep] = u_
        v = -th.ones_like(c) * float('inf')
        v[:, c_keep] = v_
    else:
        u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)

    # Cast the duals back to the caller's original precision.
    u, v = u.to(dtype), v.to(dtype)

    if return_plan:
        P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp()
        if round:
            P = round_altschuler(P, r, c)
        if log:
            return P, opt_logs
        return P
    else:
        if round:
            cost = rounded_cost_altschuler(u, v, r, c, C, gamma_f_)
        else:
            # NOTE(review): .sum() reduces over ALL dims (including any batch
            # dim), unlike the rounded path which reduces per batch — confirm
            # whether batched inputs are expected here.
            cost = ((u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp() * C).sum()
        if log:
            return cost, opt_logs
        return cost
|
mdot_tnt/mdot.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
|
|
3
|
+
method introduced in the paper "A Truncated Newton Method for Optimal Transport"
|
|
4
|
+
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
|
|
5
|
+
URL: https://openreview.net/forum?id=gWrWUaCbMa
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
from mdot_tnt.rounding import *
|
|
13
|
+
from mdot_tnt.truncated_newton import TruncatedNewtonProjector
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def preprocess_marginals(r, c, C, eps):
    """
    Drop the smallest marginal entries whose cumulative mass does not exceed eps,
    redistribute the removed mass uniformly over the surviving entries, and slice
    the cost matrix down to the retained rows and columns.

    :param r: 1 x n row marginal.
    :param c: 1 x m column marginal.
    :param C: n x m cost matrix.
    :param eps: cumulative-mass threshold below which entries are dropped.
    :return: ((r_new, r_keep), (c_new, c_keep), C_new), where *_keep holds the
        indices of retained entries (ordered by ascending mass) and C_new is C
        restricted to those rows/columns (in the same order).
    """

    def shrink(m, threshold):
        # Sort ascending so the running sum identifies the tiniest entries.
        vals, order = th.sort(m, dim=-1, descending=False)
        keep = order[th.cumsum(vals, dim=-1) > threshold]
        kept = m[:, keep]
        # Spread the dropped mass evenly so the marginal still sums to one.
        leftover = 1 - kept.sum(-1)
        return kept + leftover / kept.size(-1), keep

    r_new, r_keep = shrink(r, eps)
    c_new, c_keep = shrink(c, eps)

    return (r_new, r_keep), (c_new, c_keep), C[r_keep][:, c_keep]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
    """
    Blend each marginal with the uniform distribution to keep entries bounded
    away from zero. The blend strength is eps (clamped to at most 1), weighted
    by w_r for rows and w_c for columns.

    :param r: row marginal tensor.
    :param c: column marginal tensor.
    :param eps: smoothing amount as a tensor (clamped to max 1).
    :param w_r: fraction of eps applied to the row marginal.
    :param w_c: fraction of eps applied to the column marginal.
    :return: (r_hat, c_hat) smoothed marginals.
    """
    assert w_r + w_c == 1, "w_r and w_c must sum to 1"
    eps = eps.clamp(max=1.).unsqueeze(-1)

    def blend(m, weight):
        uniform = th.ones_like(m) / m.size(-1)
        mix = weight * eps
        return (1 - mix) * m + mix * uniform

    return blend(r, w_r), blend(c, w_c)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def adjust_schedule(q, deltas=None):
    """
    Adapt the temperature-annealing factor q based on the quality ratios of
    the last projection's Newton steps.

    :param q: current annealing factor.
    :param deltas: list of actual-to-ideal progress ratios; None leaves q as is.
    :return: the adjusted annealing factor.
    """
    if deltas is None:
        return q

    # An empty list counts as one fully successful iteration.
    worst = min(deltas + [1.])

    if worst < 0.5:
        # Poor progress: anneal more conservatively.
        return q ** 0.5
    if worst > 0.9:
        # Excellent progress: anneal more aggressively.
        return q ** 2
    return q
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
    """
    Solve the entropic-regularized optimal transport problem using the MDOT method introduced in the paper:
    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients" by Mete Kemertas,
    Allan D. Jepson and Amir-massoud Farahmand. URL: https://arxiv.org/abs/2307.08507
    Here, we use the Truncated Newton method for projection.
    :param r: The first marginal.
    :param c: The second marginal.
    :param C: The cost matrix. Recommended use is to scale the entries to be in [0, 1].
    :param gamma_f: The final temperature (inverse of the regularization weight).
    :param gamma_i: The initial temperature.
    :param p: The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
    :param q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
    :return: (u, v, gamma, k_total, logs) — final dual variables, the last temperature actually
        reached, a total operation count, and a log dictionary.
    """
    projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)

    # Entropies of the marginals; 1e-30 guards log(0) for zero entries.
    H_r = -(r * (r + 1e-30).log()).sum(-1)
    H_c = -(c * (c + 1e-30).log()).sum(-1)
    H_min = th.min(H_r, H_c)
    # Gradient-norm tolerance shrinks polynomially as the temperature grows.
    eps_fn = lambda g_: H_min / (g_ ** p)

    # NOTE(review): "eps" is initialized here but never appended to below — confirm intent.
    logs = {
        "proj_logs": [],
        "eps": [],
    }

    t = 1
    done = False
    gamma = min(gamma_i, gamma_f)
    gammas = [0., gamma]

    while not done:
        done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)

        eps_d = eps_fn(gamma)

        # Smooth the marginals so duals stay finite; smoothing error is budgeted
        # against half of the current tolerance.
        r_hat, c_hat = smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)

        if t == 1:
            # Cold start: initialize duals from the (log) smoothed marginals.
            u_init, v_init = r_hat.log(), c_hat.log()
            u_cur, v_cur = u_init.clone(), v_init.clone()

        # Keep the previous solution so we can roll back on failure and
        # extrapolate warm starts below.
        u_prev, v_prev = u_cur.clone(), v_cur.clone()
        u_cur, v_cur, proj_log, success = projector.project(
            gamma * C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)

        logs["proj_logs"].append(proj_log)
        if not success:
            # Roll back to the last successful temperature and stop annealing.
            warnings.warn("Projection failed. Returning result at the last temperature: {:.4e}".format(1 / gammas[-2]))
            u_cur = u_prev.clone()
            v_cur = v_prev.clone()
            gammas = gammas[:-1]
            break

        # Adapt the annealing rate based on how well the projection progressed.
        q = adjust_schedule(q, proj_log["deltas"])
        gamma = min(gamma * q, gamma_f)

        if not done:
            # Generate warm-started initializations for the next iteration.
            # Linear extrapolation of the duals along the temperature path.
            u_init = u_cur + (u_cur - u_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            v_init = v_cur + (v_cur - v_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])

            gammas.append(gamma)
        t += 1

    # Total operation count: projection work plus one unit per outer step.
    k_total = sum([log["n_iter"] for log in logs["proj_logs"]])
    k_total += t - 1
    logs["success"] = success
    logs["gammas"] = gammas

    return u_cur, v_cur, gammas[-1], k_total, logs
|
mdot_tnt/rounding.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file contains the implementation of the rounding algorithm proposed by Altschuler et al. (2017) in the paper
|
|
3
|
+
"Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration". The algorithm is used to
|
|
4
|
+
round the transport plan obtained from the Sinkhorn algorithm to a feasible transport plan in the set U(r, c), where r
|
|
5
|
+
and c are the row and column marginals, respectively. The algorithm is used in the mdot.py file to round the transport
|
|
6
|
+
plan and compute the cost of the rounded plan. The implementation is based on the original paper.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch as th
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def round_altschuler(P, r, c):
    """
    Round a transport plan onto the feasible set U(r, c) (Altschuler et al., 2017).

    NOTE: P is modified in place; the returned tensor is the same object.

    :param P: the input transport plan (mutated in place).
    :param r: row marginal.
    :param c: column marginal.
    :return: rounded transport plan in feasible set U(r, c).
    """
    # Scale down rows whose mass exceeds the target row marginal.
    row_scale = (r / P.sum(-1)).clamp(max=1.)
    P *= row_scale.unsqueeze(-1)

    # Likewise for columns.
    col_scale = (c / P.sum(-2)).clamp(max=1.)
    P *= col_scale.unsqueeze(-2)

    # Distribute any remaining mass deficit via a rank-one correction; the tiny
    # denominator guard keeps the division finite when the plan is already feasible.
    deficit_r = (r - P.sum(-1)).clamp(min=0)
    deficit_c = (c - P.sum(-2)).clamp(min=0)
    denom = (deficit_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
    P += deficit_r.unsqueeze(-1) @ deficit_c.unsqueeze(-2) / denom

    return P
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def rounded_cost_altschuler(u, v, r, c, C, gamma):
    """Performs rounding and cost computation in logdomain given dual variables, without storing n^2 matrices.

    NOTE: u and v are modified in place (via +=); pass clones if the caller
    needs the originals preserved.

    :param u: dual variable for rows (mutated in place).
    :param v: dual variable for columns (mutated in place).
    :param r: row marginal
    :param c: column marginal
    :param C: cost matrix
    :param gamma: temperature, i.e., the inverse of the entropic regularization weight.
    :return: cost of the rounded (feasible) plan, reduced over the last two dims.
    """
    # Log of the plan's row sums; shifting u down (delta_u <= 0) mirrors the
    # row-scaling step of Altschuler et al.'s rounding, in the log domain.
    r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
    delta_u = th.min(r.log() - r_P_log, th.zeros_like(r))
    u += delta_u

    # Same for columns.
    c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
    delta_v = th.min(c.log() - c_P_log, th.zeros_like(c))
    v += delta_v

    # Remaining row-mass deficit after both scalings, normalized to unit L1 mass.
    r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
    r_P = r_P_log.exp()
    err_r = r - r_P
    err_r /= (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30)

    # Remaining column-mass deficit (left unnormalized; it carries the total mass).
    c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
    c_P = c_P_log.exp()
    err_c = c - c_P

    # <P, C> computed stably in the log domain (C.log() gives -inf where C == 0,
    # which contributes zero mass to the logsumexp, as intended).
    cost = th.logsumexp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma * C + C.log(), dim=(-1, -2)).exp()
    # Cost contribution of the rank-one correction err_r @ err_c.
    cost += (err_r.unsqueeze(-2) @ C @ err_c.unsqueeze(-1)).sum(-1).sum(-1)

    return cost
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
|
|
2
|
+
import torch as th
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TruncatedNewtonProjector:
    """Projects dual variables onto the transport polytope's marginal constraints
    using a truncated Newton method with a preconditioned conjugate-gradient
    (PCG) inner solver. State (self.rho) persists across project() calls."""

    def __init__(self, device, dtype, **kwargs):
        self.device = device
        # rho: damping/discount factor for the Hessian approximation, warm-started
        # across successive project() calls.
        self.rho = th.zeros(1, device=device, dtype=dtype)
        self.debug = kwargs.get('debug', False)

    def project(self, gamma_C, log_r, log_c, eps_d, u, v):
        """
        Project onto the set of couplings that satisfy the marginal constraints.
        :param gamma_C: The cost matrix scaled by gamma.
        :param log_r: log of the target row marginal.
        :param log_c: log of the target column marginal.
        :param eps_d: tolerance on the L1 norm of the dual gradient (marginal error).
        :param u: initial row dual (mutated in place).
        :param v: initial column dual (mutated in place).
        :return: (u, v, logs, success).
        """
        # NOTE(review): "all_newtonsolve_steps" is initialized but never appended to below.
        logs = {
            "errs": [],
            'ls_func_cnt': 0,
            'chisinkhorn_steps': 0,
            'newtonsolve_steps': 0,
            "deltas": [],  # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
            "all_newtonsolve_steps": []
        }
        # In case of errors or issues, 10 times the tolerance level is considered
        # a good enough solution to keep MDOT going.
        success_fn = lambda err_: err_ < 10 * eps_d

        # Each LSE operation costs 4 * n^2 operations.
        self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
        self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)

        r = log_r.exp()
        c = log_c.exp()

        log_c_P = v + self.LSE_c(u)
        v += log_c - log_c_P  # Ensure c=c(P)
        log_r_P = u + self.LSE_r(v)
        # k tracks the cumulative operation count (in units of n^2 work).
        k = 8

        # Cheap Sinkhorn warm-up until the chi-squared divergence is small.
        u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
        r_P = log_r_P.exp()
        logs["errs"].append(err)
        logs["chisinkhorn_steps"] = k_
        k += k_

        # NOTE(review): num_iter is incremented but never read or logged.
        num_iter = 0

        while err > eps_d:
            num_iter += 1

            beta = 0.5
            # Inexact-Newton forcing term.
            # NOTE(review): th.max(err, 0.9 * (eps_d / err)) — a min() is more usual
            # for forcing terms; confirm this max is intentional.
            eta_k = th.max(err, 0.9 * (eps_d / err))

            # Dual gradient w.r.t. u is the row-marginal error.
            grad_k = r_P - r
            self.rho = th.max(th.zeros_like(self.rho), self.rho)

            # Materialize the plan to form the Hessian-diagonal term (n^2 memory).
            P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
            diag_PPc = ((P ** 2) / c.unsqueeze(-2)).sum(-1)
            k += 8
            delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000)
            del P  # Free up memory
            if not pcg_success:
                k += matmul_cnt
                logs["n_iter"] = k
                msg = "PCG did not converge. TruncatedNewton returning with success={}".format(success_fn(err))
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            # Relax the damping factor for the next outer iteration.
            self.rho = th.max(th.zeros_like(self.rho), 1. - (1. - rho) * 4.)
            k += matmul_cnt
            logs["newtonsolve_steps"] += matmul_cnt

            # Backtracking line search (Armijo) on the dual objective surrogate.
            alpha = th.ones_like(self.rho)
            log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
            k += 4
            linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
            if not linear_decr > 0:
                logs["n_iter"] = k
                msg = "Linear decrease condition not satisfied. TruncatedNewton returning with success={}".format(
                    success_fn(err))
                warnings.warn(msg)
                return u, v, logs, success_fn(err)

            armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr
            while not armijo:  # Check armijo condition for batch elements where err > eps_d
                alpha *= 0.5
                if alpha < 1e-9:
                    logs["n_iter"] = k
                    msg = "Line search did not converge. TruncatedNewton returning with success={}".format(
                        success_fn(err))
                    warnings.warn(msg)
                    return u, v, logs, success_fn(err)

                log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
                k += 4
                logs["ls_func_cnt"] += 4
                armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr

            # Accept the (possibly damped) Newton step.
            u += alpha * delta_u
            v += alpha * delta_v

            # The below error (before the Sinkhorn update) is used
            # to measure the progress of the algorithm with TruncatedNewton steps.
            err_before_sk = (c - log_c_P.exp()).abs().sum(-1)
            err_before_sk += (r - (u + self.LSE_r(v)).exp()).abs().sum(-1)

            # Sinkhorn update to ensure c=c(P).
            v += log_c - log_c_P

            log_r_P = u + self.LSE_r(v)
            k += 4

            u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
            r_P = log_r_P.exp()
            logs["chisinkhorn_steps"] += k_
            k += k_

            logs["errs"].append(err)
            # Actual vs ideal error reduction ratio; feeds adjust_schedule() upstream.
            logs["deltas"].append(th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item())

            if u.isnan().any() or v.isnan().any():
                raise ValueError("NaNs encountered in u or v")

        logs["n_iter"] = k

        # Since we already computed log_r_P, we can use it to perform one last Sinkhorn update on rows.
        delta_u = log_r - log_r_P
        u += delta_u

        return u, v, logs, True

    def chi_sinkhorn(self, u, v, log_r, log_c, log_r_P, eps_chi, maxOps=float('inf')):
        """Run Sinkhorn updates until the chi-squared divergence between the target
        row marginal and the plan's row marginal drops below eps_chi.
        Returns (u, v, log_r_P, err, k) with err the L1 row-marginal error and k
        the operation count (8 units per sweep)."""
        k = 0
        r = log_r.exp()
        err = (r - log_r_P.exp()).norm(p=1, dim=-1)
        r_P = log_r_P.exp()
        chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)

        while chi_squared > eps_chi and k < maxOps:
            # Row update, then column update (standard Sinkhorn sweep in log domain).
            u += log_r - log_r_P

            log_c_P = v + self.LSE_c(u)
            v += log_c - log_c_P

            log_r_P = u + self.LSE_r(v)
            r_P = log_r_P.exp()

            err = (r - r_P).norm(p=1, dim=-1)

            chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
            k += 8

        if k >= maxOps:
            raise ValueError("Chi-Sinkhorn did not converge in maxIter={} steps".format(maxOps))

        return u, v, log_r_P, err, k

    def newton_solve(self, P, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500):
        """Approximately solve the Newton system H(rho) x = -grad_k with PCG,
        where H(rho) = diag(r_P) - rho * P diag(1/c) P^T, using a diagonal
        preconditioner. The damping rho is increased (toward 1) whenever the
        true residual is still above tolerance, restarting PCG from x0.
        Returns (delta_u, delta_v, matmul_cnt, rho, success).
        NOTE(review): on the negative-curvature path, success is a boolean
        TENSOR (best_r_true_norm < err), not a Python bool — callers truth-test
        it, which assumes a single batch element."""
        rho = self.rho
        tol = err * eta_k

        def matmul_PPc(x_):
            # Computes P diag(1/c) P^T x without forming the n x n product.
            return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)

        # mml = th.compile(matmul_PPc)
        mml = matmul_PPc

        M = lambda rho_: (r_P - rho_ * diag_PPc)  # Diagonal preconditioner
        M_rho = M(th.ones_like(self.rho))
        # Clamp non-positive preconditioner entries to keep divisions safe.
        M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()

        # Jacobi-style initial guess.
        x0 = -grad_k / M_rho
        PPc_x0 = mml(x0)
        matmul_cnt = 2
        r_P_x0 = (r_P * x0)

        x = x0.clone()
        PPc_x = PPc_x0.clone()
        r_P_x = r_P_x0.clone()

        # Residual of the TRUE (undamped, rho=1) Newton system.
        res_true = r_P_x0 - PPc_x + grad_k

        linear_decr = (x * -grad_k).sum(-1)
        if linear_decr <= 0:
            raise ValueError("Linear decrease condition not satisfied")

        r_true_norm = res_true.norm(p=1, dim=-1)
        best_sol = x.clone()
        best_r_true_norm = r_true_norm.clone()

        done = False
        success = True

        # Outer loop: tighten rho toward 1 and re-run PCG until the true residual is small.
        while best_r_true_norm > tol:
            best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)

            # Move rho 75% of the way toward 1 where the residual is still too large.
            rho[r_true_norm > tol] = 1. - (1. - rho[r_true_norm > tol]) * 0.25
            M_rho = M(rho)

            if matmul_cnt > 0:
                # Restart PCG from the Jacobi initial guess under the new rho.
                x = x0.clone()
                PPc_x = PPc_x0.clone()
                r_P_x = r_P_x0.clone()

            # Residual of the DAMPED system actually being solved.
            Fr_x = r_P_x - rho * PPc_x
            res = Fr_x + grad_k

            res_true = r_P_x - PPc_x + grad_k
            r_true_norm = res_true.norm(p=1, dim=-1)
            best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
            linear_decr = (x * -grad_k).sum(-1)
            if (best_r_true_norm < tol).all() and (linear_decr > 0).all():
                break

            # Standard PCG setup: preconditioned residual and first search direction.
            y = res / M_rho
            p = -y.clone()
            ry_old = (res * y).sum(-1, keepdim=True)

            r_norm = res.norm(p=1, dim=-1)
            while (r_norm > 0.5 * (1 - beta) * tol)[best_r_true_norm > tol].any():
                PPc_p = mml(p)
                matmul_cnt += 2
                Fr_p = (r_P * p) - rho * PPc_p

                quad = (Fr_p * p).sum(-1, keepdim=True)
                if (quad <= 0)[best_r_true_norm > tol].any():
                    # Negative curvature: bail out with the best solution found so far.
                    warnings.warn("Warning: negative curvature encountered in CG. Returning best solution. "
                                  "Residual norm less than error: {}".format((best_r_true_norm < err).item()))
                    x = best_sol.clone()
                    done = True
                    success = best_r_true_norm < err
                    warnings.warn("Resetting discount factor rho = 0")
                    rho = th.zeros_like(self.rho)
                    break

                # Standard PCG updates.
                alpha = ry_old / quad
                x += alpha * p
                res += alpha * Fr_p
                r_norm = res.norm(p=1, dim=-1)

                if th.isnan(r_norm)[best_r_true_norm > tol].any() or th.isinf(r_norm)[best_r_true_norm > tol].any():
                    raise ValueError("NaNs or infs encountered in r_norm")

                PPc_x += alpha * PPc_p

                # Track the TRUE residual alongside the damped one; keep the best iterate.
                r_P_x = (r_P * x)
                res_true = r_P_x - PPc_x + grad_k
                r_true_norm = res_true.norm(p=1, dim=-1)
                best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
                best_r_true_norm = th.min(r_true_norm, best_r_true_norm)

                linear_decr = (x * -grad_k).sum(-1)
                if (best_r_true_norm <= tol).all() and (linear_decr > 0).all():
                    done = True
                    success = True
                    break

                if matmul_cnt > 2 * maxIter:
                    warnings.warn("PCG did not converge.")
                    done = True
                    success = False
                    break

                y = res / M_rho
                ry_new = (res * y).sum(-1, keepdim=True)
                p = -y + (ry_new / ry_old) * p
                ry_old = ry_new.clone()

            if done:
                break

        if r_true_norm <= tol:
            success = True

        x = best_sol
        # delta_v follows from the block elimination: delta_v = -P^T delta_u / c.
        Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
        matmul_cnt += 1

        return x, -Pc_x, matmul_cnt, rho, success
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Non-Commercial Research License (NCRL-1.0)
|
|
2
|
+
|
|
3
|
+
Copyright (C) 2025 Mete Kemertas
|
|
4
|
+
|
|
5
|
+
## 1. License Grant
|
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
|
|
7
|
+
|
|
8
|
+
## 2. Restrictions
|
|
9
|
+
### 2.1 **Non-Commercial Use Only**
|
|
10
|
+
- The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
|
|
11
|
+
- "Commercial purpose" includes, but is not limited to:
|
|
12
|
+
- Selling or licensing the Software.
|
|
13
|
+
- Using the Software in proprietary products or services.
|
|
14
|
+
- Offering the Software as part of a paid or revenue-generating service.
|
|
15
|
+
|
|
16
|
+
### 2.2 **No Warranty & Liability**
|
|
17
|
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
|
|
18
|
+
|
|
19
|
+
### 2.3 **Commercial Licensing**
|
|
20
|
+
For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
|
|
21
|
+
|
|
22
|
+
## 3. Termination
|
|
23
|
+
This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
|
|
24
|
+
|
|
25
|
+
## 4. Governing Law and Enforcement
|
|
26
|
+
This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued un
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mdot-tnt
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
|
|
5
|
+
Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
|
|
6
|
+
Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
|
|
7
|
+
Requires-Python: >=3.7
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
|
|
11
|
+
This is the official repository for the MDOT-TruncatedNewton (or MDOT-TNT)
|
|
12
|
+
algorithm [1] for solving the entropic-regularized optimal transport (OT) problem.
|
|
13
|
+
In addition to being GPU-friendly, the algorithm is stable under weak regularization and can therefore find highly
|
|
14
|
+
precise approximations of the un-regularized problem's solution quickly.
|
|
15
|
+
|
|
16
|
+
The current implementation is based on PyTorch and is compatible with both CPU and GPU. PyTorch is the only dependency.
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
For installation:
|
|
20
|
+
First, install PyTorch following the instructions at https://pytorch.org/get-started/locally/ to select the version that matches your system's configuration.
|
|
21
|
+
```bash
|
|
22
|
+
pip install mdot_tnt
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
Quickstart guide:
|
|
26
|
+
```
|
|
27
|
+
import mdot_tnt
|
|
28
|
+
import torch as th
|
|
29
|
+
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
|
30
|
+
N, M, dim = 100, 200, 128
|
|
31
|
+
|
|
32
|
+
# Sample row and column marginals from Dirichlet distributions
|
|
33
|
+
r = th.distributions.Dirichlet(th.ones(N)).sample()
|
|
34
|
+
c = th.distributions.Dirichlet(th.ones(M)).sample()
|
|
35
|
+
|
|
36
|
+
# Cost matrix from pairwise Euclidean distances of random points in R^128
|
|
37
|
+
x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
|
|
38
|
+
y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
|
|
39
|
+
C = th.cdist(x, y, p=2)
|
|
40
|
+
C /= C.max() # Normalize cost matrix to meet convention.
|
|
41
|
+
|
|
42
|
+
# Use double precision for numerical stability in high precision regime.
|
|
43
|
+
r, c, C = r.double(), c.double(), C.double()
|
|
44
|
+
|
|
45
|
+
# Solve OT problem. Increase (decrease) gamma_f for higher (lower) precision.
|
|
46
|
+
# Default is gamma_f=2**12. Expect error of order logn / gamma_f at worst, and possibly lower.
|
|
47
|
+
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12)
|
|
48
|
+
|
|
49
|
+
# To return a feasible transport plan, use the following:
|
|
50
|
+
transport_plan = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12, return_plan=True)
|
|
51
|
+
|
|
52
|
+
# In both cases, the default rounding onto the feasible set can be disabled by setting `round=False`.
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
The code is released under a custom non-commercial use license. If you use our work in
|
|
56
|
+
your research, please consider citing:
|
|
57
|
+
|
|
58
|
+
```
|
|
59
|
+
@inproceedings{
|
|
60
|
+
kemertas2025a,
|
|
61
|
+
title={A Truncated Newton Method for Optimal Transport},
|
|
62
|
+
author={Mete Kemertas and Amir-massoud Farahmand and Allan Douglas Jepson},
|
|
63
|
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
|
64
|
+
year={2025},
|
|
65
|
+
url={https://openreview.net/forum?id=gWrWUaCbMa}
|
|
66
|
+
}
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
For inquiries, email: kemertas [at] cs [dot] toronto [dot] edu
|
|
70
|
+
|
|
71
|
+
[1] Mete Kemertas, Amir-massoud Farahmand, Allan Douglas Jepson. "A Truncated Newton Method for Optimal Transport." The Thirteenth International Conference on Learning Representations (ICLR), 2025. https://openreview.net/forum?id=gWrWUaCbMa
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
mdot_tnt/__init__.py,sha256=zi0ujbN19nv_dt8wdxPQgsOuDi-4gIg4sq6y9fDfAkI,2657
|
|
2
|
+
mdot_tnt/mdot.py,sha256=EUch7ECZxDWDoNW3S84qyCx6-qbG6Mu1_XlWow862Wc,4620
|
|
3
|
+
mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
|
|
4
|
+
mdot_tnt/truncated_newton.py,sha256=YEkWJxL5H4XsuWQzfdZ0E5Ju3Npk3kKoe2l_go8wq2A,10314
|
|
5
|
+
mdot_tnt-0.1.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
|
|
6
|
+
mdot_tnt-0.1.0.dist-info/METADATA,sha256=2Ou4X10K6bUGHQhnMXoFjrv-4YUeIyreSKibMLM2gq0,2972
|
|
7
|
+
mdot_tnt-0.1.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
8
|
+
mdot_tnt-0.1.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
|
|
9
|
+
mdot_tnt-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
mdot_tnt
|