mdot-tnt 0.1.0__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/LICENSE +4 -1
- mdot_tnt-1.0.0/PKG-INFO +216 -0
- mdot_tnt-1.0.0/README.md +152 -0
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt/__init__.py +52 -8
- mdot_tnt-1.0.0/mdot_tnt/batched.py +634 -0
- mdot_tnt-1.0.0/mdot_tnt/mdot.py +203 -0
- mdot_tnt-1.0.0/mdot_tnt/py.typed +0 -0
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt/rounding.py +41 -15
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt/truncated_newton.py +107 -38
- mdot_tnt-1.0.0/mdot_tnt.egg-info/PKG-INFO +216 -0
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt.egg-info/SOURCES.txt +7 -1
- mdot_tnt-1.0.0/mdot_tnt.egg-info/requires.txt +7 -0
- mdot_tnt-1.0.0/pyproject.toml +92 -0
- mdot_tnt-1.0.0/tests/test_batched.py +169 -0
- mdot_tnt-1.0.0/tests/test_rounding.py +151 -0
- mdot_tnt-1.0.0/tests/test_solve_ot.py +137 -0
- mdot_tnt-0.1.0/PKG-INFO +0 -71
- mdot_tnt-0.1.0/README.md +0 -61
- mdot_tnt-0.1.0/mdot_tnt/mdot.py +0 -139
- mdot_tnt-0.1.0/mdot_tnt.egg-info/PKG-INFO +0 -71
- mdot_tnt-0.1.0/pyproject.toml +0 -21
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt.egg-info/dependency_links.txt +0 -0
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/mdot_tnt.egg-info/top_level.txt +0 -0
- {mdot_tnt-0.1.0 → mdot_tnt-1.0.0}/setup.cfg +0 -0
|
@@ -23,4 +23,7 @@ For commercial use, a separate license must be obtained from the Licensor. To in
|
|
|
23
23
|
This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
|
|
24
24
|
|
|
25
25
|
## 4. Governing Law and Enforcement
|
|
26
|
-
This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued
|
|
26
|
+
This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
|
|
27
|
+
|
|
28
|
+
## 5. Contact
|
|
29
|
+
For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
|
mdot_tnt-1.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mdot-tnt
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
|
|
5
|
+
Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
|
|
6
|
+
License: # Non-Commercial Research License (NCRL-1.0)
|
|
7
|
+
|
|
8
|
+
Copyright (C) 2025 Mete Kemertas
|
|
9
|
+
|
|
10
|
+
## 1. License Grant
|
|
11
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
|
|
12
|
+
|
|
13
|
+
## 2. Restrictions
|
|
14
|
+
### 2.1 **Non-Commercial Use Only**
|
|
15
|
+
- The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
|
|
16
|
+
- "Commercial purpose" includes, but is not limited to:
|
|
17
|
+
- Selling or licensing the Software.
|
|
18
|
+
- Using the Software in proprietary products or services.
|
|
19
|
+
- Offering the Software as part of a paid or revenue-generating service.
|
|
20
|
+
|
|
21
|
+
### 2.2 **No Warranty & Liability**
|
|
22
|
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
|
|
23
|
+
|
|
24
|
+
### 2.3 **Commercial Licensing**
|
|
25
|
+
For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
|
|
26
|
+
|
|
27
|
+
## 3. Termination
|
|
28
|
+
This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
|
|
29
|
+
|
|
30
|
+
## 4. Governing Law and Enforcement
|
|
31
|
+
This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
|
|
32
|
+
|
|
33
|
+
## 5. Contact
|
|
34
|
+
For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
|
|
35
|
+
|
|
36
|
+
Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
|
|
37
|
+
Project-URL: Documentation, https://mdot-tnt.readthedocs.io
|
|
38
|
+
Project-URL: Repository, https://github.com/metekemertas/mdot_tnt
|
|
39
|
+
Project-URL: Issues, https://github.com/metekemertas/mdot_tnt/issues
|
|
40
|
+
Keywords: optimal-transport,sinkhorn,entropy-regularization,machine-learning,pytorch,gpu
|
|
41
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
42
|
+
Classifier: Intended Audience :: Science/Research
|
|
43
|
+
Classifier: Intended Audience :: Developers
|
|
44
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
45
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
46
|
+
Classifier: Programming Language :: Python :: 3
|
|
47
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
48
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
49
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
50
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
51
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
52
|
+
Classifier: Operating System :: OS Independent
|
|
53
|
+
Classifier: Typing :: Typed
|
|
54
|
+
Requires-Python: >=3.8
|
|
55
|
+
Description-Content-Type: text/markdown
|
|
56
|
+
License-File: LICENSE
|
|
57
|
+
Provides-Extra: dev
|
|
58
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
59
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
60
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
61
|
+
Requires-Dist: pre-commit>=3.0; extra == "dev"
|
|
62
|
+
Requires-Dist: numpy>=1.20; extra == "dev"
|
|
63
|
+
Dynamic: license-file
|
|
64
|
+
|
|
65
|
+
# MDOT-TNT
|
|
66
|
+
|
|
67
|
+
<img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>
|
|
68
|
+
|
|
69
|
+
**A Truncated Newton Method for Optimal Transport**
|
|
70
|
+
|
|
71
|
+
[PyPI version](https://badge.fury.io/py/mdot-tnt)
|
|
72
|
+
[Python downloads](https://www.python.org/downloads/)
|
|
73
|
+
[License](LICENSE)
|
|
74
|
+
|
|
75
|
+
A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
|
|
76
|
+
|
|
77
|
+
<br clear="right"/>
|
|
78
|
+
|
|
79
|
+
## Features
|
|
80
|
+
|
|
81
|
+
- **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
|
|
82
|
+
- **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
|
|
83
|
+
- **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
|
|
84
|
+
- **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
|
|
85
|
+
- **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
|
|
86
|
+
|
|
87
|
+
## Installation
|
|
88
|
+
|
|
89
|
+
**Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.
|
|
90
|
+
|
|
91
|
+
```bash
|
|
92
|
+
pip install mdot-tnt
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
For development:
|
|
96
|
+
|
|
97
|
+
```bash
|
|
98
|
+
git clone https://github.com/metekemertas/mdot_tnt.git
|
|
99
|
+
cd mdot_tnt
|
|
100
|
+
pip install -e ".[dev]"
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## Quick Start
|
|
104
|
+
|
|
105
|
+
### Single Problem
|
|
106
|
+
|
|
107
|
+
```python
|
|
108
|
+
import torch
|
|
109
|
+
import mdot_tnt
|
|
110
|
+
|
|
111
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
112
|
+
|
|
113
|
+
# Create marginals (probability distributions)
|
|
114
|
+
n, m = 512, 512
|
|
115
|
+
r = torch.rand(n, device=device, dtype=torch.float64)
|
|
116
|
+
r = r / r.sum()
|
|
117
|
+
c = torch.rand(m, device=device, dtype=torch.float64)
|
|
118
|
+
c = c / c.sum()
|
|
119
|
+
|
|
120
|
+
# Cost matrix (e.g., pairwise distances)
|
|
121
|
+
C = torch.rand(n, m, device=device, dtype=torch.float64)
|
|
122
|
+
|
|
123
|
+
# Solve for optimal transport cost
|
|
124
|
+
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
|
|
125
|
+
|
|
126
|
+
# Or get the full transport plan
|
|
127
|
+
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
### Batched Solving
|
|
131
|
+
|
|
132
|
+
When solving multiple OT problems, use the batched solver for a significant speedup over solving them one at a time:
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
import torch
|
|
136
|
+
import mdot_tnt
|
|
137
|
+
|
|
138
|
+
device = "cuda"
|
|
139
|
+
batch_size, n, m = 32, 512, 512
|
|
140
|
+
|
|
141
|
+
# Multiple marginal pairs
|
|
142
|
+
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
|
|
143
|
+
r = r / r.sum(-1, keepdim=True)
|
|
144
|
+
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
|
|
145
|
+
c = c / c.sum(-1, keepdim=True)
|
|
146
|
+
|
|
147
|
+
# Shared cost matrix (or per-problem: shape [batch_size, n, m])
|
|
148
|
+
C = torch.rand(n, m, device=device, dtype=torch.float64)
|
|
149
|
+
|
|
150
|
+
# Solve all problems at once
|
|
151
|
+
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024) # Returns (batch_size,) tensor
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
|
|
155
|
+
|
|
156
|
+
## API Reference
|
|
157
|
+
|
|
158
|
+
### `solve_OT`
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=True, log=False)
|
|
162
|
+
```
|
|
163
|
+
|
|
164
|
+
| Parameter | Type | Description |
|
|
165
|
+
|-----------|------|-------------|
|
|
166
|
+
| `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
|
|
167
|
+
| `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
|
|
168
|
+
| `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
|
|
169
|
+
| `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
|
|
170
|
+
| `return_plan` | `bool` | If True, return transport plan instead of cost |
|
|
171
|
+
| `round` | `bool` | If True, round solution onto feasible set |
|
|
172
|
+
| `log` | `bool` | If True, also return optimization logs |
|
|
173
|
+
|
|
174
|
+
**Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
|
|
175
|
+
|
|
176
|
+
### `solve_OT_batched`
|
|
177
|
+
|
|
178
|
+
```python
|
|
179
|
+
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
|
|
180
|
+
```
|
|
181
|
+
|
|
182
|
+
Same parameters as `solve_OT`, but with batched inputs:
|
|
183
|
+
- `r`: Shape `(batch, n)`
|
|
184
|
+
- `c`: Shape `(batch, m)`
|
|
185
|
+
- `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs
|
|
186
|
+
|
|
187
|
+
**Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
|
|
188
|
+
|
|
189
|
+
## Performance Tips
|
|
190
|
+
|
|
191
|
+
1. **Use float64 inputs** when `gamma_f > 1024`; inputs of other dtypes are converted to float64 internally (with a warning) and the result is cast back to the input dtype
|
|
192
|
+
2. **Normalize cost matrices** to [0, 1] for numerical stability
|
|
193
|
+
3. **Use batched solver** when solving multiple problems of the same size `(n, m)`, optionally with a shared cost matrix
|
|
194
|
+
4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
|
|
195
|
+
|
|
196
|
+
## Citation
|
|
197
|
+
|
|
198
|
+
If you use MDOT-TNT in your research, please cite:
|
|
199
|
+
|
|
200
|
+
```bibtex
|
|
201
|
+
@inproceedings{kemertas2025truncated,
|
|
202
|
+
title={A Truncated Newton Method for Optimal Transport},
|
|
203
|
+
author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
|
|
204
|
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
|
205
|
+
year={2025},
|
|
206
|
+
url={https://openreview.net/forum?id=gWrWUaCbMa}
|
|
207
|
+
}
|
|
208
|
+
```
|
|
209
|
+
|
|
210
|
+
## License
|
|
211
|
+
|
|
212
|
+
This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.
|
|
213
|
+
|
|
214
|
+
## Contact
|
|
215
|
+
|
|
216
|
+
For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
|
mdot_tnt-1.0.0/README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# MDOT-TNT
|
|
2
|
+
|
|
3
|
+
<img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>
|
|
4
|
+
|
|
5
|
+
**A Truncated Newton Method for Optimal Transport**
|
|
6
|
+
|
|
7
|
+
[PyPI version](https://badge.fury.io/py/mdot-tnt)
|
|
8
|
+
[Python downloads](https://www.python.org/downloads/)
|
|
9
|
+
[License](LICENSE)
|
|
10
|
+
|
|
11
|
+
A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
|
|
12
|
+
|
|
13
|
+
<br clear="right"/>
|
|
14
|
+
|
|
15
|
+
## Features
|
|
16
|
+
|
|
17
|
+
- **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
|
|
18
|
+
- **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
|
|
19
|
+
- **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
|
|
20
|
+
- **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
|
|
21
|
+
- **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
|
|
22
|
+
|
|
23
|
+
## Installation
|
|
24
|
+
|
|
25
|
+
**Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install mdot-tnt
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
For development:
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
git clone https://github.com/metekemertas/mdot_tnt.git
|
|
35
|
+
cd mdot_tnt
|
|
36
|
+
pip install -e ".[dev]"
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Quick Start
|
|
40
|
+
|
|
41
|
+
### Single Problem
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import torch
|
|
45
|
+
import mdot_tnt
|
|
46
|
+
|
|
47
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
48
|
+
|
|
49
|
+
# Create marginals (probability distributions)
|
|
50
|
+
n, m = 512, 512
|
|
51
|
+
r = torch.rand(n, device=device, dtype=torch.float64)
|
|
52
|
+
r = r / r.sum()
|
|
53
|
+
c = torch.rand(m, device=device, dtype=torch.float64)
|
|
54
|
+
c = c / c.sum()
|
|
55
|
+
|
|
56
|
+
# Cost matrix (e.g., pairwise distances)
|
|
57
|
+
C = torch.rand(n, m, device=device, dtype=torch.float64)
|
|
58
|
+
|
|
59
|
+
# Solve for optimal transport cost
|
|
60
|
+
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
|
|
61
|
+
|
|
62
|
+
# Or get the full transport plan
|
|
63
|
+
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
### Batched Solving
|
|
67
|
+
|
|
68
|
+
When solving multiple OT problems, use the batched solver for a significant speedup over solving them one at a time:
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
import torch
|
|
72
|
+
import mdot_tnt
|
|
73
|
+
|
|
74
|
+
device = "cuda"
|
|
75
|
+
batch_size, n, m = 32, 512, 512
|
|
76
|
+
|
|
77
|
+
# Multiple marginal pairs
|
|
78
|
+
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
|
|
79
|
+
r = r / r.sum(-1, keepdim=True)
|
|
80
|
+
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
|
|
81
|
+
c = c / c.sum(-1, keepdim=True)
|
|
82
|
+
|
|
83
|
+
# Shared cost matrix (or per-problem: shape [batch_size, n, m])
|
|
84
|
+
C = torch.rand(n, m, device=device, dtype=torch.float64)
|
|
85
|
+
|
|
86
|
+
# Solve all problems at once
|
|
87
|
+
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024) # Returns (batch_size,) tensor
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
|
|
91
|
+
|
|
92
|
+
## API Reference
|
|
93
|
+
|
|
94
|
+
### `solve_OT`
|
|
95
|
+
|
|
96
|
+
```python
|
|
97
|
+
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., drop_tiny=False, return_plan=False, round=True, log=False)
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
| Parameter | Type | Description |
|
|
101
|
+
|-----------|------|-------------|
|
|
102
|
+
| `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
|
|
103
|
+
| `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
|
|
104
|
+
| `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
|
|
105
|
+
| `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
|
|
106
|
+
| `return_plan` | `bool` | If True, return transport plan instead of cost |
|
|
107
|
+
| `round` | `bool` | If True, round solution onto feasible set |
|
|
108
|
+
| `log` | `bool` | If True, also return optimization logs |
|
|
109
|
+
|
|
110
|
+
**Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
|
|
111
|
+
|
|
112
|
+
### `solve_OT_batched`
|
|
113
|
+
|
|
114
|
+
```python
|
|
115
|
+
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Same parameters as `solve_OT`, but with batched inputs:
|
|
119
|
+
- `r`: Shape `(batch, n)`
|
|
120
|
+
- `c`: Shape `(batch, m)`
|
|
121
|
+
- `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs
|
|
122
|
+
|
|
123
|
+
**Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
|
|
124
|
+
|
|
125
|
+
## Performance Tips
|
|
126
|
+
|
|
127
|
+
1. **Use float64 inputs** when `gamma_f > 1024`; inputs of other dtypes are converted to float64 internally (with a warning) and the result is cast back to the input dtype
|
|
128
|
+
2. **Normalize cost matrices** to [0, 1] for numerical stability
|
|
129
|
+
3. **Use batched solver** when solving multiple problems of the same size `(n, m)`, optionally with a shared cost matrix
|
|
130
|
+
4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
|
|
131
|
+
|
|
132
|
+
## Citation
|
|
133
|
+
|
|
134
|
+
If you use MDOT-TNT in your research, please cite:
|
|
135
|
+
|
|
136
|
+
```bibtex
|
|
137
|
+
@inproceedings{kemertas2025truncated,
|
|
138
|
+
title={A Truncated Newton Method for Optimal Transport},
|
|
139
|
+
author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
|
|
140
|
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
|
141
|
+
year={2025},
|
|
142
|
+
url={https://openreview.net/forum?id=gWrWUaCbMa}
|
|
143
|
+
}
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
## License
|
|
147
|
+
|
|
148
|
+
This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.
|
|
149
|
+
|
|
150
|
+
## Contact
|
|
151
|
+
|
|
152
|
+
For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
|
|
@@ -1,11 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MDOT-TNT: A Truncated Newton Method for Optimal Transport
|
|
3
|
+
|
|
4
|
+
This package provides efficient solvers for the entropic-regularized optimal transport
|
|
5
|
+
problem, as introduced in the paper "A Truncated Newton Method for Optimal Transport"
|
|
6
|
+
by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
|
|
7
|
+
URL: https://openreview.net/forum?id=gWrWUaCbMa
|
|
8
|
+
|
|
9
|
+
Main functions:
|
|
10
|
+
solve_OT: Solve a single OT problem.
|
|
11
|
+
solve_OT_batched: Solve multiple OT problems simultaneously (5-10x faster).
|
|
12
|
+
|
|
13
|
+
Example:
|
|
14
|
+
>>> import torch
|
|
15
|
+
>>> from mdot_tnt import solve_OT, solve_OT_batched
|
|
16
|
+
>>>
|
|
17
|
+
>>> # Single problem
|
|
18
|
+
>>> r = torch.rand(512, device='cuda', dtype=torch.float64)
|
|
19
|
+
>>> r = r / r.sum()
|
|
20
|
+
>>> c = torch.rand(512, device='cuda', dtype=torch.float64)
|
|
21
|
+
>>> c = c / c.sum()
|
|
22
|
+
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
|
|
23
|
+
>>> cost = solve_OT(r, c, C, gamma_f=1024.)
|
|
24
|
+
>>>
|
|
25
|
+
>>> # Batched (32 problems at once)
|
|
26
|
+
>>> r_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
27
|
+
>>> r_batch = r_batch / r_batch.sum(-1, keepdim=True)
|
|
28
|
+
>>> c_batch = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
29
|
+
>>> c_batch = c_batch / c_batch.sum(-1, keepdim=True)
|
|
30
|
+
>>> costs = solve_OT_batched(r_batch, c_batch, C, gamma_f=1024.)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
import warnings
|
|
1
35
|
|
|
2
36
|
import torch as th
|
|
3
37
|
|
|
4
|
-
from mdot_tnt.
|
|
38
|
+
from mdot_tnt.batched import solve_OT_batched
|
|
39
|
+
from mdot_tnt.mdot import mdot, preprocess_marginals
|
|
5
40
|
from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
|
|
6
41
|
|
|
42
|
+
__all__ = ["solve_OT", "solve_OT_batched"]
|
|
43
|
+
|
|
7
44
|
|
|
8
|
-
def solve_OT(r, c, C, gamma_f=
|
|
45
|
+
def solve_OT(r, c, C, gamma_f=1024.0, drop_tiny=False, return_plan=False, round=True, log=False):
|
|
9
46
|
"""
|
|
10
47
|
Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
|
|
11
48
|
:param r: n-dimensional row marginal.
|
|
@@ -23,21 +60,28 @@ def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=T
|
|
|
23
60
|
"""
|
|
24
61
|
assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
|
|
25
62
|
dtype = r.dtype
|
|
26
|
-
|
|
63
|
+
# Require high precision for gamma_f > 2^10
|
|
64
|
+
if gamma_f > 2**10 and dtype != th.float64:
|
|
65
|
+
warnings.warn(
|
|
66
|
+
"Switching to double precision for gamma_f > 2^10 during execution. "
|
|
67
|
+
f"Output will be input dtype: {dtype}."
|
|
68
|
+
)
|
|
27
69
|
r, c, C = r.double(), c.double(), C.double()
|
|
70
|
+
|
|
28
71
|
if drop_tiny:
|
|
29
|
-
drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f
|
|
72
|
+
drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f**2)
|
|
30
73
|
(r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)
|
|
31
74
|
|
|
32
75
|
u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)
|
|
33
76
|
|
|
34
|
-
u = -th.ones_like(r) * float(
|
|
35
|
-
u[
|
|
36
|
-
v = -th.ones_like(c) * float(
|
|
37
|
-
v[
|
|
77
|
+
u = -th.ones_like(r) * float("inf")
|
|
78
|
+
u[r_keep] = u_
|
|
79
|
+
v = -th.ones_like(c) * float("inf")
|
|
80
|
+
v[c_keep] = v_
|
|
38
81
|
else:
|
|
39
82
|
u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)
|
|
40
83
|
|
|
84
|
+
# Switch back to original dtype
|
|
41
85
|
u, v = u.to(dtype), v.to(dtype)
|
|
42
86
|
|
|
43
87
|
if return_plan:
|