mdot-tnt 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +52 -8
- mdot_tnt/batched.py +634 -0
- mdot_tnt/mdot.py +105 -41
- mdot_tnt/py.typed +0 -0
- mdot_tnt/rounding.py +41 -15
- mdot_tnt/truncated_newton.py +107 -38
- mdot_tnt-1.0.0.dist-info/METADATA +216 -0
- mdot_tnt-1.0.0.dist-info/RECORD +11 -0
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info}/WHEEL +1 -1
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info/licenses}/LICENSE +4 -1
- mdot_tnt-0.1.0.dist-info/METADATA +0 -71
- mdot_tnt-0.1.0.dist-info/RECORD +0 -9
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info}/top_level.txt +0 -0
mdot_tnt/__init__.py
CHANGED
|
@@ -1,11 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MDOT-TNT: A Truncated Newton Method for Optimal Transport
|
|
3
|
+
|
|
4
|
+
This package provides efficient solvers for the entropic-regularized optimal transport
|
|
5
|
+
problem, as introduced in the paper "A Truncated Newton Method for Optimal Transport"
|
|
6
|
+
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
|
|
7
|
+
URL: https://openreview.net/forum?id=gWrWUaCbMa
|
|
8
|
+
|
|
9
|
+
Main functions:
|
|
10
|
+
solve_OT: Solve a single OT problem.
|
|
11
|
+
solve_OT_batched: Solve multiple OT problems simultaneously (5-10x faster).
|
|
12
|
+
|
|
13
|
+
Example:
|
|
14
|
+
>>> import torch
|
|
15
|
+
>>> from mdot_tnt import solve_OT, solve_OT_batched
|
|
16
|
+
>>>
|
|
17
|
+
>>> # Single problem
|
|
18
|
+
>>> r = torch.rand(512, device='cuda', dtype=torch.float64)
|
|
19
|
+
>>> r = r / r.sum()
|
|
20
|
+
>>> c = torch.rand(512, device='cuda', dtype=torch.float64)
|
|
21
|
+
>>> c = c / c.sum()
|
|
22
|
+
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
|
|
23
|
+
>>> cost = solve_OT(r, c, C, gamma_f=1024.)
|
|
24
|
+
>>>
|
|
25
|
+
>>> # Batched (32 problems at once)
|
|
26
|
+
>>> r_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
27
|
+
>>> r_batch = r_batch / r_batch.sum(-1, keepdim=True)
|
|
28
|
+
>>> c_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
29
|
+
>>> c_batch = c_batch / c_batch.sum(-1, keepdim=True)
|
|
30
|
+
>>> costs = solve_OT_batched(r_batch, c_batch, C, gamma_f=1024.)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
import warnings
|
|
1
35
|
|
|
2
36
|
import torch as th
|
|
3
37
|
|
|
4
|
-
from mdot_tnt.
|
|
38
|
+
from mdot_tnt.batched import solve_OT_batched
|
|
39
|
+
from mdot_tnt.mdot import mdot, preprocess_marginals
|
|
5
40
|
from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
|
|
6
41
|
|
|
42
|
+
__all__ = ["solve_OT", "solve_OT_batched"]
|
|
43
|
+
|
|
7
44
|
|
|
8
|
-
def solve_OT(r, c, C, gamma_f=
|
|
45
|
+
def solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False):
|
|
9
46
|
"""
|
|
10
47
|
Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
|
|
11
48
|
:param r: n-dimensional row marginal.
|
|
@@ -23,21 +60,28 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
|
|
|
23
60
|
"""
|
|
24
61
|
assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
|
|
25
62
|
dtype = r.dtype
|
|
26
|
-
|
|
63
|
+
# Require high precision for gamma_f > 2^10
|
|
64
|
+
if gamma_f > 2**10 and dtype != th.float64:
|
|
65
|
+
warnings.warn(
|
|
66
|
+
"Switching to double precision for gamma_f > 2^10 during execution. "
|
|
67
|
+
f"Output will be input dtype: {dtype}."
|
|
68
|
+
)
|
|
27
69
|
r, c, C = r.double(), c.double(), C.double()
|
|
70
|
+
|
|
28
71
|
if drop_tiny:
|
|
29
|
-
drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f
|
|
72
|
+
drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
|
|
30
73
|
(r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
|
|
31
74
|
|
|
32
75
|
u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
|
|
33
76
|
|
|
34
|
-
u = -th.ones_like(r) * float(
|
|
35
|
-
u[
|
|
36
|
-
v = -th.ones_like(c) * float(
|
|
37
|
-
v[
|
|
77
|
+
u = -th.ones_like(r) * float("inf")
|
|
78
|
+
u[r_keep] = u_
|
|
79
|
+
v = -th.ones_like(c) * float("inf")
|
|
80
|
+
v[c_keep] = v_
|
|
38
81
|
else:
|
|
39
82
|
u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)
|
|
40
83
|
|
|
84
|
+
# Switch back to original dtype
|
|
41
85
|
u, v = u.to(dtype), v.to(dtype)
|
|
42
86
|
|
|
43
87
|
if return_plan:
|