mdot-tnt 0.1.0__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,4 +23,7 @@ For commercial use, a separate license must be obtained from the Licensor. To in
  This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.

  ## 4. Governing Law and Enforcement
- This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued un
+ This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
+
+ ## 5. Contact
+ For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
@@ -0,0 +1,216 @@
+ Metadata-Version: 2.4
+ Name: mdot-tnt
+ Version: 1.0.0
+ Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
+ Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
+ License: # Non-Commercial Research License (NCRL-1.0)
+
+ Copyright (C) 2025 Mete Kemertas
+
+ ## 1. License Grant
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
+
+ ## 2. Restrictions
+ ### 2.1 **Non-Commercial Use Only**
+ - The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
+ - "Commercial purpose" includes, but is not limited to:
+   - Selling or licensing the Software.
+   - Using the Software in proprietary products or services.
+   - Offering the Software as part of a paid or revenue-generating service.
+
+ ### 2.2 **No Warranty & Liability**
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
+
+ ### 2.3 **Commercial Licensing**
+ For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
+
+ ## 3. Termination
+ This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
+
+ ## 4. Governing Law and Enforcement
+ This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
+
+ ## 5. Contact
+ For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
+
+ Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
+ Project-URL: Documentation, https://mdot-tnt.readthedocs.io
+ Project-URL: Repository, https://github.com/metekemertas/mdot_tnt
+ Project-URL: Issues, https://github.com/metekemertas/mdot_tnt/issues
+ Keywords: optimal-transport,sinkhorn,entropy-regularization,machine-learning,pytorch,gpu
+ Classifier: Development Status :: 5 - Production/Stable
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Intended Audience :: Developers
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Operating System :: OS Independent
+ Classifier: Typing :: Typed
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
+ Requires-Dist: pre-commit>=3.0; extra == "dev"
+ Requires-Dist: numpy>=1.20; extra == "dev"
+ Dynamic: license-file
+
+ # MDOT-TNT
+
+ <img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>
+
+ **A Truncated Newton Method for Optimal Transport**
+
+ [![PyPI version](https://badge.fury.io/py/mdot-tnt.svg)](https://badge.fury.io/py/mdot-tnt)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![License](https://img.shields.io/badge/license-Non--Commercial-green.svg)](LICENSE)
+
+ A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
+
+ <br clear="right"/>
+
+ ## Features
+
+ - **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
+ - **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
+ - **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
+ - **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
+ - **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
+
+ ## Installation
+
+ **Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.
+
+ ```bash
+ pip install mdot-tnt
+ ```
+
+ For development:
+
+ ```bash
+ git clone https://github.com/metekemertas/mdot_tnt.git
+ cd mdot_tnt
+ pip install -e ".[dev]"
+ ```
+
+ ## Quick Start
+
+ ### Single Problem
+
+ ```python
+ import torch
+ import mdot_tnt
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Create marginals (probability distributions)
+ n, m = 512, 512
+ r = torch.rand(n, device=device, dtype=torch.float64)
+ r = r / r.sum()
+ c = torch.rand(m, device=device, dtype=torch.float64)
+ c = c / c.sum()
+
+ # Cost matrix (e.g., pairwise distances)
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
+
+ # Solve for optimal transport cost
+ cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
+
+ # Or get the full transport plan
+ plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
+ ```
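Since `round=True` by default, the returned plan should already lie on the feasible set, i.e. its row and column sums reproduce `r` and `c`. A minimal sanity check continuing the snippet above (an editorial sketch, not part of the package README):

```python
# Continues the Quick Start snippet; assumes `plan`, `r`, `c`, and `C` are still in scope.
assert torch.allclose(plan.sum(dim=1), r, atol=1e-6)  # row sums match the row marginal
assert torch.allclose(plan.sum(dim=0), c, atol=1e-6)  # column sums match the column marginal

# The transport cost is the inner product of the plan with the cost matrix;
# it should be close to the scalar returned by solve_OT without return_plan.
print((plan * C).sum().item())
```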
+
+ ### Batched Solving
+
+ When solving multiple OT problems, use the batched solver for a significant speedup over solving them sequentially:
+
+ ```python
+ import torch
+ import mdot_tnt
+
+ device = "cuda"
+ batch_size, n, m = 32, 512, 512
+
+ # Multiple marginal pairs
+ r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
+ r = r / r.sum(-1, keepdim=True)
+ c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
+ c = c / c.sum(-1, keepdim=True)
+
+ # Shared cost matrix (or per-problem: shape [batch_size, n, m])
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
+
+ # Solve all problems at once
+ costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor
+ ```
+
+ The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
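A rough way to measure that amortization on your own hardware is to time a Python loop over `solve_OT` against a single `solve_OT_batched` call. A minimal sketch using `torch.cuda.synchronize` for wall-clock timing; the `timed` helper and the problem sizes are illustrative, not part of the package API:

```python
import time

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512

r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
C = torch.rand(n, m, device=device, dtype=torch.float64)

def timed(fn):
    """Wall-clock timing with GPU synchronization so queued kernels are included."""
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return out, time.perf_counter() - t0

# One solver call per problem vs. a single batched call over the same inputs.
_, t_loop = timed(lambda: [mdot_tnt.solve_OT(r[i], c[i], C, gamma_f=1024) for i in range(batch_size)])
_, t_batch = timed(lambda: mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024))
print(f"sequential: {t_loop:.3f}s, batched: {t_batch:.3f}s")
```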
+
+ ## API Reference
+
+ ### `solve_OT`
+
+ ```python
+ mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
+ ```
+
+ | Parameter | Type | Description |
+ |-----------|------|-------------|
+ | `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
+ | `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
+ | `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
+ | `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
+ | `return_plan` | `bool` | If True, return transport plan instead of cost |
+ | `round` | `bool` | If True, round solution onto feasible set |
+ | `log` | `bool` | If True, also return optimization logs |
+
+ **Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
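Per the table above, `gamma_f` is the main accuracy knob: larger values (weaker regularization) should push the returned cost toward the unregularized optimum, typically at the price of more work. A small sweep to see this on a given problem; the marginals, cost matrix, and the specific `gamma_f` values here are illustrative, not from the package docs:

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
n = 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(n, device=device, dtype=torch.float64)
c = c / c.sum()
C = torch.rand(n, n, device=device, dtype=torch.float64)

# With the default round=True, each cost corresponds to a feasible plan, so it
# should decrease toward the unregularized optimum as gamma_f grows.
for gamma_f in (256.0, 1024.0, 4096.0):
    cost = mdot_tnt.solve_OT(r, c, C, gamma_f=gamma_f)
    print(f"gamma_f={gamma_f:>6.0f}  cost={float(cost):.6f}")
```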
+
+ ### `solve_OT_batched`
+
+ ```python
+ mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
+ ```
+
+ Same parameters as `solve_OT`, but with batched inputs:
+ - `r`: Shape `(batch, n)`
+ - `c`: Shape `(batch, m)`
+ - `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs
+
+ **Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
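For per-problem costs, the same call accepts `C` of shape `(batch, n, m)`, and with `return_plan=True` the result holds one plan per problem. A short sketch; the shapes follow the reference above, while the sizes and the cost recovery at the end are illustrative:

```python
import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 8, 128, 128

r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
C = torch.rand(batch_size, n, m, device=device, dtype=torch.float64)  # one cost matrix per problem

plans = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024, return_plan=True)
print(plans.shape)  # expected: torch.Size([8, 128, 128])

# Per-problem costs recovered from the plans.
costs = (plans * C).sum(dim=(-2, -1))
```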
+
+ ## Performance Tips
+
+ 1. **Use float64** for `gamma_f > 1024` (automatic conversion with warning)
+ 2. **Normalize cost matrices** to [0, 1] for numerical stability
+ 3. **Use batched solver** when solving multiple problems with shared structure
+ 4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
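The sketch below pulls tips 1, 2, and 4 together on a small point-cloud problem; the cost construction, sizes, and the choice `gamma_f = 2**14` are illustrative, not prescriptions from the package:

```python
import math

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
n = 512

# Squared-Euclidean cost between two point clouds, in float64 (tip 1),
# normalized so all entries lie in [0, 1] (tip 2).
x = torch.randn(n, 2, device=device, dtype=torch.float64)
y = torch.randn(n, 2, device=device, dtype=torch.float64)
C = torch.cdist(x, y) ** 2
C = C / C.max()

r = torch.full((n,), 1.0 / n, device=device, dtype=torch.float64)  # uniform marginals
c = torch.full((n,), 1.0 / n, device=device, dtype=torch.float64)

gamma_f = 2 ** 14  # tip 4: larger gamma_f for higher precision
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=gamma_f)

# Rough worst-case suboptimality guideline from tip 4: O(log n / gamma).
print(f"cost = {float(cost):.6f}, log(n)/gamma_f = {math.log(n) / gamma_f:.2e}")
```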
+
+ ## Citation
+
+ If you use MDOT-TNT in your research, please cite:
+
+ ```bibtex
+ @inproceedings{kemertas2025truncated,
+   title={A Truncated Newton Method for Optimal Transport},
+   author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
+   booktitle={The Thirteenth International Conference on Learning Representations},
+   year={2025},
+   url={https://openreview.net/forum?id=gWrWUaCbMa}
+ }
+ ```
+
+ ## License
+
+ This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.
+
+ ## Contact
+
+ For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
@@ -0,0 +1,152 @@
+ # MDOT-TNT
+
+ <img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>
+
+ **A Truncated Newton Method for Optimal Transport**
+
+ [![PyPI version](https://badge.fury.io/py/mdot-tnt.svg)](https://badge.fury.io/py/mdot-tnt)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![License](https://img.shields.io/badge/license-Non--Commercial-green.svg)](LICENSE)
+
+ A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
+
+ <br clear="right"/>
+
+ ## Features
+
+ - **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
+ - **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
+ - **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
+ - **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
+ - **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
+
+ ## Installation
+
+ **Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.
+
+ ```bash
+ pip install mdot-tnt
+ ```
+
+ For development:
+
+ ```bash
+ git clone https://github.com/metekemertas/mdot_tnt.git
+ cd mdot_tnt
+ pip install -e ".[dev]"
+ ```
+
+ ## Quick Start
+
+ ### Single Problem
+
+ ```python
+ import torch
+ import mdot_tnt
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Create marginals (probability distributions)
+ n, m = 512, 512
+ r = torch.rand(n, device=device, dtype=torch.float64)
+ r = r / r.sum()
+ c = torch.rand(m, device=device, dtype=torch.float64)
+ c = c / c.sum()
+
+ # Cost matrix (e.g., pairwise distances)
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
+
+ # Solve for optimal transport cost
+ cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
+
+ # Or get the full transport plan
+ plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
+ ```
+
+ ### Batched Solving
+
+ When solving multiple OT problems, use the batched solver for a significant speedup over solving them sequentially:
+
+ ```python
+ import torch
+ import mdot_tnt
+
+ device = "cuda"
+ batch_size, n, m = 32, 512, 512
+
+ # Multiple marginal pairs
+ r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
+ r = r / r.sum(-1, keepdim=True)
+ c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
+ c = c / c.sum(-1, keepdim=True)
+
+ # Shared cost matrix (or per-problem: shape [batch_size, n, m])
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
+
+ # Solve all problems at once
+ costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor
+ ```
+
+ The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
+
+ ## API Reference
+
+ ### `solve_OT`
+
+ ```python
+ mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
+ ```
+
+ | Parameter | Type | Description |
+ |-----------|------|-------------|
+ | `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
+ | `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
+ | `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
+ | `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
+ | `return_plan` | `bool` | If True, return transport plan instead of cost |
+ | `round` | `bool` | If True, round solution onto feasible set |
+ | `log` | `bool` | If True, also return optimization logs |
+
+ **Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
+
+ ### `solve_OT_batched`
+
+ ```python
+ mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
+ ```
+
+ Same parameters as `solve_OT`, but with batched inputs:
+ - `r`: Shape `(batch, n)`
+ - `c`: Shape `(batch, m)`
+ - `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs
+
+ **Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
+
+ ## Performance Tips
+
+ 1. **Use float64** for `gamma_f > 1024` (automatic conversion with warning)
+ 2. **Normalize cost matrices** to [0, 1] for numerical stability
+ 3. **Use batched solver** when solving multiple problems with shared structure
+ 4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
+
+ ## Citation
+
+ If you use MDOT-TNT in your research, please cite:
+
+ ```bibtex
+ @inproceedings{kemertas2025truncated,
+   title={A Truncated Newton Method for Optimal Transport},
+   author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
+   booktitle={The Thirteenth International Conference on Learning Representations},
+   year={2025},
+   url={https://openreview.net/forum?id=gWrWUaCbMa}
+ }
+ ```
+
+ ## License
+
+ This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.
+
+ ## Contact
+
+ For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
@@ -1,11 +1,48 @@
+ """
+ MDOT-TNT: A Truncated Newton Method for Optimal Transport
+
+ This package provides efficient solvers for the entropic-regularized optimal transport
+ problem, as introduced in the paper "A Truncated Newton Method for Optimal Transport"
+ by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
+ URL: https://openreview.net/forum?id=gWrWUaCbMa
+
+ Main functions:
+     solve_OT: Solve a single OT problem.
+     solve_OT_batched: Solve multiple OT problems simultaneously (5-10x faster).
+
+ Example:
+     >>> import torch
+     >>> from mdot_tnt import solve_OT, solve_OT_batched
+     >>>
+     >>> # Single problem
+     >>> r = torch.rand(512, device='cuda', dtype=torch.float64)
+     >>> r = r / r.sum()
+     >>> c = torch.rand(512, device='cuda', dtype=torch.float64)
+     >>> c = c / c.sum()
+     >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
+     >>> cost = solve_OT(r, c, C, gamma_f=1024.)
+     >>>
+     >>> # Batched (32 problems at once)
+     >>> r_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
+     >>> r_batch = r_batch / r_batch.sum(-1, keepdim=True)
+     >>> c_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
+     >>> c_batch = c_batch / c_batch.sum(-1, keepdim=True)
+     >>> costs = solve_OT_batched(r_batch, c_batch, C, gamma_f=1024.)
+ """
+
+ import math
+ import warnings

  import torch as th

- from mdot_tnt.mdot import mdot
+ from mdot_tnt.batched import solve_OT_batched
+ from mdot_tnt.mdot import mdot, preprocess_marginals
  from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler

+ __all__ = ["solve_OT", "solve_OT_batched"]
+

- def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=True, log=False):
+ def solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False):
      """
      Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
      :param r: n-dimensional row marginal.
@@ -23,21 +60,28 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
      """
      assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
      dtype = r.dtype
-     if gamma_f > 2 ** 10:
+     # Require high precision for gamma_f > 2^10
+     if gamma_f > 2**10 and dtype != th.float64:
+         warnings.warn(
+             "Switching to double precision for gamma_f > 2^10 during execution. "
+             f"Output will be input dtype: {dtype}."
+         )
          r, c, C = r.double(), c.double(), C.double()
+
      if drop_tiny:
-         drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
+         drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
          (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)

          u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)

-         u = -th.ones_like(r) * float('inf')
-         u[:, r_keep] = u_
-         v = -th.ones_like(c) * float('inf')
-         v[:, c_keep] = v_
+         u = -th.ones_like(r) * float("inf")
+         u[r_keep] = u_
+         v = -th.ones_like(c) * float("inf")
+         v[c_keep] = v_
      else:
          u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)

+     # Switch back to original dtype
      u, v = u.to(dtype), v.to(dtype)

      if return_plan:
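The dtype handling in this hunk means callers can pass float32 inputs even with a large `gamma_f`: the solver warns, runs internally in double precision, and casts the potentials back to the input dtype, as the warning text above states. A small sketch of that behavior; the tensors and the `gamma_f` value are illustrative, and CPU execution is assumed to work for small problems:

```python
import warnings

import torch
import mdot_tnt

n = 256
r = torch.rand(n)          # float32 inputs
r = r / r.sum()
c = torch.rand(n)
c = c / c.sum()
C = torch.rand(n, n)

# gamma_f > 2**10 with non-float64 inputs: expect a warning, with outputs
# returned in the original input dtype per the warning text in the diff.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2048.0)
print(len(caught) > 0)

# Passing float64 up front avoids the warning and the dtype round trip.
cost64 = mdot_tnt.solve_OT(r.double(), c.double(), C.double(), gamma_f=2048.0)
```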