freealg 0.7.16__py3-none-any.whl → 0.7.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. freealg/__init__.py +8 -6
  2. freealg/__version__.py +1 -1
  3. freealg/_algebraic_form/_branch_points.py +18 -18
  4. freealg/_algebraic_form/_continuation_algebraic.py +13 -13
  5. freealg/_algebraic_form/_cusp.py +15 -15
  6. freealg/_algebraic_form/_cusp_wrap.py +6 -6
  7. freealg/_algebraic_form/_decompress.py +16 -16
  8. freealg/_algebraic_form/_decompress4.py +31 -31
  9. freealg/_algebraic_form/_decompress5.py +23 -23
  10. freealg/_algebraic_form/_decompress6.py +13 -13
  11. freealg/_algebraic_form/_decompress7.py +15 -15
  12. freealg/_algebraic_form/_decompress8.py +17 -17
  13. freealg/_algebraic_form/_decompress9.py +18 -18
  14. freealg/_algebraic_form/_decompress_new.py +17 -17
  15. freealg/_algebraic_form/_decompress_new_2.py +57 -57
  16. freealg/_algebraic_form/_decompress_util.py +10 -10
  17. freealg/_algebraic_form/_decompressible.py +292 -0
  18. freealg/_algebraic_form/_edge.py +10 -10
  19. freealg/_algebraic_form/_homotopy4.py +9 -9
  20. freealg/_algebraic_form/_homotopy5.py +9 -9
  21. freealg/_algebraic_form/_support.py +19 -19
  22. freealg/_algebraic_form/algebraic_form.py +262 -468
  23. freealg/_base_form.py +401 -0
  24. freealg/_free_form/__init__.py +1 -4
  25. freealg/_free_form/_density_util.py +1 -1
  26. freealg/_free_form/_plot_util.py +3 -511
  27. freealg/_free_form/free_form.py +8 -367
  28. freealg/_util.py +59 -11
  29. freealg/distributions/__init__.py +2 -1
  30. freealg/distributions/_base_distribution.py +163 -0
  31. freealg/distributions/_chiral_block.py +137 -11
  32. freealg/distributions/_compound_poisson.py +168 -64
  33. freealg/distributions/_deformed_marchenko_pastur.py +137 -88
  34. freealg/distributions/_deformed_wigner.py +92 -40
  35. freealg/distributions/_fuss_catalan.py +269 -0
  36. freealg/distributions/_kesten_mckay.py +4 -130
  37. freealg/distributions/_marchenko_pastur.py +8 -196
  38. freealg/distributions/_meixner.py +4 -130
  39. freealg/distributions/_wachter.py +4 -130
  40. freealg/distributions/_wigner.py +10 -127
  41. freealg/visualization/__init__.py +2 -2
  42. freealg/visualization/{_rgb_hsv.py → _domain_coloring.py} +37 -29
  43. freealg/visualization/_plot_util.py +513 -0
  44. {freealg-0.7.16.dist-info → freealg-0.7.18.dist-info}/METADATA +1 -1
  45. freealg-0.7.18.dist-info/RECORD +74 -0
  46. freealg-0.7.16.dist-info/RECORD +0 -69
  47. /freealg/{_free_form/_sample.py → _sample.py} +0 -0
  48. /freealg/{_free_form/_support.py → _support.py} +0 -0
  49. {freealg-0.7.16.dist-info → freealg-0.7.18.dist-info}/WHEEL +0 -0
  50. {freealg-0.7.16.dist-info → freealg-0.7.18.dist-info}/licenses/AUTHORS.txt +0 -0
  51. {freealg-0.7.16.dist-info → freealg-0.7.18.dist-info}/licenses/LICENSE.txt +0 -0
  52. {freealg-0.7.16.dist-info → freealg-0.7.18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,269 @@
1
+ # SPDX-FileCopyrightText: Copyright 2026, Siavash Ameli <sameli@berkeley.edu>
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+ # SPDX-FileType: SOURCE
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify it
6
+ # under the terms of the license found in the LICENSE.txt file in the root
7
+ # directory of this source tree.
8
+
9
+
10
+ # =======
11
+ # Imports
12
+ # =======
13
+
14
+ import numpy
15
+ from .._algebraic_form._sheets_util import _pick_physical_root_scalar
16
+
17
+ __all__ = ['FussCatalan']
18
+
19
+
20
+ # ============
21
+ # Fuss-Catalan
22
+ # ============
23
+
24
class FussCatalan(object):
    """
    Fuss-Catalan (a.k.a. free Bessel / Raney) distribution.

    This is the law of squared singular values of a product of :math:`p`
    independent (square) Ginibre matrices in the large-n limit. For :math:`p=1`
    this reduces to the Marchenko--Pastur law with :math:`c = 1`.

    Parameters
    ----------

    p : int, default=2
        Order of the Fuss-Catalan distribution. :math:`p = 1` is MP
        (:math:`c=1`).

    Notes
    -----

    Let :math:`p \\geq 1` be an integer. The Stieltjes transform :math:`m(z)`
    of the :math:`p`-th Fuss-Catalan distribution solves the algebraic equation

    .. math::

        (-1)^p z^p m(z)^{p+1} - z m(z) - 1 = 0,

    where the physical branch is selected by the Herglotz condition
    Im m(z) > 0 for Im z > 0.

    Equivalently, in terms of :math:`w(z) = -1 / (1 + z m(z))`:

    .. math::

        z = (1 + w)^{p+1} / w^p.

    The support is a single interval [0, x_max] with

    .. math::

        x_max = (p+1)^{p+1} / p^p.

    (All in the standard normalization where the mean is 1.)

    **Application:**

    This law might be applicable to the Jacobian of neural networks.
    """

    # ====
    # init
    # ====

    def __init__(self, p=2):
        """
        Initialization.

        Parameters
        ----------

        p : int, default=2
            Order of the distribution; coerced with ``int`` and must be >= 1.

        Raises
        ------

        ValueError
            If ``p < 1``.
        """

        p = int(p)

        if p < 1:
            raise ValueError("p must be an integer >= 1.")

        self.p = p

    # ============
    # roots scalar
    # ============

    def _roots_scalar(self, z):
        """
        Return all algebraic roots of the defining polynomial at scalar z.

        The polynomial in m is (-1)^p z^p m^{p+1} - z m - 1; its p+1 roots
        are returned by ``numpy.roots`` (leading-zero coefficients, e.g. at
        z = 0, are stripped by numpy, which lowers the root count).
        """

        p = int(self.p)
        z = numpy.asarray(z, dtype=numpy.complex128).reshape(())
        coeffs = numpy.zeros(p + 2, dtype=numpy.complex128)

        # Polynomial in m:
        # (-1)^p z^p m^{p+1} - z m - 1 = 0
        coeffs[0] = ((-1.0)**p) * (z**p)  # m^{p+1}
        coeffs[-2] = -z                    # m^1
        coeffs[-1] = -1.0                  # m^0

        return numpy.roots(coeffs)

    # =========
    # stieltjes
    # =========

    def stieltjes(self, z, max_iter=100, tol=1e-12):
        """
        Compute the physical Stieltjes transform m(z) by Newton, with robust
        fallback to algebraic roots + physical picking.

        Parameters
        ----------

        z : complex or array_like of complex
            Evaluation point(s). Inputs of any shape are accepted.

        max_iter : int, default=100
            Maximum number of Newton iterations.

        tol : float, default=1e-12
            Relative tolerance on the Newton step size.

        Returns
        -------

        m : numpy.ndarray
            Complex array with the same shape as ``z`` (0-d when ``z`` is a
            scalar).

        Notes
        -----

        Newton's method is run on a flattened copy of the input so that the
        per-point convergence bookkeeping is correct for inputs of any
        dimension. Points where Newton ends non-finite, or violates the
        Herglotz sign condition, are recomputed from all algebraic roots via
        physical-branch picking.
        """

        p = int(self.p)
        sgn = (-1.0)**p

        z = numpy.asarray(z, dtype=numpy.complex128)
        shape = z.shape

        # Work on 1-D data: boolean masks and flat indices then agree for any
        # input rank (a row-index from numpy.where(...)[0] would not).
        zf = z.ravel()

        # Initial guess: m ~ -1/z for |z| large
        m = -1.0 / zf
        active = numpy.isfinite(m)

        # Newton on f(m) = (-1)^p z^p m^{p+1} - z m - 1
        for _ in range(int(max_iter)):
            idx = numpy.flatnonzero(active)
            if idx.size == 0:
                break

            ma = m[idx]
            za = zf[idx]

            # f and f'
            zp = za**p
            f = sgn * zp * (ma**(p + 1)) - za * ma - 1.0
            fp = sgn * zp * (p + 1) * (ma**p) - za

            step = f / fp
            mn = ma - step
            m[idx] = mn

            conv = numpy.abs(step) < tol * (1.0 + numpy.abs(mn))

            # Non-finite iterates never recover (NaN propagates), so retire
            # them now; they are repaired by the fallback below.
            done = conv | ~numpy.isfinite(mn)
            active[idx[done]] = False

        # Herglotz sanity: sign(Im z) == sign(Im m)
        sign = numpy.where(numpy.imag(zf) >= 0.0, 1.0, -1.0)
        bad = (~numpy.isfinite(m)) | (sign * numpy.imag(m) <= 0.0)

        for i in numpy.flatnonzero(bad):
            zi = zf[i]
            m[i] = _pick_physical_root_scalar(zi, self._roots_scalar(zi))

        return m.reshape(shape)

    # =======
    # density
    # =======

    def density(self, x, eta=2e-4):
        """
        Density rho(x) from Im m(x + i eta) / pi.

        Parameters
        ----------

        x : float or array_like of float
            Real evaluation point(s).

        eta : float, default=2e-4
            Small positive imaginary offset used to approach the real axis.

        Returns
        -------

        rho : numpy.ndarray
            Density values, clipped below at zero, with the shape of ``x``.
        """

        x = numpy.asarray(x, dtype=numpy.float64)
        z = x + 1j * float(eta)
        m = self.stieltjes(z)
        rho = numpy.imag(m) / numpy.pi
        return numpy.maximum(rho, 0.0)

    # =====
    # roots
    # =====

    def roots(self, z):
        """
        Return all algebraic branches at scalar z.

        Raises
        ------

        ValueError
            If ``z`` is not a scalar.
        """

        z = numpy.asarray(z, dtype=numpy.complex128)
        if z.ndim != 0:
            raise ValueError("roots(z) expects scalar z.")
        return self._roots_scalar(z.reshape(()))

    # =======
    # support
    # =======

    def support(self):
        """
        Return the single support interval [0, x_max].

        Returns
        -------

        support : list of tuple
            ``[(0.0, x_max)]`` with ``x_max = (p+1)^{p+1} / p^p``.
        """

        p = float(self.p)
        x_max = ((p + 1.0)**(p + 1.0)) / (p**p)
        return [(0.0, float(x_max))]

    # ======
    # matrix
    # ======

    def matrix(self, size, seed=None):
        """
        Generate a PSD random matrix whose ESD approximates Fuss-Catalan(p).

        Construction
        ------------
        Let G_k be i.i.d. n x n Gaussian (Ginibre) matrices. Define

            P = (1/sqrt(n)) G_1 (1/sqrt(n)) G_2 ... (1/sqrt(n)) G_p,
            A = P P^T.

        Then the ESD of A converges to Fuss-Catalan(p) as n->infty.

        Parameters
        ----------

        size : int
            Matrix dimension n (must be positive).

        seed : int, default=None
            Seed passed to ``numpy.random.default_rng``.

        Returns
        -------

        A : numpy.ndarray
            Symmetric ``(n, n)`` float64 matrix.

        Raises
        ------

        ValueError
            If ``size`` is not positive.
        """

        p = int(self.p)
        n = int(size)
        if n <= 0:
            raise ValueError("size must be a positive integer.")

        rng = numpy.random.default_rng(seed)
        scale = 1.0 / numpy.sqrt(float(n))

        # Start the product from the first factor directly rather than
        # multiplying an identity matrix (same RNG draw order, same result).
        P = scale * rng.standard_normal((n, n))
        for _ in range(p - 1):
            P = P @ (scale * rng.standard_normal((n, n)))

        A = P @ P.T
        return A

    # ====
    # poly
    # ====

    def poly(self):
        """
        Return coeffs for the exact polynomial P(z,m)=0.

            P(z,m) = (-1)^p z^p m^{p+1} - z m - 1.

        coeffs[i, j] is the coefficient of z^i m^j.
        Shape is (deg_z+1, deg_m+1) = (p+1, p+2).
        """

        p = int(self.p)
        a = numpy.zeros((p + 1, p + 2), dtype=numpy.complex128)

        # constant: -1
        a[0, 0] = -1.0

        # - z m
        a[1, 1] = -1.0

        # (-1)^p z^p m^{p+1}
        a[p, p + 1] = ((-1.0)**p)

        return a
@@ -12,15 +12,9 @@
12
12
  # =======
13
13
 
14
14
  import numpy
15
- from scipy.interpolate import interp1d
16
- from .._free_form._plot_util import plot_density, plot_hilbert, \
17
- plot_stieltjes, plot_stieltjes_on_disk, plot_samples
18
-
19
- try:
20
- from scipy.integrate import cumtrapz
21
- except ImportError:
22
- from scipy.integrate import cumulative_trapezoid as cumtrapz
23
- from scipy.stats import qmc
15
+ from ..visualization._plot_util import plot_density, plot_hilbert, \
16
+ plot_stieltjes, plot_stieltjes_on_disk
17
+ from ._base_distribution import BaseDistribution
24
18
 
25
19
  __all__ = ['KestenMcKay']
26
20
 
@@ -29,7 +23,7 @@ __all__ = ['KestenMcKay']
29
23
  # Kesten McKay
30
24
  # ============
31
25
 
32
- class KestenMcKay(object):
26
+ class KestenMcKay(BaseDistribution):
33
27
  """
34
28
  Kesten-McKay distribution.
35
29
 
@@ -433,126 +427,6 @@ class KestenMcKay(object):
433
427
 
434
428
  return m1, m2
435
429
 
436
- # ======
437
- # sample
438
- # ======
439
-
440
- def sample(self, size, x_min=None, x_max=None, method='qmc', seed=None,
441
- plot=False, latex=False, save=False):
442
- """
443
- Sample from distribution.
444
-
445
- Parameters
446
- ----------
447
-
448
- size : int
449
- Size of sample.
450
-
451
- x_min : float, default=None
452
- Minimum of sample values. If `None`, the left edge of the support
453
- is used.
454
-
455
- x_max : float, default=None
456
- Maximum of sample values. If `None`, the right edge of the support
457
- is used.
458
-
459
- method : {``'mc'``, ``'qmc'``}, default= ``'qmc'``
460
- Method of drawing samples from uniform distribution:
461
-
462
- * ``'mc'``: Monte Carlo
463
- * ``'qmc'``: Quasi Monte Carlo
464
-
465
- seed : int, default=None,
466
- Seed for random number generator.
467
-
468
- plot : bool, default=False
469
- If `True`, samples histogram is plotted.
470
-
471
- latex : bool, default=False
472
- If `True`, the plot is rendered using LaTeX. This option is
473
- relevant only if ``plot=True``.
474
-
475
- save : bool, default=False
476
- If not `False`, the plot is saved. If a string is given, it is
477
- assumed to the save filename (with the file extension). This option
478
- is relevant only if ``plot=True``.
479
-
480
- Returns
481
- -------
482
-
483
- s : numpy.ndarray
484
- Samples.
485
-
486
- Notes
487
- -----
488
-
489
- This method uses inverse transform sampling.
490
-
491
- Examples
492
- --------
493
-
494
- .. code-block::python
495
-
496
- >>> from freealg.distributions import KestenMcKay
497
- >>> km = KestenMcKay(3)
498
- >>> s = km.sample(2000)
499
-
500
- .. image:: ../_static/images/plots/km_samples.png
501
- :align: center
502
- :class: custom-dark
503
- """
504
-
505
- if x_min is None:
506
- x_min = self.lam_m
507
-
508
- if x_max is None:
509
- x_max = self.lam_p
510
-
511
- # Grid and PDF
512
- xs = numpy.linspace(x_min, x_max, size)
513
- pdf = self.density(xs)
514
-
515
- # CDF (using cumulative trapezoidal rule)
516
- cdf = cumtrapz(pdf, xs, initial=0)
517
- cdf /= cdf[-1] # normalize CDF to 1
518
-
519
- # Inverse CDF interpolator
520
- inv_cdf = interp1d(cdf, xs, bounds_error=False,
521
- fill_value=(x_min, x_max))
522
-
523
- # Random generator
524
- rng = numpy.random.default_rng(seed)
525
-
526
- # Draw from uniform distribution
527
- if method == 'mc':
528
- u = rng.random(size)
529
-
530
- elif method == 'qmc':
531
- try:
532
- engine = qmc.Halton(d=1, scramble=True, rng=rng)
533
- except TypeError:
534
- # Older scipy versions
535
- engine = qmc.Halton(d=1, scramble=True, seed=rng)
536
- u = engine.random(size).ravel()
537
-
538
- else:
539
- raise NotImplementedError('"method" is invalid.')
540
-
541
- # Draw from distribution by mapping from inverse CDF
542
- samples = inv_cdf(u).ravel()
543
-
544
- if plot:
545
- radius = 0.5 * (self.lam_p - self.lam_m)
546
- center = 0.5 * (self.lam_p + self.lam_m)
547
- scale = 1.25
548
- x_min = numpy.floor(center - radius * scale)
549
- x_max = numpy.ceil(center + radius * scale)
550
- x = numpy.linspace(x_min, x_max, 500)
551
- rho = self.density(x)
552
- plot_samples(x, rho, x_min, x_max, samples, latex=latex, save=save)
553
-
554
- return samples
555
-
556
430
  # ===============
557
431
  # haar orthogonal
558
432
  # ===============
@@ -12,16 +12,10 @@
12
12
  # =======
13
13
 
14
14
  import numpy
15
- from scipy.interpolate import interp1d
16
- from .._free_form._plot_util import plot_density, plot_hilbert, \
17
- plot_stieltjes, plot_stieltjes_on_disk, plot_samples
15
+ from ..visualization._plot_util import plot_density, plot_hilbert, \
16
+ plot_stieltjes, plot_stieltjes_on_disk
18
17
  from ..visualization import glue_branches
19
-
20
- try:
21
- from scipy.integrate import cumtrapz
22
- except ImportError:
23
- from scipy.integrate import cumulative_trapezoid as cumtrapz
24
- from scipy.stats import qmc
18
+ from ._base_distribution import BaseDistribution
25
19
 
26
20
  __all__ = ['MarchenkoPastur']
27
21
 
@@ -30,7 +24,7 @@ __all__ = ['MarchenkoPastur']
30
24
  # Marchenko Pastur
31
25
  # ================
32
26
 
33
- class MarchenkoPastur(object):
27
+ class MarchenkoPastur(BaseDistribution):
34
28
  """
35
29
  Marchenko-Pastur distribution.
36
30
 
@@ -103,8 +97,6 @@ class MarchenkoPastur(object):
103
97
  self.lam = lam
104
98
  self.sigma = sigma
105
99
 
106
- # self.lam_p = (1 + numpy.sqrt(self.lam))**2
107
- # self.lam_m = (1 - numpy.sqrt(self.lam))**2
108
100
  self.lam_p = sigma**2 * (1.0 + numpy.sqrt(lam))**2
109
101
  self.lam_m = sigma**2 * (1.0 - numpy.sqrt(lam))**2
110
102
 
@@ -123,7 +115,7 @@ class MarchenkoPastur(object):
123
115
 
124
116
  x : numpy.array, default=None
125
117
  The locations where density is evaluated at. If `None`, an interval
126
- slightly larger than the supp interval of the spectral density
118
+ slightly larger than the support interval of the spectral density
127
119
  is used.
128
120
 
129
121
  rho : numpy.array, default=None
@@ -138,7 +130,7 @@ class MarchenkoPastur(object):
138
130
 
139
131
  save : bool, default=False
140
132
  If not `False`, the plot is saved. If a string is given, it is
141
- assumed to the save filename (with the file extension). This option
133
+ assumed to the save filename (with the file extension). This option
142
134
  is relevant only if ``plot=True``.
143
135
 
144
136
  eig : numpy.array, default=None
@@ -223,7 +215,7 @@ class MarchenkoPastur(object):
223
215
 
224
216
  x : numpy.array, default=None
225
217
  The locations where Hilbert transform is evaluated at. If `None`,
226
- an interval slightly larger than the supp interval of the
218
+ an interval slightly larger than the support interval of the
227
219
  spectral density is used.
228
220
 
229
221
  plot : bool, default=False
@@ -288,46 +280,6 @@ class MarchenkoPastur(object):
288
280
 
289
281
  return hilb
290
282
 
291
- # =======================
292
- # m mp numeric vectorized
293
- # =======================
294
-
295
- # def _m_mp_numeric_vectorized(self, z, alt_branch=False, tol=1e-8):
296
- # """
297
- # Stieltjes transform (principal or secondary branch)
298
- # for Marchenko-Pastur distribution on upper half-plane.
299
- # """
300
- #
301
- # sigma = 1.0
302
- # m = numpy.empty_like(z, dtype=complex)
303
- #
304
- # # When z is too small, do not use quadratic form.
305
- # mask = numpy.abs(z) < tol
306
- # m[mask] = 1 / (sigma**2 * (1 - self.lam))
307
- #
308
- # # Use quadratic form
309
- # not_mask = ~mask
310
- # if numpy.any(not_mask):
311
- #
312
- # sign = -1 if alt_branch else 1
313
- # A = self.lam * sigma**2 * z[not_mask]
314
- # B = z[not_mask] - sigma**2 * (1 - self.lam)
315
- # D = B**2 - 4 * A
316
- # sqrtD = numpy.sqrt(D)
317
- # m1 = (-B + sqrtD) / (2 * A)
318
- # m2 = (-B - sqrtD) / (2 * A)
319
- #
320
- # # pick correct branch only for non-masked entries
321
- # upper = z[not_mask].imag >= 0
322
- # branch = numpy.empty_like(m1)
323
- # branch[upper] = numpy.where(sign*m1[upper].imag > 0, m1[upper],
324
- # m2[upper])
325
- # branch[~upper] = numpy.where(sign*m1[~upper].imag < 0,
326
- # m1[~upper], m2[~upper])
327
- # m[not_mask] = branch
328
- #
329
- # return m
330
-
331
283
  # =============
332
284
  # sqrt pos imag
333
285
  # =============
@@ -342,26 +294,6 @@ class MarchenkoPastur(object):
342
294
 
343
295
  return sq
344
296
 
345
- # ============
346
- # m mp reflect
347
- # ============
348
-
349
- # def _m_mp_reflect(self, z, alt_branch=False):
350
- # """
351
- # Analytic continuation using Schwarz reflection.
352
- # """
353
- #
354
- # mask_p = z.imag >= 0.0
355
- # mask_n = z.imag < 0.0
356
- #
357
- # m = numpy.zeros_like(z)
358
- #
359
- # f = self._m_mp_numeric_vectorized
360
- # m[mask_p] = f(z[mask_p], alt_branch=False)
361
- # m[mask_n] = f(z[mask_n], alt_branch=alt_branch)
362
- #
363
- # return m
364
-
365
297
  # ================
366
298
  # stieltjes branch
367
299
  # ================
@@ -429,7 +361,7 @@ class MarchenkoPastur(object):
429
361
 
430
362
  x : numpy.array, default=None
431
363
  The x axis of the grid where the Stieltjes transform is evaluated.
432
- If `None`, an interval slightly larger than the supp interval of
364
+ If `None`, an interval slightly larger than the support interval of
433
365
  the spectral density is used.
434
366
 
435
367
  y : numpy.array, default=None
@@ -557,126 +489,6 @@ class MarchenkoPastur(object):
557
489
  else:
558
490
  return m1
559
491
 
560
- # ======
561
- # sample
562
- # ======
563
-
564
- def sample(self, size, x_min=None, x_max=None, method='qmc', seed=None,
565
- plot=False, latex=False, save=False):
566
- """
567
- Sample from distribution.
568
-
569
- Parameters
570
- ----------
571
-
572
- size : int
573
- Size of sample.
574
-
575
- x_min : float, default=None
576
- Minimum of sample values. If `None`, the left edge of the supp
577
- is used.
578
-
579
- x_max : float, default=None
580
- Maximum of sample values. If `None`, the right edge of the supp
581
- is used.
582
-
583
- method : {``'mc'``, ``'qmc'``}, default= ``'qmc'``
584
- Method of drawing samples from uniform distribution:
585
-
586
- * ``'mc'``: Monte Carlo
587
- * ``'qmc'``: Quasi Monte Carlo
588
-
589
- seed : int, default=None,
590
- Seed for random number generator.
591
-
592
- plot : bool, default=False
593
- If `True`, samples histogram is plotted.
594
-
595
- latex : bool, default=False
596
- If `True`, the plot is rendered using LaTeX. This option is
597
- relevant only if ``plot=True``.
598
-
599
- save : bool, default=False
600
- If not `False`, the plot is saved. If a string is given, it is
601
- assumed to the save filename (with the file extension). This option
602
- is relevant only if ``plot=True``.
603
-
604
- Returns
605
- -------
606
-
607
- s : numpy.ndarray
608
- Samples.
609
-
610
- Notes
611
- -----
612
-
613
- This method uses inverse transform sampling.
614
-
615
- Examples
616
- --------
617
-
618
- .. code-block::python
619
-
620
- >>> from freealg.distributions import MarchenkoPastur
621
- >>> mp = MarchenkoPastur(1/50)
622
- >>> s = mp.sample(2000)
623
-
624
- .. image:: ../_static/images/plots/mp_samples.png
625
- :align: center
626
- :class: custom-dark
627
- """
628
-
629
- if x_min is None:
630
- x_min = self.lam_m
631
-
632
- if x_max is None:
633
- x_max = self.lam_p
634
-
635
- # Grid and PDF
636
- xs = numpy.linspace(x_min, x_max, size)
637
- pdf = self.density(xs)
638
-
639
- # CDF (using cumulative trapezoidal rule)
640
- cdf = cumtrapz(pdf, xs, initial=0)
641
- cdf /= cdf[-1] # normalize CDF to 1
642
-
643
- # Inverse CDF interpolator
644
- inv_cdf = interp1d(cdf, xs, bounds_error=False,
645
- fill_value=(x_min, x_max))
646
-
647
- # Random generator
648
- rng = numpy.random.default_rng(seed)
649
-
650
- # Draw from uniform distribution
651
- if method == 'mc':
652
- u = rng.random(size)
653
-
654
- elif method == 'qmc':
655
- try:
656
- engine = qmc.Halton(d=1, scramble=True, rng=rng)
657
- except TypeError:
658
- # Older scipy versions
659
- engine = qmc.Halton(d=1, scramble=True, seed=rng)
660
- u = engine.random(size).ravel()
661
-
662
- else:
663
- raise NotImplementedError('"method" is invalid.')
664
-
665
- # Draw from distribution by mapping from inverse CDF
666
- samples = inv_cdf(u).ravel()
667
-
668
- if plot:
669
- radius = 0.5 * (self.lam_p - self.lam_m)
670
- center = 0.5 * (self.lam_p + self.lam_m)
671
- scale = 1.25
672
- x_min = numpy.floor(center - radius * scale)
673
- x_max = numpy.ceil(center + radius * scale)
674
- x = numpy.linspace(x_min, x_max, 500)
675
- rho = self.density(x)
676
- plot_samples(x, rho, x_min, x_max, samples, latex=latex, save=save)
677
-
678
- return samples
679
-
680
492
  # ======
681
493
  # matrix
682
494
  # ======