mdot-tnt 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdot_tnt/__init__.py CHANGED
@@ -1,11 +1,19 @@
1
+ """
2
+ Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
3
+ method introduced in the paper "A Truncated Newton Method for Optimal Transport"
4
+ by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
5
+ URL: https://openreview.net/forum?id=gWrWUaCbMa
6
+ """
1
7
 
8
+ import math
2
9
  import torch as th
10
+ import warnings
3
11
 
4
- from mdot_tnt.mdot import mdot
12
+ from mdot_tnt.mdot import mdot, preprocess_marginals
5
13
  from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
6
14
 
7
15
 
8
- def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=True, log=False):
16
+ def solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=True, log=False):
9
17
  """
10
18
  Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
11
19
  :param r: n-dimensional row marginal.
@@ -23,8 +31,12 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
23
31
  """
24
32
  assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
25
33
  dtype = r.dtype
26
- if gamma_f > 2 ** 10:
34
+ # Require high precision for gamma_f > 2^10
35
+ if gamma_f > 2 ** 10 and dtype != th.float64:
36
+ warnings.warn("Switching to double precision for gamma_f > 2^10 during execution. "
37
+ "Output will be input dtype: {}.".format(dtype))
27
38
  r, c, C = r.double(), c.double(), C.double()
39
+
28
40
  if drop_tiny:
29
41
  drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
30
42
  (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
@@ -32,12 +44,13 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
32
44
  u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
33
45
 
34
46
  u = -th.ones_like(r) * float('inf')
35
- u[:, r_keep] = u_
47
+ u[r_keep] = u_
36
48
  v = -th.ones_like(c) * float('inf')
37
- v[:, c_keep] = v_
49
+ v[c_keep] = v_
38
50
  else:
39
51
  u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)
40
52
 
53
+ # Switch back to original dtype
41
54
  u, v = u.to(dtype), v.to(dtype)
42
55
 
43
56
  if return_plan:
mdot_tnt/mdot.py CHANGED
@@ -1,12 +1,4 @@
1
- """
2
- Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
3
- method introduced in the paper "A Truncated Newton Method for Optimal Transport"
4
- by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
5
- URL: https://openreview.net/forum?id=gWrWUaCbMa
6
- """
7
1
 
8
-
9
- import math
10
2
  import warnings
11
3
 
12
4
  from mdot_tnt.rounding import *
@@ -16,18 +8,18 @@ from mdot_tnt.truncated_newton import TruncatedNewtonProjector
16
8
  def preprocess_marginals(r, c, C, eps):
17
9
  """
18
10
  This function drops the smallest marginal entries whose cumulative sum does not exceed eps, and spreads the removed mass uniformly over the kept entries.
19
- :param r:
20
- :param c:
21
- :param C:
22
- :param eps:
23
- :return:
11
+ :param r: the row marginal.
12
+ :param c: the column marginal.
13
+ :param C: the cost matrix.
14
+ :param eps: the threshold for the cumulative sum of the marginal entries to be dropped
15
+ :return: the new marginals and the new cost matrix with the corresponding rows and columns dropped.
24
16
  """
25
17
 
26
18
  def preprocess_marginal(m, eps):
27
19
  m_sorted, m_idx = th.sort(m, dim=-1, descending=False)
28
20
  m_cumsum = th.cumsum(m_sorted, dim=-1)
29
21
  m_keep = m_idx[m_cumsum > eps]
30
- m_new = m[:, m_keep]
22
+ m_new = m[m_keep]
31
23
  mass_removed = 1 - m_new.sum(-1)
32
24
  m_new = m_new + mass_removed / m_new.size(-1)
33
25
 
@@ -35,6 +27,8 @@ def preprocess_marginals(r, c, C, eps):
35
27
 
36
28
  r_new, r_keep = preprocess_marginal(r, eps)
37
29
  c_new, c_keep = preprocess_marginal(c, eps)
30
+ print("Dropped {} entries from r and {} entries from c.".format(
31
+ r.size(-1) - r_new.size(-1), c.size(-1) - c_new.size(-1)))
38
32
 
39
33
  C = C[r_keep][:, c_keep]
40
34
 
@@ -42,6 +36,15 @@ def preprocess_marginals(r, c, C, eps):
42
36
 
43
37
 
44
38
  def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
39
+ """
40
+ Smooth the marginals by adding a small amount of uniform mass to each entry.
41
+ :param r: the row marginal.
42
+ :param c: the column marginal.
43
+ :param eps: the amount of mass to add to each entry.
44
+ :param w_r: the weight for the row marginal.
45
+ :param w_c: the weight for the column marginal.
46
+ :return: the smoothed marginals with a total TV distance at most eps from the original marginals.
47
+ """
45
48
  assert w_r + w_c == 1, "w_r and w_c must sum to 1"
46
49
  eps = eps.clamp(max=1.).unsqueeze(-1)
47
50
  r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
@@ -51,6 +54,12 @@ def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
51
54
 
52
55
 
53
56
  def adjust_schedule(q, deltas=None):
57
+ """
58
+ Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
59
+ :param q: the current temperature annealing schedule adjustment factor.
60
+ :param deltas: the list of deltas from the Truncated Newton method; see Sec. 3.3 of Kemertas et al. (2025).
61
+ :return: the new temperature annealing schedule adjustment factor.
62
+ """
54
63
  if deltas is None:
55
64
  return q
56
65
 
@@ -65,7 +74,7 @@ def adjust_schedule(q, deltas=None):
65
74
  return q
66
75
 
67
76
 
68
- def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
77
+ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0):
69
78
  """
70
79
  Solve the entropic-regularized optimal transport problem using the MDOT method introduced in the paper:
71
80
  "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients" by Mete Kemertas,
@@ -78,7 +87,7 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
78
87
  :param gamma_i: The initial temperature.
79
88
  :param p: The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
80
89
  :param q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
81
- :return:
90
+ :return: The dual variables u, v, the final temperature, the total number of O(n^2) primitive ops, and logs.
82
91
  """
83
92
  projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
84
93
 
@@ -109,8 +118,9 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
109
118
  u_cur, v_cur = u_init.clone(), v_init.clone()
110
119
 
111
120
  u_prev, v_prev = u_cur.clone(), v_cur.clone()
121
+ gamma_C = gamma * C
112
122
  u_cur, v_cur, proj_log, success = projector.project(
113
- gamma * C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)
123
+ gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)
114
124
 
115
125
  logs["proj_logs"].append(proj_log)
116
126
  if not success:
@@ -27,13 +27,13 @@ class TruncatedNewtonProjector:
27
27
  # a good enough solution to keep MDOT going.
28
28
  success_fn = lambda err_: err_ < 10 * eps_d
29
29
 
30
+ r = log_r.exp()
31
+ c = log_c.exp()
32
+
30
33
  # Each LSE operation costs 4 * n^2 operations.
31
34
  self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
32
35
  self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
33
36
 
34
- r = log_r.exp()
35
- c = log_c.exp()
36
-
37
37
  log_c_P = v + self.LSE_c(u)
38
38
  v += log_c - log_c_P # Ensure c=c(P)
39
39
  log_r_P = u + self.LSE_r(v)
@@ -162,8 +162,7 @@ class TruncatedNewtonProjector:
162
162
  rho = self.rho
163
163
  tol = err * eta_k
164
164
 
165
- def matmul_PPc(x_):
166
- return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
165
+ matmul_PPc = lambda x_: (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
167
166
 
168
167
  # mml = th.compile(matmul_PPc)
169
168
  mml = matmul_PPc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mdot-tnt
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
5
5
  Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
6
6
  Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
@@ -19,7 +19,7 @@ The current implementation is based on PyTorch and is compatible with both CPU a
19
19
  For installation:
20
20
  First, install PyTorch following the instructions at https://pytorch.org/get-started/locally/ to select the version that matches your system's configuration.
21
21
  ```bash
22
- pip install mdot_tnt
22
+ pip3 install mdot_tnt
23
23
  ```
24
24
 
25
25
  Quickstart guide:
@@ -33,18 +33,18 @@ N, M, dim = 100, 200, 128
33
33
  r = th.distributions.Dirichlet(th.ones(N)).sample()
34
34
  c = th.distributions.Dirichlet(th.ones(M)).sample()
35
35
 
36
- # Cost matrix from pairwise Euclidean distances of random points in R^100
36
+ # Cost matrix from squared pairwise Euclidean distances of random points in R^dim (dim = 128)
37
37
  x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
38
38
  y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
39
- C = th.cdist(x, y, p=2)
39
+ C = th.cdist(x, y, p=2) ** 2
40
40
  C /= C.max() # Normalize cost matrix to meet convention.
41
41
 
42
42
  # Use double precision for numerical stability in high precision regime.
43
- r, c, C = r.double(), c.double(), C.double()
43
+ r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)
44
44
 
45
45
  # Solve OT problem. Increase (decrease) gamma_f for higher (lower) precision.
46
- # Default is gamma_f=2**12. Expect error of order logn / gamma_f at worst, and possibly lower.
47
- cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12)
46
+ # Default is gamma_f=2**10. Expect error of order log(n) / gamma_f at worst, and possibly lower.
47
+ cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**10)
48
48
 
49
49
  # To return a feasible transport plan, use the following:
50
50
  transport_plan = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12, return_plan=True)
@@ -0,0 +1,9 @@
1
+ mdot_tnt/__init__.py,sha256=9_DJl0mXe5WiKFf_12oN4d2Bnq4Q3AcuuNz9A6FmFp8,3290
2
+ mdot_tnt/mdot.py,sha256=8SZxldnq64ySWZlUjsdGqT_prqfcs2ZrVFBGp-EJ4B0,5562
3
+ mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
4
+ mdot_tnt/truncated_newton.py,sha256=Zp4Tb65dSE_AGDElgAE1gp3V_CaXFsbDX3SJCaNEWfc,10299
5
+ mdot_tnt-0.2.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
6
+ mdot_tnt-0.2.0.dist-info/METADATA,sha256=sfqk_nDPsGJlSrXoz982dgLLAOGqI1F1obSdOsodIN8,3022
7
+ mdot_tnt-0.2.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
8
+ mdot_tnt-0.2.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
9
+ mdot_tnt-0.2.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- mdot_tnt/__init__.py,sha256=zi0ujbN19nv_dt8wdxPQgsOuDi-4gIg4sq6y9fDfAkI,2657
2
- mdot_tnt/mdot.py,sha256=EUch7ECZxDWDoNW3S84qyCx6-qbG6Mu1_XlWow862Wc,4620
3
- mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
4
- mdot_tnt/truncated_newton.py,sha256=YEkWJxL5H4XsuWQzfdZ0E5Ju3Npk3kKoe2l_go8wq2A,10314
5
- mdot_tnt-0.1.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
6
- mdot_tnt-0.1.0.dist-info/METADATA,sha256=2Ou4X10K6bUGHQhnMXoFjrv-4YUeIyreSKibMLM2gq0,2972
7
- mdot_tnt-0.1.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
8
- mdot_tnt-0.1.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
9
- mdot_tnt-0.1.0.dist-info/RECORD,,