mdot-tnt 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
mdot_tnt/__init__.py CHANGED
@@ -1,19 +1,48 @@
1
1
  """
2
- Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
3
- method introduced in the paper "A Truncated Newton Method for Optimal Transport"
2
+ MDOT-TNT: A Truncated Newton Method for Optimal Transport
3
+
4
+ This package provides efficient solvers for the entropic-regularized optimal transport
5
+ problem, as introduced in the paper "A Truncated Newton Method for Optimal Transport"
4
6
  by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
5
7
  URL: https://openreview.net/forum?id=gWrWUaCbMa
8
+
9
+ Main functions:
10
+ solve_OT: Solve a single OT problem.
11
+ solve_OT_batched: Solve multiple OT problems simultaneously (5-10x faster).
12
+
13
+ Example:
14
+ >>> import torch
15
+ >>> from mdot_tnt import solve_OT, solve_OT_batched
16
+ >>>
17
+ >>> # Single problem
18
+ >>> r = torch.rand(512, device='cuda', dtype=torch.float64)
19
+ >>> r = r / r.sum()
20
+ >>> c = torch.rand(512, device='cuda', dtype=torch.float64)
21
+ >>> c = c / c.sum()
22
+ >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
23
+ >>> cost = solve_OT(r, c, C, gamma_f=1024.)
24
+ >>>
25
+ >>> # Batched (32 problems at once)
26
+ >>> r_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
27
+ >>> r_batch = r_batch / r_batch.sum(-1, keepdim=True)
28
+ >>> c_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
29
+ >>> c_batch = c_batch / c_batch.sum(-1, keepdim=True)
30
+ >>> costs = solve_OT_batched(r_batch, c_batch, C, gamma_f=1024.)
6
31
  """
7
32
 
8
33
  import math
9
- import torch as th
10
34
  import warnings
11
35
 
36
+ import torch as th
37
+
38
+ from mdot_tnt.batched import solve_OT_batched
12
39
  from mdot_tnt.mdot import mdot, preprocess_marginals
13
40
  from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
14
41
 
42
+ __all__ = ["solve_OT", "solve_OT_batched"]
43
+
15
44
 
16
- def solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=True, log=False):
45
+ def solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False):
17
46
  """
18
47
  Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
19
48
  :param r: n-dimensional row marginal.
@@ -32,20 +61,22 @@ def solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=T
32
61
  assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
33
62
  dtype = r.dtype
34
63
  # Require high precision for gamma_f > 2^10
35
- if gamma_f > 2 ** 10 and dtype != th.float64:
36
- warnings.warn("Switching to double precision for gamma_f > 2^10 during execution. "
37
- "Output will be input dtype: {}.".format(dtype))
64
+ if gamma_f > 2**10 and dtype != th.float64:
65
+ warnings.warn(
66
+ "Switching to double precision for gamma_f > 2^10 during execution. "
67
+ f"Output will be input dtype: {dtype}."
68
+ )
38
69
  r, c, C = r.double(), c.double(), C.double()
39
70
 
40
71
  if drop_tiny:
41
- drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
72
+ drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
42
73
  (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
43
74
 
44
75
  u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
45
76
 
46
- u = -th.ones_like(r) * float('inf')
77
+ u = -th.ones_like(r) * float("inf")
47
78
  u[r_keep] = u_
48
- v = -th.ones_like(c) * float('inf')
79
+ v = -th.ones_like(c) * float("inf")
49
80
  v[c_keep] = v_
50
81
  else:
51
82
  u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)