mdot-tnt 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +41 -10
- mdot_tnt/batched.py +634 -0
- mdot_tnt/mdot.py +97 -43
- mdot_tnt/py.typed +0 -0
- mdot_tnt/rounding.py +41 -15
- mdot_tnt/truncated_newton.py +104 -34
- mdot_tnt-1.0.0.dist-info/METADATA +216 -0
- mdot_tnt-1.0.0.dist-info/RECORD +11 -0
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info}/WHEEL +1 -1
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info/licenses}/LICENSE +4 -1
- mdot_tnt-0.2.0.dist-info/METADATA +0 -71
- mdot_tnt-0.2.0.dist-info/RECORD +0 -9
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info}/top_level.txt +0 -0
mdot_tnt/batched.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batched MDOT-TNT solver for solving multiple optimal transport problems simultaneously.
|
|
3
|
+
|
|
4
|
+
This module provides batched versions of the MDOT-TNT solver that achieve significant
|
|
5
|
+
speedups (5-10x) over sequential solving by amortizing GPU synchronization overhead
|
|
6
|
+
across all problems in a batch.
|
|
7
|
+
|
|
8
|
+
Key insight: The main solver has many Python while-loops that check convergence,
|
|
9
|
+
each requiring a GPU→CPU sync. By batching N problems together, we do one sync
|
|
10
|
+
per iteration for the entire batch instead of N syncs.
|
|
11
|
+
|
|
12
|
+
Supports:
|
|
13
|
+
- Multiple marginal pairs with a shared cost matrix: r shape (batch, n), c shape (batch, m), C shape (n, m)
|
|
14
|
+
- Multiple OT problems with different costs: r shape (batch, n), c shape (batch, m), C shape (batch, n, m)
|
|
15
|
+
|
|
16
|
+
Example usage:
|
|
17
|
+
>>> import torch
|
|
18
|
+
>>> from mdot_tnt.batched import solve_OT_batched
|
|
19
|
+
>>>
|
|
20
|
+
>>> # 32 problems, each 512-dimensional
|
|
21
|
+
>>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
22
|
+
>>> r = r / r.sum(dim=-1, keepdim=True)
|
|
23
|
+
>>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
|
|
24
|
+
>>> c = c / c.sum(dim=-1, keepdim=True)
|
|
25
|
+
>>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64) # Shared cost
|
|
26
|
+
>>>
|
|
27
|
+
>>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
|
|
28
|
+
>>> print(costs.shape) # (32,)
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
import warnings
|
|
32
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
33
|
+
|
|
34
|
+
import torch as th
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BatchedTruncatedNewtonProjector:
    """
    Batched Truncated Newton projector for the MDOT algorithm.

    Projects onto the set of couplings satisfying marginal constraints,
    processing multiple problems simultaneously for efficiency. Every problem
    carries its own convergence flag; the Python-level loops test the whole
    batch once per iteration instead of once per problem.
    """

    def __init__(self, device: th.device, dtype: th.dtype, **kwargs):
        """
        Initialize the projector.

        Args:
            device: PyTorch device for computations.
            dtype: Data type for tensors.
            **kwargs: Additional options (debug: bool for verbose output).
        """
        self.device = device
        self.dtype = dtype
        # NOTE: stored for API compatibility; no method of this class reads it.
        self.debug = kwargs.get("debug", False)

    def project(
        self,
        gamma_C: th.Tensor,
        log_r: th.Tensor,
        log_c: th.Tensor,
        eps_d: Union[float, th.Tensor],
        u: th.Tensor,
        v: th.Tensor,
        active_mask: Optional[th.Tensor] = None,
    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], th.Tensor]:
        """
        Project onto the constraint set for all problems in the batch.

        Args:
            gamma_C: (batch, n, m) or (n, m) cost matrix scaled by gamma.
            log_r: (batch, n) log of row marginals.
            log_c: (batch, m) log of column marginals.
            eps_d: Convergence tolerance on the L1 row-marginal error,
                scalar or (batch,) tensor.
            u: (batch, n) initial row dual variables.
            v: (batch, m) initial column dual variables.
            active_mask: (batch,) bool tensor, True for problems to process.
                Defaults to all True.

        Returns:
            u: (batch, n) updated row dual variables. The final row update is
                applied to every problem, so entries outside ``active_mask``
                are shifted too and should be ignored by the caller.
            v: (batch, m) updated column dual variables.
            logs: Dictionary with optimization statistics: ``n_iter`` (the
                running operation tally ``k``), ``errs`` (largest row error
                among active problems after each phase) and ``deltas``
                (currently never filled).
            success: (batch,) bool tensor indicating convergence per problem.
        """
        batch_size = u.shape[0]

        if active_mask is None:
            active_mask = th.ones(batch_size, device=self.device, dtype=th.bool)

        # Normalize eps_d to a (batch,) tensor; the chi^2 phases use a looser
        # tolerance derived from it.
        eps_d = self._to_batch_tensor(eps_d, batch_size)
        eps_chi = eps_d ** (2 / 5)

        logs: Dict[str, Any] = {"n_iter": 0, "errs": [], "deltas": []}

        # A shared (n, m) cost broadcasts against the batch dimension.
        if gamma_C.dim() == 2:
            gamma_C = gamma_C.unsqueeze(0)

        r = log_r.exp()
        c = log_c.exp()

        # Batched log-sum-exp marginal operators (row / column).
        def LSE_r(v_):
            return th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)

        def LSE_c(u_):
            return th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)

        # Initial Sinkhorn column step so that c(P) == c.
        log_c_P = v + LSE_c(u)
        v = v + log_c - log_c_P
        log_r_P = u + LSE_r(v)
        k = 8

        # Chi-Sinkhorn initialization phase.
        u, v, log_r_P, err = self._chi_sinkhorn_batched(
            u, v, log_r, log_c, log_r_P, eps_chi, LSE_r, LSE_c, active_mask
        )
        r_P = log_r_P.exp()
        logs["errs"].append(self._max_masked(err, active_mask))
        k += 8 * 10

        converged = err <= eps_d
        success = converged.clone()

        num_iter = 0
        max_iter = 100

        # Main Newton loop: only problems that are active and not yet
        # converged are updated.
        while (~converged & active_mask).any() and num_iter < max_iter:
            num_iter += 1
            working = ~converged & active_mask
            work_col = working.unsqueeze(-1)

            # Forcing term: eta_k = max(err, 0.9 * eps_d / err).
            eta_k = th.clamp(err, min=0.9 * eps_d / err.clamp(min=1e-30))
            grad_k = r_P - r

            # Transport plan and diagonal of P diag(1/c) P^T for the Hessian.
            P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
            diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
            k += 8

            delta_u, delta_v, matmul_cnt, pcg_success = self._newton_solve_batched(
                P, c, diag_PPc, grad_k, r_P, err, eta_k, working
            )
            success = success & (pcg_success | ~working)
            k += matmul_cnt

            # Backtracking line search with an Armijo condition on the
            # column-mass excess.
            linear_decr = -(grad_k * delta_u).sum(-1)
            alpha = th.ones(batch_size, device=self.device, dtype=self.dtype)
            step = alpha.unsqueeze(-1)
            log_c_P = v + step * delta_v + LSE_c(u + step * delta_u)
            k += 4
            armijo = (log_c_P.exp().sum(-1) - 1) <= (0.99 * alpha * linear_decr)
            armijo = armijo | ~working

            ls_iter = 0
            while not armijo.all() and ls_iter < 20:
                alpha = th.where(armijo, alpha, alpha * 0.5)
                step = alpha.unsqueeze(-1)
                log_c_P = v + step * delta_v + LSE_c(u + step * delta_u)
                k += 4
                armijo = (log_c_P.exp().sum(-1) - 1) <= (0.99 * alpha * linear_decr)
                armijo = armijo | ~working
                ls_iter += 1

            # Take the step for working problems only.
            step = alpha.unsqueeze(-1)
            u = th.where(work_col, u + step * delta_u, u)
            v = th.where(work_col, v + step * delta_v, v)

            # Sinkhorn column correction restores c(P) == c.
            v = th.where(work_col, v + log_c - log_c_P, v)

            log_r_P = u + LSE_r(v)
            k += 4

            # Chi-Sinkhorn refinement.
            u, v, log_r_P, err = self._chi_sinkhorn_batched(
                u, v, log_r, log_c, log_r_P, eps_chi, LSE_r, LSE_c, working
            )
            r_P = log_r_P.exp()
            logs["errs"].append(self._max_masked(err, active_mask))

            converged = converged | (err <= eps_d)

        logs["n_iter"] = k

        # Final row update makes the row marginals exact.
        u = u + (log_r - log_r_P)

        success = success | converged
        return u, v, logs, success

    @staticmethod
    def _max_masked(err: th.Tensor, mask: th.Tensor) -> float:
        """Largest entry of the (non-negative) error vector over masked-in problems."""
        return th.where(mask, err, th.zeros_like(err)).max().item()

    def _to_batch_tensor(self, val: Union[float, th.Tensor], batch_size: int) -> th.Tensor:
        """Convert scalar or tensor to (batch,) shaped tensor."""
        if not isinstance(val, th.Tensor):
            val = th.tensor(val, device=self.device, dtype=self.dtype)
        if val.dim() == 0:
            val = val.expand(batch_size)
        return val

    def _chi_sinkhorn_batched(
        self, u, v, log_r, log_c, log_r_P, eps_chi, LSE_r, LSE_c, active_mask, max_iter=100
    ):
        """
        Batched chi-squared Sinkhorn iterations.

        Alternates full row and column Sinkhorn updates for problems in
        ``active_mask`` whose chi^2 row discrepancy exceeds ``eps_chi``.

        Returns:
            u, v, log_r_P (log row marginal of the current plan) and err, the
            (batch,) L1 row-marginal error.
        """
        r = log_r.exp()
        r_P = log_r_P.exp()

        err = (r - r_P).abs().sum(-1)
        chi_squared = ((r - r_P) ** 2 / r_P.clamp(min=1e-30)).sum(-1)

        eps_chi = self._to_batch_tensor(eps_chi, u.shape[0])
        working = (chi_squared > eps_chi) & active_mask

        for _ in range(max_iter):
            if not working.any():
                break

            delta_u = log_r - log_r_P
            u = th.where(working.unsqueeze(-1), u + delta_u, u)

            log_c_P = v + LSE_c(u)
            delta_v = log_c - log_c_P
            v = th.where(working.unsqueeze(-1), v + delta_v, v)

            log_r_P = u + LSE_r(v)
            r_P = log_r_P.exp()

            err = (r - r_P).abs().sum(-1)
            chi_squared = ((r - r_P) ** 2 / r_P.clamp(min=1e-30)).sum(-1)
            working = (chi_squared > eps_chi) & active_mask

        return u, v, log_r_P, err

    def _newton_solve_batched(
        self, P, c, diag_PPc, grad_k, r_P, err, eta_k, active_mask, max_iter=50
    ):
        """
        Batched preconditioned conjugate gradient Newton solve.

        Approximately solves (diag(r_P) - P diag(1/c) P^T) x = -grad_k for
        each problem, to the L1 residual tolerance ``err * eta_k``. Problems
        outside ``active_mask``, and problems whose residual already meets
        the tolerance, are frozen: their iterates no longer change, and the
        loop stops as soon as no problem is still running.

        Returns:
            delta_u, delta_v (the matching column step), the operation count,
            and a (batch,) bool tensor marking problems whose residual met
            the tolerance.
        """
        tol = err * eta_k

        # Diagonal (Jacobi) preconditioner, kept strictly positive.
        M_rho = r_P - diag_PPc
        M_rho = th.where(M_rho > 0, M_rho, M_rho.clamp(min=1e-10))

        x = -grad_k / M_rho
        r_vec = r_P * x - self._batched_PPc_matmul(P, c, x) + grad_k
        matmul_cnt = 2

        y = r_vec / M_rho
        p = -y
        ry_old = (r_vec * y).sum(-1, keepdim=True)

        r_norm = r_vec.abs().sum(-1)
        running = active_mask & (r_norm > tol)

        for _ in range(max_iter):
            if not running.any():
                break
            step = running.unsqueeze(-1)

            PPc_p = self._batched_PPc_matmul(P, c, p)
            matmul_cnt += 2

            Fr_p = r_P * p - PPc_p
            quad = (Fr_p * p).sum(-1, keepdim=True)
            # Guard against non-positive curvature from round-off.
            quad = th.where(quad > 0, quad, th.ones_like(quad))

            alpha = ry_old / quad
            x = th.where(step, x + alpha * p, x)
            r_vec = th.where(step, r_vec + alpha * Fr_p, r_vec)

            r_norm = r_vec.abs().sum(-1)
            running = running & (r_norm > tol)

            y = r_vec / M_rho
            ry_new = (r_vec * y).sum(-1, keepdim=True)
            p = th.where(step, -y + (ry_new / ry_old.clamp(min=1e-30)) * p, p)
            ry_old = th.where(step, ry_new, ry_old)

        Pc_x = (x.unsqueeze(-2) @ P).squeeze(-2) / c

        # Success: residual norm is below tolerance.
        r_norm = r_vec.abs().sum(-1)
        success = r_norm <= tol

        return x, -Pc_x, matmul_cnt, success

    def _batched_PPc_matmul(self, P, c, x):
        """Compute P @ (P^T @ x / c) efficiently in batched form."""
        PTx = (x.unsqueeze(-1) * P).sum(-2)
        PTx_over_c = PTx / c
        return (PTx_over_c.unsqueeze(-2) * P).sum(-1)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _batched_smooth_marginals(
    r: th.Tensor,
    c: th.Tensor,
    eps: Union[float, th.Tensor],
    w_r: float = 0.5,
    w_c: float = 0.5,
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Smooth marginals by mixing with the uniform distribution.

    ``r_hat = (1 - w_r * eps) * r + w_r * eps / n`` and likewise for ``c``,
    with ``eps`` clamped to at most 1.

    Args:
        r: (batch, n) row marginals.
        c: (batch, m) column marginals.
        eps: (batch,) tensor, 0-dim tensor, or Python float smoothing factor.
        w_r, w_c: Weights for row/column smoothing (must sum to 1).

    Returns:
        r_hat, c_hat: Smoothed marginals.
    """
    # Accept plain Python numbers as advertised; tensors keep their own dtype.
    if not isinstance(eps, th.Tensor):
        eps = th.tensor(eps, dtype=r.dtype, device=r.device)
    eps = eps.clamp(max=1.0)
    if eps.dim() == 0:
        eps = eps.unsqueeze(0)
    # (batch, 1) so the factor broadcasts over the marginal dimension.
    eps = eps.unsqueeze(-1)

    r_hat = (1 - w_r * eps) * r + w_r * eps / r.size(-1)
    c_hat = (1 - w_c * eps) * c + w_c * eps / c.size(-1)

    return r_hat, c_hat
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _batched_mdot(
    r: th.Tensor,
    c: th.Tensor,
    C: th.Tensor,
    gamma_f: float,
    gamma_i: float = 16,
    p: float = 1.5,
    q: float = 2.0,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, int, Dict[str, Any]]:
    """
    Batched MDOT (Mirror Descent Optimal Transport) solver.

    Solves multiple entropic-regularized OT problems simultaneously using
    temperature annealing with truncated Newton projections. gamma starts at
    ``min(gamma_i, gamma_f)`` and is multiplied by ``q`` (capped at
    ``gamma_f``) after each projection until every problem reaches ``gamma_f``.

    Args:
        r: (batch, n) row marginals.
        c: (batch, m) column marginals.
        C: (n, m) or (batch, n, m) cost matrix.
        gamma_f: Final temperature (inverse regularization weight).
        gamma_i: Initial temperature.
        p: Exponent for the epsilon schedule.
        q: Temperature annealing factor.

    Returns:
        u: (batch, n) optimal row dual variables.
        v: (batch, m) optimal column dual variables.
        gamma_final: (batch,) final temperature achieved per problem.
        k_total: Total number of primitive operations.
        logs: Optimization logs with keys ``proj_logs`` (one dict per
            projection), ``gammas`` (gamma vector after each update),
            ``success`` ((batch,) bool: convergence flag reported by the most
            recent projection of each problem) and ``outer_iterations``.
    """
    batch_size = r.shape[0]
    device = r.device
    dtype = r.dtype

    projector = BatchedTruncatedNewtonProjector(device=device, dtype=dtype)

    # Entropy bounds set the per-problem tolerance schedule eps(gamma).
    H_r = -(r * (r + 1e-30).log()).sum(-1)
    H_c = -(c * (c + 1e-30).log()).sum(-1)
    H_min = th.min(H_r, H_c)

    def eps_fn(g_):
        return H_min / (g_**p)

    logs: Dict[str, Any] = {"proj_logs": [], "gammas": []}

    gamma = min(gamma_i, gamma_f)
    gamma_per_problem = th.full((batch_size,), gamma, device=device, dtype=dtype)
    gamma_prev = th.zeros((batch_size,), device=device, dtype=dtype)
    # Never narrowed at present; kept so the projector API stays masked.
    active_mask = th.ones(batch_size, device=device, dtype=th.bool)
    # Per-problem convergence flag from the latest projection.
    success_mask = th.zeros(batch_size, device=device, dtype=th.bool)

    # Initialize dual variables from the smoothed marginals.
    eps_d = eps_fn(gamma)
    r_hat, c_hat = _batched_smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)
    u_init = r_hat.log()
    v_init = c_hat.log()
    u_cur = u_init.clone()
    v_cur = v_init.clone()
    u_prev = u_cur.clone()
    v_prev = v_cur.clone()

    t = 1
    max_outer_iter = 50
    done_all: Any = False

    while active_mask.any() and t < max_outer_iter and not done_all:
        done = th.abs(gamma_per_problem - gamma_f) < 1e-5
        done_all = (done | ~active_mask).all()
        active_col = active_mask.unsqueeze(-1)

        eps_d = eps_fn(gamma_per_problem)
        r_hat, c_hat = _batched_smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)

        # Scale cost matrix by per-problem gamma.
        gamma_scale = gamma_per_problem.unsqueeze(-1).unsqueeze(-1)
        gamma_C = gamma_scale * (C.unsqueeze(0) if C.dim() == 2 else C)

        # Remember the current iterate for the extrapolation below.
        u_prev = th.where(active_col, u_cur, u_prev)
        v_prev = th.where(active_col, v_cur, v_prev)

        # Project using warm-started initial values.
        u_new, v_new, proj_log, success = projector.project(
            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init, active_mask
        )

        u_cur = th.where(active_col, u_new, u_cur)
        v_cur = th.where(active_col, v_new, v_cur)
        success_mask = th.where(active_mask, success, success_mask)

        logs["proj_logs"].append(proj_log)

        # Keep the last two projected gammas for extrapolation.
        gamma_prev_old = gamma_prev
        gamma_prev = gamma_per_problem.clone()

        # Anneal gamma for problems not yet at gamma_f.
        gamma_per_problem = th.where(
            active_mask & ~done, th.clamp(gamma_per_problem * q, max=gamma_f), gamma_per_problem
        )

        # Warm start the next projection by linear extrapolation of the duals
        # in gamma, similar to the unbatched solver. The factor is clamped to
        # [-2, 2] to avoid instability when gamma changes rapidly.
        if not done_all:
            # gamma_prev_old starts at 0; the clamp avoids division by zero.
            denom = (gamma_prev - gamma_prev_old).clamp(min=1e-10)
            extrap_factor = ((gamma_per_problem - gamma_prev) / denom).unsqueeze(-1)
            extrap_factor = extrap_factor.clamp(-2.0, 2.0)
            extrapolate = active_col & (t > 1)
            u_init = th.where(extrapolate, u_cur + (u_cur - u_prev) * extrap_factor, u_cur)
            v_init = th.where(extrapolate, v_cur + (v_cur - v_prev) * extrap_factor, v_cur)

        logs["gammas"].append(gamma_per_problem.clone())
        t += 1

    k_total = sum(proj_log["n_iter"] for proj_log in logs["proj_logs"])
    logs["success"] = success_mask
    logs["outer_iterations"] = t - 1

    return u_cur, v_cur, gamma_per_problem, k_total, logs
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _batched_round(P: th.Tensor, r: th.Tensor, c: th.Tensor) -> th.Tensor:
    """
    Round approximate transport plans onto U(r, c) (Altschuler et al. rounding).

    Three steps per problem:
    1. Rows whose mass exceeds r are scaled down to r.
    2. Columns whose mass exceeds c are scaled down to c.
    3. The leftover row/column deficits are restored with a rank-1 correction.

    Args:
        P: (batch, n, m) approximate transport plans.
        r: (batch, n) row marginals.
        c: (batch, m) column marginals.

    Returns:
        (batch, n, m) feasible transport plans in U(r, c).
    """
    tiny = 1e-30

    # Step 1: shrink over-full rows (factor never exceeds 1).
    row_factor = th.clamp(r / P.sum(-1).clamp(min=tiny), max=1.0)
    scaled = P * row_factor[..., :, None]

    # Step 2: shrink over-full columns.
    col_factor = th.clamp(c / scaled.sum(-2).clamp(min=tiny), max=1.0)
    scaled = scaled * col_factor[..., None, :]

    # Step 3: spread the remaining deficit as an outer product.
    missing_rows = (r - scaled.sum(-1)).clamp(min=0)
    missing_cols = (c - scaled.sum(-2)).clamp(min=0)
    total_missing = missing_rows.norm(p=1, dim=-1, keepdim=True).unsqueeze(-1) + tiny
    return scaled + missing_rows[..., :, None] * missing_cols[..., None, :] / total_missing
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def _batched_rounded_cost(
    u: th.Tensor, v: th.Tensor, r: th.Tensor, c: th.Tensor, C: th.Tensor, gamma: th.Tensor
) -> th.Tensor:
    """
    Compute the transport cost of the rounded plan defined by dual variables.

    The plan is ``P = exp(u_i + v_j - gamma * C_ij)``. Altschuler rounding is
    applied to it in three steps:

    1. Rows are scaled down to at most ``r``.
    2. Columns are scaled down to at most ``c``.
    3. A rank-1 correction restores the exact marginals.

    The two scaling steps are performed on the duals in the log domain. The
    correction's cost is evaluated as a bilinear form, so the corrected plan
    is never formed. The scaled plan itself is materialized, so memory use is
    O(batch * n * m).

    The main term is the direct sum ``<P, C>``, which is valid for cost
    matrices with zero or negative entries.

    Args:
        u: (batch, n) row dual variables.
        v: (batch, m) column dual variables.
        r: (batch, n) row marginals.
        c: (batch, m) column marginals.
        C: (n, m) or (batch, n, m) cost matrix.
        gamma: (batch,) temperature per problem.

    Returns:
        costs: (batch,) transport costs of the rounded plans.
    """
    batch_size = u.shape[0]

    if C.dim() == 2:
        C = C.unsqueeze(0).expand(batch_size, -1, -1)

    # Scaled cost, computed once and reused by every marginal below.
    gamma_C = gamma.unsqueeze(-1).unsqueeze(-1) * C

    # Row rounding in log domain: lower u only where row mass exceeds r.
    r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma_C, dim=-1)
    u = u + th.clamp(r.log() - r_P_log, max=0)

    # Column rounding in log domain: lower v only where column mass exceeds c.
    c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma_C, dim=-2)
    v = v + th.clamp(c.log() - c_P_log, max=0)

    # Scaled plan and its remaining marginal deficits.
    P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C).exp()
    err_r = r - P.sum(-1)
    err_c = c - P.sum(-2)
    err_r_normalized = err_r / (err_r.abs().sum(-1, keepdim=True) + 1e-30)

    cost_main = (P * C).sum(dim=(-2, -1))

    # Rank-1 correction term: err_r^T C err_c / ||err_r||_1.
    cost_correction = (
        (err_r_normalized.unsqueeze(-2) @ C @ err_c.unsqueeze(-1)).squeeze(-1).squeeze(-1)
    )

    return cost_main + cost_correction
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def solve_OT_batched(
    r: th.Tensor,
    c: th.Tensor,
    C: th.Tensor,
    gamma_f: float = 1024.0,
    drop_tiny: bool = False,
    return_plan: bool = False,
    round: bool = True,
    log: bool = False,
) -> Union[th.Tensor, Tuple[th.Tensor, Dict[str, Any]]]:
    """
    Solve multiple entropic-regularized optimal transport problems in a single batched call.

    This function provides significant speedup (5-10x) over solving problems sequentially
    by amortizing GPU synchronization overhead across all problems in the batch.

    Args:
        r: (batch, n) row marginals. Each row must sum to 1.
        c: (batch, m) column marginals. Each row must sum to 1.
        C: Cost matrix. Either (n, m) for shared cost across all problems,
            or (batch, n, m) for per-problem costs. Recommended to scale to [0, 1].
        gamma_f: Temperature (inverse of regularization weight). Higher values give
            more accurate solutions but take longer. Stable up to ~2^18 with float64.
        drop_tiny: Not supported in batched solver. Raises NotImplementedError if True.
        return_plan: If True, return transport plans instead of costs.
        round: If True, apply Altschuler rounding for feasible solutions.
        log: If True, also return optimization logs.

    Returns:
        If return_plan is False: (batch,) tensor of transport costs.
        If return_plan is True: (batch, n, m) tensor of transport plans.
        If log is True: tuple of (result, logs_dict).
        When inputs are promoted to float64 (gamma_f > 2^10), the plan or cost
        is computed in float64 and only the final result is cast back to the
        input dtype of ``r``.

    Raises:
        ValueError: If input ranks or shapes are inconsistent.
        NotImplementedError: If ``drop_tiny`` is True.

    Example:
        >>> # Solve 32 OT problems of size 512×512
        >>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
        >>> r = r / r.sum(-1, keepdim=True)
        >>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
        >>> c = c / c.sum(-1, keepdim=True)
        >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
        >>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
    """
    # Input validation
    if r.dim() != 2:
        raise ValueError(f"r must be 2D (batch, n), got shape {r.shape}")
    if c.dim() != 2:
        raise ValueError(f"c must be 2D (batch, m), got shape {c.shape}")
    if C.dim() not in [2, 3]:
        raise ValueError(f"C must be 2D (n, m) or 3D (batch, n, m), got shape {C.shape}")
    if r.shape[0] != c.shape[0]:
        raise ValueError(f"Batch size mismatch: r has {r.shape[0]}, c has {c.shape[0]}")
    if C.dim() == 3 and C.shape[0] != r.shape[0]:
        raise ValueError(f"Batch size mismatch: C has {C.shape[0]}, r has {r.shape[0]}")
    if C.shape[-2] != r.shape[1] or C.shape[-1] != c.shape[1]:
        raise ValueError(
            f"Cost matrix shape {tuple(C.shape[-2:])} does not match marginal sizes "
            f"(n={r.shape[1]}, m={c.shape[1]})"
        )

    if drop_tiny:
        raise NotImplementedError(
            "drop_tiny is not yet implemented for batched solver. "
            "Use solve_OT with drop_tiny=True for individual problems instead."
        )

    # Dtype the caller expects back.
    dtype = r.dtype

    # Use double precision for high gamma. All post-processing below stays in
    # the working precision: u + v - gamma * C cancels large terms, so it
    # must not be evaluated after casting the duals back down.
    if gamma_f > 2**10 and dtype != th.float64:
        warnings.warn(
            f"Switching to float64 for gamma_f > 2^10. Output will be converted back to {dtype}.",
            stacklevel=2,
        )
        r, c, C = r.double(), c.double(), C.double()

    # Solve
    u, v, gamma_final, k_total, opt_logs = _batched_mdot(r, c, C, gamma_f)
    opt_logs["k_total"] = k_total

    # Broadcast a shared cost over the batch when forming plans.
    C_expanded = C.unsqueeze(0) if C.dim() == 2 else C
    gamma_for_plan = gamma_final.unsqueeze(-1).unsqueeze(-1)

    if return_plan:
        P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_for_plan * C_expanded).exp()
        if round:
            P = _batched_round(P, r, c)
        result = P.to(dtype)
    elif round:
        result = _batched_rounded_cost(u, v, r, c, C, gamma_final).to(dtype)
    else:
        P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_for_plan * C_expanded).exp()
        result = (P * C_expanded).sum(dim=(-2, -1)).to(dtype)

    return (result, opt_logs) if log else result
|