freealg 0.7.12__py3-none-any.whl → 0.7.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. freealg/__version__.py +1 -1
  2. freealg/_algebraic_form/_cusp.py +357 -0
  3. freealg/_algebraic_form/_cusp_wrap.py +268 -0
  4. freealg/_algebraic_form/_decompress2.py +2 -0
  5. freealg/_algebraic_form/_decompress4.py +739 -0
  6. freealg/_algebraic_form/_decompress5.py +738 -0
  7. freealg/_algebraic_form/_decompress6.py +492 -0
  8. freealg/_algebraic_form/_decompress7.py +355 -0
  9. freealg/_algebraic_form/_decompress8.py +369 -0
  10. freealg/_algebraic_form/_decompress9.py +363 -0
  11. freealg/_algebraic_form/_decompress_new.py +431 -0
  12. freealg/_algebraic_form/_decompress_new_2.py +1631 -0
  13. freealg/_algebraic_form/_decompress_util.py +172 -0
  14. freealg/_algebraic_form/_homotopy2.py +289 -0
  15. freealg/_algebraic_form/_homotopy3.py +215 -0
  16. freealg/_algebraic_form/_homotopy4.py +320 -0
  17. freealg/_algebraic_form/_homotopy5.py +185 -0
  18. freealg/_algebraic_form/_moments.py +0 -1
  19. freealg/_algebraic_form/_support.py +132 -177
  20. freealg/_algebraic_form/algebraic_form.py +21 -2
  21. freealg/distributions/_compound_poisson.py +481 -0
  22. freealg/distributions/_deformed_marchenko_pastur.py +6 -7
  23. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/METADATA +1 -1
  24. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/RECORD +28 -12
  25. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/WHEEL +0 -0
  26. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/licenses/AUTHORS.txt +0 -0
  27. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/licenses/LICENSE.txt +0 -0
  28. {freealg-0.7.12.dist-info → freealg-0.7.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,172 @@
1
+ # SPDX-FileCopyrightText: Copyright 2026, Siavash Ameli <sameli@berkeley.edu>
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+ # SPDX-FileType: SOURCE
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify it under
6
+ # the terms of the license found in the LICENSE.txt file in the root directory
7
+ # of this source tree.
8
+
9
+
10
+ # =======
11
+ # Imports
12
+ # =======
13
+
14
+ import numpy
15
+ from ._continuation_algebraic import powers
16
+
17
+ __all__ = ['build_time_grid', 'eval_P_partials']
18
+
19
+
20
+ # ===============
21
+ # build time grid
22
+ # ===============
23
+
24
def build_time_grid(sizes, n0, min_n_times=0):
    """
    Build the sorted time grid on which the solver is run.

    Each requested size ``n`` is mapped to the time ``t = log(n / n0)``. The
    grid always contains ``t = 0`` and every requested time. If it has fewer
    than ``min_n_times`` points, it is refined by repeatedly bisecting its
    largest gap.

    Parameters
    ----------
    sizes : array_like of float
        Requested matrix sizes (e.g. ``[2000, 3000, 4000, 8000]``).
    n0 : float
        Initial size; it corresponds to ``t = 0``.
    min_n_times : int or None, default=0
        Minimum number of time points in the returned grid. ``None`` is
        treated as ``0`` (no refinement).

    Returns
    -------
    t_all : numpy.ndarray
        Sorted, duplicate-free time grid.
    idx_req : numpy.ndarray of int
        Indices of the requested times inside ``t_all``, in the same order as
        ``sizes``.

    Notes
    -----
    Refinement stops early if bisecting the largest gap does not add a new
    point (the midpoint coincides with an existing node, e.g. for non-finite
    endpoints). In that case the grid may have fewer than ``min_n_times``
    points, but the function is guaranteed to terminate.
    """

    sizes = numpy.asarray(sizes, dtype=float)
    t_req = numpy.log(sizes / float(n0))

    # Always include t=0 and T=max(t_req). numpy.unique already returns its
    # result sorted, so no extra sort is needed.
    T = float(numpy.max(t_req)) if t_req.size else 0.0
    t_all = numpy.unique(numpy.r_[0.0, t_req, T])

    # Add points only if needed: split the largest gap at its midpoint
    n_min = int(min_n_times) if min_n_times is not None else 0
    while 2 <= t_all.size < n_min:
        k = int(numpy.argmax(numpy.diff(t_all)))
        mid = 0.5 * (t_all[k] + t_all[k + 1])
        refined = numpy.unique(numpy.r_[t_all, mid])

        # No new point was created: further bisection cannot make progress,
        # so stop instead of looping forever.
        if refined.size == t_all.size:
            break

        t_all = refined

    # Requested times were inserted into t_all verbatim, so searchsorted
    # (leftmost insertion index) recovers their positions.
    idx_req = numpy.searchsorted(t_all, t_req)

    return t_all, idx_req
61
+
62
+
63
+ # ===============
64
+ # eval P partials
65
+ # ===============
66
+
67
def eval_P_partials(z, m, a_coeffs):
    """
    Evaluate the polynomial P(z, m) together with dP/dz and dP/dm.

    The polynomial is given by `a_coeffs` in the monomial basis

        P(z, m) = sum_{j=0..s} a_j(z) * m^j,
        a_j(z) = sum_{i=0..deg_z} a_coeffs[i, j] * z^i.

    The three outputs are broadcast over `z` and `m`.

    Parameters
    ----------
    z : complex or array_like of complex
        First argument of P.
    m : complex or array_like of complex
        Second argument of P. Must be broadcast-compatible with `z`.
    a_coeffs : ndarray, shape (deg_z+1, s+1)
        Coefficient matrix of P in the monomial basis.

    Returns
    -------
    P : complex or ndarray of complex
        Value P(z, m).
    Pz : complex or ndarray of complex
        Partial derivative dP/dz at (z, m).
    Pm : complex or ndarray of complex
        Partial derivative dP/dm at (z, m).

    Notes
    -----
    When both `z` and `m` are scalars, each a_j(z) and its z-derivative are
    evaluated with Horner's scheme, followed by Horner's scheme in m. For
    array inputs, power tables obtained from `powers` are used instead.

    Examples
    --------
    .. code-block:: python

        P, Pz, Pm = eval_P_partials(1.0 + 1j, 0.2 + 0.3j, a_coeffs)
    """

    z = numpy.asarray(z, dtype=complex)
    m = numpy.asarray(m, dtype=complex)

    deg_z = int(a_coeffs.shape[0] - 1)
    s = int(a_coeffs.shape[1] - 1)

    # ---- scalar path: nested Horner evaluation ----
    if (z.ndim == 0) and (m.ndim == 0):
        z0 = complex(z)
        m0 = complex(m)

        coef_m = numpy.empty(s + 1, dtype=complex)     # a_j(z0)
        dcoef_m = numpy.empty(s + 1, dtype=complex)    # d a_j / dz at z0

        for col in range(s + 1):
            column = a_coeffs[:, col]

            acc = 0.0 + 0.0j
            dacc = 0.0 + 0.0j
            for row in range(deg_z, -1, -1):
                acc = acc * z0 + column[row]
                # The derivative polynomial has no constant-term contribution
                if row > 0:
                    dacc = dacc * z0 + (row * column[row])

            coef_m[col] = acc
            dcoef_m[col] = dacc

        # Horner in m; the m-derivative is accumulated from the running value
        # before that value is updated.
        value = coef_m[s]
        dvalue_dz = dcoef_m[s]
        dvalue_dm = 0.0 + 0.0j
        for col in range(s - 1, -1, -1):
            dvalue_dm = dvalue_dm * m0 + value
            value = value * m0 + coef_m[col]
            dvalue_dz = dvalue_dz * m0 + dcoef_m[col]

        return value, dvalue_dz, dvalue_dm

    # ---- array path: power tables on flattened, broadcast inputs ----
    out_shape = numpy.broadcast(z, m).shape
    z_flat = numpy.broadcast_to(z, out_shape).ravel()
    m_flat = numpy.broadcast_to(m, out_shape).ravel()

    z_tab = powers(z_flat, deg_z)
    m_tab = powers(m_flat, s)

    # d/dz of z^i is i * z^(i-1); column 0 remains zero
    dz_tab = numpy.zeros_like(z_tab)
    for i in range(1, deg_z + 1):
        dz_tab[:, i] = i * z_tab[:, i - 1]

    n_pts = z_flat.size
    P = numpy.zeros(n_pts, dtype=complex)
    Pz = numpy.zeros(n_pts, dtype=complex)
    Pm = numpy.zeros(n_pts, dtype=complex)

    for j in range(s + 1):
        column = a_coeffs[:, j]
        m_pow = m_tab[:, j]

        aj = z_tab @ column
        P += aj * m_pow

        ajp = dz_tab @ column
        Pz += ajp * m_pow

        if j >= 1:
            Pm += (j * aj) * m_tab[:, j - 1]

    return P.reshape(out_shape), Pz.reshape(out_shape), Pm.reshape(out_shape)
@@ -0,0 +1,289 @@
1
+ # =======
2
+ # Imports
3
+ # =======
4
+
5
+ import numpy
6
+ from ._moments import AlgebraicStieltjesMoments
7
+ from tqdm import tqdm
8
+
9
+ __all__ = ['StieltjesPoly']
10
+
11
+
12
+ # ===========
13
+ # select root
14
+ # ===========
15
+
16
+ # def select_root(roots, z, target):
17
+ # """
18
+ # Select the root among Herglotz candidates at a given z closest to a
19
+ # given target.
20
+
21
+ # Parameters
22
+ # ----------
23
+ # roots : array_like of complex
24
+ # Candidate roots for m at the given z.
25
+ # z : complex
26
+ # Evaluation point. The Stieltjes/Herglotz branch satisfies
27
+ # sign(Im(m)) = sign(Im(z)) away from the real axis.
28
+ # target : complex
29
+ # Previous continuation value used to enforce continuity, or
30
+ # target value.
31
+
32
+ # Returns
33
+ # -------
34
+ # w : complex
35
+ # Selected root corresponding to the Stieltjes branch.
36
+ # """
37
+
38
def select_root(roots, z, target, tol_im=1e-10, tiny_im=1e-6, ratio=50.0):
    """
    Select one root among the candidates for m at a given z.

    Candidates are first restricted (softly) to the half-plane of z: a root
    is kept when ``Im(root) * sign(Im(z)) > -tol_im``, where ``sign(Im(z))``
    is taken as +1 if ``Im(z) == 0``. If no root passes, all roots are kept.

    Among the kept candidates, if there is a clear split between a
    "big-imag" root and a "tiny-imag" root, the root with the largest signed
    imaginary part is returned. Otherwise the root closest to ``target`` is
    returned (continuity).

    Parameters
    ----------
    roots : array_like of complex
        Candidate roots for m at the given z. Must not be empty.
    z : complex
        Evaluation point.
    target : complex
        Reference value (e.g. previous continuation value) used by the
        continuity fallback.
    tol_im : float, default=1e-10
        Tolerance of the soft half-plane filter.
    tiny_im : float, default=1e-6
        Threshold separating "tiny" from "big" signed imaginary parts.
    ratio : float, default=50.0
        Minimum ratio of the largest to the smallest signed imaginary part
        (the smallest floored at 1e-16) for the split to count as clear.

    Returns
    -------
    w : numpy.complex128
        Selected root.

    Raises
    ------
    ValueError
        If ``roots`` is empty.
    """

    z = complex(z)
    roots = numpy.asarray(roots, dtype=numpy.complex128).ravel()

    # Fail with a clear message instead of an opaque NumPy reduction error
    if roots.size == 0:
        raise ValueError("roots must contain at least one candidate.")

    half_sign = numpy.sign(z.imag)
    if half_sign == 0.0:
        half_sign = 1.0

    # candidates in the correct half-plane (soft); fall back to all roots
    cand = roots[(numpy.imag(roots) * half_sign) > -tol_im]
    if cand.size == 0:
        cand = roots
    im = numpy.imag(cand) * half_sign

    # if there is a "big-imag" candidate and a "tiny-imag" candidate, pick
    # the big-imag one
    im_max = float(im.max())
    im_min = float(im.min())

    if (im_max > tiny_im and im_min < tiny_im and
            im_max / max(im_min, 1e-16) > ratio):
        return cand[int(numpy.argmax(im))]

    # otherwise fall back to continuity
    t = complex(target)
    return cand[int(numpy.argmin(numpy.abs(cand - t)))]
65
+
66
+ # def select_root(roots, z, target, tol_im=1e-10, tiny_im=1e-6, ratio=50.0):
67
+ # z = complex(z)
68
+ # roots = numpy.asarray(roots, dtype=numpy.complex128).ravel()
69
+
70
+ # s = numpy.sign(z.imag)
71
+ # if s == 0.0:
72
+ # s = 1.0
73
+
74
+ # im = numpy.imag(roots) * s
75
+ # cand = roots[im > -tol_im]
76
+ # if cand.size == 0:
77
+ # cand = roots
78
+ # im = numpy.imag(cand) * s
79
+ # else:
80
+ # im = numpy.imag(cand) * s
81
+
82
+ # im_max = float(im.max())
83
+ # im_min = float(im.min())
84
+
85
+ # # If there's a clear "big-Im" vs "tiny-Im" split, take big-Im
86
+ # if im_max > tiny_im and im_min < tiny_im and im_max / max(im_min, 1e-16) > ratio:
87
+ # return cand[int(numpy.argmax(im))]
88
+
89
+ # # Otherwise continuity
90
+ # t = complex(target)
91
+ # return cand[int(numpy.argmin(numpy.abs(cand - t)))]
92
+
93
+
94
+
95
+
96
+
97
+ # ==============
98
+ # stieltjes poly
99
+ # ==============
100
+
101
class StieltjesPoly(object):
    """
    Stieltjes-branch evaluator for an algebraic equation P(z, m) = 0.

    Parameters
    ----------
    a : ndarray, shape (L, K)
        Coefficient matrix defining P(z, m) in the monomial basis.
    eps : float or None, optional
        If Im(z) == 0, use z + i*eps as the boundary evaluation point.
        If None and Im(z) == 0, eps is set to 1e-8 * max(1, |z|).
    height : float, default = 2.0
        Imaginary height factor used to build a safe start radius.
    steps : int, default = 100
        Number of continuation steps along the vertical leg.
    order : int, default = 15
        Number of moments in Stieltjes estimate.
    reanchor : int, default = 1
        During 1D line-sweeps, every ``reanchor`` points we re-run a full
        evaluate() at the current point. Setting this to 1 is the most robust.
        ``None`` or a non-positive value disables re-anchoring after the
        first point.
    """

    def __init__(self, a, eps=None, height=2.0, steps=100, order=15,
                 reanchor=1):
        a = numpy.asarray(a)
        if a.ndim != 2:
            raise ValueError("a must be a 2D array.")

        self.a = a
        self.a_l, _ = a.shape
        self.eps = eps
        self.height = float(height)
        self.steps = int(steps)
        self.order = int(order)
        self.reanchor = reanchor

        self.mom = AlgebraicStieltjesMoments(a)

        # Start point far enough away (imag direction)
        self.rad = 1.0 + self.height * self.mom.radius(self.order)
        self.z0_p = 1j * self.rad
        self.z0_m = -1j * self.rad

        # Moment anchors at z0
        self.m0_p = self.mom.stieltjes(self.z0_p, self.order)
        self.m0_m = self.mom.stieltjes(self.z0_m, self.order)

    def _poly_coeffs_m(self, z_val):
        """
        Return the coefficients of P(z_val, m) as a polynomial in m, ordered
        from the highest power of m to the lowest (as numpy.roots expects).
        """

        z_powers = z_val ** numpy.arange(self.a_l)
        return (z_powers @ self.a)[::-1]

    def _poly_roots(self, z_val):
        """
        Return all roots in m of P(z_val, m) = 0.
        """

        coeffs = numpy.asarray(self._poly_coeffs_m(z_val),
                               dtype=numpy.complex128)
        return numpy.roots(coeffs)

    def evaluate(self, z):
        """
        Evaluate the Stieltjes-branch solution m(z) at a single point.

        Strategy:
          1) Move to z_mid = x + i*sign(y)*rad (high above the real axis).
          2) At z_mid, select the branch with ``select_root`` using the
             target -1/z_mid (the leading large-|z| term of the Stieltjes
             transform of a probability measure).
          3) Continue vertically from z_mid down to z_eval = x + i*y
             (or x + i*eps when z is real), re-selecting the root at each
             of ``steps`` points using the previous value as target.

        Parameters
        ----------
        z : complex
            Evaluation point.

        Returns
        -------
        m : numpy.complex128
            Selected root at the evaluation point.

        Raises
        ------
        ValueError
            If ``steps`` is smaller than 1.
        """

        z = complex(z)

        if self.steps < 1:
            raise ValueError("steps must be >= 1.")

        # Boundary-value interpretation on the real axis
        if z.imag == 0.0:
            if self.eps is None:
                eps_loc = 1e-8 * max(1.0, abs(z))
            else:
                eps_loc = float(self.eps)
            z_eval = z + 1j * eps_loc
        else:
            z_eval = z

        half_sign = numpy.sign(z_eval.imag)
        if half_sign == 0.0:
            half_sign = 1.0

        # High-imag anchor at same real part
        z_mid = complex(z_eval.real, half_sign * self.rad)

        # Anchor target at z_mid. NOTE: the asymptote -1/z is used here, not
        # the moment-based estimate self.mom.stieltjes(z_mid, self.order).
        m_mid_target = -1.0 / z_mid

        # Select the branch at z_mid (roots are solved only once here)
        w_prev = select_root(self._poly_roots(z_mid), z_mid, m_mid_target)

        # Vertical continuation: z_mid -> z_eval
        for tau in numpy.linspace(0.0, 1.0, int(self.steps) + 1)[1:]:
            z_tau = z_mid + tau * (z_eval - z_mid)
            w_prev = select_root(self._poly_roots(z_tau), z_tau, w_prev)

        return w_prev

    def _is_flat_imag_line_1d(self, z_flat):
        """
        Return True if ``z_flat`` is a 1D array of at least two points lying
        on a horizontal line off the real axis (constant, finite, non-zero
        imaginary part up to a relative tolerance of 1e-14).
        """

        if z_flat.ndim != 1 or z_flat.size < 2:
            return False

        y = numpy.imag(z_flat)
        if not numpy.all(numpy.isfinite(y)):
            return False

        y0 = float(y[0])
        if y0 == 0.0:
            return False

        # all in same half-plane
        if not numpy.all(numpy.sign(y) == numpy.sign(y0)):
            return False

        # nearly constant imaginary part (relative)
        tol = 1e-14 * max(1.0, abs(y0))
        return bool(numpy.max(numpy.abs(y - y0)) <= tol)

    def _sweep_line(self, z_sorted, progress=False):
        """
        Evaluate m along a line of points, in the given order.

        The first point is always evaluated with evaluate(). Each later
        point is either re-anchored with evaluate() (every ``reanchor``
        points) or selected by continuity from the previous value.
        """

        n = z_sorted.size
        out = numpy.empty_like(z_sorted, dtype=numpy.complex128)

        # Always anchor the first point robustly
        m_prev = self.evaluate(z_sorted[0])
        out[0] = m_prev

        it = range(1, n)
        if progress:
            it = tqdm(it, total=n - 1)

        reanchor = self.reanchor
        do_reanchor = (reanchor is not None) and (int(reanchor) > 0)
        reanchor = int(reanchor) if do_reanchor else 0

        for k in it:
            zk = z_sorted[k]

            # With default reanchor=1, this becomes "evaluate every point"
            if do_reanchor and (k % reanchor == 0):
                m_prev = self.evaluate(zk)
            else:
                rk = self._poly_roots(zk)
                m_prev = select_root(rk, zk, m_prev)

            out[k] = m_prev

        return out

    def __call__(self, z, progress=False):
        """
        Evaluate m(z) for a scalar or an array of points.

        Parameters
        ----------
        z : complex or array_like of complex
            Evaluation point(s).
        progress : bool, default=False
            If True, show a tqdm progress bar for array inputs.

        Returns
        -------
        m : complex or ndarray of complex
            Scalar for scalar input; otherwise an array with the shape of z.
        """

        # Scalar fast-path
        if numpy.isscalar(z):
            return self.evaluate(z)

        z_arr = numpy.asarray(z, dtype=numpy.complex128)
        out = numpy.empty(z_arr.shape, dtype=numpy.complex128)

        # 1D horizontal line sweep (density-style queries)
        if z_arr.ndim == 1:
            z_flat = z_arr
            if self._is_flat_imag_line_1d(z_flat):
                # Sweep in order of increasing real part, then undo the sort
                order = numpy.argsort(numpy.real(z_flat))
                inv = numpy.empty_like(order)
                inv[order] = numpy.arange(order.size)

                z_sorted = z_flat[order]
                out_sorted = self._sweep_line(z_sorted, progress=progress)

                out[:] = out_sorted[inv]
                return out

        # Fallback: elementwise
        if progress:
            indices = tqdm(numpy.ndindex(z_arr.shape), total=z_arr.size)
        else:
            indices = numpy.ndindex(z_arr.shape)

        for idx in indices:
            out[idx] = self.evaluate(z_arr[idx])

        return out
@@ -0,0 +1,215 @@
1
+ # =======
2
+ # Imports
3
+ # =======
4
+
5
+ import numpy
6
+ from tqdm import tqdm
7
+ from ._moments import AlgebraicStieltjesMoments
8
+
9
+ __all__ = ['StieltjesPoly']
10
+
11
+
12
+ # ==============
13
+ # Stieltjes Poly
14
+ # ==============
15
+
16
class StieltjesPoly(object):
    """
    Stieltjes-branch evaluator for P(z, m)=0 with robust 1D tracking.

    For 1D arrays on a horizontal line (z = x + i*delta), uses Viterbi
    tracking across the whole line to avoid branch mis-selection.
    Otherwise falls back to pointwise evaluate().

    Parameters
    ----------
    a : ndarray
        Polynomial coefficient matrix (2D).
    eps : float or None
        Imaginary offset when Im(z)=0. If None, 1e-8 * max(1, |z|) is used.
    height : float
        Radius factor for safe imaginary height.
    steps : int
        Vertical continuation steps in evaluate(). Must be >= 1.
    order : int
        Moment order.
    lam_im : float
        Viterbi penalty strength for tiny |Im(m)|.
    tol_im : float
        Herglotz tolerance.
    """

    def __init__(self, a, eps=None, height=2.0, steps=100, order=15,
                 lam_im=1.0e3, tol_im=1.0e-12):
        a = numpy.asarray(a)
        if a.ndim != 2:
            raise ValueError("a must be a 2D array.")

        self.a = a
        self.a_l, _ = a.shape

        self.eps = eps
        self.height = float(height)
        self.steps = int(steps)
        self.order = int(order)

        self.lam_im = float(lam_im)
        self.tol_im = float(tol_im)

        self.mom = AlgebraicStieltjesMoments(a)

        # Start height far enough away from the real axis
        self.rad = 1.0 + self.height * self.mom.radius(self.order)
        self.z0_p = 1j * self.rad
        self.z0_m = -1j * self.rad

        # Moment anchors at z0
        self.m0_p = self.mom.stieltjes(self.z0_p, self.order)
        self.m0_m = self.mom.stieltjes(self.z0_m, self.order)

    # -----------
    # poly roots
    # -----------

    def _poly_coeffs_m(self, z_val):
        """
        Return the coefficients of P(z_val, m) as a polynomial in m, ordered
        from the highest power of m to the lowest (as numpy.roots expects).
        """

        z_powers = z_val ** numpy.arange(self.a_l)
        return (z_powers @ self.a)[::-1]

    def _poly_roots(self, z_val):
        """
        Return all roots in m of P(z_val, m) = 0.
        """

        coeffs = numpy.asarray(self._poly_coeffs_m(z_val),
                               dtype=numpy.complex128)
        return numpy.roots(coeffs)

    # -------------
    # point evaluate
    # -------------

    def _select_root_continuity(self, roots, z, target):
        """
        Return the root closest to ``target`` among those in the half-plane
        of z (within ``tol_im``); if none qualifies, among all roots.
        """

        z = complex(z)
        roots = numpy.asarray(roots, dtype=numpy.complex128).ravel()

        s = numpy.sign(z.imag)
        if s == 0.0:
            s = 1.0

        im = numpy.imag(roots) * s
        cand = roots[im > -self.tol_im]
        if cand.size == 0:
            cand = roots

        t = complex(target)
        return cand[int(numpy.argmin(numpy.abs(cand - t)))]

    def evaluate(self, z):
        """
        Evaluate the Stieltjes-branch solution m(z) at a single point.

        The branch is anchored at z_mid = Re(z) + i*sign(Im(z))*rad using the
        moment-based estimate, then continued vertically down to the
        evaluation point in ``steps`` steps, selecting at each step the root
        closest to the previous value.

        Parameters
        ----------
        z : complex
            Evaluation point. If Im(z) == 0, z + i*eps is used instead.

        Returns
        -------
        m : numpy.complex128
            Selected root at the evaluation point.

        Raises
        ------
        ValueError
            If ``steps`` is smaller than 1. Without any continuation step the
            result would be the value at z_mid rather than at z.
        """

        z = complex(z)

        if self.steps < 1:
            raise ValueError("steps must be >= 1.")

        # Boundary-value interpretation on the real axis
        if z.imag == 0.0:
            if self.eps is None:
                eps_loc = 1e-8 * max(1.0, abs(z))
            else:
                eps_loc = float(self.eps)
            z_eval = z + 1j * eps_loc
        else:
            z_eval = z

        s = numpy.sign(z_eval.imag)
        if s == 0.0:
            s = 1.0

        z_mid = complex(z_eval.real, s * self.rad)

        # anchor using moment estimate at z_mid
        target = self.mom.stieltjes(z_mid, self.order)
        w_prev = self._select_root_continuity(self._poly_roots(z_mid), z_mid,
                                              target)

        # vertical continuation
        for tau in numpy.linspace(0.0, 1.0, int(self.steps) + 1)[1:]:
            z_tau = z_mid + tau * (z_eval - z_mid)
            w_prev = self._select_root_continuity(self._poly_roots(z_tau),
                                                  z_tau, w_prev)

        return w_prev

    # -----------------
    # viterbi utilities
    # -----------------

    def _is_flat_imag_line_1d(self, z):
        """
        Return True if ``z`` is a 1D array of at least two points lying on a
        horizontal line off the real axis (constant, finite, non-zero
        imaginary part up to a relative tolerance of 1e-14).
        """

        if z.ndim != 1 or z.size < 2:
            return False
        y = numpy.imag(z)
        if not numpy.all(numpy.isfinite(y)):
            return False
        y0 = float(y[0])
        if y0 == 0.0:
            return False
        if not numpy.all(numpy.sign(y) == numpy.sign(y0)):
            return False
        tol = 1e-14 * max(1.0, abs(y0))
        return bool(numpy.max(numpy.abs(y - y0)) <= tol)

    def _herglotz_mask(self, roots, z):
        """
        Return a boolean mask of the roots lying in the half-plane of z,
        within the tolerance ``tol_im``.
        """

        s = numpy.sign(z.imag)
        if s == 0.0:
            s = 1.0
        return (numpy.imag(roots) * s) > -self.tol_im

    def _viterbi_track(self, z_sorted, roots_all):
        """
        Pick one root per point by dynamic programming along the line.

        Parameters
        ----------
        z_sorted : array_like of complex, shape (N,)
            Points along the line, in sweep order.
        roots_all : array_like of complex, shape (N, S)
            All S roots at each of the N points.

        Returns
        -------
        m : ndarray of complex, shape (N,)
            Root selected at each point, minimising the sum of per-point
            costs and squared jumps between consecutive points.
        """

        z_sorted = numpy.asarray(z_sorted, dtype=numpy.complex128)
        R = numpy.asarray(roots_all, dtype=numpy.complex128)
        N, S = R.shape

        # Per-point cost: large penalty outside the Herglotz half-plane,
        # plus a penalty growing as |Im(root)| becomes tiny.
        unary = numpy.zeros((N, S), dtype=numpy.float64)
        for k in range(N):
            mask = self._herglotz_mask(R[k], z_sorted[k])
            unary[k, ~mask] += 1e30
            im = numpy.abs(numpy.imag(R[k]))
            unary[k] += self.lam_im / numpy.maximum(im, 1e-16)

        dp = numpy.full((N, S), numpy.inf, dtype=numpy.float64)
        prev = numpy.full((N, S), -1, dtype=numpy.int64)

        dp[0] = unary[0]

        # Forward pass: cost[i, j] is the best cost of reaching root j at
        # point k through root i at point k-1.
        for k in range(1, N):
            diff = R[k][None, :] - R[k-1][:, None]
            cost = dp[k-1][:, None] + (numpy.abs(diff) ** 2)
            i_star = numpy.argmin(cost, axis=0)
            dp[k] = unary[k] + cost[i_star, numpy.arange(S)]
            prev[k] = i_star

        # Backward pass: recover the optimal path
        j = int(numpy.argmin(dp[-1]))
        path = numpy.empty(N, dtype=numpy.int64)
        path[-1] = j
        for k in range(N - 1, 0, -1):
            path[k - 1] = prev[k, path[k]]

        return R[numpy.arange(N), path]

    # -----
    # call
    # -----

    def __call__(self, z, progress=False):
        """
        Evaluate m(z) for a scalar or an array of points.

        Parameters
        ----------
        z : complex or array_like of complex
            Evaluation point(s).
        progress : bool, default=False
            If True, show a tqdm progress bar for array inputs.

        Returns
        -------
        m : complex or ndarray of complex
            Scalar for scalar input; otherwise an array with the shape of z.
        """

        if numpy.isscalar(z):
            return self.evaluate(z)

        z_arr = numpy.asarray(z, dtype=numpy.complex128)

        # Viterbi path only for 1D horizontal lines
        if z_arr.ndim == 1 and self._is_flat_imag_line_1d(z_arr):
            # Track in order of increasing real part, then undo the sort
            order = numpy.argsort(numpy.real(z_arr))
            inv = numpy.empty_like(order)
            inv[order] = numpy.arange(order.size)

            z_sorted = z_arr[order]

            # Root solving dominates the cost, so report progress on it
            z_iter = z_sorted
            if progress:
                z_iter = tqdm(z_sorted, total=z_sorted.size)

            roots_all = numpy.array([self._poly_roots(zk) for zk in z_iter],
                                    dtype=numpy.complex128)
            m_sorted = self._viterbi_track(z_sorted, roots_all)
            return m_sorted[inv]

        # Fallback: elementwise
        out = numpy.empty(z_arr.shape, dtype=numpy.complex128)
        it = numpy.ndindex(z_arr.shape)
        if progress:
            it = tqdm(it, total=z_arr.size)
        for idx in it:
            out[idx] = self.evaluate(z_arr[idx])
        return out