mdot-tnt 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdot_tnt/__init__.py +41 -10
- mdot_tnt/batched.py +634 -0
- mdot_tnt/mdot.py +97 -43
- mdot_tnt/py.typed +0 -0
- mdot_tnt/rounding.py +41 -15
- mdot_tnt/truncated_newton.py +104 -34
- mdot_tnt-1.0.0.dist-info/METADATA +216 -0
- mdot_tnt-1.0.0.dist-info/RECORD +11 -0
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info}/WHEEL +1 -1
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info/licenses}/LICENSE +4 -1
- mdot_tnt-0.2.0.dist-info/METADATA +0 -71
- mdot_tnt-0.2.0.dist-info/RECORD +0 -9
- {mdot_tnt-0.2.0.dist-info → mdot_tnt-1.0.0.dist-info}/top_level.txt +0 -0
mdot_tnt/mdot.py
CHANGED
@@ -1,21 +1,33 @@
+"""Core MDOT solver using truncated Newton projection."""
 
 import warnings
+from typing import Any, Dict, List, Tuple, Union
+
+import torch as th
 
-from mdot_tnt.rounding import *
 from mdot_tnt.truncated_newton import TruncatedNewtonProjector
 
 
-def preprocess_marginals(r, c, C, eps):
+def preprocess_marginals(
+    r: th.Tensor, c: th.Tensor, C: th.Tensor, eps: float
+) -> Tuple[Tuple[th.Tensor, th.Tensor], Tuple[th.Tensor, th.Tensor], th.Tensor]:
     """
-
-
-    :
-
-
-
+    Drop the smallest marginal entries whose cumulative sum is below a threshold.
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        C: The cost matrix of shape (n, m).
+        eps: The threshold for the cumulative sum of the marginal entries to be dropped.
+
+    Returns:
+        A tuple containing:
+        - (r_new, r_keep): The new row marginal and indices of kept entries.
+        - (c_new, c_keep): The new column marginal and indices of kept entries.
+        - C: The cost matrix with corresponding rows and columns dropped.
     """
 
-    def preprocess_marginal(m, eps):
+    def preprocess_marginal(m: th.Tensor, eps: float) -> Tuple[th.Tensor, th.Tensor]:
         m_sorted, m_idx = th.sort(m, dim=-1, descending=False)
         m_cumsum = th.cumsum(m_sorted, dim=-1)
         m_keep = m_idx[m_cumsum > eps]
@@ -27,76 +39,115 @@ def preprocess_marginals(r, c, C, eps):
 
     r_new, r_keep = preprocess_marginal(r, eps)
     c_new, c_keep = preprocess_marginal(c, eps)
-    print(
-        f"Dropped {r.size(-1) - r_new.size(-1)} entries from r and {c.size(-1) - c_new.size(-1)} entries from c.")
+    print(
+        f"Dropped {r.size(-1) - r_new.size(-1)} entries from r and {c.size(-1) - c_new.size(-1)} entries from c."
+    )
 
     C = C[r_keep][:, c_keep]
 
     return (r_new, r_keep), (c_new, c_keep), C
 
 
-def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
+def smooth_marginals(
+    r: th.Tensor,
+    c: th.Tensor,
+    eps: th.Tensor,
+    w_r: float = 0.5,
+    w_c: float = 0.5,
+) -> Tuple[th.Tensor, th.Tensor]:
     """
     Smooth the marginals by adding a small amount of uniform mass to each entry.
-
-    :
-
-
-
-
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        eps: The amount of mass to add to each entry.
+        w_r: The weight for the row marginal.
+        w_c: The weight for the column marginal.
+
+    Returns:
+        A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps
+        from the original marginals.
     """
     assert w_r + w_c == 1, "w_r and w_c must sum to 1"
-    eps = eps.clamp(max=1.).unsqueeze(-1)
+    eps = eps.clamp(max=1.0).unsqueeze(-1)
     r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
     c_hat = (1 - w_c * eps) * c + w_c * eps * th.ones_like(c) / c.size(-1)
 
     return r_hat, c_hat
 
 
-def adjust_schedule(q, deltas=None):
+def adjust_schedule(q: float, deltas: Union[List[float], None] = None) -> float:
     """
     Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
-
-    :
-
+
+    Args:
+        q: The current temperature annealing schedule adjustment factor.
+        deltas: The list of deltas from the Truncated Newton method;
+            see Sec. 3.3 of Kemertas et al. (2025).
+
+    Returns:
+        The new temperature annealing schedule adjustment factor.
     """
     if deltas is None:
         return q
 
-    deltas = deltas + [1.]  # If deltas is empty, we assume that the first iteration was successful
+    deltas = deltas + [1.0]  # If deltas is empty, we assume that the first iteration was successful
    delta_min = min(deltas)
 
     if delta_min < 0.5:
-        q = q ** 0.5
+        q = q**0.5
     elif delta_min > 0.9:
-        q = q ** 2
+        q = q**2
 
     return q
 
 
-def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0):
+def mdot(
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma_f: float,
+    gamma_i: float = 16,
+    p: float = 1.5,
+    q: float = 2.0,
+) -> Tuple[th.Tensor, th.Tensor, float, int, Dict[str, Any]]:
     """
-    Solve the entropic-regularized optimal transport problem using the MDOT method
-
-
+    Solve the entropic-regularized optimal transport problem using the MDOT method.
+
+    This implements the MDOT method introduced in the paper:
+    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients"
+    by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand.
+    URL: https://arxiv.org/abs/2307.08507
+
     Here, we use the Truncated Newton method for projection.
-
-    :
-
-
-
-
-
-
+
+    Args:
+        r: The first marginal of shape (n,).
+        c: The second marginal of shape (m,).
+        C: The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
+        gamma_f: The final temperature (inverse of the regularization weight).
+        gamma_i: The initial temperature.
+        p: The exponent for the epsilon function, used to determine the stopping
+            criterion for the dual gradient.
+        q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
+
+    Returns:
+        A tuple containing:
+        - u: The row dual variables of shape (n,).
+        - v: The column dual variables of shape (m,).
+        - gamma: The final temperature achieved.
+        - k_total: The total number of O(n^2) primitive operations.
+        - logs: Dictionary with optimization statistics.
     """
     projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
 
     H_r = -(r * (r + 1e-30).log()).sum(-1)
     H_c = -(c * (c + 1e-30).log()).sum(-1)
     H_min = th.min(H_r, H_c)
-    eps_fn = lambda g_: H_min / (g_ ** p)
+    eps_fn = lambda g_: H_min / (g_**p)
 
-    logs = {
+    logs: Dict[str, Any] = {
         "proj_logs": [],
         "eps": [],
     }
@@ -104,7 +155,7 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0):
     t = 1
     done = False
     gamma = min(gamma_i, gamma_f)
-    gammas = [0., gamma]
+    gammas = [0.0, gamma]
 
     while not done:
         done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)
@@ -120,11 +171,14 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0):
         u_prev, v_prev = u_cur.clone(), v_cur.clone()
         gamma_C = gamma * C
         u_cur, v_cur, proj_log, success = projector.project(
-            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)
+            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init
+        )
 
         logs["proj_logs"].append(proj_log)
         if not success:
-            warnings.warn(f"Projection failed. Returning result at the last temperature: {1 / gammas[-2]:.4e}")
+            warnings.warn(
+                f"Projection failed. Returning result at the last temperature: {1 / gammas[-2]:.4e}"
+            )
             u_cur = u_prev.clone()
             v_cur = v_prev.clone()
             gammas = gammas[:-1]
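The typed `mdot` signature above pins down the public entry point. A minimal usage sketch follows; the marginals, cost matrix, and `gamma_f` value are illustrative, not taken from the package:

```python
# Sketch only: illustrative inputs for the typed mdot() signature above.
import torch as th

from mdot_tnt.mdot import mdot
from mdot_tnt.rounding import rounded_cost_altschuler

n, m = 256, 256
r = th.full((n,), 1.0 / n, dtype=th.float64)  # row marginal, sums to 1
c = th.full((m,), 1.0 / m, dtype=th.float64)  # column marginal, sums to 1
C = th.rand(n, m, dtype=th.float64)  # cost entries in [0, 1], as the docstring recommends

# gamma_f is the final temperature, i.e. the inverse of the regularization weight.
u, v, gamma, k_total, logs = mdot(r, c, C, gamma_f=1024.0)

# Round the dual solution to a feasible transport cost in the log domain.
cost = rounded_cost_altschuler(u, v, r, c, C, gamma)
print(f"cost={cost.item():.6f}, k_total={k_total}, final gamma={gamma:.1f}")
```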
mdot_tnt/py.typed
ADDED
File without changes
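The new `py.typed` file is a PEP 561 marker: shipping it makes the annotations added throughout 1.0.0 visible to type checkers. A sketch of the effect on a hypothetical mistyped call (the mypy message is approximate):

```python
# With py.typed present, mypy validates calls against the new annotations.
from mdot_tnt.mdot import mdot

# This hypothetical call would also fail at runtime; the point is that a
# checker now flags it statically, roughly as:
#   error: Argument 1 to "mdot" has incompatible type "list[float]"; expected "Tensor"
mdot([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], gamma_f=64.0)
```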
mdot_tnt/rounding.py
CHANGED
@@ -6,16 +6,22 @@ and c are the row and column marginals, respectively. The algorithm is used in t
 plan and compute the cost of the rounded plan. The implementation is based on the original paper.
 """
 
+from typing import Union
+
 import torch as th
 
 
-def round_altschuler(P, r, c):
+def round_altschuler(P: th.Tensor, r: th.Tensor, c: th.Tensor) -> th.Tensor:
     """
     Performs rounding given a transport plan and marginals.
-
-    :
-
-
+
+    Args:
+        P: The input transport plan of shape (n, m).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+
+    Returns:
+        Rounded transport plan in feasible set U(r, c).
     """
     X = th.min(r / P.sum(-1), th.ones_like(r))
     P *= X.unsqueeze(-1)
@@ -25,19 +31,39 @@ def round_altschuler(P, r, c):
 
     err_r = (r - P.sum(-1)).clamp(min=0)
     err_c = (c - P.sum(-2)).clamp(min=0)
-    P += err_r.unsqueeze(-1) @ err_c.unsqueeze(-2) / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    P += (
+        err_r.unsqueeze(-1)
+        @ err_c.unsqueeze(-2)
+        / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    )
 
     return P
 
 
-def rounded_cost_altschuler(u, v, r, c, C, gamma):
-
-    :
-    :
-    :
-    :
-    :
-
+def rounded_cost_altschuler(
+    u: th.Tensor,
+    v: th.Tensor,
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma: Union[float, th.Tensor],
+) -> th.Tensor:
+    """
+    Performs rounding and cost computation in log-domain given dual variables.
+
+    This function computes the transport cost without storing the full n×m transport plan,
+    making it memory efficient.
+
+    Args:
+        u: Dual variable for rows of shape (n,).
+        v: Dual variable for columns of shape (m,).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+        C: Cost matrix of shape (n, m).
+        gamma: Temperature (inverse of the entropic regularization weight).
+
+    Returns:
+        The optimal transport cost as a scalar tensor.
     """
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     delta_u = th.min(r.log() - r_P_log, th.zeros_like(r))
@@ -50,7 +76,7 @@ def rounded_cost_altschuler(u, v, r, c, C, gamma):
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     r_P = r_P_log.exp()
     err_r = r - r_P
-    err_r /=
+    err_r /= err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30
 
     c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
     c_P = c_P_log.exp()
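A short sketch of the rounding routine above on an infeasible plan (tensors illustrative; note from the diff body that `round_altschuler` scales `P` in place, hence the `clone()`):

```python
# Sketch: project an arbitrary positive plan onto U(r, c) with round_altschuler().
import torch as th

from mdot_tnt.rounding import round_altschuler

r = th.tensor([0.5, 0.5], dtype=th.float64)
c = th.tensor([0.3, 0.7], dtype=th.float64)
P = th.rand(2, 2, dtype=th.float64)  # arbitrary positive plan with wrong marginals

P_rounded = round_altschuler(P.clone(), r, c)  # clone(): the function mutates its input

# The rounded plan now satisfies both marginal constraints.
assert th.allclose(P_rounded.sum(-1), r)
assert th.allclose(P_rounded.sum(-2), c)
```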
mdot_tnt/truncated_newton.py
CHANGED
@@ -1,27 +1,67 @@
+"""Truncated Newton projector for the MDOT algorithm."""
 
-import torch as th
 import warnings
+from typing import Any, Callable, Dict, Tuple, Union
+
+import torch as th
 
 
 class TruncatedNewtonProjector:
-    def __init__(self, device, dtype, **kwargs):
+    """
+    Truncated Newton projector for the MDOT algorithm.
+
+    Projects onto the set of couplings satisfying marginal constraints using
+    a preconditioned conjugate gradient method within a Newton framework.
+    """
+
+    def __init__(self, device: th.device, dtype: th.dtype, **kwargs: Any) -> None:
+        """
+        Initialize the projector.
+
+        Args:
+            device: PyTorch device for computations.
+            dtype: Data type for tensors.
+            **kwargs: Additional options (debug: bool for verbose output).
+        """
         self.device = device
         self.rho = th.zeros(1, device=device, dtype=dtype)
-        self.debug = kwargs.get(
-
-    def project(self, gamma_C, log_r, log_c, eps_d, u, v):
+        self.debug = kwargs.get("debug", False)
+        self.LSE_r: Callable[[th.Tensor], th.Tensor]
+        self.LSE_c: Callable[[th.Tensor], th.Tensor]
+
+    def project(
+        self,
+        gamma_C: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        eps_d: Union[float, th.Tensor],
+        u: th.Tensor,
+        v: th.Tensor,
+    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], bool]:
         """
         Project onto the set of couplings that satisfy the marginal constraints.
-
-        :
+
+        Args:
+            gamma_C: The cost matrix scaled by gamma, shape (n, m).
+            log_r: Log of row marginals, shape (n,).
+            log_c: Log of column marginals, shape (m,).
+            eps_d: Convergence tolerance for the dual gradient norm.
+            u: Initial row dual variables, shape (n,).
+            v: Initial column dual variables, shape (m,).
+
+        Returns:
+            u: Updated row dual variables.
+            v: Updated column dual variables.
+            logs: Dictionary with optimization statistics.
+            success: Whether projection converged successfully.
        """
-        logs = {
+        logs: Dict[str, Any] = {
             "errs": [],
-
-
-
+            "ls_func_cnt": 0,
+            "chisinkhorn_steps": 0,
+            "newtonsolve_steps": 0,
             "deltas": [],  # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
-            "all_newtonsolve_steps": []
+            "all_newtonsolve_steps": [],
         }
         # In case of errors or issues, 10 times the tolerance level is considered
         # a good enough solution to keep MDOT going.
@@ -57,19 +97,20 @@ class TruncatedNewtonProjector:
             self.rho = th.max(th.zeros_like(self.rho), self.rho)
 
             P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
-            diag_PPc = ((P ** 2) / c.unsqueeze(-2)).sum(-1)
+            diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
             k += 8
             delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
-                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000)
+                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000
+            )
             del P  # Free up memory
             if not pcg_success:
                 k += matmul_cnt
                 logs["n_iter"] = k
-                msg = "PCG did not converge. TruncatedNewton returning with success={}".format(success_fn(err))
+                msg = f"PCG did not converge. TruncatedNewton returning with success={success_fn(err)}"
                 warnings.warn(msg)
                 return u, v, logs, success_fn(err)
 
-            self.rho = th.max(th.zeros_like(self.rho), 1. - (1. - rho) * 4.)
+            self.rho = th.max(th.zeros_like(self.rho), 1.0 - (1.0 - rho) * 4.0)
             k += matmul_cnt
             logs["newtonsolve_steps"] += matmul_cnt
 
@@ -79,8 +120,7 @@ class TruncatedNewtonProjector:
             linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
             if not linear_decr > 0:
                 logs["n_iter"] = k
-                msg = "Linear decrease condition not satisfied. TruncatedNewton returning with success={}".format(
-                    success_fn(err))
+                msg = f"Linear decrease condition not satisfied. TruncatedNewton returning with success={success_fn(err)}"
                 warnings.warn(msg)
                 return u, v, logs, success_fn(err)
 
@@ -89,8 +129,7 @@ class TruncatedNewtonProjector:
                 alpha *= 0.5
                 if alpha < 1e-9:
                     logs["n_iter"] = k
-                    msg = "Line search did not converge. TruncatedNewton returning with success={}".format(
-                        success_fn(err))
+                    msg = f"Line search did not converge. TruncatedNewton returning with success={success_fn(err)}"
                     warnings.warn(msg)
                     return u, v, logs, success_fn(err)
 
@@ -113,13 +152,17 @@ class TruncatedNewtonProjector:
             log_r_P = u + self.LSE_r(v)
             k += 4
 
-            u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
+            u, v, log_r_P, err, k_ = self.chi_sinkhorn(
+                u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
+            )
             r_P = log_r_P.exp()
             logs["chisinkhorn_steps"] += k_
             k += k_
 
             logs["errs"].append(err)
-            logs["deltas"].append(th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item())
+            logs["deltas"].append(
+                th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item()
+            )
 
             if u.isnan().any() or v.isnan().any():
                 raise ValueError("NaNs encountered in u or v")
@@ -132,7 +175,16 @@ class TruncatedNewtonProjector:
 
         return u, v, logs, True
 
-    def chi_sinkhorn(self, u, v, log_r, log_c, log_r_P, eps_chi, maxOps=float("inf")):
+    def chi_sinkhorn(
+        self,
+        u: th.Tensor,
+        v: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        log_r_P: th.Tensor,
+        eps_chi: Union[float, th.Tensor],
+        maxOps: float = float("inf"),
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, int]:
         k = 0
         r = log_r.exp()
         err = (r - log_r_P.exp()).norm(p=1, dim=-1)
@@ -154,27 +206,40 @@ class TruncatedNewtonProjector:
             k += 8
 
             if k >= maxOps:
-                raise ValueError("Chi-Sinkhorn did not converge in maxIter={} steps".format(maxOps))
+                raise ValueError(f"Chi-Sinkhorn did not converge in maxIter={maxOps} steps")
 
         return u, v, log_r_P, err, k
 
-    def newton_solve(self, P, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500):
+    def newton_solve(
+        self,
+        P: th.Tensor,
+        c: th.Tensor,
+        diag_PPc: th.Tensor,
+        grad_k: th.Tensor,
+        r_P: th.Tensor,
+        err: th.Tensor,
+        beta: float = 0.5,
+        eta_k: Union[float, th.Tensor] = 0.5,
+        maxIter: int = 500,
+    ) -> Tuple[th.Tensor, th.Tensor, int, th.Tensor, bool]:
         rho = self.rho
         tol = err * eta_k
 
-        matmul_PPc = lambda x_: (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
+        matmul_PPc = lambda x_: (
+            P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))
+        ).squeeze(-1)
 
         # mml = th.compile(matmul_PPc)
         mml = matmul_PPc
 
-        M = lambda rho_:
+        M = lambda rho_: r_P - rho_ * diag_PPc  # Diagonal preconditioner
         M_rho = M(th.ones_like(self.rho))
         M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()
 
         x0 = -grad_k / M_rho
         PPc_x0 = mml(x0)
         matmul_cnt = 2
-        r_P_x0 =
+        r_P_x0 = r_P * x0
 
         x = x0.clone()
         PPc_x = PPc_x0.clone()
@@ -197,7 +262,7 @@ class TruncatedNewtonProjector:
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
             best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
 
-            rho[r_true_norm > tol] = 1. - (1. - rho[r_true_norm > tol]) * 0.25
+            rho[r_true_norm > tol] = 1.0 - (1.0 - rho[r_true_norm > tol]) * 0.25
             M_rho = M(rho)
 
             if matmul_cnt > 0:
@@ -227,8 +292,10 @@ class TruncatedNewtonProjector:
 
             quad = (Fr_p * p).sum(-1, keepdim=True)
             if (quad <= 0)[best_r_true_norm > tol].any():
-                warnings.warn(
-
+                warnings.warn(
+                    "Warning: negative curvature encountered in CG. Returning best solution. "
+                    f"Residual norm less than error: {(best_r_true_norm < err).item()}"
+                )
                 x = best_sol.clone()
                 done = True
                 success = best_r_true_norm < err
@@ -241,12 +308,15 @@ class TruncatedNewtonProjector:
             res += alpha * Fr_p
             r_norm = res.norm(p=1, dim=-1)
 
-            if th.isnan(r_norm)[best_r_true_norm > tol].any() or th.isinf(r_norm)[best_r_true_norm > tol].any():
+            if (
+                th.isnan(r_norm)[best_r_true_norm > tol].any()
+                or th.isinf(r_norm)[best_r_true_norm > tol].any()
+            ):
                 raise ValueError("NaNs or infs encountered in r_norm")
 
             PPc_x += alpha * PPc_p
 
-            r_P_x =
+            r_P_x = r_P * x
             res_true = r_P_x - PPc_x + grad_k
             r_true_norm = res_true.norm(p=1, dim=-1)
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
@@ -279,4 +349,4 @@ class TruncatedNewtonProjector:
         Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
         matmul_cnt += 1
 
-        return x, -Pc_x, matmul_cnt, rho, success
+        return x, -Pc_x, matmul_cnt, rho, success
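A sketch of driving the projector directly through the typed `project` signature above; `mdot` normally constructs these inputs itself, so all values below are illustrative:

```python
# Sketch: call the projector directly, mirroring how mdot() invokes it.
import torch as th

from mdot_tnt.truncated_newton import TruncatedNewtonProjector

n, m = 128, 128
r = th.full((n,), 1.0 / n, dtype=th.float64)
c = th.full((m,), 1.0 / m, dtype=th.float64)
gamma_C = 64.0 * th.rand(n, m, dtype=th.float64)  # cost matrix already scaled by gamma

projector = TruncatedNewtonProjector(device=gamma_C.device, dtype=gamma_C.dtype)
u0 = th.zeros(n, dtype=th.float64)  # initial row duals
v0 = th.zeros(m, dtype=th.float64)  # initial column duals

u, v, logs, success = projector.project(gamma_C, r.log(), c.log(), 1e-4, u0, v0)
print(f"success={success}, newton ops={logs['newtonsolve_steps']}, "
      f"chi-sinkhorn ops={logs['chisinkhorn_steps']}")
```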