freealg 0.1.11__py3-none-any.whl → 0.7.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. freealg/__init__.py +8 -2
  2. freealg/__version__.py +1 -1
  3. freealg/_algebraic_form/__init__.py +12 -0
  4. freealg/_algebraic_form/_branch_points.py +288 -0
  5. freealg/_algebraic_form/_constraints.py +139 -0
  6. freealg/_algebraic_form/_continuation_algebraic.py +706 -0
  7. freealg/_algebraic_form/_decompress.py +641 -0
  8. freealg/_algebraic_form/_decompress2.py +204 -0
  9. freealg/_algebraic_form/_edge.py +330 -0
  10. freealg/_algebraic_form/_homotopy.py +323 -0
  11. freealg/_algebraic_form/_moments.py +448 -0
  12. freealg/_algebraic_form/_sheets_util.py +145 -0
  13. freealg/_algebraic_form/_support.py +309 -0
  14. freealg/_algebraic_form/algebraic_form.py +1232 -0
  15. freealg/_free_form/__init__.py +16 -0
  16. freealg/{_chebyshev.py → _free_form/_chebyshev.py} +75 -43
  17. freealg/_free_form/_decompress.py +993 -0
  18. freealg/_free_form/_density_util.py +243 -0
  19. freealg/_free_form/_jacobi.py +359 -0
  20. freealg/_free_form/_linalg.py +508 -0
  21. freealg/{_pade.py → _free_form/_pade.py} +42 -208
  22. freealg/{_plot_util.py → _free_form/_plot_util.py} +37 -22
  23. freealg/{_sample.py → _free_form/_sample.py} +58 -22
  24. freealg/_free_form/_series.py +454 -0
  25. freealg/_free_form/_support.py +214 -0
  26. freealg/_free_form/free_form.py +1362 -0
  27. freealg/_geometric_form/__init__.py +13 -0
  28. freealg/_geometric_form/_continuation_genus0.py +175 -0
  29. freealg/_geometric_form/_continuation_genus1.py +275 -0
  30. freealg/_geometric_form/_elliptic_functions.py +174 -0
  31. freealg/_geometric_form/_sphere_maps.py +63 -0
  32. freealg/_geometric_form/_torus_maps.py +118 -0
  33. freealg/_geometric_form/geometric_form.py +1094 -0
  34. freealg/_util.py +56 -110
  35. freealg/distributions/__init__.py +7 -1
  36. freealg/distributions/_chiral_block.py +494 -0
  37. freealg/distributions/_deformed_marchenko_pastur.py +726 -0
  38. freealg/distributions/_deformed_wigner.py +386 -0
  39. freealg/distributions/_kesten_mckay.py +29 -15
  40. freealg/distributions/_marchenko_pastur.py +224 -95
  41. freealg/distributions/_meixner.py +47 -37
  42. freealg/distributions/_wachter.py +29 -17
  43. freealg/distributions/_wigner.py +27 -14
  44. freealg/visualization/__init__.py +12 -0
  45. freealg/visualization/_glue_util.py +32 -0
  46. freealg/visualization/_rgb_hsv.py +125 -0
  47. freealg-0.7.12.dist-info/METADATA +172 -0
  48. freealg-0.7.12.dist-info/RECORD +53 -0
  49. {freealg-0.1.11.dist-info → freealg-0.7.12.dist-info}/WHEEL +1 -1
  50. freealg/_decompress.py +0 -180
  51. freealg/_jacobi.py +0 -218
  52. freealg/_support.py +0 -85
  53. freealg/freeform.py +0 -967
  54. freealg-0.1.11.dist-info/METADATA +0 -140
  55. freealg-0.1.11.dist-info/RECORD +0 -24
  56. /freealg/{_damp.py → _free_form/_damp.py} +0 -0
  57. {freealg-0.1.11.dist-info → freealg-0.7.12.dist-info}/licenses/AUTHORS.txt +0 -0
  58. {freealg-0.1.11.dist-info → freealg-0.7.12.dist-info}/licenses/LICENSE.txt +0 -0
  59. {freealg-0.1.11.dist-info → freealg-0.7.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,706 @@
1
+ # SPDX-FileCopyrightText: Copyright 2025, Siavash Ameli <sameli@berkeley.edu>
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+ # SPDX-FileType: SOURCE
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify it under
6
+ # the terms of the license found in the LICENSE.txt file in the root directory
7
+ # of this source tree.
8
+
9
+
10
+ # =======
11
+ # Imports
12
+ # =======
13
+
14
+ import numpy
15
+ from .._geometric_form._continuation_genus0 import joukowski_z
16
+ from ._constraints import build_moment_constraint_matrix
17
+
18
# Public names exported by ``from ... import *``. Note that
# track_one_sheet_on_grid (called by build_sheets_from_roots) is not listed.
__all__ = ['sample_z_joukowski', 'filter_z_away_from_cuts', 'powers',
           'fit_polynomial_relation', 'sanity_check_stieltjes_branch',
           'eval_P', 'eval_roots', 'build_sheets_from_roots']
21
+
22
+
23
+ # ======================
24
+ # normalize coefficients
25
+ # ======================
26
+
27
+ def _normalize_coefficients(arr):
28
+ """
29
+ Trim rows and columns on the sides (equivalent to factorizing or reducing
30
+ degree) and normalize so that the sum of the first column is one.
31
+ """
32
+
33
+ a = numpy.asarray(arr).copy()
34
+
35
+ if a.size == 0:
36
+ return a
37
+
38
+ # --- Trim zero rows (top and bottom) ---
39
+ non_zero_rows = numpy.any(a != 0, axis=1)
40
+ if not numpy.any(non_zero_rows):
41
+ return a[:0, :0]
42
+
43
+ first_row = numpy.argmax(non_zero_rows)
44
+ last_row = len(non_zero_rows) - numpy.argmax(non_zero_rows[::-1])
45
+ a = a[first_row:last_row, :]
46
+
47
+ # --- Trim zero columns (left and right) ---
48
+ non_zero_cols = numpy.any(a != 0, axis=0)
49
+ if not numpy.any(non_zero_cols):
50
+ return a[:, :0]
51
+
52
+ first_col = numpy.argmax(non_zero_cols)
53
+ last_col = len(non_zero_cols) - numpy.argmax(non_zero_cols[::-1])
54
+ a = a[:, first_col:last_col]
55
+
56
+ # --- Normalize so first column sums to 1 ---
57
+ col_sum = numpy.sum(numpy.abs(a[:, 0]))
58
+ if col_sum != 0:
59
+ a = a / col_sum
60
+
61
+ return a
62
+
63
+
64
+ # ==================
65
+ # sample z joukowski
66
+ # ==================
67
+
68
def sample_z_joukowski(a, b, n_samples=4096, r=1.25, n_r=3, r_min=None):
    """
    Sample complex points by mapping circles |w| = const through joukowski_z.

    For each radius, n_samples // 2 points w = radius * exp(i*theta) are
    taken on the upper half circle. The angles theta lie at the midpoints of
    n_samples // 2 equal sub-intervals of (0, pi). Each point is mapped with
    ``joukowski_z(w, a, b)`` and the complex conjugates of the mapped points
    are appended. (Presumably joukowski_z maps the exterior of the unit disk
    to the plane slit along [a, b] -- confirm in _continuation_genus0.)

    Parameters
    ----------
    a, b : float
        Interval end points forwarded to joukowski_z.
    n_samples : int, default=4096
        Number of points per radius; must be even.
    r : float, default=1.25
        Outermost radius.
    n_r : int, default=3
        Number of radii. Values below 1 (or None) are treated as 1. With one
        radius only ``r`` is used.
    r_min : float, optional
        Innermost radius when n_r > 1. Default: 1 + 0.05 * (r - 1) if r > 1,
        else 1.

    Returns
    -------
    numpy.ndarray
        1-D array of length n_r * n_samples: for each radius, the mapped
        upper-half points followed by their conjugates.

    Raises
    ------
    ValueError
        If n_samples is odd.
    """

    if r_min is None:
        if r > 1.0:
            r_min = 1.0 + 0.05 * (r - 1.0)
        else:
            r_min = 1.0

    if n_r is None or n_r < 1:
        n_r = 1

    if n_samples % 2 != 0:
        raise ValueError('n_samples should be even.')

    if n_r == 1:
        radii = numpy.array([r], dtype=float)
    else:
        radii = numpy.linspace(r_min, r, n_r)

    half = n_samples // 2
    angles = numpy.pi * (numpy.arange(half) + 0.5) / half
    unit_arc = numpy.exp(1j * angles)

    pieces = []
    for radius in radii:
        z_upper = joukowski_z(radius * unit_arc, a, b)
        pieces.extend((z_upper, numpy.conjugate(z_upper)))

    return numpy.concatenate(pieces)
95
+
96
+
97
+ # =======================
98
+ # filter z away from cuts
99
+ # =======================
100
+
101
def filter_z_away_from_cuts(z, cuts, y_eps=1e-2, x_pad=0.0):
    """
    Remove points that lie in a thin band around any branch cut.

    A point is dropped when |Im z| <= y_eps and
    a - x_pad <= Re z <= b + x_pad for at least one (a, b) in ``cuts``.

    Parameters
    ----------
    z : array_like
        Complex points; flattened.
    cuts : iterable of (a, b)
        Real intervals.
    y_eps : float, default=1e-2
        Half-height of the excluded band.
    x_pad : float, default=0.0
        Extra padding added to both ends of every interval.

    Returns
    -------
    numpy.ndarray
        1-D complex128 array of the retained points, in flattened order.
    """

    pts = numpy.asarray(z, dtype=numpy.complex128).ravel()
    re_part = numpy.real(pts)
    near_axis = numpy.abs(numpy.imag(pts)) <= y_eps

    drop = numpy.zeros(pts.size, dtype=bool)
    for lo, hi in cuts:
        in_span = (re_part >= lo - x_pad) & (re_part <= hi + x_pad)
        drop |= near_axis & in_span

    return pts[~drop]
115
+
116
+
117
+ # ======
118
+ # powers
119
+ # ======
120
+
121
def powers(x, deg):
    """
    Table of successive powers of x.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array of n values (uses ``x.size``).
    deg : int
        Highest power.

    Returns
    -------
    numpy.ndarray
        Complex array of shape (n, deg + 1) whose column k is x**k, built by
        repeated multiplication (column k = column k-1 times x).
    """

    factors = numpy.ones((x.size, deg + 1), dtype=complex)
    if deg > 0:
        # Column 0 stays 1; every later column holds one factor of x, so the
        # running product along axis 1 gives 1, x, x*x, (x*x)*x, ...
        factors[:, 1:] = x[..., None]
    return numpy.cumprod(factors, axis=1)
128
+
129
+
130
+ # =======================
131
+ # fit polynomial relation
132
+ # =======================
133
+
134
def _select_coefficient_pairs(deg_z, s, triangular):
    """
    Return the exponent pairs (i, j), i.e. monomials z**i * m**j, kept by
    the ``triangular`` mask, with j as the outer and i as the inner index.
    """

    tri = None
    if triangular is not None:
        tri = str(triangular).strip().lower()
        if tri in ['none', '']:
            tri = None

    all_pairs = [(i, j) for j in range(s + 1) for i in range(deg_z + 1)]

    if tri is None:
        return all_pairs

    if tri in ['lower', 'l']:
        return [(i, j) for (i, j) in all_pairs if i >= j]

    if tri in ['upper', 'u']:
        return [(i, j) for (i, j) in all_pairs if i <= j]

    if tri in ['antidiag', 'anti', 'antidiagonal', 'ad']:
        pairs = [(i, j) for (i, j) in all_pairs if (i + j) <= deg_z]
        if len(pairs) == 0:
            raise ValueError('antidiag constraint removed all coefficients.')
        return pairs

    raise ValueError("triangular must be None, 'lower', 'upper', or " +
                     "'antidiag'.")


def _smallest_singular_pair(M, ridge_lambda=0.0):
    """
    Return (v, svals): the right singular vector of M for its smallest
    singular value, and all singular values (descending).

    If ridge_lambda > 0, rows sqrt(ridge_lambda) * I are appended first.
    NOTE(review): this shifts every squared singular value by ridge_lambda
    without changing the right singular vectors, so it only affects the
    reported singular values, not the selected vector.
    """

    M = numpy.asarray(M)
    n_col = M.shape[1]

    if ridge_lambda > 0.0:
        L = numpy.sqrt(ridge_lambda) * numpy.eye(n_col, dtype=float)
        M = numpy.vstack([M, L])

    # With fewer rows than columns, the reduced SVD only spans the row space,
    # so vh[-1] would not be a null vector. Zero-padding to a square matrix
    # exposes the null directions (with zero singular values).
    if M.shape[0] < n_col:
        pad = numpy.zeros((n_col - M.shape[0], n_col), dtype=M.dtype)
        M = numpy.vstack([M, pad])

    _, svals, vh = numpy.linalg.svd(M, full_matrices=False)
    return vh[-1, :], svals


def _fit_metrics(svals):
    """
    Diagnostics of the SVD solve: smallest singular value, ratio of the two
    smallest singular values, and the number of singular values below
    1e-12 times the largest.
    """

    s_min = float(svals[-1])

    if svals.size < 2:
        # A single free coefficient: nothing competes with the solution.
        gap_ratio = float('inf')
    elif svals[-1] > 0.0:
        gap_ratio = float(svals[-2] / svals[-1])
    elif svals[-2] > 0.0:
        gap_ratio = float('inf')
    else:
        gap_ratio = float('nan')

    return {
        's_min': s_min,
        'gap_ratio': gap_ratio,
        'n_small': float(int(numpy.sum(svals <= svals[0] * 1e-12))),
    }


def fit_polynomial_relation(z, m, s, deg_z, ridge_lambda=0.0, weights=None,
                            triangular=None, normalize=False,
                            mu=None, mu_reg=None):
    """
    Fits polynomial P(z, m) = 0 with samples from the physical branch.

    P(z, m) = sum over kept (i, j) of a[i, j] * z**i * m**j. The coefficient
    vector is the right singular vector for the smallest singular value of
    the (column-scaled) design matrix. Re(A) and Im(A) are stacked so that,
    absent complex constraint rows, the coefficients come out real.

    Parameters
    ----------
    z, m : array_like
        Sample points and the corresponding branch values; flattened. They
        must have the same, non-zero size.
    s : int
        Degree in m (>= 1).
    deg_z : int
        Degree in z (>= 0).
    ridge_lambda : float, default=0.0
        If positive, sqrt(ridge_lambda) * I rows are appended to the system.
    weights : array_like, optional
        Per-sample weights (negatives clipped to 0). Rows of the design
        matrix are scaled by sqrt(weight).
    triangular : {None, 'lower', 'upper', 'antidiag'}, optional
        Restricts which (i, j) monomials are fitted. 'lower' keeps i >= j,
        'upper' keeps i <= j, and 'antidiag' keeps i + j <= deg_z.
    normalize : bool, default=False
        If True, the coefficient array is trimmed and rescaled with
        _normalize_coefficients.
    mu : optional
        Moment data passed to build_moment_constraint_matrix (format defined
        there). When given and non-empty, constraints B c = 0 are added.
    mu_reg : float, optional
        None enforces the constraints exactly (solve in the nullspace of B).
        A positive value adds them as soft rows weighted by sqrt(mu_reg).
        Zero or negative ignores them.

    Returns
    -------
    full : numpy.ndarray
        Complex array of shape (deg_z + 1, s + 1), or trimmed if
        normalize=True. full[i, j] is the coefficient of z**i * m**j.
    fit_metrics : dict
        's_min', 'gap_ratio' and 'n_small' diagnostics of the final SVD.
        'gap_ratio' is inf when only one coefficient is free.

    Raises
    ------
    ValueError
        On mismatched or empty inputs, invalid degrees, or an invalid
        ``triangular`` value.
    RuntimeError
        If hard moment constraints leave no free coefficients.
    """

    z = numpy.asarray(z, dtype=complex).ravel()
    m = numpy.asarray(m, dtype=complex).ravel()

    if z.size != m.size:
        raise ValueError('z and m must have the same size.')
    if z.size == 0:
        raise ValueError('z and m must not be empty.')
    if s < 1:
        raise ValueError('s must be >= 1.')
    if deg_z < 0:
        raise ValueError('deg_z must be >= 0.')

    zp = powers(z, deg_z)
    mp = powers(m, s)

    if weights is None:
        w = None
    else:
        w = numpy.asarray(weights, dtype=float).ravel()
        if w.size != z.size:
            raise ValueError('weights must have the same size as z.')
        w = numpy.sqrt(numpy.maximum(w, 0.0))

    pairs = _select_coefficient_pairs(deg_z, s, triangular)
    n_coef = len(pairs)

    A = numpy.empty((z.size, n_coef), dtype=complex)
    for k, (i, j) in enumerate(pairs):
        A[:, k] = zp[:, i] * mp[:, j]

    if w is not None:
        A = A * w[:, None]

    # Enforce real coefficients by solving: Re(A) c = 0 and Im(A) c = 0
    Ar = numpy.vstack([A.real, A.imag])

    # Column scaling for conditioning; undone on the solution below.
    s_col = numpy.max(numpy.abs(Ar), axis=0)
    s_col[s_col == 0.0] = 1.0
    As = Ar / s_col[None, :]

    # Optional moment constraints B c = 0 (scaled like the columns of As).
    Bs = None
    if mu is not None:
        B = build_moment_constraint_matrix(pairs, deg_z, s, mu)
        if B.shape[0] > 0:
            Bs = B / s_col[None, :]

    if Bs is None:
        # No (effective) constraints.
        coef_scaled, svals = _smallest_singular_pair(As, ridge_lambda)

    elif mu_reg is None:
        # Hard constraints: solve in the nullspace of Bs.
        _, sB, vhB = numpy.linalg.svd(Bs, full_matrices=True)
        tolB = 1e-12 * (sB[0] if sB.size else 1.0)
        rankB = int(numpy.sum(sB > tolB))
        if rankB >= n_coef:
            raise RuntimeError(
                'Moment constraints leave no feasible coefficients.')

        N = vhB[rankB:, :].T  # (n_coef, n_free)
        y, svals = _smallest_singular_pair(As @ N, ridge_lambda)
        coef_scaled = N @ y

    else:
        mu_reg = float(mu_reg)
        if mu_reg > 0.0:
            # Soft constraints: weighted extra rows.
            system = numpy.vstack([As, numpy.sqrt(mu_reg) * Bs])
        else:
            # mu_reg <= 0 => ignore constraints
            system = As
        coef_scaled, svals = _smallest_singular_pair(system, ridge_lambda)

    coef = coef_scaled / s_col

    full = numpy.zeros((deg_z + 1, s + 1), dtype=complex)
    for k, (i, j) in enumerate(pairs):
        full[i, j] = coef[k]

    if normalize:
        full = _normalize_coefficients(full)

    fit_metrics = _fit_metrics(svals)

    return full, fit_metrics
298
+
299
+
300
+ # =============================
301
+ # sanity check stieltjes branch
302
+ # =============================
303
+
304
def _finite_roots_at(z, a_coeffs):
    """
    Roots in m of P(z, m) = 0 at a single point z, or None if they cannot be
    obtained as finite numbers. That covers non-finite polynomial
    coefficients, numpy.roots raising LinAlgError, and non-finite roots.
    """

    c = _poly_coef_in_m(numpy.array([z]), a_coeffs)[0]
    if not numpy.all(numpy.isfinite(c)):
        return None
    try:
        r = numpy.roots(c[::-1])
    except numpy.linalg.LinAlgError:
        return None
    if not numpy.all(numpy.isfinite(r)):
        return None
    return r


def sanity_check_stieltjes_branch(a_coeffs, x_min, x_max, eta=0.1,
                                  n_x=64, y0=None, max_bad_frac=0.05):
    """
    Quick sanity check: does P(z,m)=0 admit a continuously trackable root with
    Im(m)>0 along z=x+i*eta.

    A starting root is chosen at z0 = i*y0 as the one nearest to -1/z0. At
    each of the n_x points x in linspace(x_min, x_max, n_x), the root nearest
    to the previously selected one is taken. A point counts as bad if its
    roots cannot be computed as finite numbers, or if the selected root has
    Im(m) <= 0.

    Parameters
    ----------
    a_coeffs : numpy.ndarray
        Coefficients a[i, j] of z**i * m**j.
    x_min, x_max : float
        Range of real parts tested.
    eta : float, default=0.1
        Imaginary part of the test line.
    n_x : int, default=64
        Number of test points (at least 4 are used).
    y0 : float, optional
        Height of the starting point. Default: 10 * max(1, |x_min|, |x_max|).
    max_bad_frac : float, default=0.05
        Largest acceptable fraction of bad points.

    Returns
    -------
    status : dict
        'ok' (frac_bad <= max_bad_frac), 'frac_bad' (n_bad / n_x),
        'n_test' (n_x, or 0 if no starting root was found) and 'n_bad'.
    """

    x_min = float(x_min)
    x_max = float(x_max)
    eta = float(eta)
    n_x = max(int(n_x), 4)

    if y0 is None:
        y0 = 10.0 * max(1.0, abs(x_min), abs(x_max))
    y0 = float(y0)

    z0 = 1j * y0
    m0_target = -1.0 / z0

    r0 = _finite_roots_at(z0, a_coeffs)
    if r0 is None or r0.size == 0:
        return {'ok': False, 'frac_bad': 1.0, 'n_test': 0, 'n_bad': 0}

    m_prev = r0[int(numpy.argmin(numpy.abs(r0 - m0_target)))]

    # NOTE(review): the step from z0 straight to x_min + i*eta is a single
    # jump, so the first nearest-root choice is not a true continuation.
    # Tracking along intermediate points between them would be more robust.
    n_bad = 0
    for z in numpy.linspace(x_min, x_max, n_x) + 1j * eta:
        r = _finite_roots_at(z, a_coeffs)
        if r is None or r.size == 0:
            n_bad += 1
            continue

        m_prev = r[int(numpy.argmin(numpy.abs(r - m_prev)))]
        if m_prev.imag <= 0.0:
            n_bad += 1

    frac_bad = float(n_bad) / float(n_x)

    status = {
        'ok': frac_bad <= float(max_bad_frac),
        'frac_bad': frac_bad,
        'n_test': n_x,
        'n_bad': n_bad
    }

    return status
369
+
370
+
371
+ # ======
372
+ # eval P
373
+ # ======
374
+
375
def eval_P(z, m, a_coeffs):
    """
    Evaluate P(z, m) = sum_{i, j} a_coeffs[i, j] * z**i * m**j.

    Parameters
    ----------
    z, m : array_like
        Complex arguments, broadcast against each other.
    a_coeffs : numpy.ndarray
        2-D array of shape (deg_z + 1, s + 1).

    Returns
    -------
    numpy.ndarray
        Complex values of P with the broadcast shape of z and m.
    """

    z_arr = numpy.asarray(z, dtype=complex)
    m_arr = numpy.asarray(m, dtype=complex)
    n_z = int(a_coeffs.shape[0] - 1)
    n_m = int(a_coeffs.shape[1] - 1)

    target = numpy.broadcast(z_arr, m_arr).shape
    z_flat = numpy.broadcast_to(z_arr, target).ravel()
    m_flat = numpy.broadcast_to(m_arr, target).ravel()

    z_pow = powers(z_flat, n_z)
    m_pow = powers(m_flat, n_m)

    # Accumulate (coefficient of m**j evaluated at z) * m**j over j.
    acc = numpy.zeros(z_flat.size, dtype=complex)
    for col, m_col in enumerate(m_pow.T):
        acc = acc + (z_pow @ a_coeffs[:, col]) * m_col

    return acc.reshape(target)
395
+
396
+
397
+ # ==============
398
+ # poly coef in m
399
+ # ==============
400
+
401
def _poly_coef_in_m(z, a_coeffs):
    """
    Coefficients of P(z_k, m) as a polynomial in m, for every point z_k.

    Parameters
    ----------
    z : array_like
        Complex points; flattened.
    a_coeffs : numpy.ndarray
        2-D array of shape (deg_z + 1, s + 1); a_coeffs[i, j] multiplies
        z**i * m**j.

    Returns
    -------
    numpy.ndarray
        Complex array of shape (z.size, s + 1) where entry [k, j] is
        sum_i a_coeffs[i, j] * z_k**i (ascending powers of m).
    """

    z_flat = numpy.asarray(z, dtype=complex).ravel()
    n_rows = int(a_coeffs.shape[0])
    n_cols = int(a_coeffs.shape[1])
    z_pow = powers(z_flat, n_rows - 1)

    out = numpy.empty((z_flat.size, n_cols), dtype=complex)
    for col in range(n_cols):
        out[:, col] = z_pow @ a_coeffs[:, col]
    return out
412
+
413
+
414
+ # ==============
415
+ # root quadratic
416
+ # ==============
417
+
418
+ def _roots_quadratic(c0, c1, c2):
419
+
420
+ disc = c1 * c1 - 4.0 * c2 * c0
421
+ sq = numpy.sqrt(disc)
422
+ den = 2.0 * c2
423
+
424
+ r1 = (-c1 + sq) / den
425
+ r2 = (-c1 - sq) / den
426
+ return numpy.stack([r1, r2], axis=1)
427
+
428
+
429
+ # ============
430
+ # cbrt complex
431
+ # ============
432
+
433
+ def _cbrt_complex(z):
434
+
435
+ z = numpy.asarray(z, dtype=complex)
436
+ r = numpy.abs(z)
437
+ th = numpy.angle(z)
438
+ return (r ** (1.0 / 3.0)) * numpy.exp(1j * th / 3.0)
439
+
440
+
441
+ # ==========
442
+ # root cubic
443
+ # ==========
444
+
445
def _roots_cubic(c0, c1, c2, c3):
    """
    Roots of c3*m**3 + c2*m**2 + c1*m + c0 = 0 via Cardano's formula,
    elementwise.

    The cubic is made monic by dividing by c3 (so c3 is assumed non-zero).
    It is then depressed with x = y - a/3 to y**3 + p*y + q = 0 and solved
    with u = cbrt(-q/2 + sqrt(Delta)) and v = -p / (3u). The roots are
    u + v, w*u + conj(w)*v and conj(w)*u + w*v, shifted back by -a/3, where
    w = exp(2*pi*i/3).

    Inputs are indexed with boolean masks and the results are stacked along
    axis 1, so 1-D coefficient arrays of equal length are expected (as
    passed by eval_roots).

    Returns
    -------
    numpy.ndarray
        Complex array of shape (n, 3).
    """

    c0 = numpy.asarray(c0, dtype=complex)
    c1 = numpy.asarray(c1, dtype=complex)
    c2 = numpy.asarray(c2, dtype=complex)
    c3 = numpy.asarray(c3, dtype=complex)

    # Monic form x**3 + a x**2 + b x + c.
    a = c2 / c3
    b = c1 / c3
    c = c0 / c3

    # Depressed cubic y**3 + p y + q with x = y - a/3.
    p = b - (a * a) / 3.0
    q = (2.0 * a * a * a) / 27.0 - (a * b) / 3.0 + c

    Delta = (q * q) / 4.0 + (p * p * p) / 27.0
    sqrtD = numpy.sqrt(Delta)

    A = -q / 2.0 + sqrtD
    u = _cbrt_complex(A)

    # If the chosen Cardano branch gives u ~ 0 (which would make v = -p/(3u)
    # blow up), switch to the other branch -q/2 - sqrt(Delta).
    eps = 1e-30
    small = numpy.abs(u) < eps
    if numpy.any(small):
        u2 = _cbrt_complex(-q / 2.0 - sqrtD)
        u = numpy.where(small, u2, u)

    # v from u*v = -p/3 where u is usable; otherwise fall back to a cube root
    # of -q.
    small = numpy.abs(u) < eps
    v = numpy.empty_like(u)
    v[~small] = -p[~small] / (3.0 * u[~small])
    v[small] = _cbrt_complex(-q[small])

    # w is the primitive cube root of unity exp(2*pi*i/3).
    y1 = u + v
    w = complex(-0.5, numpy.sqrt(3.0) / 2.0)
    y2 = w * u + numpy.conjugate(w) * v
    y3 = numpy.conjugate(w) * u + w * v

    # Undo the depressing shift.
    x1 = y1 - a / 3.0
    x2 = y2 - a / 3.0
    x3 = y3 - a / 3.0

    return numpy.stack([x1, x2, x3], axis=1)
486
+
487
+
488
+ # ==========
489
+ # eval roots
490
+ # ==========
491
+
492
def eval_roots(z, a_coeffs):
    """
    All roots in m of P(z, m) = 0 at each sample point.

    Parameters
    ----------
    z : array_like
        Complex evaluation points; flattened.
    a_coeffs : numpy.ndarray
        2-D array of shape (deg_z + 1, s + 1); a_coeffs[i, j] multiplies
        z**i * m**j.

    Returns
    -------
    numpy.ndarray
        Complex array of shape (z.size, s).

        For s <= 3, closed-form formulas are used. For larger s, numpy.roots
        is applied per point, with two adjustments:

        - If the degree in m drops at a point (vanishing leading
          coefficients), the missing roots are reported as infinity, so
          every row still has s entries.
        - Points whose coefficients are non-finite, or for which numpy.roots
          raises LinAlgError, are filled with NaN instead of raising.
    """

    z = numpy.asarray(z, dtype=complex).ravel()
    c = _poly_coef_in_m(z, a_coeffs)

    s = int(c.shape[1] - 1)
    if s == 1:
        m = -c[:, 0] / c[:, 1]
        return m[:, None]

    if s == 2:
        return _roots_quadratic(c[:, 0], c[:, 1], c[:, 2])

    if s == 3:
        return _roots_cubic(c[:, 0], c[:, 1], c[:, 2], c[:, 3])

    roots = numpy.full((z.size, s), numpy.nan, dtype=complex)
    for i in range(z.size):
        row = c[i, :]
        if not numpy.all(numpy.isfinite(row)):
            continue
        try:
            r = numpy.roots(row[::-1])
        except numpy.linalg.LinAlgError:
            continue
        # numpy.roots strips leading zero coefficients, returning fewer than
        # s roots when the degree drops; those roots have gone to infinity.
        roots[i, :r.size] = r
        roots[i, r.size:] = numpy.inf
    return roots
512
+
513
+
514
+ # =======================
515
+ # track one sheet on grid
516
+ # =======================
517
+
518
def track_one_sheet_on_grid(z, roots, sheet_seed, cuts=None, i0=None, j0=None):
    """
    Follow one sheet of the root set across a 2-D grid by nearest-root
    continuation, visiting grid points breadth-first from a seed point.

    This is mostly used for visualization of the sheets.

    Parameters
    ----------
    z : array_like, shape (n_y, n_x)
        Complex grid. Only z[:, 0] is used to read row imaginary parts when
        choosing the default seed row and the on-axis tolerance.
    roots : numpy.ndarray, shape (n_y * n_x, s)
        Roots at every grid point. They are reshaped to (n_y, n_x, s), so
        they are assumed to be in the row-major order of z.
    sheet_seed : complex
        The starting root at (i0, j0) is the candidate nearest to this value.
        A non-zero imaginary part also enables the Im-sign preference below.
    cuts : sequence of (a, b), optional
        Real intervals. A step between a row with Im > y_eps and a row with
        Im < -y_eps is refused when the midpoint of the two real parts lies
        in one of these intervals.
    i0, j0 : int, optional
        Seed grid indices. Defaults: the first row with Im(z[:, 0]) > 0 (or
        the middle row if none), and the middle column.

    Returns
    -------
    numpy.ndarray, shape (n_y, n_x)
        Complex sheet values. Grid points that are never reached keep
        nan + nan*1j.

    Raises
    ------
    ValueError
        If roots has no columns.
    """

    z = numpy.asarray(z)
    n_y, n_x = z.shape
    s = roots.shape[1]
    if s < 1:
        raise ValueError("s must be >= 1.")

    R = roots.reshape((n_y, n_x, s))

    if i0 is None:
        ycol = numpy.imag(z[:, 0])
        pos = numpy.where(ycol > 0.0)[0]
        i0 = int(pos[0]) if pos.size > 0 else (n_y // 2)

    if j0 is None:
        j0 = n_x // 2

    seed_imag = float(numpy.imag(sheet_seed))
    cand0 = R[i0, j0, :]
    idx0 = int(numpy.argmin(numpy.abs(cand0 - sheet_seed)))

    sheet = numpy.full((n_y, n_x), numpy.nan + 1j * numpy.nan, dtype=complex)
    sheet[i0, j0] = cand0[idx0]

    # FIFO queue stored in two preallocated index arrays (each cell is
    # enqueued at most once, so n_y * n_x slots suffice).
    visited = numpy.zeros((n_y, n_x), dtype=bool)
    q_i = numpy.empty(n_y * n_x, dtype=int)
    q_j = numpy.empty(n_y * n_x, dtype=int)

    head = 0
    tail = 0
    q_i[tail] = i0
    q_j[tail] = j0
    tail += 1
    visited[i0, j0] = True

    # 4-connected neighbourhood.
    neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    # Rows with |Im| <= y_eps (just under half the smallest row spacing)
    # are treated as lying on the real axis.
    y_unique = numpy.unique(numpy.imag(z[:, 0]))
    if y_unique.size >= 2:
        dy = float(numpy.min(numpy.diff(y_unique)))
        y_eps = 0.49 * dy
    else:
        y_eps = 0.0

    def crosses_cut(x_mid):
        # True when x_mid lies inside any cut interval.
        if cuts is None:
            return False
        for a, b in cuts:
            if a <= x_mid <= b:
                return True
        return False

    while head < tail:
        i = int(q_i[head])
        j = int(q_j[head])
        head += 1

        m_prev = sheet[i, j]
        y1 = float(numpy.imag(z[i, j]))
        x1 = float(numpy.real(z[i, j]))

        for di, dj in neighbors:
            i2 = i + di
            j2 = j + dj
            if i2 < 0 or i2 >= n_y or j2 < 0 or j2 >= n_x:
                continue
            if visited[i2, j2]:
                continue

            y2 = float(numpy.imag(z[i2, j2]))
            x2 = float(numpy.real(z[i2, j2]))

            # Refuse a direct step across a cut from one half-plane to the
            # other. The neighbour is not marked visited, so it may still be
            # reached later along another path (e.g. through an on-axis row).
            if cuts is not None:
                if (y1 > y_eps and y2 < -y_eps) or \
                        (y1 < -y_eps and y2 > y_eps):
                    x_mid = 0.5 * (x1 + x2)
                    if crosses_cut(x_mid):
                        continue

            # Nearest-root continuation from the current cell.
            cand = R[i2, j2, :]
            d = numpy.abs(cand - m_prev)
            idx = int(numpy.argmin(d))

            # For a seed with non-zero imaginary part, prefer candidates whose
            # Im sign equals sign(seed_imag) in the upper half-plane and its
            # opposite in the lower half-plane (or Im == 0); among those,
            # pick the nearest.
            if seed_imag != 0.0:
                y_sign = 1.0 if y2 >= 0.0 else -1.0
                target = float(numpy.sign(seed_imag) * y_sign)
                if target != 0.0:
                    sgn = numpy.sign(numpy.imag(cand))
                    ok = (sgn == numpy.sign(target)) | (sgn == 0.0)
                    if numpy.any(ok):
                        ok_idx = numpy.where(ok)[0]
                        idx = int(ok_idx[numpy.argmin(d[ok])])

            sheet[i2, j2] = cand[idx]
            visited[i2, j2] = True
            q_i[tail] = i2
            q_j[tail] = j2
            tail += 1

    return sheet
622
+
623
+
624
+ # =======================
625
+ # build sheets from roots
626
+ # =======================
627
+
628
def build_sheets_from_roots(z, roots, m1, cuts=None, i0=None, j0=None):
    """
    Split per-point roots on a 2-D grid into continuous sheets, with the
    physical sheet placed first.

    One sheet is tracked (with track_one_sheet_on_grid) from each root at
    the seed point. Seeds are ordered by imaginary part, then the sheet
    seeded by the root nearest to m1 at the seed point is swapped to
    position 0. Two clean-ups follow:

    - If ``cuts`` is given, the physical sheet is overwritten with m1 on the
      on-axis row at columns whose real part lies inside a cut.
    - For every non-physical sheet, the rows within 1.1 row spacings of the
      real axis are replaced by the nearest row above (Im >= 0 rows) or
      below (Im < 0 rows) that band.

    Parameters
    ----------
    z : array_like, shape (n_y, n_x)
        Complex grid. Each row is assumed to share one imaginary part; only
        z[:, 0] is used to read it. Rows may be in ascending or descending
        Im order.
    roots : numpy.ndarray, shape (n_y * n_x, s)
        Roots at every grid point, in the row-major order of z.
    m1 : array_like, shape (n_y, n_x)
        Reference values of the physical branch on the grid.
    cuts : sequence of (a, b), optional
        Real intervals of the branch cuts.
    i0, j0 : int, optional
        Seed grid indices. Defaults: the first row with Im > 0 (or the
        middle row), and the middle column.

    Returns
    -------
    sheets : list of numpy.ndarray
        s complex arrays of shape (n_y, n_x); sheets[0] is the physical one.
    idxs : list of int
        idxs[k] is the index, among the roots at the seed point, of the root
        that seeded sheets[k].

    Raises
    ------
    ValueError
        If roots has no columns.
    """

    z = numpy.asarray(z)
    m1 = numpy.asarray(m1)

    n_y, n_x = z.shape
    s = roots.shape[1]
    if s < 1:
        raise ValueError("s must be >= 1.")

    # Imaginary part of each grid row.
    ycol = numpy.imag(z[:, 0])

    if i0 is None:
        pos = numpy.where(ycol > 0.0)[0]
        i0 = int(pos[0]) if pos.size > 0 else (n_y // 2)

    if j0 is None:
        j0 = n_x // 2

    R0 = roots.reshape((n_y, n_x, s))[i0, j0, :]
    idx_phys = int(numpy.argmin(numpy.abs(R0 - m1[i0, j0])))

    idxs = sorted(range(s), key=lambda k: numpy.imag(R0[k]))
    sheets = [track_one_sheet_on_grid(z, roots, R0[k], cuts=cuts,
                                      i0=i0, j0=j0)
              for k in idxs]

    # Move the physical sheet to the front.
    phys_pos = idxs.index(idx_phys)
    if phys_pos != 0:
        sheets[0], sheets[phys_pos] = sheets[phys_pos], sheets[0]
        idxs[0], idxs[phys_pos] = idxs[phys_pos], idxs[0]

    # Smallest spacing between distinct row heights (0 for a single row).
    y_unique = numpy.unique(ycol)
    if y_unique.size >= 2:
        dy = float(numpy.min(numpy.diff(y_unique)))
    else:
        dy = 0.0

    if cuts is not None:
        # On the row closest to the real axis, use m1 for the physical sheet
        # at columns that lie on a cut.
        eps_cut = 0.49 * dy
        i_cut = numpy.where(numpy.abs(ycol) <= eps_cut)[0]
        if i_cut.size > 0:
            i_cut = int(i_cut[numpy.argmin(numpy.abs(ycol[i_cut]))])

            X = numpy.real(z[i_cut, :])
            on_cut = numpy.zeros(n_x, dtype=bool)
            for a, b in cuts:
                on_cut |= (X >= a) & (X <= b)

            sheets[0][i_cut, on_cut] = m1[i_cut, on_cut]

    # Replace rows in the band |Im| <= 1.1*dy of the non-physical sheets by
    # the nearest row just above / just below the band.
    eps_band = 1.1 * dy
    i_band = numpy.where(numpy.abs(ycol) <= eps_band)[0]
    i_up = numpy.where(ycol > eps_band)[0]
    i_dn = numpy.where(ycol < -eps_band)[0]
    if (i_band.size > 0) and (i_up.size > 0) and (i_dn.size > 0):
        # Select by Im value rather than by position so the nearest rows are
        # used whether the grid rows are in ascending or descending Im order.
        i_up = int(i_up[numpy.argmin(ycol[i_up])])
        i_dn = int(i_dn[numpy.argmax(ycol[i_dn])])
        for sheet in sheets[1:]:
            for i in i_band:
                if ycol[i] >= 0.0:
                    sheet[i, :] = sheet[i_up, :]
                else:
                    sheet[i, :] = sheet[i_dn, :]

    return sheets, idxs