mdot-tnt 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdot_tnt/mdot.py CHANGED
@@ -1,33 +1,37 @@
-"""
-Code for solving the entropic-regularized optimal transport problem via the MDOT-TruncatedNewton (MDOT-TNT)
-method introduced in the paper "A Truncated Newton Method for Optimal Transport"
-by Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson (ICLR, 2025).
-URL: https://openreview.net/forum?id=gWrWUaCbMa
-"""
+"""Core MDOT solver using truncated Newton projection."""
 
-
-import math
 import warnings
+from typing import Any, Dict, List, Tuple, Union
+
+import torch as th
 
-from mdot_tnt.rounding import *
 from mdot_tnt.truncated_newton import TruncatedNewtonProjector
 
 
-def preprocess_marginals(r, c, C, eps):
+def preprocess_marginals(
+    r: th.Tensor, c: th.Tensor, C: th.Tensor, eps: float
+) -> Tuple[Tuple[th.Tensor, th.Tensor], Tuple[th.Tensor, th.Tensor], th.Tensor]:
     """
-    This function drops the smallest entries whose cumulative sum equals
-    :param r:
-    :param c:
-    :param C:
-    :param eps:
-    :return:
+    Drop the smallest marginal entries whose cumulative sum is below a threshold.
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        C: The cost matrix of shape (n, m).
+        eps: The threshold for the cumulative sum of the marginal entries to be dropped.
+
+    Returns:
+        A tuple containing:
+        - (r_new, r_keep): The new row marginal and indices of kept entries.
+        - (c_new, c_keep): The new column marginal and indices of kept entries.
+        - C: The cost matrix with corresponding rows and columns dropped.
     """
 
-    def preprocess_marginal(m, eps):
+    def preprocess_marginal(m: th.Tensor, eps: float) -> Tuple[th.Tensor, th.Tensor]:
         m_sorted, m_idx = th.sort(m, dim=-1, descending=False)
         m_cumsum = th.cumsum(m_sorted, dim=-1)
         m_keep = m_idx[m_cumsum > eps]
-        m_new = m[:, m_keep]
+        m_new = m[m_keep]
         mass_removed = 1 - m_new.sum(-1)
         m_new = m_new + mass_removed / m_new.size(-1)
 
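The dropping rule in `preprocess_marginal` above sorts the marginal ascending, accumulates mass, keeps only the entries past the `eps` threshold, and spreads the removed mass uniformly over the survivors. A minimal sketch of that rule on concrete numbers (values are illustrative, not from the package):

    import torch as th

    m = th.tensor([0.5, 0.3, 0.15, 0.04, 0.01])
    eps = 0.1
    m_sorted, m_idx = th.sort(m, dim=-1, descending=False)  # values [0.01, 0.04, 0.15, 0.30, 0.50]
    m_cumsum = th.cumsum(m_sorted, dim=-1)                  # [0.01, 0.05, 0.20, 0.50, 1.00]
    m_keep = m_idx[m_cumsum > eps]                          # drops the 0.01 and 0.04 entries
    m_new = m[m_keep]                                       # [0.15, 0.30, 0.50]
    m_new = m_new + (1 - m_new.sum(-1)) / m_new.size(-1)    # redistribute the removed 0.05 uniformly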
@@ -35,59 +39,115 @@ def preprocess_marginals(r, c, C, eps):
 
     r_new, r_keep = preprocess_marginal(r, eps)
     c_new, c_keep = preprocess_marginal(c, eps)
+    print(
+        f"Dropped {r.size(-1) - r_new.size(-1)} entries from r and {c.size(-1) - c_new.size(-1)} entries from c."
+    )
 
     C = C[r_keep][:, c_keep]
 
     return (r_new, r_keep), (c_new, c_keep), C
 
 
-def smooth_marginals(r, c, eps, w_r=0.5, w_c=0.5):
+def smooth_marginals(
+    r: th.Tensor,
+    c: th.Tensor,
+    eps: th.Tensor,
+    w_r: float = 0.5,
+    w_c: float = 0.5,
+) -> Tuple[th.Tensor, th.Tensor]:
+    """
+    Smooth the marginals by adding a small amount of uniform mass to each entry.
+
+    Args:
+        r: The row marginal of shape (n,).
+        c: The column marginal of shape (m,).
+        eps: The amount of mass to add to each entry.
+        w_r: The weight for the row marginal.
+        w_c: The weight for the column marginal.
+
+    Returns:
+        A tuple (r_hat, c_hat) of smoothed marginals with total TV distance at most eps
+        from the original marginals.
+    """
     assert w_r + w_c == 1, "w_r and w_c must sum to 1"
-    eps = eps.clamp(max=1.).unsqueeze(-1)
+    eps = eps.clamp(max=1.0).unsqueeze(-1)
     r_hat = (1 - w_r * eps) * r + w_r * eps * th.ones_like(r) / r.size(-1)
     c_hat = (1 - w_c * eps) * c + w_c * eps * th.ones_like(c) / c.size(-1)
 
     return r_hat, c_hat
 
 
-def adjust_schedule(q, deltas=None):
+def adjust_schedule(q: float, deltas: Union[List[float], None] = None) -> float:
+    """
+    Adjust the temperature annealing schedule based on the success of the Truncated Newton method.
+
+    Args:
+        q: The current temperature annealing schedule adjustment factor.
+        deltas: The list of deltas from the Truncated Newton method;
+            see Sec. 3.3 of Kemertas et al. (2025).
+
+    Returns:
+        The new temperature annealing schedule adjustment factor.
+    """
     if deltas is None:
         return q
 
-    deltas = deltas + [1.]  # If deltas is empty, we assume that the first iteration was successful
+    deltas = deltas + [1.0]  # If deltas is empty, we assume that the first iteration was successful
     delta_min = min(deltas)
 
     if delta_min < 0.5:
-        q = q ** 0.5
+        q = q**0.5
     elif delta_min > 0.9:
-        q = q ** 2
+        q = q**2
 
     return q
 
 
-def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
+def mdot(
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma_f: float,
+    gamma_i: float = 16,
+    p: float = 1.5,
+    q: float = 2.0,
+) -> Tuple[th.Tensor, th.Tensor, float, int, Dict[str, Any]]:
     """
-    Solve the entropic-regularized optimal transport problem using the MDOT method introduced in the paper:
-    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients" by Mete Kemertas,
-    Allan D. Jepson and Amir-massoud Farahmand. URL: https://arxiv.org/abs/2307.08507
+    Solve the entropic-regularized optimal transport problem using the MDOT method.
+
+    This implements the MDOT method introduced in the paper:
+    "Efficient and Accurate Optimal Transport with Mirror Descent and Conjugate Gradients"
+    by Mete Kemertas, Allan D. Jepson and Amir-massoud Farahmand.
+    URL: https://arxiv.org/abs/2307.08507
+
     Here, we use the Truncated Newton method for projection.
-    :param r: The first marginal.
-    :param c: The second marginal.
-    :param C: The cost matrix. Recommended use is to scale the entries to be in [0, 1].
-    :param gamma_f: The final temperature (inverse of the regularization weight).
-    :param gamma_i: The initial temperature.
-    :param p: The exponent for the epsilon function, used to determine the stopping criterion for the dual gradient.
-    :param q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
-    :return:
+
+    Args:
+        r: The first marginal of shape (n,).
+        c: The second marginal of shape (m,).
+        C: The cost matrix of shape (n, m). Recommended to scale entries to [0, 1].
+        gamma_f: The final temperature (inverse of the regularization weight).
+        gamma_i: The initial temperature.
+        p: The exponent for the epsilon function, used to determine the stopping
+            criterion for the dual gradient.
+        q: The temperature annealing (or mirror descent step size) schedule adjustment factor.
+
+    Returns:
+        A tuple containing:
+        - u: The row dual variables of shape (n,).
+        - v: The column dual variables of shape (m,).
+        - gamma: The final temperature achieved.
+        - k_total: The total number of O(n^2) primitive operations.
+        - logs: Dictionary with optimization statistics.
     """
     projector = TruncatedNewtonProjector(device=C.device, dtype=C.dtype)
 
     H_r = -(r * (r + 1e-30).log()).sum(-1)
     H_c = -(c * (c + 1e-30).log()).sum(-1)
     H_min = th.min(H_r, H_c)
-    eps_fn = lambda g_: H_min / (g_ ** p)
+    eps_fn = lambda g_: H_min / (g_**p)
 
-    logs = {
+    logs: Dict[str, Any] = {
         "proj_logs": [],
         "eps": [],
     }
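To make the schedule-adjustment rule concrete: `adjust_schedule` square-roots `q` when the worst Newton step achieved well under the predicted reduction (`delta_min < 0.5`) and squares it when every step was nearly ideal (`delta_min > 0.9`). A small worked example under the logic shown above (values are illustrative):

    from mdot_tnt.mdot import adjust_schedule

    q = 2.0
    assert adjust_schedule(q, deltas=None) == 2.0           # no history: q unchanged
    assert adjust_schedule(q, deltas=[0.3]) == 2.0 ** 0.5   # delta_min < 0.5: anneal more gently
    assert adjust_schedule(q, deltas=[0.95]) == 4.0         # delta_min > 0.9: anneal more aggressively
    assert adjust_schedule(q, deltas=[0.7]) == 2.0          # otherwise q is kept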
@@ -95,7 +155,7 @@ def mdot(r, c, C, gamma_f, gamma_i=16, p=1.5, q=2.0**(1/3)):
     t = 1
     done = False
     gamma = min(gamma_i, gamma_f)
-    gammas = [0., gamma]
+    gammas = [0.0, gamma]
 
     while not done:
         done = abs(gamma - gamma_f) < 1e-5  # Check if gamma == gamma_f (modulo rounding errors)
@@ -109,12 +169,16 @@
         u_cur, v_cur = u_init.clone(), v_init.clone()
 
         u_prev, v_prev = u_cur.clone(), v_cur.clone()
+        gamma_C = gamma * C
         u_cur, v_cur, proj_log, success = projector.project(
-            gamma * C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init)
+            gamma_C, r_hat.log(), c_hat.log(), eps_d / 2, u_init, v_init
+        )
 
         logs["proj_logs"].append(proj_log)
         if not success:
-            warnings.warn("Projection failed. Returning result at the last temperature: {:.4e}".format(1 / gammas[-2]))
+            warnings.warn(
+                f"Projection failed. Returning result at the last temperature: {1 / gammas[-2]:.4e}"
+            )
             u_cur = u_prev.clone()
             v_cur = v_prev.clone()
             gammas = gammas[:-1]
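Taken together, the new signatures suggest the following end-to-end call pattern. This is a minimal usage sketch based only on the signatures and docstrings visible in this diff; the problem sizes, marginals, and `gamma_f` value are illustrative:

    import torch as th
    from mdot_tnt.mdot import mdot
    from mdot_tnt.rounding import rounded_cost_altschuler

    n = 512
    r = th.rand(n, dtype=th.float64); r /= r.sum()  # row marginal on the simplex
    c = th.rand(n, dtype=th.float64); c /= c.sum()  # column marginal on the simplex
    C = th.rand(n, n, dtype=th.float64)             # cost matrix with entries in [0, 1]

    u, v, gamma, k_total, logs = mdot(r, c, C, gamma_f=1024.0)
    cost = rounded_cost_altschuler(u, v, r, c, C, gamma)  # feasible (rounded) transport cost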
mdot_tnt/py.typed ADDED
Empty file (a PEP 561 marker indicating the package now ships inline type annotations).
mdot_tnt/rounding.py CHANGED
@@ -6,16 +6,22 @@ and c are the row and column marginals, respectively. The algorithm is used in t
 plan and compute the cost of the rounded plan. The implementation is based on the original paper.
 """
 
+from typing import Union
+
 import torch as th
 
 
-def round_altschuler(P, r, c):
+def round_altschuler(P: th.Tensor, r: th.Tensor, c: th.Tensor) -> th.Tensor:
     """
     Performs rounding given a transport plan and marginals.
-    :param P: the input transport plan
-    :param r: row marginal
-    :param c: column marginal
-    :return: rounded transport plan in feasible set U(r, c).
+
+    Args:
+        P: The input transport plan of shape (n, m).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+
+    Returns:
+        Rounded transport plan in feasible set U(r, c).
     """
     X = th.min(r / P.sum(-1), th.ones_like(r))
     P *= X.unsqueeze(-1)
@@ -25,19 +31,39 @@ def round_altschuler(P, r, c):
 
     err_r = (r - P.sum(-1)).clamp(min=0)
     err_c = (c - P.sum(-2)).clamp(min=0)
-    P += err_r.unsqueeze(-1) @ err_c.unsqueeze(-2) / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    P += (
+        err_r.unsqueeze(-1)
+        @ err_c.unsqueeze(-2)
+        / (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30).unsqueeze(-1)
+    )
 
     return P
 
 
-def rounded_cost_altschuler(u, v, r, c, C, gamma):
-    """Performs rounding and cost computation in logdomain given dual variables, without storing n^2 matrices.
-    :param u: dual variable for rows
-    :param v: dual variable for columns
-    :param r: row marginal
-    :param c: column marginal
-    :param C: cost matrix
-    :param gamma: temperature, i.e., the inverse of the entropic regularization weight.
+def rounded_cost_altschuler(
+    u: th.Tensor,
+    v: th.Tensor,
+    r: th.Tensor,
+    c: th.Tensor,
+    C: th.Tensor,
+    gamma: Union[float, th.Tensor],
+) -> th.Tensor:
+    """
+    Performs rounding and cost computation in log-domain given dual variables.
+
+    This function computes the transport cost without storing the full n×m transport plan,
+    making it memory efficient.
+
+    Args:
+        u: Dual variable for rows of shape (n,).
+        v: Dual variable for columns of shape (m,).
+        r: Row marginal of shape (n,).
+        c: Column marginal of shape (m,).
+        C: Cost matrix of shape (n, m).
+        gamma: Temperature (inverse of the entropic regularization weight).
+
+    Returns:
+        The optimal transport cost as a scalar tensor.
     """
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     delta_u = th.min(r.log() - r_P_log, th.zeros_like(r))
@@ -50,7 +76,7 @@ def rounded_cost_altschuler(u, v, r, c, C, gamma):
     r_P_log = u + th.logsumexp(v.unsqueeze(-2) - gamma * C, dim=-1)
     r_P = r_P_log.exp()
     err_r = r - r_P
-    err_r /= (err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30)
+    err_r /= err_r.norm(p=1, dim=-1, keepdim=True) + 1e-30
 
     c_P_log = v + th.logsumexp(u.unsqueeze(-1) - gamma * C, dim=-2)
     c_P = c_P_log.exp()
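Per its docstring, `round_altschuler` returns a plan in the feasible set U(r, c). A quick illustrative check of that property (sizes and inputs arbitrary; the middle of the function is not shown in this diff, so this is a sketch of the expected behavior, exact up to floating-point error):

    import torch as th
    from mdot_tnt.rounding import round_altschuler

    n, m = 64, 48
    r = th.rand(n, dtype=th.float64); r /= r.sum()
    c = th.rand(m, dtype=th.float64); c /= c.sum()
    P = th.rand(n, m, dtype=th.float64); P /= P.sum()  # arbitrary unit-mass plan, infeasible in general

    P_rounded = round_altschuler(P.clone(), r, c)      # clone: the function scales P in place
    assert th.allclose(P_rounded.sum(-1), r)           # row sums now match r
    assert th.allclose(P_rounded.sum(-2), c)           # column sums now match c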
mdot_tnt/truncated_newton.py CHANGED
@@ -1,39 +1,79 @@
+"""Truncated Newton projector for the MDOT algorithm."""
 
-import torch as th
 import warnings
+from typing import Any, Callable, Dict, Tuple, Union
+
+import torch as th
 
 
 class TruncatedNewtonProjector:
-    def __init__(self, device, dtype, **kwargs):
+    """
+    Truncated Newton projector for the MDOT algorithm.
+
+    Projects onto the set of couplings satisfying marginal constraints using
+    a preconditioned conjugate gradient method within a Newton framework.
+    """
+
+    def __init__(self, device: th.device, dtype: th.dtype, **kwargs: Any) -> None:
+        """
+        Initialize the projector.
+
+        Args:
+            device: PyTorch device for computations.
+            dtype: Data type for tensors.
+            **kwargs: Additional options (debug: bool for verbose output).
+        """
         self.device = device
         self.rho = th.zeros(1, device=device, dtype=dtype)
-        self.debug = kwargs.get('debug', False)
-
-    def project(self, gamma_C, log_r, log_c, eps_d, u, v):
+        self.debug = kwargs.get("debug", False)
+        self.LSE_r: Callable[[th.Tensor], th.Tensor]
+        self.LSE_c: Callable[[th.Tensor], th.Tensor]
+
+    def project(
+        self,
+        gamma_C: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        eps_d: Union[float, th.Tensor],
+        u: th.Tensor,
+        v: th.Tensor,
+    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Any], bool]:
         """
         Project onto the set of couplings that satisfy the marginal constraints.
-        :param gamma_C: The cost matrix scaled by gamma.
-        :param log_r:
+
+        Args:
+            gamma_C: The cost matrix scaled by gamma, shape (n, m).
+            log_r: Log of row marginals, shape (n,).
+            log_c: Log of column marginals, shape (m,).
+            eps_d: Convergence tolerance for the dual gradient norm.
+            u: Initial row dual variables, shape (n,).
+            v: Initial column dual variables, shape (m,).
+
+        Returns:
+            u: Updated row dual variables.
+            v: Updated column dual variables.
+            logs: Dictionary with optimization statistics.
+            success: Whether projection converged successfully.
         """
-        logs = {
+        logs: Dict[str, Any] = {
             "errs": [],
-            'ls_func_cnt': 0,
-            'chisinkhorn_steps': 0,
-            'newtonsolve_steps': 0,
+            "ls_func_cnt": 0,
+            "chisinkhorn_steps": 0,
+            "newtonsolve_steps": 0,
             "deltas": [],  # Ratios of actual to theoretically predicted (ideal) reduction in gradient norm.
-            "all_newtonsolve_steps": []
+            "all_newtonsolve_steps": [],
         }
         # In case of errors or issues, 10 times the tolerance level is considered
        # a good enough solution to keep MDOT going.
        success_fn = lambda err_: err_ < 10 * eps_d
 
+        r = log_r.exp()
+        c = log_c.exp()
+
        # Each LSE operation costs 4 * n^2 operations.
        self.LSE_r = lambda v_: th.logsumexp(v_.unsqueeze(-2) - gamma_C, dim=-1)
        self.LSE_c = lambda u_: th.logsumexp(u_.unsqueeze(-1) - gamma_C, dim=-2)
 
-        r = log_r.exp()
-        c = log_c.exp()
-
        log_c_P = v + self.LSE_c(u)
        v += log_c - log_c_P  # Ensure c=c(P)
        log_r_P = u + self.LSE_r(v)
@@ -57,19 +97,20 @@ class TruncatedNewtonProjector:
             self.rho = th.max(th.zeros_like(self.rho), self.rho)
 
             P = th.exp(u.unsqueeze(-1) + v.unsqueeze(-2) - gamma_C)
-            diag_PPc = ((P ** 2) / c.unsqueeze(-2)).sum(-1)
+            diag_PPc = ((P**2) / c.unsqueeze(-2)).sum(-1)
             k += 8
             delta_u, delta_v, matmul_cnt, rho, pcg_success = self.newton_solve(
-                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000)
+                P, c, diag_PPc, grad_k, r_P, err, beta, eta_k, maxIter=5000
+            )
             del P  # Free up memory
             if not pcg_success:
                 k += matmul_cnt
                 logs["n_iter"] = k
-                msg = "PCG did not converge. TruncatedNewton returning with success={}".format(success_fn(err))
+                msg = f"PCG did not converge. TruncatedNewton returning with success={success_fn(err)}"
                 warnings.warn(msg)
                 return u, v, logs, success_fn(err)
 
-            self.rho = th.max(th.zeros_like(self.rho), 1. - (1. - rho) * 4.)
+            self.rho = th.max(th.zeros_like(self.rho), 1.0 - (1.0 - rho) * 4.0)
             k += matmul_cnt
             logs["newtonsolve_steps"] += matmul_cnt
 
@@ -79,8 +120,7 @@
             linear_decr = -(grad_k * delta_u).sum(-1, keepdim=True)
             if not linear_decr > 0:
                 logs["n_iter"] = k
-                msg = "Linear decrease condition not satisfied. TruncatedNewton returning with success={}".format(
-                    success_fn(err))
+                msg = f"Linear decrease condition not satisfied. TruncatedNewton returning with success={success_fn(err)}"
                 warnings.warn(msg)
                 return u, v, logs, success_fn(err)
 
@@ -89,8 +129,7 @@
                 alpha *= 0.5
                 if alpha < 1e-9:
                     logs["n_iter"] = k
-                    msg = "Line search did not converge. TruncatedNewton returning with success={}".format(
-                        success_fn(err))
+                    msg = f"Line search did not converge. TruncatedNewton returning with success={success_fn(err)}"
                     warnings.warn(msg)
                     return u, v, logs, success_fn(err)
 
@@ -113,13 +152,17 @@
             log_r_P = u + self.LSE_r(v)
             k += 4
 
-            u, v, log_r_P, err, k_ = self.chi_sinkhorn(u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5))
+            u, v, log_r_P, err, k_ = self.chi_sinkhorn(
+                u, v, log_r, log_c, log_r_P, eps_d ** (2 / 5)
+            )
             r_P = log_r_P.exp()
             logs["chisinkhorn_steps"] += k_
             k += k_
 
             logs["errs"].append(err)
-            logs["deltas"].append(th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item())
+            logs["deltas"].append(
+                th.min((logs["errs"][-2] - err_before_sk) / ((1 - eta_k) * logs["errs"][-2])).item()
+            )
 
             if u.isnan().any() or v.isnan().any():
                 raise ValueError("NaNs encountered in u or v")
@@ -132,7 +175,16 @@
 
         return u, v, logs, True
 
-    def chi_sinkhorn(self, u, v, log_r, log_c, log_r_P, eps_chi, maxOps=float('inf')):
+    def chi_sinkhorn(
+        self,
+        u: th.Tensor,
+        v: th.Tensor,
+        log_r: th.Tensor,
+        log_c: th.Tensor,
+        log_r_P: th.Tensor,
+        eps_chi: Union[float, th.Tensor],
+        maxOps: float = float("inf"),
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, int]:
         k = 0
         r = log_r.exp()
         err = (r - log_r_P.exp()).norm(p=1, dim=-1)
@@ -154,28 +206,40 @@
             k += 8
 
             if k >= maxOps:
-                raise ValueError("Chi-Sinkhorn did not converge in maxIter={} steps".format(maxOps))
+                raise ValueError(f"Chi-Sinkhorn did not converge in maxIter={maxOps} steps")
 
         return u, v, log_r_P, err, k
 
-    def newton_solve(self, P, c, diag_PPc, grad_k, r_P, err, beta=0.5, eta_k=0.5, maxIter=500):
+    def newton_solve(
+        self,
+        P: th.Tensor,
+        c: th.Tensor,
+        diag_PPc: th.Tensor,
+        grad_k: th.Tensor,
+        r_P: th.Tensor,
+        err: th.Tensor,
+        beta: float = 0.5,
+        eta_k: Union[float, th.Tensor] = 0.5,
+        maxIter: int = 500,
+    ) -> Tuple[th.Tensor, th.Tensor, int, th.Tensor, bool]:
         rho = self.rho
         tol = err * eta_k
 
-        def matmul_PPc(x_):
-            return (P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))).squeeze(-1)
+        matmul_PPc = lambda x_: (
+            P @ ((x_.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1))
+        ).squeeze(-1)
 
         # mml = th.compile(matmul_PPc)
         mml = matmul_PPc
 
-        M = lambda rho_: (r_P - rho_ * diag_PPc)  # Diagonal preconditioner
+        M = lambda rho_: r_P - rho_ * diag_PPc  # Diagonal preconditioner
         M_rho = M(th.ones_like(self.rho))
         M_rho[M_rho <= 0] = M_rho[M_rho > 0].min()
 
         x0 = -grad_k / M_rho
         PPc_x0 = mml(x0)
         matmul_cnt = 2
-        r_P_x0 = (r_P * x0)
+        r_P_x0 = r_P * x0
 
         x = x0.clone()
         PPc_x = PPc_x0.clone()
@@ -198,7 +262,7 @@
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
             best_r_true_norm = th.min(r_true_norm, best_r_true_norm)
 
-            rho[r_true_norm > tol] = 1. - (1. - rho[r_true_norm > tol]) * 0.25
+            rho[r_true_norm > tol] = 1.0 - (1.0 - rho[r_true_norm > tol]) * 0.25
             M_rho = M(rho)
 
             if matmul_cnt > 0:
@@ -228,8 +292,10 @@
 
             quad = (Fr_p * p).sum(-1, keepdim=True)
             if (quad <= 0)[best_r_true_norm > tol].any():
-                warnings.warn("Warning: negative curvature encountered in CG. Returning best solution. "
-                              "Residual norm less than error: {}".format((best_r_true_norm < err).item()))
+                warnings.warn(
+                    "Warning: negative curvature encountered in CG. Returning best solution. "
+                    f"Residual norm less than error: {(best_r_true_norm < err).item()}"
+                )
                 x = best_sol.clone()
                 done = True
                 success = best_r_true_norm < err
@@ -242,12 +308,15 @@
             res += alpha * Fr_p
             r_norm = res.norm(p=1, dim=-1)
 
-            if th.isnan(r_norm)[best_r_true_norm > tol].any() or th.isinf(r_norm)[best_r_true_norm > tol].any():
+            if (
+                th.isnan(r_norm)[best_r_true_norm > tol].any()
+                or th.isinf(r_norm)[best_r_true_norm > tol].any()
+            ):
                 raise ValueError("NaNs or infs encountered in r_norm")
 
             PPc_x += alpha * PPc_p
 
-            r_P_x = (r_P * x)
+            r_P_x = r_P * x
             res_true = r_P_x - PPc_x + grad_k
             r_true_norm = res_true.norm(p=1, dim=-1)
             best_sol[r_true_norm < best_r_true_norm] = x[r_true_norm < best_r_true_norm]
@@ -280,4 +349,4 @@
         Pc_x = ((x.unsqueeze(-2) @ P).transpose(-2, -1) / c.unsqueeze(-1)).squeeze(-1)
         matmul_cnt += 1
 
-        return x, -Pc_x, matmul_cnt, rho, success
+        return x, -Pc_x, matmul_cnt, rho, success
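For orientation: reading the residual update in `newton_solve` (`res_true = r_P * x - PPc_x + grad_k`), the method appears to solve the Newton system (diag(r_P) - P diag(1/c) Pᵀ) Δu = -grad_k by conjugate gradients with the diagonal preconditioner M(rho) = r_P - rho * diag_PPc. A generic diagonally-preconditioned CG sketch of that idea (illustrative only, not the package's implementation):

    import torch as th

    def pcg(A, b, M_diag, tol=1e-8, max_iter=100):
        """Solve A x = b for an SPD operator A with a positive diagonal preconditioner."""
        x = th.zeros_like(b)
        res = b - A(x)                   # initial residual (equals b since x = 0)
        z = res / M_diag                 # apply the diagonal preconditioner
        p = z.clone()
        rz = (res * z).sum()
        for _ in range(max_iter):
            Ap = A(p)
            alpha = rz / (p * Ap).sum()  # exact minimizer along the search direction
            x += alpha * p
            res -= alpha * Ap
            if res.norm(p=1) < tol:      # stop once the residual is small
                break
            z = res / M_diag
            rz_new = (res * z).sum()
            p = z + (rz_new / rz) * p    # next direction, A-conjugate to the previous ones
            rz = rz_new
        return x

Relative to this textbook loop, `newton_solve` above additionally tracks the best iterate seen, monitors the true (unpreconditioned) residual, adapts `rho`, and bails out on negative curvature.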