mdot-tnt 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,216 @@
1
+ Metadata-Version: 2.4
2
+ Name: mdot-tnt
3
+ Version: 1.0.0
4
+ Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
5
+ Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
6
+ License: # Non-Commercial Research License (NCRL-1.0)
7
+
8
+ Copyright (C) 2025 Mete Kemertas
9
+
10
+ ## 1. License Grant
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
12
+
13
+ ## 2. Restrictions
14
+ ### 2.1 **Non-Commercial Use Only**
15
+ - The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
16
+ - "Commercial purpose" includes, but is not limited to:
17
+ - Selling or licensing the Software.
18
+ - Using the Software in proprietary products or services.
19
+ - Offering the Software as part of a paid or revenue-generating service.
20
+
21
+ ### 2.2 **No Warranty & Liability**
22
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
23
+
24
+ ### 2.3 **Commercial Licensing**
25
+ For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
26
+
27
+ ## 3. Termination
28
+ This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
29
+
30
+ ## 4. Governing Law and Enforcement
31
+ This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
32
+
33
+ ## 5. Contact
34
+ For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
35
+
36
+ Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
37
+ Project-URL: Documentation, https://mdot-tnt.readthedocs.io
38
+ Project-URL: Repository, https://github.com/metekemertas/mdot_tnt
39
+ Project-URL: Issues, https://github.com/metekemertas/mdot_tnt/issues
40
+ Keywords: optimal-transport,sinkhorn,entropy-regularization,machine-learning,pytorch,gpu
41
+ Classifier: Development Status :: 5 - Production/Stable
42
+ Classifier: Intended Audience :: Science/Research
43
+ Classifier: Intended Audience :: Developers
44
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
45
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
46
+ Classifier: Programming Language :: Python :: 3
47
+ Classifier: Programming Language :: Python :: 3.8
48
+ Classifier: Programming Language :: Python :: 3.9
49
+ Classifier: Programming Language :: Python :: 3.10
50
+ Classifier: Programming Language :: Python :: 3.11
51
+ Classifier: Programming Language :: Python :: 3.12
52
+ Classifier: Operating System :: OS Independent
53
+ Classifier: Typing :: Typed
54
+ Requires-Python: >=3.8
55
+ Description-Content-Type: text/markdown
56
+ License-File: LICENSE
57
+ Provides-Extra: dev
58
+ Requires-Dist: pytest>=7.0; extra == "dev"
59
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
60
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
61
+ Requires-Dist: pre-commit>=3.0; extra == "dev"
62
+ Requires-Dist: numpy>=1.20; extra == "dev"
63
+ Dynamic: license-file
64
+
65
+ # MDOT-TNT
66
+
67
+ <img src="assets/logo.png" alt="MDOT-TNT Logo" width="180" align="right"/>
68
+
69
+ **A Truncated Newton Method for Optimal Transport**
70
+
71
+ [![PyPI version](https://badge.fury.io/py/mdot-tnt.svg)](https://badge.fury.io/py/mdot-tnt)
72
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
73
+ [![License](https://img.shields.io/badge/license-Non--Commercial-green.svg)](LICENSE)
74
+
75
+ A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
76
+
77
+ <br clear="right"/>
78
+
79
+ ## Features
80
+
81
+ - **High Precision**: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
82
+ - **GPU Accelerated**: Fully compatible with CUDA for fast computation on large problems
83
+ - **Batched Solving**: Solve multiple OT problems simultaneously in batched mode
84
+ - **Memory Efficient**: Log-domain computations and efficient rounding avoid storing full transport plans
85
+ - **PyTorch Native**: Seamless integration with PyTorch, supporting autograd-compatible inputs
86
+
87
+ ## Installation
88
+
89
+ **Prerequisites**: Install [PyTorch](https://pytorch.org/get-started/locally/) for your system configuration first.
90
+
91
+ ```bash
92
+ pip install mdot-tnt
93
+ ```
94
+
95
+ For development:
96
+
97
+ ```bash
98
+ git clone https://github.com/metekemertas/mdot_tnt.git
99
+ cd mdot_tnt
100
+ pip install -e ".[dev]"
101
+ ```
102
+
103
+ ## Quick Start
104
+
105
+ ### Single Problem
106
+
107
+ ```python
108
+ import torch
109
+ import mdot_tnt
110
+
111
+ device = "cuda" if torch.cuda.is_available() else "cpu"
112
+
113
+ # Create marginals (probability distributions)
114
+ n, m = 512, 512
115
+ r = torch.rand(n, device=device, dtype=torch.float64)
116
+ r = r / r.sum()
117
+ c = torch.rand(m, device=device, dtype=torch.float64)
118
+ c = c / c.sum()
119
+
120
+ # Cost matrix (e.g., pairwise distances)
121
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
122
+
123
+ # Solve for optimal transport cost
124
+ cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
125
+
126
+ # Or get the full transport plan
127
+ plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
128
+ ```
129
+
130
+ ### Batched Solving
131
+
132
+ When solving multiple OT problems, use the batched solver, which can be significantly faster than solving them one at a time:
133
+
134
+ ```python
135
+ import torch
136
+ import mdot_tnt
137
+
138
+ device = "cuda" if torch.cuda.is_available() else "cpu"
139
+ batch_size, n, m = 32, 512, 512
140
+
141
+ # Multiple marginal pairs
142
+ r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
143
+ r = r / r.sum(-1, keepdim=True)
144
+ c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
145
+ c = c / c.sum(-1, keepdim=True)
146
+
147
+ # Shared cost matrix (or per-problem: shape [batch_size, n, m])
148
+ C = torch.rand(n, m, device=device, dtype=torch.float64)
149
+
150
+ # Solve all problems at once
151
+ costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024) # Returns (batch_size,) tensor
152
+ ```
153
+
154
+ The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
155
+
156
+ ## API Reference
157
+
158
+ ### `solve_OT`
159
+
160
+ ```python
161
+ mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
162
+ ```
163
+
164
+ | Parameter | Type | Description |
165
+ |-----------|------|-------------|
166
+ | `r` | `Tensor` | Row marginal of shape `(n,)`, must sum to 1 |
167
+ | `c` | `Tensor` | Column marginal of shape `(m,)`, must sum to 1 |
168
+ | `C` | `Tensor` | Cost matrix of shape `(n, m)`, recommended to normalize to [0, 1] |
169
+ | `gamma_f` | `float` | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
170
+ | `return_plan` | `bool` | If True, return transport plan instead of cost |
171
+ | `round` | `bool` | If True, round solution onto feasible set |
172
+ | `log` | `bool` | If True, also return optimization logs |
173
+
174
+ **Returns**: Transport cost (scalar) or plan `(n, m)`, optionally with logs dict.
175
+
176
+ ### `solve_OT_batched`
177
+
178
+ ```python
179
+ mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
180
+ ```
181
+
182
+ Same parameters as `solve_OT`, but with batched inputs:
183
+ - `r`: Shape `(batch, n)`
184
+ - `c`: Shape `(batch, m)`
185
+ - `C`: Shape `(n, m)` for shared cost, or `(batch, n, m)` for per-problem costs
186
+
187
+ **Returns**: Costs `(batch,)` or plans `(batch, n, m)`.
188
+
189
+ ## Performance Tips
190
+
191
+ 1. **Use float64** for `gamma_f > 1024` (automatic conversion with warning)
192
+ 2. **Normalize cost matrices** to [0, 1] for numerical stability
193
+ 3. **Use batched solver** when solving multiple problems with shared structure
194
+ 4. **Increase `gamma_f`** for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
195
+
196
+ ## Citation
197
+
198
+ If you use MDOT-TNT in your research, please cite:
199
+
200
+ ```bibtex
201
+ @inproceedings{kemertas2025truncated,
202
+ title={A Truncated Newton Method for Optimal Transport},
203
+ author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
204
+ booktitle={The Thirteenth International Conference on Learning Representations},
205
+ year={2025},
206
+ url={https://openreview.net/forum?id=gWrWUaCbMa}
207
+ }
208
+ ```
209
+
210
+ ## License
211
+
212
+ This code is released under a [non-commercial use license](LICENSE). For commercial licensing inquiries, please contact the authors.
213
+
214
+ ## Contact
215
+
216
+ For questions or issues, please [open an issue](https://github.com/metekemertas/mdot_tnt/issues) or email: kemertas [at] cs [dot] toronto [dot] edu
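The API reference says only that `round=True` rounds the solution onto the feasible set. That set consists of the nonnegative plans whose row sums equal `r` and whose column sums equal `c`. The minimal sketch below shows what such a step can look like, using the rounding procedure of Altschuler, Weed and Rigollet (2017). It is written in plain Python lists rather than PyTorch for readability. This is an assumption for exposition, not the contents of `mdot_tnt/rounding.py`, and the helper name `round_to_feasible` is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def round_to_feasible(F: Matrix, r: List[float], c: List[float]) -> Matrix:
    """Round a nonnegative n x m matrix F onto {P >= 0 : P 1 = r, P^T 1 = c}.

    Hypothetical helper sketching the rounding step of Altschuler, Weed and
    Rigollet (2017). Assumes sum(r) == sum(c), F entrywise positive, r and c positive.
    """
    n, m = len(r), len(c)
    # Step 1: scale each row down so that its sum does not exceed r_i.
    x = [min(r[i] / sum(F[i]), 1.0) for i in range(n)]
    G = [[x[i] * F[i][j] for j in range(m)] for i in range(n)]
    # Step 2: scale each column down so that its sum does not exceed c_j.
    col = [sum(G[i][j] for i in range(n)) for j in range(m)]
    y = [min(c[j] / col[j], 1.0) for j in range(m)]
    G = [[G[i][j] * y[j] for j in range(m)] for i in range(n)]
    # Step 3: both marginal deficits are now nonnegative and have equal totals,
    # so a rank-one correction restores the row and column sums exactly.
    err_r = [r[i] - sum(G[i]) for i in range(n)]
    err_c = [c[j] - sum(G[i][j] for i in range(n)) for j in range(m)]
    total = sum(err_c)
    if total > 0:
        G = [[G[i][j] + err_r[i] * err_c[j] / total for j in range(m)] for i in range(n)]
    return G


if __name__ == "__main__":
    F = [[0.2, 0.1], [0.3, 0.4]]  # an infeasible approximate plan
    r, c = [0.5, 0.5], [0.4, 0.6]
    P = round_to_feasible(F, r, c)
    print("row sums:", [sum(row) for row in P])
    print("col sums:", [sum(P[i][j] for i in range(2)) for j in range(2)])
```

Both scaling steps can only remove mass, so every remaining deficit is nonnegative. Both deficit vectors also sum to the same total, which is why a single rank-one correction fixes the rows and columns at once.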
@@ -0,0 +1,11 @@
1
+ mdot_tnt/__init__.py,sha256=f44F4mkhV1jg1Nm1q-pSwHilO0D0Mol0moduRmYIo0U,4357
2
+ mdot_tnt/batched.py,sha256=x2MH1ghMoaIOFBcarPD8LPOfyLKrvvMNekz9-SmwG-s,22616
3
+ mdot_tnt/mdot.py,sha256=6MwuOZmkTAlSeg3KgX6Dq9v8ISJGeX5Fzp7OBV3jpZw,6798
4
+ mdot_tnt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ mdot_tnt/rounding.py,sha256=eD20MV5eLy48gEAEEfr-duXmYXV5uY1ncep4_UFccKc,2916
6
+ mdot_tnt/truncated_newton.py,sha256=gH8Ta9j1h8oQ2VOuBt8kPXF0CFkAGXIxWzwx8I5kAlQ,12333
7
+ mdot_tnt-1.0.0.dist-info/licenses/LICENSE,sha256=LAzHjfxooWpseE8S3R4-rLm2tUmdoRwvimgpArBByVI,1998
8
+ mdot_tnt-1.0.0.dist-info/METADATA,sha256=bncxqe6JToreZmaUtYRdrPaCmUvNDdkV3gApJ2L7Vd8,9000
9
+ mdot_tnt-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
10
+ mdot_tnt-1.0.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
11
+ mdot_tnt-1.0.0.dist-info/RECORD,,
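Each RECORD row above has the form `path,sha256=<digest>,<size in bytes>`. Under the wheel format, the digest is the urlsafe base64 encoding of the SHA-256 hash with the trailing `=` padding removed. The sketch below recomputes such rows and checks an archive against its RECORD. The names `record_hash`, `record_line` and `verify_wheel` are hypothetical helpers, and the sketch only handles SHA-256 entries, which are the only kind that appear in these RECORD files.

```python
import base64
import csv
import hashlib
import io
import zipfile
from typing import IO, List, Union


def record_hash(data: bytes) -> str:
    """RECORD-style hash: 'sha256=' + urlsafe base64 of the digest, '=' padding stripped."""
    digest = hashlib.sha256(data).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def record_line(path: str, data: bytes) -> str:
    """Format one RECORD row: archive path, hash, size in bytes."""
    return f"{path},{record_hash(data)},{len(data)}"


def verify_wheel(wheel: Union[str, IO[bytes]]) -> List[str]:
    """Return archive paths whose contents disagree with their RECORD entry.

    Assumes every hashed entry uses sha256, as in the RECORD files shown above.
    """
    mismatches = []
    with zipfile.ZipFile(wheel) as zf:
        record_name = next(n for n in zf.namelist() if n.endswith(".dist-info/RECORD"))
        text = zf.read(record_name).decode("utf-8")
        for row in csv.reader(io.StringIO(text)):
            if not row:
                continue
            path, hash_field, size = row
            if not hash_field:
                # RECORD lists itself with empty hash and size fields.
                continue
            data = zf.read(path)
            if hash_field != record_hash(data) or int(size) != len(data):
                mismatches.append(path)
    return mismatches


if __name__ == "__main__":
    # py.typed is an empty marker file, so its row depends only on the empty input.
    print(record_line("mdot_tnt/py.typed", b""))
```

For example, `record_line("mdot_tnt/py.typed", b"")` returns `mdot_tnt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0`, which matches the entry in the 1.0.0 RECORD. Running `verify_wheel` on a downloaded copy of the wheel should return an empty list if the archive is intact.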
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.3.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -23,4 +23,7 @@ For commercial use, a separate license must be obtained from the Licensor. To in
23
23
  This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
24
24
 
25
25
  ## 4. Governing Law and Enforcement
26
- This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued un
26
+ This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright laws in the jurisdiction where infringement occurs.
27
+
28
+ ## 5. Contact
29
+ For licensing inquiries, please contact: **kemertas@cs.toronto.edu**.
@@ -1,71 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: mdot-tnt
3
- Version: 0.2.0
4
- Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
5
- Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
6
- Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
7
- Requires-Python: >=3.7
8
- Description-Content-Type: text/markdown
9
- License-File: LICENSE
10
-
11
- This is the official repository for the MDOT-TruncatedNewton (or MDOT-TNT)
12
- algorithm [1] for solving the entropic-regularized optimal transport (OT) problem.
13
- In addition to being GPU-friendly, the algorithm is stable under weak regularization and can therefore find highly
14
- precise approximations of the un-regularized problem's solution quickly.
15
-
16
- The current implementation is based on PyTorch and is compatible with both CPU and GPU. PyTorch is the only dependency.
17
-
18
-
19
- For installation:
20
- First, install PyTorch following the instructions at https://pytorch.org/get-started/locally/ to select the version that matches your system's configuration.
21
- ```bash
22
- pip3 install mdot_tnt
23
- ```
24
-
25
- Quickstart guide:
26
- ```
27
- import mdot_tnt
28
- import torch as th
29
- device = 'cuda' if th.cuda.is_available() else 'cpu'
30
- N, M, dim = 100, 200, 128
31
-
32
- # Sample row and column marginals from Dirichlet distributions
33
- r = th.distributions.Dirichlet(th.ones(N)).sample()
34
- c = th.distributions.Dirichlet(th.ones(M)).sample()
35
-
36
- # Cost matrix from pairwise Euclidean distances squared given random points in R^dim
37
- x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
38
- y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
39
- C = th.cdist(x, y, p=2) ** 2
40
- C /= C.max() # Normalize cost matrix to meet convention.
41
-
42
- # Use double precision for numerical stability in high precision regime.
43
- r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)
44
-
45
- # Solve OT problem. Increase (decrease) gamma_f for higher (lower) precision.
46
- # Default is gamma_f=2**10. Expect error of order logn / gamma_f at worst, and possibly lower.
47
- cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**10)
48
-
49
- # To return a feasible transport plan, use the following:
50
- transport_plan = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12, return_plan=True)
51
-
52
- # In both cases, the default rounding onto the feasible set can be disabled by setting `round=False`.
53
- ```
54
-
55
- The code is released under a custom non-commercial use license. If you use our work in
56
- your research, please consider citing:
57
-
58
- ```
59
- @inproceedings{
60
- kemertas2025a,
61
- title={A Truncated Newton Method for Optimal Transport},
62
- author={Mete Kemertas and Amir-massoud Farahmand and Allan Douglas Jepson},
63
- booktitle={The Thirteenth International Conference on Learning Representations},
64
- year={2025},
65
- url={https://openreview.net/forum?id=gWrWUaCbMa}
66
- }
67
- ```
68
-
69
- For inquiries, email: kemertas [at] cs [dot] toronto [dot] edu
70
-
71
- [1] Mete Kemertas, Amir-massoud Farahmand, Allan Douglas Jepson. "A Truncated Newton Method for Optimal Transport." The Thirteenth International Conference on Learning Representations (ICLR), 2025. https://openreview.net/forum?id=gWrWUaCbMa
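Both READMEs state that the error is at worst of order log n / γ. If you take that bound at face value with a constant of 1, a target tolerance translates into a choice of `gamma_f` as in the sketch below. The constant of 1 is an assumption, because the hidden constant is not stated. The helper `pick_gamma_f` is hypothetical and not part of the `mdot_tnt` API. Rounding up to a power of two simply mirrors the `2**10` and `2**12` values used in the quickstart.

```python
import math


def pick_gamma_f(n: int, target_error: float) -> float:
    """Smallest power of two gamma_f >= 1 with log(n) / gamma_f <= target_error.

    Hypothetical helper: assumes the worst-case error is log(n) / gamma_f,
    i.e. that the constant hidden in O(log n / gamma) equals 1.
    """
    needed = math.log(n) / target_error
    return float(2 ** max(0, math.ceil(math.log2(needed))))


if __name__ == "__main__":
    for tol in (1e-2, 1e-3):
        gamma_f = pick_gamma_f(512, tol)
        # The 1.0.0 README recommends float64 once gamma_f exceeds 1024.
        print(f"tol={tol:g}: gamma_f={gamma_f:g}, float64 advised: {gamma_f > 1024}")
```

With n = 512, this gives 1024 for a tolerance of 1e-2 and 8192 for 1e-3. Because the bound is a worst case, smaller values may already be accurate enough in practice.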
@@ -1,9 +0,0 @@
1
- mdot_tnt/__init__.py,sha256=9_DJl0mXe5WiKFf_12oN4d2Bnq4Q3AcuuNz9A6FmFp8,3290
2
- mdot_tnt/mdot.py,sha256=8SZxldnq64ySWZlUjsdGqT_prqfcs2ZrVFBGp-EJ4B0,5562
3
- mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
4
- mdot_tnt/truncated_newton.py,sha256=Zp4Tb65dSE_AGDElgAE1gp3V_CaXFsbDX3SJCaNEWfc,10299
5
- mdot_tnt-0.2.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
6
- mdot_tnt-0.2.0.dist-info/METADATA,sha256=sfqk_nDPsGJlSrXoz982dgLLAOGqI1F1obSdOsodIN8,3022
7
- mdot_tnt-0.2.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
8
- mdot_tnt-0.2.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
9
- mdot_tnt-0.2.0.dist-info/RECORD,,