mdot-tnt 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdot_tnt/__init__.py ADDED
@@ -0,0 +1,57 @@
1
+
2
+ import torch as th
3
+
4
+ from mdot_tnt.mdot import mdot
5
+ from mdot_tnt.rounding import round_altschuler, rounded_cost_altschuler
6
+
7
+
8
def solve_OT(r, c, C, gamma_f=4096., drop_tiny=False, return_plan=False, round=True, log=False):
    """
    Solve the entropic-regularized optimal transport problem. Inputs r, c, C are required to be torch tensors.
    :param r: n-dimensional row marginal.
    :param c: m-dimensional column marginal.
    :param C: n x m cost matrix. Recommended use is to scale the entries to be in [0, 1].
    :param gamma_f: The temperature (inverse of the regularization weight). For many problems, stable up to 2^18.
    Higher values return more accurate solutions but take longer to converge. Use double precision if gamma_f large.
    :param drop_tiny: If either marginal is known to be sparse, set this to True to drop tiny entries for a speedup.
    If return_plan is True, the returned plan will be in the original dimensions.
    :param return_plan: If True, return the optimal transport plan rather than the cost.
    :param round: If True, use the rounding algorithm of Altschuler et al. (2017) to (a) return a feasible plan
    if return_plan is True and (b) the cost of the rounded plan if return_plan is False.
    :param log: If True, additionally return a dictionary containing logs of the optimization process.
    :return: If return_plan is True, return the optimal transport plan as a torch tensor. Otherwise, return the cost.
    """
    # NOTE: `round` intentionally shadows the builtin to keep the public keyword API unchanged.
    assert all(isinstance(x, th.Tensor) for x in [r, c, C]), "r, c, and C must be torch tensors"
    dtype = r.dtype
    if gamma_f > 2 ** 10:
        # High temperature makes the potentials large; switch to double for stability.
        r, c, C = r.double(), c.double(), C.double()
    if drop_tiny:
        # BUGFIX: `math` and `preprocess_marginals` were previously unresolved names in
        # this module, so drop_tiny=True raised NameError. Import them locally to keep
        # the module's top-level import surface unchanged.
        import math
        from mdot_tnt.mdot import preprocess_marginals

        drop_lessthan = math.log(min(r.size(-1), c.size(-1))) / (gamma_f ** 2)
        (r_, r_keep), (c_, c_keep), C_ = preprocess_marginals(r, c, C, drop_lessthan)

        u_, v_, gamma_f_, k_total, opt_logs = mdot(r_, c_, C_, gamma_f)

        # Re-embed the reduced solution into the original dimensions; dropped entries
        # receive -inf potentials, i.e. zero mass in the implied plan.
        u = -th.ones_like(r) * float('inf')
        u[:, r_keep] = u_
        v = -th.ones_like(c) * float('inf')
        v[:, c_keep] = v_
    else:
        u, v, gamma_f_, k_total, opt_logs = mdot(r, c, C, gamma_f)

    u, v = u.to(dtype), v.to(dtype)

    if return_plan:
        # Reconstruct the plan from the dual potentials: P_ij = exp(u_i + v_j - gamma*C_ij).
        P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp()
        if round:
            P = round_altschuler(P, r, c)
        if log:
            return P, opt_logs
        return P
    else:
        if round:
            cost = rounded_cost_altschuler(u, v, r, c, C, gamma_f_)
        else:
            cost = ((u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_f_ * C).exp() * C).sum()
        if log:
            return cost, opt_logs
        return cost
mdot_tnt/mdot.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
3
+ method introduced in the paper "A Truncated Newton Method for Optimal Transport"
4
+ by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
5
+ URL: https://openreview.net/forum?id=gWrWUaCbMa
6
+ """
7
+
8
+
9
+ import math
10
+ import warnings
11
+
12
+ from mdot_tnt.rounding import *
13
+ from mdot_tnt.truncated_newton import TruncatedNewtonProjector
14
+
15
+
16
def preprocess_marginals(r, c, C, eps):
    """
    Sparsify the marginals by dropping their smallest entries.

    Entries are discarded greedily from smallest to largest while the cumulative
    removed mass stays at most ``eps``; the removed mass is then spread uniformly
    over the surviving entries so each marginal still sums to one. The cost matrix
    is restricted to the surviving rows and columns.

    NOTE(review): indexing assumes r and c carry a leading batch dimension
    (``m[:, keep]``) while C is an unbatched n x m matrix — confirm with callers.

    :param r: row marginal, shape (batch, n).
    :param c: column marginal, shape (batch, m).
    :param C: n x m cost matrix.
    :param eps: mass-removal budget.
    :return: ((r_new, r_keep), (c_new, c_keep), C_sub), where *_keep hold kept indices.
    """

    def _shrink(marginal, budget):
        # Ascending sort: the running sum identifies the smallest entries whose
        # combined mass does not exceed `budget`.
        vals, order = th.sort(marginal, dim=-1, descending=False)
        keep = order[th.cumsum(vals, dim=-1) > budget]
        kept = marginal[:, keep]
        # Spread the dropped mass uniformly so the marginal stays normalized.
        deficit = 1 - kept.sum(-1)
        return kept + deficit / kept.size(-1), keep

    r_new, r_keep = _shrink(r, eps)
    c_new, c_keep = _shrink(c, eps)

    return (r_new, r_keep), (c_new, c_keep), C[r_keep][:, c_keep]
42
+
43
+
44
def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
    """
    Mix each marginal with the uniform distribution to keep entries bounded away from zero.

    :param r: row marginal (tensor).
    :param c: column marginal (tensor).
    :param eps: smoothing budget as a tensor; clamped at 1 so the mixture stays convex.
    :param w_r: fraction of the budget allocated to the row marginal.
    :param w_c: fraction of the budget allocated to the column marginal; w_r + w_c must be 1.
    :return: (r_hat, c_hat), the smoothed marginals (each still sums to 1).
    """
    # BUGFIX: the original exact check `w_r + w_c == 1` rejected valid splits such as
    # 0.3/0.7 due to floating-point rounding; use a tolerance-based comparison instead.
    assert math.isclose(w_r + w_c, 1.0, rel_tol=0.0, abs_tol=1e-9), "w_r and w_c must sum to 1"
    eps = eps.clamp(max=1.).unsqueeze(-1)
    r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
    c_hat = (1 - w_c * eps) * c + w_c * eps * th.ones_like(c) / c.size(-1)

    return r_hat, c_hat
51
+
52
+
53
def adjust_schedule(q, deltas=None):
    """
    Adapt the temperature-annealing factor based on recent progress ratios.

    :param q: current annealing factor.
    :param deltas: observed actual/ideal gradient-norm reduction ratios, or None.
    :return: q unchanged when deltas is None or the worst ratio is in [0.5, 0.9];
        its square root (anneal slower) when any ratio fell below 0.5;
        its square (anneal faster) when every ratio exceeded 0.9.
    """
    if deltas is None:
        return q

    # An empty history counts as one fully successful iteration.
    worst = min(deltas + [1.])

    if worst < 0.5:
        return q ** 0.5
    if worst > 0.9:
        return q ** 2
    return q
66
+
67
+
68
def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
    """
    Solve the entropic-regularized optimal transport problem using the MDOT method introduced in the paper:
    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients" by Mete Kemertas,
    Allan D. Jepson and Amir-massoud Farahmand. URL: https://arxiv.org/abs/2307.08507
    Here, we use the Truncated Newton method for projection.
    :param r: The first marginal.
    :param c: The second marginal.
    :param C: The cost matrix. Recommended use is to scale the entries to be in [0, 1].
    :param gamma_f: The final temperature (inverse of the regularization weight).
    :param gamma_i: The initial temperature.
    :param p: The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
    :param q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
    :return: (u, v, gamma, k_total, logs): the final dual potentials, the temperature actually
        reached (gammas[-1]; may be below gamma_f if a projection failed), a total operation
        count, and a dict of logs ("proj_logs", "success", "gammas", ...).
    """
    projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)

    # Marginal entropies; the 1e-30 offset guards log(0) for sparse marginals.
    H_r = -(r * (r + 1e-30).log()).sum(-1)
    H_c = -(c * (c + 1e-30).log()).sum(-1)
    H_min = th.min(H_r, H_c)
    # Dual-gradient tolerance shrinks as the temperature grows: eps(gamma) = H_min / gamma^p.
    eps_fn = lambda g_: H_min / (g_ ** p)

    logs = {
        "proj_logs": [],
        "eps": [],  # NOTE(review): never appended to below — dead log entry, kept for API stability.
    }

    t = 1
    done = False
    gamma = min(gamma_i, gamma_f)
    # gammas[0] = 0 acts as a sentinel so the warm-start extrapolation is well-defined at t=1.
    gammas = [0., gamma]

    while not done:
        done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)

        eps_d = eps_fn(gamma)

        # Smooth marginals (asymmetrically: 90% of the budget on rows) so entries
        # stay bounded away from zero at tolerance eps_d / 2.
        r_hat, c_hat = smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)

        if t == 1:
            # Cold start: initialize the potentials from the (smoothed) log-marginals.
            u_init, v_init = r_hat.log(), c_hat.log()
            u_cur, v_cur = u_init.clone(), v_init.clone()

        # Keep the previous solution so we can (a) roll back on failure and
        # (b) extrapolate a warm start for the next temperature.
        u_prev, v_prev = u_cur.clone(), v_cur.clone()
        u_cur, v_cur, proj_log, success = projector.project(
            gamma * C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)

        logs["proj_logs"].append(proj_log)
        if not success:
            # Roll back to the last successful temperature and stop annealing.
            warnings.warn("Projection failed. Returning result at the last temperature: {:.4e}".format(1 / gammas[-2]))
            u_cur = u_prev.clone()
            v_cur = v_prev.clone()
            gammas = gammas[:-1]
            break

        # Adapt the annealing factor using the projector's progress ratios, then anneal.
        q = adjust_schedule(q, proj_log["deltas"])
        gamma = min(gamma * q, gamma_f)

        if not done:
            # Generate warm-started initializations for the next iteration.
            # Linear extrapolation of the potentials in gamma (secant through the
            # last two temperatures, evaluated at the new one).
            u_init = u_cur + (u_cur - u_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])
            v_init = v_cur + (v_cur - v_prev) * (gamma - gammas[-1]) / (gammas[-1] - gammas[-2])

            gammas.append(gamma)
            t += 1

    # Total operation count: projector iterations plus one unit per temperature step.
    k_total = sum([log["n_iter"] for log in logs["proj_logs"]])
    k_total += t - 1
    logs["success"] = success
    logs["gammas"] = gammas

    return u_cur, v_cur, gammas[-1], k_total, logs
mdot_tnt/rounding.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ This file contains the implementation of the rounding algorithm proposed by Altschuler et al. (2017) in the paper
3
+ "Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration". The algorithm is used to
4
+ round the transport plan obtained from the Sinkhorn algorithm to a feasible transport plan in the set U(r, c), where r
5
+ and c are the row and column marginals, respectively. The algorithm is used in the mdot.py file to round the transport
6
+ plan and compute the cost of the rounded plan. The implementation is based on the original paper.
7
+ """
8
+
9
+ import torch as th
10
+
11
+
12
def round_altschuler(P, r, c):
    """
    Round a transport plan onto the feasible set U(r, c) (Altschuler et al., 2017, Algorithm 2).

    NOTE: P is modified in place; the returned tensor aliases the input.
    :param P: the input transport plan
    :param r: row marginal
    :param c: column marginal
    :return: rounded transport plan in feasible set U(r, c).
    """
    # Scale down every row that carries too much mass (scale factors capped at 1).
    row_scale = (r / P.sum(-1)).clamp(max=1.)
    P.mul_(row_scale.unsqueeze(-1))

    # Then do the same for the columns.
    col_scale = (c / P.sum(-2)).clamp(max=1.)
    P.mul_(col_scale.unsqueeze(-2))

    # Any remaining deficit is restored via a rank-one correction that exactly
    # fills both marginals (deficit totals agree, so normalizing by ||def_r||_1 suffices).
    def_r = (r - P.sum(-1)).clamp(min=0)
    def_c = (c - P.sum(-2)).clamp(min=0)
    denom = (def_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
    P.add_(def_r.unsqueeze(-1) @ def_c.unsqueeze(-2) / denom)

    return P
31
+
32
+
33
def rounded_cost_altschuler(u, v, r, c, C, gamma):
    """Performs rounding and cost computation in logdomain given dual variables, without storing n^2 matrices.

    Log-domain analogue of the Altschuler et al. (2017) rounding: the dual potentials are
    shifted downward so neither marginal of the implied plan overshoots its target, then the
    residual mass is accounted for via a rank-one correction term added to the cost.

    NOTE(review): u and v are modified in place (via +=), so the caller's tensors are
    clobbered — confirm callers do not reuse them afterwards.
    :param u: dual variable for rows
    :param v: dual variable for columns
    :param r: row marginal
    :param c: column marginal
    :param C: cost matrix
    :param gamma: temperature, i.e., the inverse of the entropic regularization weight.
    :return: scalar cost of the (implicitly) rounded plan.
    """
    # Row marginal of the current plan in log space: log r(P)_i = u_i + LSE_j(v_j - gamma C_ij).
    r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
    # Shift u down (never up) so that r(P) <= r elementwise — the log-domain
    # equivalent of multiplying rows by min(r / r(P), 1).
    delta_u = th.min(r.log() - r_P_log, th.zeros_like(r))
    u += delta_u

    # Same downward-only shift for the columns.
    c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
    delta_v = th.min(c.log() - c_P_log, th.zeros_like(c))
    v += delta_v

    # Row deficit of the scaled plan, normalized to a probability vector.
    r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
    r_P = r_P_log.exp()
    err_r = r - r_P
    err_r /= (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30)

    # Column deficit (kept unnormalized: the rank-one correction is err_r err_c^T / ||err_r||_1).
    c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
    c_P = c_P_log.exp()
    err_c = c - c_P

    # <P, C> computed in log space; zero-cost entries give log C = -inf and
    # correctly contribute nothing to the logsumexp.
    cost = th.logsumexp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma * C + C.log(), dim=(-1, -2)).exp()
    # Add the transport cost of the rank-one correction term.
    cost += (err_r.unsqueeze(-2) @ C @ err_c.unsqueeze(-1)).sum(-1).sum(-1)

    return cost
@@ -0,0 +1,283 @@
1
+
2
+ import torch as th
3
+ import warnings
4
+
5
+
6
+ class TruncatedNewtonProjector:
7
+ def __init__(self, device, dtype, **kwargs):
8
+ self.device = device
9
+ self.rho = th.zeros(1, device=device, dtype=dtype)
10
+ self.debug = kwargs.get('debug', False)
11
+
12
+ def project(self, gamma_C, log_r, log_c, eps_d, u, v):
13
+ """
14
+ Project onto the set of couplings that satisfy the marginal constraints.
15
+ :param gamma_C: The cost matrix scaled by gamma.
16
+ :param log_r:
17
+ """
18
+ logs = {
19
+ "errs": [],
20
+ 'ls_func_cnt': 0,
21
+ 'chisinkhorn_steps': 0,
22
+ 'newtonsolve_steps': 0,
23
+ "deltas": [], # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
24
+ "all_newtonsolve_steps": []
25
+ }
26
+ # In case of errors or issues, 10 times the tolerance level is considered
27
+ # a good enough solution to keep MDOT going.
28
+ success_fn = lambda err_: err_ < 10 * eps_d
29
+
30
+ # Each LSE operation costs 4 * n^2 operations.
31
+ self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
32
+ self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
33
+
34
+ r = log_r.exp()
35
+ c = log_c.exp()
36
+
37
+ log_c_P = v + self.LSE_c(u)
38
+ v += log_c - log_c_P # Ensure c=c(P)
39
+ log_r_P = u + self.LSE_r(v)
40
+ k = 8
41
+
42
+ u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
43
+ r_P = log_r_P.exp()
44
+ logs["errs"].append(err)
45
+ logs["chisinkhorn_steps"] = k_
46
+ k += k_
47
+
48
+ num_iter = 0
49
+
50
+ while err > eps_d:
51
+ num_iter += 1
52
+
53
+ beta = 0.5
54
+ eta_k = th.max(err, 0.9 * (eps_d / err))
55
+
56
+ grad_k = r_P - r
57
+ self.rho = th.max(th.zeros_like(self.rho), self.rho)
58
+
59
+ P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
60
+ diag_PPc = ((P ** 2) / c.unsqueeze(-2)).sum(-1)
61
+ k += 8
62
+ delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
63
+ P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000)
64
+ del P # Free up memory
65
+ if not pcg_success:
66
+ k += matmul_cnt
67
+ logs["n_iter"] = k
68
+ msg = "PCG did not converge. TruncatedNewton returning with success={}".format(success_fn(err))
69
+ warnings.warn(msg)
70
+ return u, v, logs, success_fn(err)
71
+
72
+ self.rho = th.max(th.zeros_like(self.rho), 1. - (1. - rho) * 4.)
73
+ k += matmul_cnt
74
+ logs["newtonsolve_steps"] += matmul_cnt
75
+
76
+ alpha = th.ones_like(self.rho)
77
+ log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
78
+ k += 4
79
+ linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
80
+ if not linear_decr > 0:
81
+ logs["n_iter"] = k
82
+ msg = "Linear decrease condition not satisfied. TruncatedNewton returning with success={}".format(
83
+ success_fn(err))
84
+ warnings.warn(msg)
85
+ return u, v, logs, success_fn(err)
86
+
87
+ armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr
88
+ while not armijo: # Check armijo condition for batch elements where err > eps_d
89
+ alpha *= 0.5
90
+ if alpha < 1e-9:
91
+ logs["n_iter"] = k
92
+ msg = "Line search did not converge. TruncatedNewton returning with success={}".format(
93
+ success_fn(err))
94
+ warnings.warn(msg)
95
+ return u, v, logs, success_fn(err)
96
+
97
+ log_c_P = v + alpha * delta_v + self.LSE_c(u + alpha * delta_u)
98
+ k += 4
99
+ logs["ls_func_cnt"] += 4
100
+ armijo = log_c_P.exp().sum(-1, keepdim=True) - 1 <= 0.99 * alpha * linear_decr
101
+
102
+ u += alpha * delta_u
103
+ v += alpha * delta_v
104
+
105
+ # The below error (before the Sinkhorn update) is used
106
+ # to measure the progress of the algorithm with TruncatedNewton steps.
107
+ err_before_sk = (c - log_c_P.exp()).abs().sum(-1)
108
+ err_before_sk += (r - (u + self.LSE_r(v)).exp()).abs().sum(-1)
109
+
110
+ # Sinkhorn update to ensure c=c(P).
111
+ v += log_c - log_c_P
112
+
113
+ log_r_P = u + self.LSE_r(v)
114
+ k += 4
115
+
116
+ u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
117
+ r_P = log_r_P.exp()
118
+ logs["chisinkhorn_steps"] += k_
119
+ k += k_
120
+
121
+ logs["errs"].append(err)
122
+ logs["deltas"].append(th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item())
123
+
124
+ if u.isnan().any() or v.isnan().any():
125
+ raise ValueError("NaNs encountered in u or v")
126
+
127
+ logs["n_iter"] = k
128
+
129
+ # Since we already computed log_r_P, we can use it to perform one last Sinkhorn update on rows.
130
+ delta_u = log_r - log_r_P
131
+ u += delta_u
132
+
133
+ return u, v, logs, True
134
+
135
+ def chi_sinkhorn(self, u, v, log_r, log_c, log_r_P, eps_chi, maxOps=float('inf')):
136
+ k = 0
137
+ r = log_r.exp()
138
+ err = (r - log_r_P.exp()).norm(p=1, dim=-1)
139
+ r_P = log_r_P.exp()
140
+ chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
141
+
142
+ while chi_squared > eps_chi and k < maxOps:
143
+ u += log_r - log_r_P
144
+
145
+ log_c_P = v + self.LSE_c(u)
146
+ v += log_c - log_c_P
147
+
148
+ log_r_P = u + self.LSE_r(v)
149
+ r_P = log_r_P.exp()
150
+
151
+ err = (r - r_P).norm(p=1, dim=-1)
152
+
153
+ chi_squared = ((r - r_P) ** 2 / r_P).sum(-1)
154
+ k += 8
155
+
156
+ if k >= maxOps:
157
+ raise ValueError("Chi-Sinkhorn did not converge in maxIter={} steps".format(maxOps))
158
+
159
+ return u, v, log_r_P, err, k
160
+
161
+ def newton_solve(self, P, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500):
162
+ rho = self.rho
163
+ tol = err * eta_k
164
+
165
+ def matmul_PPc(x_):
166
+ return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
167
+
168
+ # mml = th.compile(matmul_PPc)
169
+ mml = matmul_PPc
170
+
171
+ M = lambda rho_: (r_P - rho_ * diag_PPc) # Diagonal preconditioner
172
+ M_rho = M(th.ones_like(self.rho))
173
+ M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()
174
+
175
+ x0 = -grad_k / M_rho
176
+ PPc_x0 = mml(x0)
177
+ matmul_cnt = 2
178
+ r_P_x0 = (r_P * x0)
179
+
180
+ x = x0.clone()
181
+ PPc_x = PPc_x0.clone()
182
+ r_P_x = r_P_x0.clone()
183
+
184
+ res_true = r_P_x0 - PPc_x + grad_k
185
+
186
+ linear_decr = (x * -grad_k).sum(-1)
187
+ if linear_decr <= 0:
188
+ raise ValueError("Linear decrease condition not satisfied")
189
+
190
+ r_true_norm = res_true.norm(p=1, dim=-1)
191
+ best_sol = x.clone()
192
+ best_r_true_norm = r_true_norm.clone()
193
+
194
+ done = False
195
+ success = True
196
+
197
+ while best_r_true_norm > tol:
198
+ best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
199
+ best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
200
+
201
+ rho[r_true_norm > tol] = 1. - (1. - rho[r_true_norm > tol]) * 0.25
202
+ M_rho = M(rho)
203
+
204
+ if matmul_cnt > 0:
205
+ x = x0.clone()
206
+ PPc_x = PPc_x0.clone()
207
+ r_P_x = r_P_x0.clone()
208
+
209
+ Fr_x = r_P_x - rho * PPc_x
210
+ res = Fr_x + grad_k
211
+
212
+ res_true = r_P_x - PPc_x + grad_k
213
+ r_true_norm = res_true.norm(p=1, dim=-1)
214
+ best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
215
+ linear_decr = (x * -grad_k).sum(-1)
216
+ if (best_r_true_norm < tol).all() and (linear_decr > 0).all():
217
+ break
218
+
219
+ y = res / M_rho
220
+ p = -y.clone()
221
+ ry_old = (res * y).sum(-1, keepdim=True)
222
+
223
+ r_norm = res.norm(p=1, dim=-1)
224
+ while (r_norm > 0.5 * (1 - beta) * tol)[best_r_true_norm > tol].any():
225
+ PPc_p = mml(p)
226
+ matmul_cnt += 2
227
+ Fr_p = (r_P * p) - rho * PPc_p
228
+
229
+ quad = (Fr_p * p).sum(-1, keepdim=True)
230
+ if (quad <= 0)[best_r_true_norm > tol].any():
231
+ warnings.warn("Warning: negative curvature encountered in CG. Returning best solution. "
232
+ "Residual norm less than error: {}".format((best_r_true_norm < err).item()))
233
+ x = best_sol.clone()
234
+ done = True
235
+ success = best_r_true_norm < err
236
+ warnings.warn("Resetting discount factor rho = 0")
237
+ rho = th.zeros_like(self.rho)
238
+ break
239
+
240
+ alpha = ry_old / quad
241
+ x += alpha * p
242
+ res += alpha * Fr_p
243
+ r_norm = res.norm(p=1, dim=-1)
244
+
245
+ if th.isnan(r_norm)[best_r_true_norm > tol].any() or th.isinf(r_norm)[best_r_true_norm > tol].any():
246
+ raise ValueError("NaNs or infs encountered in r_norm")
247
+
248
+ PPc_x += alpha * PPc_p
249
+
250
+ r_P_x = (r_P * x)
251
+ res_true = r_P_x - PPc_x + grad_k
252
+ r_true_norm = res_true.norm(p=1, dim=-1)
253
+ best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
254
+ best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
255
+
256
+ linear_decr = (x * -grad_k).sum(-1)
257
+ if (best_r_true_norm <= tol).all() and (linear_decr > 0).all():
258
+ done = True
259
+ success = True
260
+ break
261
+
262
+ if matmul_cnt > 2 * maxIter:
263
+ warnings.warn("PCG did not converge.")
264
+ done = True
265
+ success = False
266
+ break
267
+
268
+ y = res / M_rho
269
+ ry_new = (res * y).sum(-1, keepdim=True)
270
+ p = -y + (ry_new / ry_old) * p
271
+ ry_old = ry_new.clone()
272
+
273
+ if done:
274
+ break
275
+
276
+ if r_true_norm <= tol:
277
+ success = True
278
+
279
+ x = best_sol
280
+ Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
281
+ matmul_cnt += 1
282
+
283
+ return x, -Pc_x, matmul_cnt, rho, success
@@ -0,0 +1,26 @@
1
+ # Non-Commercial Research License (NCRL-1.0)
2
+
3
+ Copyright (C) 2025 Mete Kemertas
4
+
5
+ ## 1. License Grant
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to use, copy, modify, merge, publish, and distribute copies of the Software **solely for non-commercial research, educational, and personal purposes**, subject to the following conditions:
7
+
8
+ ## 2. Restrictions
9
+ ### 2.1 **Non-Commercial Use Only**
10
+ - The Software **may NOT** be used for any commercial purpose without explicit written permission from the Licensor.
11
+ - "Commercial purpose" includes, but is not limited to:
12
+ - Selling or licensing the Software.
13
+ - Using the Software in proprietary products or services.
14
+ - Offering the Software as part of a paid or revenue-generating service.
15
+
16
+ ### 2.2 **No Warranty & Liability**
17
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE SOFTWARE.
18
+
19
+ ### 2.3 **Commercial Licensing**
20
+ For commercial use, a separate license must be obtained from the Licensor. To inquire about licensing, please contact: **kemertas@cs.toronto.edu**.
21
+
22
+ ## 3. Termination
23
+ This license automatically terminates if the Licensee breaches any of its terms. Upon termination, all rights granted under this license are revoked, and the Licensee must cease using and distributing the Software.
24
+
25
+ ## 4. Governing Law and Enforcement
26
+ This license shall be governed by and construed in accordance with the laws of Ontario, Canada. However, violations of this license may also be pursued under applicable copyright and intellectual property laws.
@@ -0,0 +1,71 @@
1
+ Metadata-Version: 2.1
2
+ Name: mdot-tnt
3
+ Version: 0.1.0
4
+ Summary: A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
5
+ Author-email: Mete Kemertas <kemertas@cs.toronto.edu>
6
+ Project-URL: Homepage, https://github.com/metekemertas/mdot_tnt
7
+ Requires-Python: >=3.7
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+
11
+ This is the official repository for the MDOT-TruncatedNewton (or MDOT-TNT)
12
+ algorithm [1] for solving the entropic-regularized optimal transport (OT) problem.
13
+ In addition to being GPU-friendly, the algorithm is stable under weak regularization and can therefore find highly
14
+ precise approximations of the un-regularized problem's solution quickly.
15
+
16
+ The current implementation is based on PyTorch and is compatible with both CPU and GPU. PyTorch is the only dependency.
17
+
18
+
19
+ For installation:
20
+ First, install PyTorch following the instructions at https://pytorch.org/get-started/locally/ to select the version that matches your system's configuration.
21
+ ```bash
22
+ pip install mdot_tnt
23
+ ```
24
+
25
+ Quickstart guide:
26
+ ```
27
+ import mdot_tnt
28
+ import torch as th
29
+ device = 'cuda' if th.cuda.is_available() else 'cpu'
30
+ N, M, dim = 100, 200, 128
31
+
32
+ # Sample row and column marginals from Dirichlet distributions
33
+ r = th.distributions.Dirichlet(th.ones(N)).sample()
34
+ c = th.distributions.Dirichlet(th.ones(M)).sample()
35
+
36
+ # Cost matrix from pairwise Euclidean distances of random points in R^100
37
+ x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
38
+ y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))
39
+ C = th.cdist(x, y, p=2)
40
+ C /= C.max() # Normalize cost matrix to meet convention.
41
+
42
+ # Use double precision for numerical stability in high precision regime.
43
+ r, c, C = r.double(), c.double(), C.double()
44
+
45
+ # Solve OT problem. Increase (decrease) gamma_f for higher (lower) precision.
46
+ # Default is gamma_f=2**12. Expect error of order logn / gamma_f at worst, and possibly lower.
47
+ cost = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12)
48
+
49
+ # To return a feasible transport plan, use the following:
50
+ transport_plan = mdot_tnt.solve_OT(r, c, C, gamma_f=2**12, return_plan=True)
51
+
52
+ # In both cases, the default rounding onto the feasible set can be disabled by setting `round=False`.
53
+ ```
54
+
55
+ The code is released under a custom non-commercial use license. If you use our work in
56
+ your research, please consider citing:
57
+
58
+ ```
59
+ @inproceedings{
60
+ kemertas2025a,
61
+ title={A Truncated Newton Method for Optimal Transport},
62
+ author={Mete Kemertas and Amir-massoud Farahmand and Allan Douglas Jepson},
63
+ booktitle={The Thirteenth International Conference on Learning Representations},
64
+ year={2025},
65
+ url={https://openreview.net/forum?id=gWrWUaCbMa}
66
+ }
67
+ ```
68
+
69
+ For inquiries, email: kemertas [at] cs [dot] toronto [dot] edu
70
+
71
+ [1] Mete Kemertas, Amir-massoud Farahmand, Allan Douglas Jepson. "A Truncated Newton Method for Optimal Transport." The Thirteenth International Conference on Learning Representations (ICLR), 2025. https://openreview.net/forum?id=gWrWUaCbMa
@@ -0,0 +1,9 @@
1
+ mdot_tnt/__init__.py,sha256=zi0ujbN19nv_dt8wdxPQgsOuDi-4gIg4sq6y9fDfAkI,2657
2
+ mdot_tnt/mdot.py,sha256=EUch7ECZxDWDoNW3S84qyCx6-qbG6Mu1_XlWow862Wc,4620
3
+ mdot_tnt/rounding.py,sha256=Q7QBPsFzBqnMZKnlV147ruStUCme6gwt6HnjTjqBezk,2405
4
+ mdot_tnt/truncated_newton.py,sha256=YEkWJxL5H4XsuWQzfdZ0E5Ju3Npk3kKoe2l_go8wq2A,10314
5
+ mdot_tnt-0.1.0.dist-info/LICENSE,sha256=sXw3FpVqouAddNhfwD6nXaSgABFyLnXuSn_ghjU0AhY,1837
6
+ mdot_tnt-0.1.0.dist-info/METADATA,sha256=2Ou4X10K6bUGHQhnMXoFjrv-4YUeIyreSKibMLM2gq0,2972
7
+ mdot_tnt-0.1.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
8
+ mdot_tnt-0.1.0.dist-info/top_level.txt,sha256=HmxTNtoLH7F20hgZVFdfUowIQ2fviSX64wSG1HP8Iao,9
9
+ mdot_tnt-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ mdot_tnt