mdot-tnt 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +52 -8
- mdot_tnt/batched.py +634 -0
- mdot_tnt/mdot.py +105 -41
- mdot_tnt/py.typed +0 -0
- mdot_tnt/rounding.py +41 -15
- mdot_tnt/truncated_newton.py +107 -38
- mdot_tnt-1.0.0.dist-info/METADATA +216 -0
- mdot_tnt-1.0.0.dist-info/RECORD +11 -0
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info}/WHEEL +1 -1
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info/licenses}/LICENSE +4 -1
- mdot_tnt-0.1.0.dist-info/METADATA +0 -71
- mdot_tnt-0.1.0.dist-info/RECORD +0 -9
- {mdot_tnt-0.1.0.dist-info → mdot_tnt-1.0.0.dist-info}/top_level.txt +0 -0
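For orientation before the per-file diffs: the 1.0.0 signatures shown below can be exercised end to end as in the following sketch. The calls to mdot and rounded_cost_altschuler match the annotated signatures in this diff; the random problem instance and the choice of gamma_f are illustrative only.

import torch as th

from mdot_tnt.mdot import mdot
from mdot_tnt.rounding import rounded_cost_altschuler

n, m = 256, 256
C = th.rand(n, m, dtype=th.float64)  # cost matrix; entries in [0, 1] as the docstring recommends
r = th.softmax(th.randn(n, dtype=th.float64), dim=-1)  # row marginal (sums to 1)
c = th.softmax(th.randn(m, dtype=th.float64), dim=-1)  # column marginal (sums to 1)

# Anneal the temperature up to gamma_f; the regularization weight is 1/gamma_f.
u, v, gamma, k_total, logs = mdot(r, c, C, gamma_f=1024.0)

# Round the implicit plan P = exp(u[:, None] + v[None, :] - gamma * C) into U(r, c)
# and evaluate <P, C>, all in the log domain and without materializing a rounded P.
cost = rounded_cost_altschuler(u, v, r, c, C, gamma)
print(f"OT cost ~= {cost.item():.6f} after {k_total} O(n^2) primitive operations")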
mdot_tnt/mdot.py
CHANGED

@@ -1,33 +1,37 @@
-"""
-Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
-method introduced in the paper "A Truncated Newton Method for Optimal Transport"
-by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
-URL: https://openreview.net/forum?id=gWrWUaCbMa
-"""
+"""Core MDOT solver using truncated Newton projection."""
 
-
-import math
 import warnings
+from typing import Any, Dict, List, Tuple, Union
+
+import torch as th
 
-from mdot_tnt.rounding import *
 from mdot_tnt.truncated_newton import TruncatedNewtonProjector
 
 
-def preprocess_marginals(r, c, C, eps):
+def preprocess_marginals(
+    r: th.Tensor, c: th.Tensor, C: th.Tensor, eps: float
+) -> Tuple[Tuple[th.Tensor, th.Tensor], Tuple[th.Tensor, th.Tensor], th.Tensor]:
     """
-
-
-    :
-
-
-
+    Drop the smallest marginal entries whose cumulative sum is below a threshold.
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        C: The cost matrix of shape (n, m).
+        eps: The threshold for the cumulative sum of the marginal entries to be dropped.
+
+    Returns:
+        A tuple containing:
+            - (r_new, r_keep): The new row marginal and indices of kept entries.
+            - (c_new, c_keep): The new column marginal and indices of kept entries.
+            - C: The cost matrix with corresponding rows and columns dropped.
     """
 
-    def preprocess_marginal(m, eps):
+    def preprocess_marginal(m: th.Tensor, eps: float) -> Tuple[th.Tensor, th.Tensor]:
         m_sorted, m_idx = th.sort(m, dim=-1, descending=False)
         m_cumsum = th.cumsum(m_sorted, dim=-1)
         m_keep = m_idx[m_cumsum > eps]
-        m_new = m[
+        m_new = m[m_keep]
         mass_removed = 1 - m_new.sum(-1)
         m_new = m_new + mass_removed / m_new.size(-1)
 
@@ -35,59 +39,115 @@ def preprocess_marginals(r, c, C, eps):
 
     r_new, r_keep = preprocess_marginal(r, eps)
     c_new, c_keep = preprocess_marginal(c, eps)
+    print(
+        f"Dropped {r.size(-1) - r_new.size(-1)} entries from r and {c.size(-1) - c_new.size(-1)} entries from c."
+    )
 
     C = C[r_keep][:, c_keep]
 
     return (r_new, r_keep), (c_new, c_keep), C
 
 
-def smooth_marginals(
+def smooth_marginals(
+    r: th.Tensor,
+    c: th.Tensor,
+    eps: th.Tensor,
+    w_r: float = 0.5,
+    w_c: float = 0.5,
+) -> Tuple[th.Tensor, th.Tensor]:
+    """
+    Smooth the marginals by adding a small amount of uniform mass to each entry.
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        eps: The amount of mass to add to each entry.
+        w_r: The weight for the row marginal.
+        w_c: The weight for the column marginal.
+
+    Returns:
+        A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps
+        from the original marginals.
+    """
     assert w_r + w_c == 1, "w_r and w_c must sum to 1"
-    eps = eps.clamp(max=1.).unsqueeze(-1)
+    eps = eps.clamp(max=1.0).unsqueeze(-1)
     r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
     c_hat = (1 - w_c * eps) * c + w_c * eps * th.ones_like(c) / c.size(-1)
 
     return r_hat, c_hat
 
 
-def adjust_schedule(q, deltas=None):
+def adjust_schedule(q: float, deltas: Union[List[float], None] = None) -> float:
+    """
+    Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
+
+    Args:
+        q: The current temperature annealing schedule adjustment factor.
+        deltas: The list of deltas from the Truncated Newton method;
+            see Sec. 3.3 of Kemertas et al. (2025).
+
+    Returns:
+        The new temperature annealing schedule adjustment factor.
+    """
     if deltas is None:
         return q
 
-    deltas = deltas + [1.]  # If deltas is empty, we assume that the first iteration was successful
+    deltas = deltas + [1.0]  # If deltas is empty, we assume that the first iteration was successful
     delta_min = min(deltas)
 
    if delta_min < 0.5:
-        q = q ** 0.5
+        q = q**0.5
     elif delta_min > 0.9:
-        q = q ** 2
+        q = q**2
 
     return q
 
 
-def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
+def mdot(
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma_f: float,
+    gamma_i: float = 16,
+    p: float = 1.5,
+    q: float = 2.0,
+) -> Tuple[th.Tensor, th.Tensor, float, int, Dict[str, Any]]:
     """
-    Solve the entropic-regularized optimal transport problem using the MDOT method
-
-
+    Solve the entropic-regularized optimal transport problem using the MDOT method.
+
+    This implements the MDOT method introduced in the paper:
+    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients"
+    by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand.
+    URL: https://arxiv.org/abs/2307.08507
+
     Here, we use the Truncated Newton method for projection.
-
-    :
-
-
-
-
-
-
+
+    Args:
+        r: The first marginal of shape (n,).
+        c: The second marginal of shape (m,).
+        C: The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
+        gamma_f: The final temperature (inverse of the regularization weight).
+        gamma_i: The initial temperature.
+        p: The exponent for the epsilon function, used to determine the stopping
+            criterion for the dual gradient.
+        q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
+
+    Returns:
+        A tuple containing:
+            - u: The row dual variables of shape (n,).
+            - v: The column dual variables of shape (m,).
+            - gamma: The final temperature achieved.
+            - k_total: The total number of O(n^2) primitive operations.
+            - logs: Dictionary with optimization statistics.
     """
     projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
 
     H_r = -(r * (r + 1e-30).log()).sum(-1)
     H_c = -(c * (c + 1e-30).log()).sum(-1)
     H_min = th.min(H_r, H_c)
-    eps_fn = lambda g_: H_min / (g_ ** p)
+    eps_fn = lambda g_: H_min / (g_**p)
 
-    logs = {
+    logs: Dict[str, Any] = {
         "proj_logs": [],
         "eps": [],
     }

@@ -95,7 +155,7 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
     t = 1
     done = False
     gamma = min(gamma_i, gamma_f)
-    gammas = [0., gamma]
+    gammas = [0.0, gamma]
 
     while not done:
         done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)

@@ -109,12 +169,16 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
         u_cur, v_cur = u_init.clone(), v_init.clone()
 
         u_prev, v_prev = u_cur.clone(), v_cur.clone()
+        gamma_C = gamma * C
         u_cur, v_cur, proj_log, success = projector.project(
-
+            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init
+        )
 
         logs["proj_logs"].append(proj_log)
         if not success:
-            warnings.warn(
+            warnings.warn(
+                f"Projection failed. Returning result at the last temperature: {1 / gammas[-2]:.4e}"
+            )
             u_cur = u_prev.clone()
             v_cur = v_prev.clone()
             gammas = gammas[:-1]
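The two marginal-preparation helpers changed above are usable on their own. A short sketch with synthetic values, following the new docstrings (the eps values are arbitrary):

import torch as th

from mdot_tnt.mdot import preprocess_marginals, smooth_marginals

r = th.tensor([0.6, 0.3, 0.1 - 1e-7, 1e-7], dtype=th.float64)  # one negligible entry
c = th.full((4,), 0.25, dtype=th.float64)
C = th.rand(4, 4, dtype=th.float64)

# Drop entries whose (sorted) cumulative mass is below 1e-6 and shrink C to match;
# the removed mass is redistributed so r_new still sums to 1.
(r_new, r_keep), (c_new, c_keep), C_new = preprocess_marginals(r, c, C, eps=1e-6)
assert C_new.shape == (r_new.numel(), c_new.numel())  # here (3, 4)

# Mix a little uniform mass into each marginal; eps is a tensor because
# smooth_marginals calls eps.clamp(...) on it.
r_hat, c_hat = smooth_marginals(r_new, c_new, eps=th.tensor(1e-3, dtype=th.float64))
assert th.isclose(r_hat.sum(), th.tensor(1.0, dtype=th.float64))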
mdot_tnt/py.typed
ADDED

File without changes
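Together with the annotations added throughout the modules, this marker makes mdot-tnt a typed package under PEP 561, so type checkers such as mypy will now consume the package's inline signatures. A small illustration (the call itself is a hypothetical example, checked against the adjust_schedule annotations above):

from mdot_tnt.mdot import adjust_schedule

q: float = adjust_schedule(2.0, deltas=[0.95, 0.97])  # min(deltas + [1.0]) > 0.9, so q becomes 4.0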
mdot_tnt/rounding.py
CHANGED

@@ -6,16 +6,22 @@ and c are the row and column marginals, respectively. The algorithm is used in t
 plan and compute the cost of the rounded plan. The implementation is based on the original paper.
 """
 
+from typing import Union
+
 import torch as th
 
 
-def round_altschuler(P, r, c):
+def round_altschuler(P: th.Tensor, r: th.Tensor, c: th.Tensor) -> th.Tensor:
     """
     Performs rounding given a transport plan and marginals.
-
-    :
-
-
+
+    Args:
+        P: The input transport plan of shape (n, m).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+
+    Returns:
+        Rounded transport plan in feasible set U(r, c).
     """
     X = th.min(r / P.sum(-1), th.ones_like(r))
     P *= X.unsqueeze(-1)

@@ -25,19 +31,39 @@ def round_altschuler(P, r, c):
 
     err_r = (r - P.sum(-1)).clamp(min=0)
     err_c = (c - P.sum(-2)).clamp(min=0)
-    P += err_r.unsqueeze(-1) @ err_c.unsqueeze(-2) / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    P += (
+        err_r.unsqueeze(-1)
+        @ err_c.unsqueeze(-2)
+        / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    )
 
     return P
 
 
-def rounded_cost_altschuler(u, v, r, c, C, gamma):
-
-    :
-    :
-    :
-    :
-    :
-
+def rounded_cost_altschuler(
+    u: th.Tensor,
+    v: th.Tensor,
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma: Union[float, th.Tensor],
+) -> th.Tensor:
+    """
+    Performs rounding and cost computation in log-domain given dual variables.
+
+    This function computes the transport cost without storing the full n×m transport plan,
+    making it memory efficient.
+
+    Args:
+        u: Dual variable for rows of shape (n,).
+        v: Dual variable for columns of shape (m,).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+        C: Cost matrix of shape (n, m).
+        gamma: Temperature (inverse of the entropic regularization weight).
+
+    Returns:
+        The optimal transport cost as a scalar tensor.
     """
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     delta_u = th.min(r.log() - r_P_log, th.zeros_like(r))

@@ -50,7 +76,7 @@ def rounded_cost_altschuler(u, v, r, c, C, gamma):
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     r_P = r_P_log.exp()
     err_r = r - r_P
-    err_r /=
+    err_r /= err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30
 
     c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
     c_P = c_P_log.exp()
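A quick feasibility check of round_altschuler on a synthetic near-feasible plan (illustrative only; the function modifies P in place, hence the clone):

import torch as th

from mdot_tnt.rounding import round_altschuler

n, m = 8, 8
r = th.full((n,), 1.0 / n, dtype=th.float64)
c = th.full((m,), 1.0 / m, dtype=th.float64)

# A positive plan with slightly deficient marginals, as an unconverged solver might produce.
P = th.rand(n, m, dtype=th.float64)
P /= P.sum() * 1.01

P_rounded = round_altschuler(P.clone(), r, c)
assert th.allclose(P_rounded.sum(-1), r) and th.allclose(P_rounded.sum(-2), c)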
mdot_tnt/truncated_newton.py
CHANGED

@@ -1,39 +1,79 @@
+"""Truncated Newton projector for the MDOT algorithm."""
 
-import torch as th
 import warnings
+from typing import Any, Callable, Dict, Tuple, Union
+
+import torch as th
 
 
 class TruncatedNewtonProjector:
-
+    """
+    Truncated Newton projector for the MDOT algorithm.
+
+    Projects onto the set of couplings satisfying marginal constraints using
+    a preconditioned conjugate gradient method within a Newton framework.
+    """
+
+    def __init__(self, device: th.device, dtype: th.dtype, **kwargs: Any) -> None:
+        """
+        Initialize the projector.
+
+        Args:
+            device: PyTorch device for computations.
+            dtype: Data type for tensors.
+            **kwargs: Additional options (debug: bool for verbose output).
+        """
         self.device = device
         self.rho = th.zeros(1, device=device, dtype=dtype)
-        self.debug = kwargs.get(
-
-
+        self.debug = kwargs.get("debug", False)
+        self.LSE_r: Callable[[th.Tensor], th.Tensor]
+        self.LSE_c: Callable[[th.Tensor], th.Tensor]
+
+    def project(
+        self,
+        gamma_C: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        eps_d: Union[float, th.Tensor],
+        u: th.Tensor,
+        v: th.Tensor,
+    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], bool]:
         """
         Project onto the set of couplings that satisfy the marginal constraints.
-
-        :
+
+        Args:
+            gamma_C: The cost matrix scaled by gamma, shape (n, m).
+            log_r: Log of row marginals, shape (n,).
+            log_c: Log of column marginals, shape (m,).
+            eps_d: Convergence tolerance for the dual gradient norm.
+            u: Initial row dual variables, shape (n,).
+            v: Initial column dual variables, shape (m,).
+
+        Returns:
+            u: Updated row dual variables.
+            v: Updated column dual variables.
+            logs: Dictionary with optimization statistics.
+            success: Whether projection converged successfully.
         """
-        logs = {
+        logs: Dict[str, Any] = {
             "errs": [],
-
-
-
+            "ls_func_cnt": 0,
+            "chisinkhorn_steps": 0,
+            "newtonsolve_steps": 0,
             "deltas": [],  # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
-            "all_newtonsolve_steps": []
+            "all_newtonsolve_steps": [],
         }
         # In case of errors or issues, 10 times the tolerance level is considered
         # a good enough solution to keep MDOT going.
         success_fn = lambda err_: err_ < 10 * eps_d
 
+        r = log_r.exp()
+        c = log_c.exp()
+
         # Each LSE operation costs 4 * n^2 operations.
         self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
         self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
 
-        r = log_r.exp()
-        c = log_c.exp()
-
         log_c_P = v + self.LSE_c(u)
         v += log_c - log_c_P  # Ensure c=c(P)
         log_r_P = u + self.LSE_r(v)

@@ -57,19 +97,20 @@ class TruncatedNewtonProjector:
         self.rho = th.max(th.zeros_like(self.rho), self.rho)
 
         P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
-        diag_PPc = ((P ** 2) / c.unsqueeze(-2)).sum(-1)
+        diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
         k += 8
         delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
-            P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000)
+            P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000
+        )
         del P  # Free up memory
         if not pcg_success:
             k += matmul_cnt
             logs["n_iter"] = k
-            msg = "PCG did not converge. TruncatedNewton returning with success={
+            msg = f"PCG did not converge. TruncatedNewton returning with success={success_fn(err)}"
             warnings.warn(msg)
             return u, v, logs, success_fn(err)
 
-        self.rho = th.max(th.zeros_like(self.rho), 1. - (1. - rho) * 4.)
+        self.rho = th.max(th.zeros_like(self.rho), 1.0 - (1.0 - rho) * 4.0)
         k += matmul_cnt
         logs["newtonsolve_steps"] += matmul_cnt
 

@@ -79,8 +120,7 @@ class TruncatedNewtonProjector:
         linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
         if not linear_decr > 0:
             logs["n_iter"] = k
-            msg = "Linear decrease condition not satisfied. TruncatedNewton returning with success={}"
-                success_fn(err))
+            msg = f"Linear decrease condition not satisfied. TruncatedNewton returning with success={success_fn(err)}"
             warnings.warn(msg)
             return u, v, logs, success_fn(err)
 

@@ -89,8 +129,7 @@ class TruncatedNewtonProjector:
             alpha *= 0.5
             if alpha < 1e-9:
                 logs["n_iter"] = k
-                msg = "Line search did not converge. TruncatedNewton returning with success={}"
-                    success_fn(err))
+                msg = f"Line search did not converge. TruncatedNewton returning with success={success_fn(err)}"
                 warnings.warn(msg)
                 return u, v, logs, success_fn(err)
 

@@ -113,13 +152,17 @@ class TruncatedNewtonProjector:
         log_r_P = u + self.LSE_r(v)
         k += 4
 
-        u, v, log_r_P, err, k_ = self.chi_sinkhorn(
+        u, v, log_r_P, err, k_ = self.chi_sinkhorn(
+            u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
+        )
         r_P = log_r_P.exp()
         logs["chisinkhorn_steps"] += k_
         k += k_
 
         logs["errs"].append(err)
-        logs["deltas"].append(
+        logs["deltas"].append(
+            th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item()
+        )
 
         if u.isnan().any() or v.isnan().any():
             raise ValueError("NaNs encountered in u or v")

@@ -132,7 +175,16 @@ class TruncatedNewtonProjector:
 
         return u, v, logs, True
 
-    def chi_sinkhorn(
+    def chi_sinkhorn(
+        self,
+        u: th.Tensor,
+        v: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        log_r_P: th.Tensor,
+        eps_chi: Union[float, th.Tensor],
+        maxOps: float = float("inf"),
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, int]:
         k = 0
         r = log_r.exp()
         err = (r - log_r_P.exp()).norm(p=1, dim=-1)

@@ -154,28 +206,40 @@ class TruncatedNewtonProjector:
             k += 8
 
             if k >= maxOps:
-                raise ValueError("Chi-Sinkhorn did not converge in maxIter={} steps"
+                raise ValueError(f"Chi-Sinkhorn did not converge in maxIter={maxOps} steps")
 
         return u, v, log_r_P, err, k
 
-    def newton_solve(
+    def newton_solve(
+        self,
+        P: th.Tensor,
+        c: th.Tensor,
+        diag_PPc: th.Tensor,
+        grad_k: th.Tensor,
+        r_P: th.Tensor,
+        err: th.Tensor,
+        beta: float = 0.5,
+        eta_k: Union[float, th.Tensor] = 0.5,
+        maxIter: int = 500,
+    ) -> Tuple[th.Tensor, th.Tensor, int, th.Tensor, bool]:
         rho = self.rho
         tol = err * eta_k
 
-
-
+        matmul_PPc = lambda x_: (
+            P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))
+        ).squeeze(-1)
 
         # mml = th.compile(matmul_PPc)
         mml = matmul_PPc
 
-        M = lambda rho_:
+        M = lambda rho_: r_P - rho_ * diag_PPc  # Diagonal preconditioner
         M_rho = M(th.ones_like(self.rho))
         M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()
 
         x0 = -grad_k / M_rho
         PPc_x0 = mml(x0)
         matmul_cnt = 2
-        r_P_x0 =
+        r_P_x0 = r_P * x0
 
         x = x0.clone()
         PPc_x = PPc_x0.clone()

@@ -198,7 +262,7 @@ class TruncatedNewtonProjector:
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
             best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
 
-            rho[r_true_norm > tol] = 1. - (1. - rho[r_true_norm > tol]) * 0.25
+            rho[r_true_norm > tol] = 1.0 - (1.0 - rho[r_true_norm > tol]) * 0.25
             M_rho = M(rho)
 
             if matmul_cnt > 0:

@@ -228,8 +292,10 @@ class TruncatedNewtonProjector:
 
             quad = (Fr_p * p).sum(-1, keepdim=True)
             if (quad <= 0)[best_r_true_norm > tol].any():
-                warnings.warn(
-
+                warnings.warn(
+                    "Warning: negative curvature encountered in CG. Returning best solution. "
+                    f"Residual norm less than error: {(best_r_true_norm < err).item()}"
+                )
                 x = best_sol.clone()
                 done = True
                 success = best_r_true_norm < err

@@ -242,12 +308,15 @@ class TruncatedNewtonProjector:
             res += alpha * Fr_p
             r_norm = res.norm(p=1, dim=-1)
 
-            if
+            if (
+                th.isnan(r_norm)[best_r_true_norm > tol].any()
+                or th.isinf(r_norm)[best_r_true_norm > tol].any()
+            ):
                 raise ValueError("NaNs or infs encountered in r_norm")
 
             PPc_x += alpha * PPc_p
 
-            r_P_x =
+            r_P_x = r_P * x
             res_true = r_P_x - PPc_x + grad_k
             r_true_norm = res_true.norm(p=1, dim=-1)
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]

@@ -280,4 +349,4 @@ class TruncatedNewtonProjector:
         Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
         matmul_cnt += 1
 
-        return x, -Pc_x, matmul_cnt, rho, success
+        return x, -Pc_x, matmul_cnt, rho, success
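The projector can also be driven directly at a fixed temperature, bypassing mdot's annealing loop. A sketch based on the project(...) signature above, with synthetic data; note that project updates the passed dual variables in place (v += ... is visible in the diff):

import torch as th

from mdot_tnt.truncated_newton import TruncatedNewtonProjector

n, m = 64, 64
C = th.rand(n, m, dtype=th.float64)
r = th.softmax(th.randn(n, dtype=th.float64), dim=-1)
c = th.softmax(th.randn(m, dtype=th.float64), dim=-1)
gamma = 64.0

projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
u0 = th.zeros(n, dtype=th.float64)
v0 = th.zeros(m, dtype=th.float64)

# Drive the L1 norm of the dual gradient below eps_d at this fixed temperature.
u, v, logs, success = projector.project(gamma * C, r.log(), c.log(), 1e-6, u0, v0)

P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma * C)  # the implicit coupling
print(success, (P.sum(-1) - r).norm(p=1).item())  # small on success (within the 10x fallback of eps_d)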