mdot-tnt 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdot_tnt/batched.py ADDED
@@ -0,0 +1,634 @@
1
+ """
2
+ Batched MDOT-TNT solver that handles multiple optimal transport problems simultaneously.
3
+
4
+ This module provides batched versions of the MDOT-TNT solver that achieve significant
5
+ speedups (5-10x) over sequential solving by amortizing GPU synchronization overhead
6
+ across all problems in a batch.
7
+
8
+ Key insight: The main solver has many Python while-loops that check convergence,
9
+ each requiring a GPU→CPU sync. By batching N problems together, we do one sync
10
+ per iteration for the entire batch instead of N syncs.
11
+
12
+ Supports:
13
+ - Multiple marginal pairs with a shared cost matrix: r shape (batch, n), c shape (batch, m), C shape (n, m)
13
+ - Multiple OT problems with different cost matrices: r shape (batch, n), c shape (batch, m), C shape (batch, n, m)
15
+
16
+ Example usage:
17
+ >>> import torch
18
+ >>> from mdot_tnt.batched import solve_OT_batched
19
+ >>>
20
+ >>> # 32 problems, each 512-dimensional
21
+ >>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
22
+ >>> r = r / r.sum(dim=-1, keepdim=True)
23
+ >>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
24
+ >>> c = c / c.sum(dim=-1, keepdim=True)
25
+ >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64) # Shared cost
26
+ >>>
27
+ >>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
28
+ >>> print(costs.shape) # (32,)
29
+ """
30
+
31
+ import warnings
32
+ from typing import Any, Dict, Optional, Tuple, Union
33
+
34
+ import torch as th
35
+
36
+
37
+ class BatchedTruncatedNewtonProjector:
38
+ """
39
+ Batched Truncated Newton projector for the MDOT algorithm.
40
+
41
+ Projects onto the set of couplings satisfying marginal constraints,
42
+ processing multiple problems simultaneously for efficiency.
43
+ """
44
+
45
+ def __init__(self, device: th.device, dtype: th.dtype, **kwargs):
46
+ """
47
+ Initialize the projector.
48
+
49
+ Args:
50
+ device: PyTorch device for computations.
51
+ dtype: Data type for tensors.
52
+ **kwargs: Additional options (debug: bool for verbose output).
53
+ """
54
+ self.device = device
55
+ self.dtype = dtype
56
+ self.debug = kwargs.get("debug", False)
57
+
58
+ def project(
59
+ self,
60
+ gamma_C: th.Tensor,
61
+ log_r: th.Tensor,
62
+ log_c: th.Tensor,
63
+ eps_d: Union[float, th.Tensor],
64
+ u: th.Tensor,
65
+ v: th.Tensor,
66
+ active_mask: Optional[th.Tensor] = None,
67
+ ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], th.Tensor]:
68
+ """
69
+ Project onto the constraint set for all problems in the batch.
70
+
71
+ Args:
72
+ gamma_C: (batch, n, m) or (n, m) cost matrix scaled by gamma.
73
+ log_r: (batch, n) log of row marginals.
74
+ log_c: (batch, m) log of column marginals.
75
+ eps_d: Convergence tolerance, scalar or (batch,) tensor.
76
+ u: (batch, n) initial row dual variables.
77
+ v: (batch, m) initial column dual variables.
78
+ active_mask: (batch,) bool tensor, True for problems to process.
79
+
80
+ Returns:
81
+ u: (batch, n) updated row dual variables.
82
+ v: (batch, m) updated column dual variables.
83
+ logs: Dictionary with optimization statistics.
84
+ success: (batch,) bool tensor indicating convergence per problem.
85
+ """
86
+ batch_size = u.shape[0]
87
+
88
+ if active_mask is None:
89
+ active_mask = th.ones(batch_size, device=self.device, dtype=th.bool)
90
+
91
+ # Normalize eps_d to (batch,) tensor
92
+ eps_d = self._to_batch_tensor(eps_d, batch_size)
93
+
94
+ logs: Dict[str, Any] = {"n_iter": 0, "errs": [], "deltas": []}
95
+
96
+ # Handle shared vs per-problem cost matrix
97
+ if gamma_C.dim() == 2:
98
+ gamma_C = gamma_C.unsqueeze(0)
99
+
100
+ r = log_r.exp()
101
+ c = log_c.exp()
102
+
103
+ # Define batched LSE operations
104
+ def LSE_r(v_):
105
+ return th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
106
+
107
+ def LSE_c(u_):
108
+ return th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
109
+
110
+ # Initial Sinkhorn step so that the column marginal c(P) matches the target c
111
+ log_c_P = v + LSE_c(u)
112
+ v = v + log_c - log_c_P
113
+ log_r_P = u + LSE_r(v)
114
+ k = 8  # operation counter: tracks n*m-sized passes (logsumexp/matmul); reported in logs["n_iter"]
115
+
116
+ # Chi-Sinkhorn initialization phase
117
+ u, v, log_r_P, err = self._chi_sinkhorn_batched(
118
+ u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5), LSE_r, LSE_c, active_mask
119
+ )
120
+ r_P = log_r_P.exp()
121
+ logs["errs"].append(err.max().item())
122
+ k += 8 * 10
123
+
124
+ converged = err <= eps_d
125
+ success = converged.clone()
126
+
127
+ num_iter = 0
128
+ max_iter = 100
129
+
130
+ # Main Newton loop
131
+ while (~converged & active_mask).any() and num_iter < max_iter:
132
+ num_iter += 1
133
+ working = ~converged & active_mask
134
+
135
+ eta_k = th.clamp(err, min=0.9 * eps_d / err.clamp(min=1e-30))  # forcing term: PCG tolerance err * eta_k = max(err**2, 0.9 * eps_d)
136
+ grad_k = r_P - r
137
+
138
+ # Compute transport plan for Hessian
139
+ P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
140
+ diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
141
+ k += 8
142
+
143
+ # Newton solve
144
+ delta_u, delta_v, matmul_cnt, pcg_success = self._newton_solve_batched(
145
+ P, c, diag_PPc, grad_k, r_P, err, eta_k, working
146
+ )
147
+ success = success & (pcg_success | ~working)
148
+ k += matmul_cnt
149
+
150
+ # Line search with Armijo condition
151
+ alpha = th.ones(batch_size, device=self.device, dtype=self.dtype)
152
+ log_c_P = v + alpha.unsqueeze(-1) * delta_v + LSE_c(u + alpha.unsqueeze(-1) * delta_u)
153
+ k += 4
154
+
155
+ linear_decr = -(grad_k * delta_u).sum(-1)
156
+ armijo = (log_c_P.exp().sum(-1) - 1) <= (0.99 * alpha * linear_decr)
157
+ armijo = armijo | ~working
158
+
159
+ ls_iter = 0
160
+ while not armijo.all() and ls_iter < 20:
161
+ alpha = th.where(armijo, alpha, alpha * 0.5)
162
+ log_c_P = (
163
+ v + alpha.unsqueeze(-1) * delta_v + LSE_c(u + alpha.unsqueeze(-1) * delta_u)
164
+ )
165
+ k += 4
166
+ armijo = (log_c_P.exp().sum(-1) - 1) <= (0.99 * alpha * linear_decr)
167
+ armijo = armijo | ~working
168
+ ls_iter += 1
169
+
170
+ # Update dual variables for working problems
171
+ u = th.where(working.unsqueeze(-1), u + alpha.unsqueeze(-1) * delta_u, u)
172
+ v = th.where(working.unsqueeze(-1), v + alpha.unsqueeze(-1) * delta_v, v)
173
+
174
+ # Sinkhorn correction
175
+ v = th.where(working.unsqueeze(-1), v + log_c - log_c_P, v)
176
+
177
+ log_r_P = u + LSE_r(v)
178
+ k += 4
179
+
180
+ # Chi-Sinkhorn refinement
181
+ u, v, log_r_P, err = self._chi_sinkhorn_batched(
182
+ u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5), LSE_r, LSE_c, working
183
+ )
184
+ r_P = log_r_P.exp()
185
+ logs["errs"].append(err.max().item())
186
+
187
+ converged = converged | (err <= eps_d)
188
+
189
+ logs["n_iter"] = k
190
+
191
+ # Final row update
192
+ delta_u = log_r - log_r_P
193
+ u = u + delta_u
194
+
195
+ success = success | converged
196
+ return u, v, logs, success
197
+
198
+ def _to_batch_tensor(self, val: Union[float, th.Tensor], batch_size: int) -> th.Tensor:
199
+ """Convert scalar or tensor to (batch,) shaped tensor."""
200
+ if not isinstance(val, th.Tensor):
201
+ val = th.tensor(val, device=self.device, dtype=self.dtype)
202
+ if val.dim() == 0:
203
+ val = val.expand(batch_size)
204
+ return val
205
+
206
+ def _chi_sinkhorn_batched(
207
+ self, u, v, log_r, log_c, log_r_P, eps_chi, LSE_r, LSE_c, active_mask, max_iter=100
208
+ ):
209
+ """Batched chi-squared Sinkhorn iterations for initialization."""
210
+ r = log_r.exp()
211
+ r_P = log_r_P.exp()
212
+
213
+ err = (r - r_P).abs().sum(-1)
214
+ chi_squared = ((r - r_P) ** 2 / r_P.clamp(min=1e-30)).sum(-1)
215
+
216
+ eps_chi = self._to_batch_tensor(eps_chi, u.shape[0])
217
+ working = (chi_squared > eps_chi) & active_mask
218
+
219
+ for _ in range(max_iter):
220
+ if not working.any():
221
+ break
222
+
223
+ delta_u = log_r - log_r_P
224
+ u = th.where(working.unsqueeze(-1), u + delta_u, u)
225
+
226
+ log_c_P = v + LSE_c(u)
227
+ delta_v = log_c - log_c_P
228
+ v = th.where(working.unsqueeze(-1), v + delta_v, v)
229
+
230
+ log_r_P = u + LSE_r(v)
231
+ r_P = log_r_P.exp()
232
+
233
+ err = (r - r_P).abs().sum(-1)
234
+ chi_squared = ((r - r_P) ** 2 / r_P.clamp(min=1e-30)).sum(-1)
235
+ working = (chi_squared > eps_chi) & active_mask
236
+
237
+ return u, v, log_r_P, err
238
+
239
+ def _newton_solve_batched(
240
+ self, P, c, diag_PPc, grad_k, r_P, err, eta_k, active_mask, max_iter=50
241
+ ):
242
+ """Batched preconditioned conjugate gradient Newton solve."""
243
+ tol = err * eta_k
244
+
245
+ # Diagonal preconditioner
246
+ M_rho = r_P - diag_PPc
247
+ M_rho = th.where(M_rho > 0, M_rho, M_rho.clamp(min=1e-10))
248
+
249
+ x = -grad_k / M_rho
250
+ r_vec = r_P * x - self._batched_PPc_matmul(P, c, x) + grad_k
251
+ matmul_cnt = 2
252
+
253
+ y = r_vec / M_rho
254
+ p = -y.clone()
255
+ ry_old = (r_vec * y).sum(-1, keepdim=True)
256
+
257
+ for _ in range(max_iter):
258
+ PPc_p = self._batched_PPc_matmul(P, c, p)
259
+ matmul_cnt += 2
260
+
261
+ Fr_p = r_P * p - PPc_p
262
+ quad = (Fr_p * p).sum(-1, keepdim=True)
263
+ quad = th.where(quad > 0, quad, th.ones_like(quad))  # guard against non-positive curvature directions
264
+
265
+ alpha = ry_old / quad
266
+ x = x + alpha * p
267
+ r_vec = r_vec + alpha * Fr_p
268
+
269
+ r_norm = r_vec.abs().sum(-1)
270
+ if (r_norm <= tol).all():
271
+ break
272
+
273
+ y = r_vec / M_rho
274
+ ry_new = (r_vec * y).sum(-1, keepdim=True)
275
+ p = -y + (ry_new / ry_old.clamp(min=1e-30)) * p
276
+ ry_old = ry_new
277
+
278
+ # Recover the column step from the row step: delta_v = -(P^T @ delta_u) / c
+ Pc_x = (x.unsqueeze(-2) @ P).squeeze(-2) / c
279
+
280
+ # Track convergence: success if residual norm is below tolerance
281
+ r_norm = r_vec.abs().sum(-1)
282
+ success = r_norm <= tol
283
+
284
+ return x, -Pc_x, matmul_cnt, success
285
+
286
+ def _batched_PPc_matmul(self, P, c, x):
287
+ """Compute P @ (P^T @ x / c) efficiently in batched form."""
288
+ PTx = (x.unsqueeze(-1) * P).sum(-2)
289
+ PTx_over_c = PTx / c
290
+ return (PTx_over_c.unsqueeze(-2) * P).sum(-1)
291
+
292
+
293
+ def _batched_smooth_marginals(
294
+ r: th.Tensor, c: th.Tensor, eps: th.Tensor, w_r: float = 0.5, w_c: float = 0.5
295
+ ) -> Tuple[th.Tensor, th.Tensor]:
296
+ """
297
+ Smooth the marginals by mixing them with the uniform distribution.
298
+
299
+ Args:
300
+ r: (batch, n) row marginals.
301
+ c: (batch, m) column marginals.
302
+ eps: (batch,) or scalar smoothing factor.
303
+ w_r, w_c: Weights for row/column smoothing (must sum to 1).
304
+
305
+ Returns:
306
+ r_hat, c_hat: Smoothed marginals.
307
+ """
308
+ eps = eps.clamp(max=1.0)
309
+ if eps.dim() == 0:
310
+ eps = eps.unsqueeze(0)
311
+ eps = eps.unsqueeze(-1)
312
+
313
+ r_hat = (1 - w_r * eps) * r + w_r * eps / r.size(-1)
314
+ c_hat = (1 - w_c * eps) * c + w_c * eps / c.size(-1)
315
+
316
+ return r_hat, c_hat
317
+
318
+
319
+ def _batched_mdot(
320
+ r: th.Tensor,
321
+ c: th.Tensor,
322
+ C: th.Tensor,
323
+ gamma_f: float,
324
+ gamma_i: float = 16,
325
+ p: float = 1.5,
326
+ q: float = 2.0,
327
+ ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, int, Dict[str, Any]]:
328
+ """
329
+ Batched MDOT (Mirror Descent Optimal Transport) solver.
330
+
331
+ Solves multiple entropic-regularized OT problems simultaneously using
332
+ temperature annealing with truncated Newton projections.
333
+
334
+ Args:
335
+ r: (batch, n) row marginals.
336
+ c: (batch, m) column marginals.
337
+ C: (n, m) or (batch, n, m) cost matrix.
338
+ gamma_f: Final temperature (inverse regularization weight).
339
+ gamma_i: Initial temperature.
340
+ p: Exponent for the epsilon schedule.
341
+ q: Temperature annealing factor.
342
+
343
+ Returns:
344
+ u: (batch, n) optimal row dual variables.
345
+ v: (batch, m) optimal column dual variables.
346
+ gamma_final: (batch,) final temperature achieved per problem.
347
+ k_total: Total number of primitive operations.
348
+ logs: Optimization logs.
349
+ """
350
+ batch_size = r.shape[0]
351
+ device = r.device
352
+ dtype = r.dtype
353
+
354
+ projector = BatchedTruncatedNewtonProjector(device=device, dtype=dtype)
355
+
356
+ # Compute entropy bounds for epsilon schedule
357
+ H_r = -(r * (r + 1e-30).log()).sum(-1)
358
+ H_c = -(c * (c + 1e-30).log()).sum(-1)
359
+ H_min = th.min(H_r, H_c)
360
+ eps_fn = lambda g_: H_min / (g_**p)
361
+
362
+ logs: Dict[str, Any] = {"proj_logs": [], "gammas": []}
363
+
364
+ gamma = min(gamma_i, gamma_f)
365
+ gamma_per_problem = th.full((batch_size,), gamma, device=device, dtype=dtype)
366
+ gamma_prev = th.zeros((batch_size,), device=device, dtype=dtype)
367
+ active_mask = th.ones(batch_size, device=device, dtype=th.bool)
368
+
369
+ # Initialize dual variables
370
+ eps_d = eps_fn(gamma)
371
+ r_hat, c_hat = _batched_smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)
372
+ u_init = r_hat.log()
373
+ v_init = c_hat.log()
374
+ u_cur = u_init.clone()
375
+ v_cur = v_init.clone()
376
+ u_prev = u_cur.clone()
377
+ v_prev = v_cur.clone()
378
+
379
+ t = 1
380
+ max_outer_iter = 50
381
+ done_all: Any = False
382
+
383
+ while active_mask.any() and t < max_outer_iter and not done_all:
384
+ done = th.abs(gamma_per_problem - gamma_f) < 1e-5
385
+ done_all = (done | ~active_mask).all()
386
+
387
+ eps_d = eps_fn(gamma_per_problem)
388
+ r_hat, c_hat = _batched_smooth_marginals(r, c, eps_d / 2, w_r=0.9, w_c=0.1)
389
+
390
+ # Scale cost matrix by per-problem gamma
391
+ if C.dim() == 2:
392
+ gamma_C = gamma_per_problem.unsqueeze(-1).unsqueeze(-1) * C.unsqueeze(0)
393
+ else:
394
+ gamma_C = gamma_per_problem.unsqueeze(-1).unsqueeze(-1) * C
395
+
396
+ # Save previous values for warm-starting
397
+ u_prev = th.where(active_mask.unsqueeze(-1), u_cur.clone(), u_prev)
398
+ v_prev = th.where(active_mask.unsqueeze(-1), v_cur.clone(), v_prev)
399
+
400
+ # Project using warm-started initial values
401
+ u_new, v_new, proj_log, success = projector.project(
402
+ gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init, active_mask
403
+ )
404
+
405
+ u_cur = th.where(active_mask.unsqueeze(-1), u_new, u_cur)
406
+ v_cur = th.where(active_mask.unsqueeze(-1), v_new, v_cur)
407
+
408
+ logs["proj_logs"].append(proj_log)
409
+
410
+ # Store previous gamma for warm-starting
411
+ gamma_prev_old = gamma_prev.clone()
412
+ gamma_prev = gamma_per_problem.clone()
413
+
414
+ # Update gamma for non-converged problems
415
+ gamma_per_problem = th.where(
416
+ active_mask & ~done, th.clamp(gamma_per_problem * q, max=gamma_f), gamma_per_problem
417
+ )
418
+
419
+ # Warm-start initialization for next iteration (extrapolation)
420
+ # Uses linear extrapolation from the previous two iterates, similar to the
421
+ # unbatched solver in mdot.py. The extrapolation factor is clamped to [-2, 2]
422
+ # to prevent instability when gamma changes rapidly between iterations.
423
+ if not done_all:
424
+ # Guard against a zero denominator for problems whose gamma did not change (e.g., already at gamma_f)
425
+ denom = (gamma_prev - gamma_prev_old).clamp(min=1e-10)
426
+ extrap_factor = ((gamma_per_problem - gamma_prev) / denom).unsqueeze(-1)
427
+ extrap_factor = extrap_factor.clamp(-2.0, 2.0)
428
+ u_init = th.where(
429
+ active_mask.unsqueeze(-1) & (t > 1), u_cur + (u_cur - u_prev) * extrap_factor, u_cur
430
+ )
431
+ v_init = th.where(
432
+ active_mask.unsqueeze(-1) & (t > 1), v_cur + (v_cur - v_prev) * extrap_factor, v_cur
433
+ )
434
+
435
+ logs["gammas"].append(gamma_per_problem.clone())
436
+ t += 1
437
+
438
+ k_total = sum([log["n_iter"] for log in logs["proj_logs"]])
439
+ logs["success"] = active_mask
440
+ logs["outer_iterations"] = t - 1
441
+
442
+ return u_cur, v_cur, gamma_per_problem, k_total, logs
443
+
444
+
445
+ def _batched_round(P: th.Tensor, r: th.Tensor, c: th.Tensor) -> th.Tensor:
446
+ """
447
+ Batched Altschuler rounding to project onto feasible transport plans.
448
+
449
+ Args:
450
+ P: (batch, n, m) approximate transport plans.
451
+ r: (batch, n) row marginals.
452
+ c: (batch, m) column marginals.
453
+
454
+ Returns:
455
+ P_rounded: (batch, n, m) feasible transport plans in U(r, c).
456
+ """
457
+ # Scale rows
458
+ row_sums = P.sum(-1)
459
+ X = th.clamp(r / row_sums.clamp(min=1e-30), max=1.0)
460
+ P = P * X.unsqueeze(-1)
461
+
462
+ # Scale columns
463
+ col_sums = P.sum(-2)
464
+ Y = th.clamp(c / col_sums.clamp(min=1e-30), max=1.0)
465
+ P = P * Y.unsqueeze(-2)
466
+
467
+ # Fix remaining error with rank-1 correction
468
+ err_r = (r - P.sum(-1)).clamp(min=0)
469
+ err_c = (c - P.sum(-2)).clamp(min=0)
470
+ err_r_norm = err_r.norm(p=1, dim=-1, keepdim=True).unsqueeze(-1) + 1e-30
471
+ P = P + err_r.unsqueeze(-1) * err_c.unsqueeze(-2) / err_r_norm
472
+
473
+ return P
474
+
475
+
476
+ def _batched_rounded_cost(
477
+ u: th.Tensor, v: th.Tensor, r: th.Tensor, c: th.Tensor, C: th.Tensor, gamma: th.Tensor
478
+ ) -> th.Tensor:
479
+ """
480
+ Compute the transport cost with rounding applied in the log domain.
482
+
483
+ Rounding is applied to the dual variables rather than to an explicitly
+ rounded plan, so no separate rounded (batch, n, m) plan has to be stored;
+ intermediate (batch, n, m) tensors are still formed by the logsumexp calls.
483
+
484
+ Args:
485
+ u: (batch, n) row dual variables.
486
+ v: (batch, m) column dual variables.
487
+ r: (batch, n) row marginals.
488
+ c: (batch, m) column marginals.
489
+ C: (n, m) or (batch, n, m) cost matrix.
490
+ gamma: (batch,) temperature per problem.
491
+
492
+ Returns:
493
+ costs: (batch,) optimal transport costs.
494
+ """
495
+ batch_size = u.shape[0]
496
+
497
+ if C.dim() == 2:
498
+ C = C.unsqueeze(0).expand(batch_size, -1, -1)
499
+
500
+ gamma = gamma.unsqueeze(-1).unsqueeze(-1)
501
+
502
+ # Row rounding in log domain
503
+ r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
504
+ delta_u = th.clamp(r.log() - r_P_log, max=0)
505
+ u = u + delta_u
506
+
507
+ # Column rounding in log domain
508
+ c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
509
+ delta_v = th.clamp(c.log() - c_P_log, max=0)
510
+ v = v + delta_v
511
+
512
+ # Compute row error for rank-1 correction
513
+ r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
514
+ r_P = r_P_log.exp()
515
+ err_r = r - r_P
516
+ err_r_normalized = err_r / (err_r.abs().sum(-1, keepdim=True) + 1e-30)
517
+
518
+ # Column marginal after rounding
519
+ c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
520
+ c_P = c_P_log.exp()
521
+ err_c = c - c_P
522
+
523
+ # Main cost term (in log domain for stability). Assumes C >= 0; zero entries
+ # of C are treated as exp(-30) by the clamp on C.log().
524
+ log_P = u.unsqueeze(-1) + v.unsqueeze(-2) - gamma * C
525
+ cost_main = th.logsumexp(log_P + C.log().clamp(min=-30), dim=(-1, -2)).exp()
526
+
527
+ # Rank-1 correction term
528
+ cost_correction = (
529
+ (err_r_normalized.unsqueeze(-2) @ C @ err_c.unsqueeze(-1)).squeeze(-1).squeeze(-1)
530
+ )
531
+
532
+ return cost_main + cost_correction
533
+
534
+
535
+ def solve_OT_batched(
536
+ r: th.Tensor,
537
+ c: th.Tensor,
538
+ C: th.Tensor,
539
+ gamma_f: float = 1024.0,
540
+ drop_tiny: bool = False,
541
+ return_plan: bool = False,
542
+ round: bool = True,
543
+ log: bool = False,
544
+ ) -> Union[th.Tensor, Tuple[th.Tensor, Dict[str, Any]]]:
545
+ """
546
+ Solve multiple entropic-regularized optimal transport problems in a single batched call.
547
+
548
+ This function provides significant speedup (5-10x) over solving problems sequentially
549
+ by amortizing GPU synchronization overhead across all problems in the batch.
550
+
551
+ Args:
552
+ r: (batch, n) row marginals. Each row must sum to 1.
553
+ c: (batch, m) column marginals. Each row must sum to 1.
554
+ C: Cost matrix. Either (n, m) for shared cost across all problems,
555
+ or (batch, n, m) for per-problem costs. Recommended to scale to [0, 1].
556
+ gamma_f: Temperature (inverse of regularization weight). Higher values give
557
+ more accurate solutions but take longer. Stable up to ~2^18 with float64.
558
+ drop_tiny: Not supported in batched solver. Raises NotImplementedError if True.
559
+ return_plan: If True, return transport plans instead of costs.
560
+ round: If True, apply Altschuler rounding for feasible solutions.
561
+ log: If True, also return optimization logs.
562
+
563
+ Returns:
564
+ If return_plan is False: (batch,) tensor of transport costs.
565
+ If return_plan is True: (batch, n, m) tensor of transport plans.
566
+ If log is True: tuple of (result, logs_dict).
567
+
568
+ Example:
569
+ >>> # Solve 32 OT problems of size 512×512
570
+ >>> r = torch.rand(32, 512, device='cuda', dtype=torch.float64)
571
+ >>> r = r / r.sum(-1, keepdim=True)
572
+ >>> c = torch.rand(32, 512, device='cuda', dtype=torch.float64)
573
+ >>> c = c / c.sum(-1, keepdim=True)
574
+ >>> C = torch.rand(512, 512, device='cuda', dtype=torch.float64)
575
+ >>> costs = solve_OT_batched(r, c, C, gamma_f=1024.)
576
+ """
577
+ # Input validation
578
+ if r.dim() != 2:
579
+ raise ValueError(f"r must be 2D (batch, n), got shape {r.shape}")
580
+ if c.dim() != 2:
581
+ raise ValueError(f"c must be 2D (batch, m), got shape {c.shape}")
582
+ if C.dim() not in [2, 3]:
583
+ raise ValueError(f"C must be 2D (n, m) or 3D (batch, n, m), got shape {C.shape}")
584
+ if r.shape[0] != c.shape[0]:
585
+ raise ValueError(f"Batch size mismatch: r has {r.shape[0]}, c has {c.shape[0]}")
586
+ if C.dim() == 3 and C.shape[0] != r.shape[0]:
587
+ raise ValueError(f"Batch size mismatch: C has {C.shape[0]}, r has {r.shape[0]}")
588
+
589
+ if drop_tiny:
590
+ raise NotImplementedError(
591
+ "drop_tiny is not yet implemented for batched solver. "
592
+ "Use solve_OT with drop_tiny=True for individual problems instead."
593
+ )
594
+
595
+ dtype = r.dtype
596
+
597
+ # Use double precision for high gamma
598
+ if gamma_f > 2**10 and dtype != th.float64:
599
+ warnings.warn(
600
+ f"Switching to float64 for gamma_f > 2^10. Output will be converted back to {dtype}."
601
+ )
602
+ r, c, C = r.double(), c.double(), C.double()
603
+
604
+ # Solve
605
+ u, v, gamma_final, k_total, opt_logs = _batched_mdot(r, c, C, gamma_f)
606
+
607
+ # Convert back to original dtype
608
+ u, v = u.to(dtype), v.to(dtype)
+ r, c = r.to(dtype), c.to(dtype)
609
+ gamma_final = gamma_final.to(dtype)
610
+ if C.dtype != dtype:
611
+ C = C.to(dtype)
612
+
613
+ opt_logs["k_total"] = k_total
614
+
615
+ if return_plan:
616
+ # Expand C for broadcasting if shared
617
+ C_expanded = C.unsqueeze(0) if C.dim() == 2 else C
618
+ gamma_for_plan = gamma_final.unsqueeze(-1).unsqueeze(-1)
619
+
620
+ P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_for_plan * C_expanded).exp()
621
+ if round:
622
+ P = _batched_round(P, r, c)
623
+
624
+ return (P, opt_logs) if log else P
625
+ else:
626
+ if round:
627
+ costs = _batched_rounded_cost(u, v, r, c, C, gamma_final)
628
+ else:
629
+ C_expanded = C.unsqueeze(0) if C.dim() == 2 else C
630
+ gamma_for_plan = gamma_final.unsqueeze(-1).unsqueeze(-1)
631
+ P = (u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_for_plan * C_expanded).exp()
632
+ costs = (P * C_expanded).sum(dim=(-2, -1))
633
+
634
+ return (costs, opt_logs) if log else costs
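
For reference, a minimal usage sketch of the per-problem cost path added in this release (the function, flags, and log key are taken from the code above; the sizes and random data are illustrative only):

    import torch
    from mdot_tnt.batched import solve_OT_batched

    # 16 problems, each with its own 256x256 cost matrix (C carries a batch dimension)
    batch, n = 16, 256
    r = torch.rand(batch, n, dtype=torch.float64)
    r = r / r.sum(dim=-1, keepdim=True)
    c = torch.rand(batch, n, dtype=torch.float64)
    c = c / c.sum(dim=-1, keepdim=True)
    C = torch.rand(batch, n, n, dtype=torch.float64)  # per-problem costs in [0, 1]

    # Request rounded transport plans and optimization logs instead of just costs
    P, logs = solve_OT_batched(r, c, C, gamma_f=1024., return_plan=True, log=True)
    print(P.shape)          # torch.Size([16, 256, 256])
    print(logs["k_total"])  # total primitive-operation count for the whole batch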
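
With the default round=True, costs come from _batched_rounded_cost, which applies the rounding to the dual variables in the log domain; this should agree closely with the cost of the explicitly rounded plans returned under return_plan=True. A small consistency check (again with illustrative sizes; the close agreement is an expectation, not a guarantee stated in this diff):

    import torch
    from mdot_tnt.batched import solve_OT_batched

    torch.manual_seed(0)
    r = torch.rand(8, 128, dtype=torch.float64)
    r = r / r.sum(dim=-1, keepdim=True)
    c = torch.rand(8, 128, dtype=torch.float64)
    c = c / c.sum(dim=-1, keepdim=True)
    C = torch.rand(128, 128, dtype=torch.float64)  # shared cost matrix

    # Default path: log-domain rounded costs
    costs = solve_OT_batched(r, c, C, gamma_f=1024.)

    # Recompute costs from explicitly rounded plans
    P = solve_OT_batched(r, c, C, gamma_f=1024., return_plan=True)
    costs_from_plans = (P * C.unsqueeze(0)).sum(dim=(-2, -1))

    print((costs - costs_from_plans).abs().max())  # expected to be small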
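
The 5-10x speedup quoted in the module docstring comes from amortizing per-iteration GPU synchronizations across the batch. A rough way to observe this is to time the batched call against a sequential loop; the unbatched solve_OT entry point is not part of this diff, so the baseline below simply calls the batched solver one problem at a time, and the actual speedup depends on the device and problem sizes:

    import time

    import torch
    from mdot_tnt.batched import solve_OT_batched

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    r = torch.rand(32, 512, device=device, dtype=torch.float64)
    r = r / r.sum(dim=-1, keepdim=True)
    c = torch.rand(32, 512, device=device, dtype=torch.float64)
    c = c / c.sum(dim=-1, keepdim=True)
    C = torch.rand(512, 512, device=device, dtype=torch.float64)

    def timed(fn):
        """Time fn(), synchronizing the GPU so the measurement is meaningful."""
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        out = fn()
        if device == 'cuda':
            torch.cuda.synchronize()
        return out, time.perf_counter() - start

    _, t_batched = timed(lambda: solve_OT_batched(r, c, C, gamma_f=1024.))
    _, t_sequential = timed(
        lambda: [solve_OT_batched(r[i:i + 1], c[i:i + 1], C, gamma_f=1024.)
                 for i in range(r.shape[0])]
    )
    print(f"batched: {t_batched:.2f}s  sequential: {t_sequential:.2f}s")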