mdot-tnt 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +18 -5
- mdot_tnt/mdot.py +27 -17
- mdot_tnt/truncated_newton.py +4 -5
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-0.2.0.dist-info}/METADATA +7 -7
- mdot_tnt-0.2.0.dist-info/RECORD +9 -0
- mdot_tnt-0.1.0.dist-info/RECORD +0 -9
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-0.2.0.dist-info}/LICENSE +0 -0
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-0.2.0.dist-info}/WHEEL +0 -0
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-0.2.0.dist-info}/top_level.txt +0 -0
mdot_tnt/__init__.py
CHANGED
|
@@ -1,11 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
|
|
3
|
+
method introduced in the paper "A Truncated Newton Method for Optimal Transport"
|
|
4
|
+
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
|
|
5
|
+
URL: https://openreview.net/forum?id=gWrWUaCbMa
|
|
6
|
+
"""
|
|
1
7
|
|
|
8
|
+
import math
|
|
2
9
|
import torch as th
|
|
10
|
+
import warnings
|
|
3
11
|
|
|
4
|
-
from mdot_tnt.mdot import mdot
|
|
12
|
+
from mdot_tnt.mdot import mdot, preprocess_marginals
|
|
5
13
|
from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
|
|
6
14
|
|
|
7
15
|
|
|
8
|
-
def solve_OT(r, c, C, gamma_f=
|
|
16
|
+
def solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=True, log=False):
|
|
9
17
|
"""
|
|
10
18
|
Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
|
|
11
19
|
:param r: n-dimensional row marginal.
|
|
@@ -23,8 +31,12 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
|
|
|
23
31
|
"""
|
|
24
32
|
assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
|
|
25
33
|
dtype = r.dtype
|
|
26
|
-
|
|
34
|
+
# Require high precision for gamma_f > 2^10
|
|
35
|
+
if gamma_f > 2 ** 10 and dtype != th.float64:
|
|
36
|
+
warnings.warn("Switching to double precision for gamma_f > 2^10 during execution. "
|
|
37
|
+
"Output will be input dtype: {}.".format(dtype))
|
|
27
38
|
r, c, C = r.double(), c.double(), C.double()
|
|
39
|
+
|
|
28
40
|
if drop_tiny:
|
|
29
41
|
drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
|
|
30
42
|
(r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
|
|
@@ -32,12 +44,13 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
|
|
|
32
44
|
u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
|
|
33
45
|
|
|
34
46
|
u = -th.ones_like(r) * float('inf')
|
|
35
|
-
u[
|
|
47
|
+
u[r_keep] = u_
|
|
36
48
|
v = -th.ones_like(c) * float('inf')
|
|
37
|
-
v[
|
|
49
|
+
v[c_keep] = v_
|
|
38
50
|
else:
|
|
39
51
|
u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)
|
|
40
52
|
|
|
53
|
+
# Switch back to original dtype
|
|
41
54
|
u, v = u.to(dtype), v.to(dtype)
|
|
42
55
|
|
|
43
56
|
if return_plan:
|
mdot_tnt/mdot.py
CHANGED
|
@@ -1,12 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
|
|
3
|
-
method introduced in the paper "A Truncated Newton Method for Optimal Transport"
|
|
4
|
-
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
|
|
5
|
-
URL: https://openreview.net/forum?id=gWrWUaCbMa
|
|
6
|
-
"""
|
|
7
1
|
|
|
8
|
-
|
|
9
|
-
import math
|
|
10
2
|
import warnings
|
|
11
3
|
|
|
12
4
|
from mdot_tnt.rounding import *
|
|
@@ -16,18 +8,18 @@ from mdot_tnt.truncated_newton import TruncatedNewtonProjector
|
|
|
16
8
|
def preprocess_marginals(r, c, C, eps):
|
|
17
9
|
"""
|
|
18
10
|
This function drops the smallest entries whose cumulative sum equals
|
|
19
|
-
:param r:
|
|
20
|
-
:param c:
|
|
21
|
-
:param C:
|
|
22
|
-
:param eps:
|
|
23
|
-
:return:
|
|
11
|
+
:param r: the row marginal.
|
|
12
|
+
:param c: the column marginal.
|
|
13
|
+
:param C: the cost matrix.
|
|
14
|
+
:param eps: the threshold for the cumulative sum of the marginal entries to be dropped
|
|
15
|
+
:return: the new marginals and the new cost matrix with the corresponding rows and columns dropped.
|
|
24
16
|
"""
|
|
25
17
|
|
|
26
18
|
def preprocess_marginal(m, eps):
|
|
27
19
|
m_sorted, m_idx = th.sort(m, dim=-1, descending=False)
|
|
28
20
|
m_cumsum = th.cumsum(m_sorted, dim=-1)
|
|
29
21
|
m_keep = m_idx[m_cumsum > eps]
|
|
30
|
-
m_new = m[
|
|
22
|
+
m_new = m[m_keep]
|
|
31
23
|
mass_removed = 1 - m_new.sum(-1)
|
|
32
24
|
m_new = m_new + mass_removed / m_new.size(-1)
|
|
33
25
|
|
|
@@ -35,6 +27,8 @@ def preprocess_marginals(r, c, C, eps):
|
|
|
35
27
|
|
|
36
28
|
r_new, r_keep = preprocess_marginal(r, eps)
|
|
37
29
|
c_new, c_keep = preprocess_marginal(c, eps)
|
|
30
|
+
print("Dropped {} entries from r and {} entries from c.".format(
|
|
31
|
+
r.size(-1) - r_new.size(-1), c.size(-1) - c_new.size(-1)))
|
|
38
32
|
|
|
39
33
|
C = C[r_keep][:, c_keep]
|
|
40
34
|
|
|
@@ -42,6 +36,15 @@ def preprocess_marginals(r, c, C, eps):
|
|
|
42
36
|
|
|
43
37
|
|
|
44
38
|
def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
|
|
39
|
+
"""
|
|
40
|
+
Smooth the marginals by adding a small amount of uniform mass to each entry.
|
|
41
|
+
:param r: the row marginal.
|
|
42
|
+
:param c: the column marginal.
|
|
43
|
+
:param eps: the amount of mass to add to each entry.
|
|
44
|
+
:param w_r: the weight for the row marginal.
|
|
45
|
+
:param w_c: the weight for the column marginal.
|
|
46
|
+
:return: the smoothed marginals with a total TV distance at most eps from the original marginals.
|
|
47
|
+
"""
|
|
45
48
|
assert w_r + w_c == 1, "w_r and w_c must sum to 1"
|
|
46
49
|
eps = eps.clamp(max=1.).unsqueeze(-1)
|
|
47
50
|
r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
|
|
@@ -51,6 +54,12 @@ def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
|
|
|
51
54
|
|
|
52
55
|
|
|
53
56
|
def adjust_schedule(q, deltas=None):
|
|
57
|
+
"""
|
|
58
|
+
Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
|
|
59
|
+
:param q: the current temperature annealing schedule adjustment factor.
|
|
60
|
+
:param deltas: the list of deltas from the Truncated Newton method; see Sec. 3.3 of Kemertas et al. (2025).
|
|
61
|
+
:return: the new temperature annealing schedule adjustment factor.
|
|
62
|
+
"""
|
|
54
63
|
if deltas is None:
|
|
55
64
|
return q
|
|
56
65
|
|
|
@@ -65,7 +74,7 @@ def adjust_schedule(q, deltas=None):
|
|
|
65
74
|
return q
|
|
66
75
|
|
|
67
76
|
|
|
68
|
-
def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0
|
|
77
|
+
def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0):
|
|
69
78
|
"""
|
|
70
79
|
Solve the entropic-regularized optimal transport problem using the MDOT method introduced in the paper:
|
|
71
80
|
"Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients" by Mete Kemertas,
|
|
@@ -78,7 +87,7 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
|
|
|
78
87
|
:param gamma_i: The initial temperature.
|
|
79
88
|
:param p: The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
|
|
80
89
|
:param q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
|
|
81
|
-
:return:
|
|
90
|
+
:return: The dual variables u, v, the final temperature, the total number of O(n^2) primitive ops, and logs.
|
|
82
91
|
"""
|
|
83
92
|
projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
|
|
84
93
|
|
|
@@ -109,8 +118,9 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
|
|
|
109
118
|
u_cur, v_cur = u_init.clone(), v_init.clone()
|
|
110
119
|
|
|
111
120
|
u_prev, v_prev = u_cur.clone(), v_cur.clone()
|
|
121
|
+
gamma_C = gamma * C
|
|
112
122
|
u_cur, v_cur, proj_log, success = projector.project(
|
|
113
|
-
|
|
123
|
+
gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)
|
|
114
124
|
|
|
115
125
|
logs["proj_logs"].append(proj_log)
|
|
116
126
|
if not success:
|
mdot_tnt/truncated_newton.py
CHANGED
|
@@ -27,13 +27,13 @@ class TruncatedNewtonProjector:
|
|
|
27
27
|
# a good enough solution to keep MDOT going.
|
|
28
28
|
success_fn = lambda err_: err_ < 10 * eps_d
|
|
29
29
|
|
|
30
|
+
r = log_r.exp()
|
|
31
|
+
c = log_c.exp()
|
|
32
|
+
|
|
30
33
|
# Each LSE operation costs 4 * n^2 operations.
|
|
31
34
|
self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
|
|
32
35
|
self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
|
|
33
36
|
|
|
34
|
-
r = log_r.exp()
|
|
35
|
-
c = log_c.exp()
|
|
36
|
-
|
|
37
37
|
log_c_P = v + self.LSE_c(u)
|
|
38
38
|
v += log_c - log_c_P # Ensure c=c(P)
|
|
39
39
|
log_r_P = u + self.LSE_r(v)
|
|
@@ -162,8 +162,7 @@ class TruncatedNewtonProjector:
|
|
|
162
162
|
rho = self.rho
|
|
163
163
|
tol = err * eta_k
|
|
164
164
|
|
|
165
|
-
|
|
166
|
-
return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
|
|
165
|
+
matmul_PPc = lambda x_: (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
|
|
167
166
|
|
|
168
167
|
# mml = th.compile(matmul_PPc)
|
|
169
168
|
mml = matmul_PPc
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mdot-tnt
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
|
|
5
5
|
Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
|
|
6
6
|
Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
|
|
@@ -19,7 +19,7 @@ The current implementation is based on PyTorch and is compatible with both CPU a
|
|
|
19
19
|
For installation:
|
|
20
20
|
First, install PyTorch following the instructions at https://pytorch.org/get-started/locally/ to select the version that matches your system's configuration.
|
|
21
21
|
```bash
|
|
22
|
-
|
|
22
|
+
pip3 install mdot_tnt
|
|
23
23
|
```
|
|
24
24
|
|
|
25
25
|
Quickstart guide:
|
|
@@ -33,18 +33,18 @@ N, M, dim = 100, 200, 128
|
|
|
33
33
|
r = th.distributions.Dirichlet(th.ones(N)).sample()
|
|
34
34
|
c = th.distributions.Dirichlet(th.ones(M)).sample()
|
|
35
35
|
|
|
36
|
-
# Cost matrix from pairwise Euclidean distances
|
|
36
|
+
# Cost matrix from squared pairwise Euclidean distances between random points in R^dim
|
|
37
37
|
x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
|
|
38
38
|
y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
|
|
39
|
-
C = th.cdist(x, y, p=2)
|
|
39
|
+
C = th.cdist(x, y, p=2) ** 2
|
|
40
40
|
C /= C.max() # Normalize cost matrix to meet convention.
|
|
41
41
|
|
|
42
42
|
# Use double precision for numerical stability in high precision regime.
|
|
43
|
-
r, c, C = r.double(), c.double(), C.double()
|
|
43
|
+
r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)
|
|
44
44
|
|
|
45
45
|
# Solve OT problem. Increase (decrease) gamma_f for higher (lower) precision.
|
|
46
|
-
# Default is gamma_f=2**
|
|
47
|
-
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**
|
|
46
|
+
# Default is gamma_f=2**10. Expect an error of order log(n) / gamma_f at worst, and possibly lower.
|
|
47
|
+
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**10)
|
|
48
48
|
|
|
49
49
|
# To return a feasible transport plan, use the following:
|
|
50
50
|
transport_plan = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12, return_plan=True)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
mdot_tnt/__init__.py,sha256=9_DJl0mXe5WiKFf_12oN4d2Bnq4Q3AcuuNz9A6FmFp8,3290
|
|
2
|
+
mdot_tnt/mdot.py,sha256=8SZxldnq64ySWZlUjsdGqT_prqfcs2ZrVFBGp-EJ4B0,5562
|
|
3
|
+
mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
|
|
4
|
+
mdot_tnt/truncated_newton.py,sha256=Zp4Tb65dSE_AGDElgAE1gp3V_CaXFsbDX3SJCaNEWfc,10299
|
|
5
|
+
mdot_tnt-0.2.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
|
|
6
|
+
mdot_tnt-0.2.0.dist-info/METADATA,sha256=sfqk_nDPsGJlSrXoz982dgLLAOGqI1F1obSdOsodIN8,3022
|
|
7
|
+
mdot_tnt-0.2.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
8
|
+
mdot_tnt-0.2.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
|
|
9
|
+
mdot_tnt-0.2.0.dist-info/RECORD,,
|
mdot_tnt-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
mdot_tnt/__init__.py,sha256=zi0ujbN19nv_dt8wdxPQgsOuDi-4gIg4sq6y9fDfAkI,2657
|
|
2
|
-
mdot_tnt/mdot.py,sha256=EUch7ECZxDWDoNW3S84qyCx6-qbG6Mu1_XlWow862Wc,4620
|
|
3
|
-
mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
|
|
4
|
-
mdot_tnt/truncated_newton.py,sha256=YEkWJxL5H4XsuWQzfdZ0E5Ju3Npk3kKoe2l_go8wq2A,10314
|
|
5
|
-
mdot_tnt-0.1.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
|
|
6
|
-
mdot_tnt-0.1.0.dist-info/METADATA,sha256=2Ou4X10K6bUGHQhnMXoFjrv-4YUeIyreSKibMLM2gq0,2972
|
|
7
|
-
mdot_tnt-0.1.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
8
|
-
mdot_tnt-0.1.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
|
|
9
|
-
mdot_tnt-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|