diffaaable 1.1.0__tar.gz → 1.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: diffaaable
3
- Version: 1.1.0
3
+ Version: 1.2.1
4
4
  Summary: JAX-differentiable AAA algorithm
5
5
  Keywords: python
6
6
  Author-email: Jan David Fischbach <fischbach@kit.edu>
@@ -10,11 +10,13 @@ Classifier: Programming Language :: Python :: 3.9
10
10
  Classifier: Programming Language :: Python :: 3.10
11
11
  Classifier: Programming Language :: Python :: 3.11
12
12
  Classifier: Operating System :: OS Independent
13
- Requires-Dist: numpy
13
+ License-File: LICENSE
14
+ Requires-Dist: numpy>2
14
15
  Requires-Dist: jax
15
16
  Requires-Dist: jaxlib
16
17
  Requires-Dist: baryrat
17
18
  Requires-Dist: jaxopt
19
+ Requires-Dist: pandas>=2.2.3
18
20
  Requires-Dist: tbump>=6.11.0 ; extra == "dev"
19
21
  Requires-Dist: towncrier ; extra == "dev"
20
22
  Requires-Dist: pre-commit ; extra == "dev"
@@ -25,12 +27,12 @@ Requires-Dist: pytest-benchmark ; extra == "dev"
25
27
  Requires-Dist: matplotlib ; extra == "dev"
26
28
  Requires-Dist: jupytext ; extra == "docs"
27
29
  Requires-Dist: matplotlib ; extra == "docs"
28
- Requires-Dist: jupyter-book ; extra == "docs"
30
+ Requires-Dist: jupyter-book<2 ; extra == "docs"
29
31
  Requires-Dist: sphinx_math_dollar ; extra == "docs"
30
32
  Provides-Extra: dev
31
33
  Provides-Extra: docs
32
34
 
33
- # diffaaable 1.1.0
35
+ # diffaaable 1.2.1
34
36
 
35
37
  ![](docs/assets/diffaaable.png)
36
38
 
@@ -39,8 +41,9 @@ A detailed derivation of the used matrix expressions is provided in the appendix
39
41
  Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
40
42
  Additionally, the following application-specific extensions to the AAA algorithm are included:
41
43
 
42
- - **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluation needed to precisely locate poles within some domain
43
- - **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
44
+ - **Adaptive**: Adaptive refinement strategy (called *Iterative Sample Refinement* (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
45
+ - **Vectorial**: AAA algorithm acting on vector-valued functions $\mathbf{f}(z)$ as presented in [^3].
46
+ - **Tensor**: Convenience alternative to the vector-valued AAA algorithm (`vectorial`) that accepts a tensor-valued function $F_k$ of arbitrary dimensionality, instead of the single dimension that `vectorial` requires.
44
47
  - **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
45
48
  - **Selective Refinement**: Divide-and-conquer scheme that captures many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
46
49
 
@@ -60,7 +63,7 @@ When using this software package for scientific work please cite the associated
60
63
  +++
61
64
 
62
65
  [^1]: https://arxiv.org/pdf/2403.19404
63
- [^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
66
+ [^2]: "A framework to compute resonances arising from multiple scattering", https://doi.org/10.1002/adts.202400989
64
67
  [^3]: https://doi.org/10.1093/imanum/draa098
65
68
  [^4]: https://doi.org/10.48550/arXiv.2405.19582
66
69
 
@@ -1,4 +1,4 @@
1
- # diffaaable 1.1.0
1
+ # diffaaable 1.2.1
2
2
 
3
3
  ![](docs/assets/diffaaable.png)
4
4
 
@@ -7,8 +7,9 @@ A detailed derivation of the used matrix expressions is provided in the appendix
7
7
  Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
8
8
  Additionally, the following application-specific extensions to the AAA algorithm are included:
9
9
 
10
- - **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluation needed to precisely locate poles within some domain
11
- - **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
10
+ - **Adaptive**: Adaptive refinement strategy (called *Iterative Sample Refinement* (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
11
+ - **Vectorial**: AAA algorithm acting on vector-valued functions $\mathbf{f}(z)$ as presented in [^3].
12
+ - **Tensor**: Convenience alternative to the vector-valued AAA algorithm (`vectorial`) that accepts a tensor-valued function $F_k$ of arbitrary dimensionality, instead of the single dimension that `vectorial` requires.
12
13
  - **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
13
14
  - **Selective Refinement**: Divide-and-conquer scheme that captures many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
14
15
 
@@ -28,6 +29,6 @@ When using this software package for scientific work please cite the associated
28
29
  +++
29
30
 
30
31
  [^1]: https://arxiv.org/pdf/2403.19404
31
- [^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
32
+ [^2]: "A framework to compute resonances arising from multiple scattering", https://doi.org/10.1002/adts.202400989
32
33
  [^3]: https://doi.org/10.1093/imanum/draa098
33
34
  [^4]: https://doi.org/10.48550/arXiv.2405.19582
@@ -0,0 +1,15 @@
1
+ """diffaaable - JAX-differentiable AAA algorithm"""
2
+
3
+ __version__ = "1.2.1"
4
+
5
+
6
+ from diffaaable.core import aaa
7
+ from diffaaable.adaptive import adaptive_aaa
8
+ from diffaaable.lorentz import lorentz_aaa
9
+ from diffaaable.selective import selective_subdivision_aaa
10
+ from diffaaable.set_aaa import set_aaa
11
+ from diffaaable.tensor import tensor_aaa
12
+ from diffaaable.vectorial import vectorial_aaa
13
+ from diffaaable.util import poles, residues
14
+
15
+ __all__ = ["aaa", "adaptive_aaa", "lorentz_aaa", "selective_subdivision_aaa", "set_aaa", "tensor_aaa", "vectorial_aaa", "poles", "residues"]
@@ -1,3 +1,4 @@
1
+ import functools
1
2
  from typing import Union
2
3
  import jax.numpy as np
3
4
  import numpy.typing as npt
@@ -8,7 +9,7 @@ import matplotlib.pyplot as plt
8
9
  from functools import partial
9
10
  Domain = tuple[complex, complex]
10
11
 
11
- def top_right(a: npt.NDArray[complex], b: npt.NDArray[complex]):
12
+ def top_right(a: npt.NDArray, b: npt.NDArray):
12
13
  return np.logical_and(a.imag>=b.imag, a.real>=b.real)
13
14
 
14
15
  def domain_mask(domain: Domain, z_n):
@@ -111,11 +112,14 @@ def next_samples_heat(
111
112
 
112
113
  return add_samples
113
114
 
114
- aaa = jax.tree_util.Partial(aaa)
115
+ vanilla_aaa = jax.tree_util.Partial(aaa)
115
116
 
116
117
  def mask(z_k, f_k, f_k_dot, cutoff):
117
- m = np.abs(f_k)<cutoff #filter out values, that have diverged too strongly
118
- m = np.logical_and(m, ~np.isnan(f_k)) #filter out nans
118
+ def all_f(f):
119
+ # to make sure all f_k (in case of tensor valued functions) behave nice
120
+ return np.squeeze(np.apply_over_axes(np.all, f, np.arange(f.ndim-1)+1))
121
+ m = all_f(np.abs(f_k)<cutoff) #filter out values, that have diverged too strongly
122
+ m = np.logical_and(m, all_f(~np.isnan(f_k))) #filter out nans
119
123
  m = np.logical_and(m, ~np.isnan(z_k)) #filter out nans
120
124
 
121
125
  if m.ndim == 2:
@@ -150,6 +154,11 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
150
154
  Tangent of `f`. If provided JVPs of `f` will be collected throughout the
151
155
  iterations. For use in custom_jvp
152
156
  """
157
+ if sampling is None:
158
+ sampling = next_samples
159
+
160
+ if aaa is None:
161
+ aaa = vanilla_aaa
153
162
 
154
163
  if sampling == "heat":
155
164
  sampling = next_samples_heat
@@ -235,16 +244,18 @@ def adaptive_aaa(z_k_0: npt.NDArray,
235
244
  radius: float = None,
236
245
  domain: Domain = None,
237
246
  f_k_0: npt.NDArray = None,
238
- sampling: callable = next_samples,
247
+ sampling: callable = None,
239
248
  prev_z_n: npt.NDArray = None,
240
249
  return_samples: bool = False,
241
- aaa: callable = aaa):
250
+ aaa: callable = None):
242
251
  """ An 2x adaptive Antoulas–Anderson algorithm for rational approximation of
243
252
  meromorphic functions that are costly to evaluate.
244
253
 
245
254
  The algorithm iteratively places additional sample points close to estimated
246
255
  positions of poles identified during the past iteration. By this refinement
247
- scheme the number of function evaluations can be reduced.
256
+ scheme the number of function evaluations can be reduced. A more detailed
257
+ description of the iterative sample refinement (ISR) algorithm is provided
258
+ in (https://doi.org/10.1002/adts.202400989).
248
259
 
249
260
  It is JAX differentiable wrt. the approximated function `f`, via its other
250
261
  arguments besides `z`. `f` should be provided as a `jax.tree_util.Partial`
@@ -267,13 +278,36 @@ def adaptive_aaa(z_k_0: npt.NDArray,
267
278
  tol: float
268
279
  Tolerance used in AAA (see `diffaaable.aaa`)
269
280
  radius: float
270
- Distance from the assumed poles for nex samples
281
+ Distance from the assumed poles for next samples
271
282
  domain: tuple[complex, complex]
272
283
  Tuple of min (lower left) and max (upper right) values defining a
273
284
  rectangle in the complex plane. Assumed poles outside of the domain
274
285
  will not receive refinement.
275
286
  f_k_0:
276
287
  Allows user to provide f evaluated at z_k_0
288
+ sampling: callable
289
+ strategy to determine the next sample points. The function should
290
+ accept the following arguments:
291
+ - z_n: np.ndarray
292
+ estimated poles
293
+ - prev_z_n: np.ndarray
294
+ previous estimated poles
295
+ - samples: np.ndarray
296
+ current sample points
297
+ - domain: Domain
298
+ domain in which to refine poles
299
+ - radius: float
300
+ distance from the assumed poles for next samples
301
+ - randkey: jax.random.PRNGKey
302
+ random key for sampling
303
+ prev_z_n: np.ndarray
304
+ the previous poles that will be passed to the first evaluation of the sampling strategy
305
+ return_samples: bool
306
+ If True, the function returns the samples used for the AAA approximation
307
+ and the function evaluations at these points as the 5th and 6th return values (indices 4 and 5).
308
+ aaa: callable
309
+ The AAA variant to be used. By default `diffaaable.aaa` is used.
310
+ If you want to use the tensor AAA, you can pass `diffaaable.tensor.tensor_aaa`.
277
311
 
278
312
 
279
313
  Returns
@@ -286,6 +320,12 @@ def adaptive_aaa(z_k_0: npt.NDArray,
286
320
  Weights of Barycentric Approximation
287
321
  z_n: np.array
288
322
  Poles of Barycentric Approximation
323
+
324
+ z_k_final: np.array
325
+ all sample points used for the AAA approximation. Only returned if
326
+ `return_samples` is True.
327
+ f_k_final: np.array
328
+ `f(z_k_final)` Only returned if `return_samples` is True.
289
329
  """
290
330
  return _adaptive_aaa(
291
331
  z_k_0=z_k_0, f=f, evolutions=evolutions, cutoff=cutoff, tol=tol, mmax=mmax,
@@ -5,8 +5,11 @@ import numpy.typing as npt
5
5
  import jax
6
6
  import numpy as np
7
7
  from baryrat import aaa as oaaa # ordinary aaa
8
+ from .tensor import tensor_aaa
8
9
  import functools
9
- import scipy.linalg
10
+ from .util import poles
11
+
12
+ USE_SETAAA = False
10
13
 
11
14
  @functools.wraps(oaaa)
12
15
  @jax.custom_jvp
@@ -14,17 +17,22 @@ def aaa(z_k: npt.NDArray, f_k: npt.NDArray, tol: float=1e-13, mmax: int=100):
14
17
  """
15
18
  Wraped aaa to enable JAX based autodiff.
16
19
  """
17
- r = oaaa(z_k, f_k, tol=tol, mmax=mmax)
18
- z_j = r.nodes
19
- f_j = r.values
20
- w_j = r.weights
21
-
22
- mask = w_j!=0
23
- z_j = z_j[mask]
24
- f_j = f_j[mask]
25
- w_j = w_j[mask]
26
-
27
- z_n = poles(z_j, w_j)
20
+ if USE_SETAAA:
21
+ print("Using set_aaa")
22
+ z_j, f_j, w_j, z_n = tensor_aaa(z_k, f_k[:, None], tol_aaa=tol, mmax_aaa=mmax)
23
+ f_j = f_j[:, 0]
24
+ else:
25
+ r = oaaa(z_k, f_k, tol=tol, mmax=mmax)
26
+ z_j = r.nodes
27
+ f_j = r.values
28
+ w_j = r.weights
29
+
30
+ mask = w_j!=0
31
+ z_j = z_j[mask]
32
+ f_j = f_j[mask]
33
+ w_j = w_j[mask]
34
+
35
+ z_n = poles(z_j, w_j)
28
36
 
29
37
  z_n = z_n[jnp.argsort(-jnp.abs(z_n))]
30
38
 
@@ -166,59 +174,3 @@ def aaa_jvp(primals, tangents):
166
174
  tangent_out = z_j_dot, f_j_dot, w_j_dot, z_n_dot
167
175
 
168
176
  return primal_out, tangent_out
169
-
170
- def poles(z_j,w_j):
171
- """
172
- The poles of a barycentric rational with given nodes and weights.
173
- Poles lifted by zeros of the nominator are included.
174
- Thus the values $f_j$ do not contribute and don't need to be provided
175
- The implementation was modified from `baryrat` to support JAX AD.
176
-
177
- Parameters
178
- ----------
179
- z_j : array (m,)
180
- nodes of the barycentric rational
181
- w_j : array (m,)
182
- weights of the barycentric rational
183
-
184
- Returns
185
- -------
186
- z_n : array (m-1,)
187
- poles of the barycentric rational (more strictly zeros of the denominator)
188
- """
189
- f_j = np.ones_like(z_j)
190
-
191
- B = np.eye(len(w_j) + 1)
192
- B[0,0] = 0
193
- E = np.block([[0, w_j],
194
- [f_j[:,None], np.diag(z_j)]])
195
- evals = scipy.linalg.eigvals(E, B)
196
- return evals[np.isfinite(evals)]
197
-
198
- def residues(z_j,f_j,w_j,z_n):
199
- '''
200
- Residues for given poles via formula for simple poles
201
- of quotients of analytic functions.
202
- The implementation was modified from `baryrat` to support JAX AD.
203
-
204
- Parameters
205
- ----------
206
- z_j : array (m,)
207
- nodes of the barycentric rational
208
- w_j : array (m,)
209
- weights of the barycentric rational
210
- z_n : array (n,)
211
- poles of interest of the barycentric rational (n<=m-1)
212
-
213
- Returns
214
- -------
215
- r_n : array (n,)
216
- residues of poles `z_n`
217
- '''
218
-
219
- C_pol = 1.0 / (z_n[:,None] - z_j[None,:])
220
- N_pol = C_pol.dot(f_j*w_j)
221
- Ddiff_pol = (-C_pol**2).dot(w_j)
222
- res = N_pol / Ddiff_pol
223
-
224
- return jnp.nan_to_num(res)
@@ -0,0 +1,267 @@
1
+ import jax.numpy as np
2
+ import numpy as onp
3
+ from diffaaable.core import aaa
4
+ from diffaaable.util import residues
5
+ from diffaaable.adaptive import Domain, domain_mask, adaptive_aaa
6
+ import matplotlib.pyplot as plt
7
+
8
def reduced_domain(domain, reduction=1-1/12):
    """
    Utility: Rescale the domain about its center. Can be used to shrink or
    enlarge the domain.

    The new corners are ``domain[0]*r + domain[1]*(1-r)`` and
    ``domain[1]*r + domain[0]*(1-r)``. The center is preserved and the side
    lengths are scaled by ``2*r - 1``.

    Parameters
    ----------
    domain: Domain
        The domain to rescale, given as (lower left, upper right) corners.
    reduction: float
        Rescaling parameter ``r``:

        - 1 leaves the domain unchanged.
        - 0.75 halves the side lengths.
        - 0.5 collapses the domain onto its center.
        - 1.5 doubles the side lengths. Any value above 1 enlarges the domain.
        - Values below 0.5 swap the two corners.

        The default ``1-1/12`` scales the side lengths by 5/6.

    Returns
    -------
    Domain
        The rescaled pair of corners.
    """
    r = reduction
    return (
        domain[0]*r + domain[1]*(1-r),
        domain[1]*r + domain[0]*(1-r),
    )
25
+
26
def sample_cross(domain):
    """
    Utility: Four samples arranged as a cross around the domain center.

    The returned points are the midpoints of the right, left, top and bottom
    edges of the rectangle spanned by `domain` (in that order).
    """
    half_extent = 0.5 * (domain[1] - domain[0])
    offsets = np.array([
        half_extent.real,
        -half_extent.real,
        1j * half_extent.imag,
        -1j * half_extent.imag,
    ])
    return domain_center(domain) + offsets
30
+
31
def sample_domain(domain: Domain, N: int):
    """
    Utility: Regular grid of roughly `N` samples inside `domain`.

    The domain is first shrunk with the default factor of `reduced_domain`, so
    no sample lies on the border. The grid has ``round(sqrt(N))`` points along
    each axis. The result is flattened with the real part varying fastest.
    """
    points_per_axis = np.round(np.sqrt(N)).astype(int)
    shrunk = reduced_domain(domain)
    lower, upper = shrunk[0], shrunk[1]
    real_axis = np.linspace(lower.real, upper.real, points_per_axis)
    imag_axis = np.linspace(lower.imag, upper.imag, points_per_axis)
    grid_real, grid_imag = np.meshgrid(real_axis, imag_axis)
    return (grid_real + 1j*grid_imag).flatten()
39
+
40
def sample_rim(domain: Domain, N: int):
    """
    Utility: Samples evenly spaced along the border of the domain.

    ``N//4`` points are placed on each of the four edges, in the order bottom,
    top, left, right. The corners themselves are excluded.
    """
    per_edge = N // 4
    # linspace includes both end points; dropping them skips the corners
    along_real = np.linspace(domain[0].real, domain[1].real, per_edge + 2)[1:-1]
    along_imag = 1j * np.linspace(domain[0].imag, domain[1].imag, per_edge + 2)[1:-1]
    bottom = along_real + 1j*domain[0].imag
    top = along_real + 1j*domain[1].imag
    left = along_imag + domain[0].real
    right = along_imag + domain[1].real
    return np.array([bottom, top, left, right]).flatten()
50
+
51
def anti_domain(domain: Domain):
    """
    Utility: The other diagonal of the rectangle spanned by `domain`.

    The imaginary parts of the two corners are swapped. For (lower left,
    upper right) input this yields (upper left, lower right), and vice versa.
    """
    first, second = domain[0], domain[1]
    swapped_first = first.real + 1j*second.imag
    swapped_second = second.real + 1j*first.imag
    return (swapped_first, swapped_second)
56
+
57
def domain_center(domain: Domain):
    """Utility: Center of the rectangle, i.e. the mean of the corner points in `domain`."""
    corners = np.array(domain)
    return corners.mean()
59
+
60
def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
    """
    Utility: Subdivide the domain into two subdomains.

    Parameters
    ----------
    domain: Domain
        The domain to subdivide, given as (lower left, upper right) corners.
    divide_horizontal: bool
        If True, the domain is cut along a horizontal line through `center`
        and the result is [upper half, lower half]. If False, it is cut along
        a vertical line and the result is [left half, right half].
    center: complex
        The point through which to divide the domain. If None, the mean of the
        domain corners is used (see `domain_center`).
    """
    if center is None:
        center = domain_center(domain)
    lower_left, upper_right = domain[0], domain[1]
    upper_left = lower_left.real + 1j*upper_right.imag
    lower_right = upper_right.real + 1j*lower_left.imag

    # the four quadrants around `center`, each as (lower left, upper right)
    north_east = (center, upper_right)
    north_west = anti_domain((upper_left, center))
    south_west = (lower_left, center)
    south_east = anti_domain((center, lower_right))

    if not divide_horizontal:
        # left and right halves
        return [(south_west[0], north_west[1]), (south_east[0], north_east[1])]
    # upper and lower halves
    return [(north_west[0], north_east[1]), (south_west[0], south_east[1])]
88
+
89
def plot_domain(domain: Domain, size: float=1):
    """
    Utility: Plot the outline of the domain as a closed rectangle.

    Parameters
    ----------
    domain: Domain
        The domain to plot, given as (lower left, upper right) corners.
    size: float
        The relative size used to scale the line width (``lw=size/30``).

    Returns
    -------
    The value returned by `plt.plot`.

    TODO
    ----
    - Add a color argument to the function.
    """
    lower_left, upper_right = domain[0], domain[1]
    outline = np.array([
        lower_left,
        upper_right.real + 1j*lower_left.imag,  # lower right
        upper_right,
        lower_left.real + 1j*upper_right.imag,  # upper left
        lower_left,                             # close the rectangle
    ])
    return plt.plot(outline.real, outline.imag, lw=size/30, zorder=1)
111
+
112
def all_poles_known(poles, prev, tol):
    """
    Utility: Check whether `poles` agree with the previously known poles `prev`.

    This is used as a termination criterion in `selective_subdivision_aaa`.

    Both sets must have the same length. Every pole in `poles` must have a
    partner in `prev` closer than `tol`, and vice versa. Checking both
    directions prevents two new poles from matching the same previous pole
    while another previous pole has no partner.

    Parameters
    ----------
    poles: array
        Newly found poles.
    prev: array or None
        Previously known poles. None means nothing is known.
    tol: float
        Maximum distance for two poles to be considered identical.

    Returns
    -------
    bool
    """
    if prev is None or len(prev) != len(poles):
        return False

    close = np.abs(poles[:, None] - prev[None, :]) < tol
    every_new_matched = bool(np.all(np.any(close, axis=1)))
    every_prev_matched = bool(np.all(np.any(close, axis=0)))
    return every_new_matched and every_prev_matched
119
+
120
+
121
def selective_subdivision_aaa(f: callable,
                              domain: Domain,
                              N: int = 36,
                              max_poles: int = 400,
                              cutoff: float = None,
                              tol_aaa: float = 1e-9,
                              tol_pol: float = 1e-5,
                              suggestions = None,
                              on_rim: bool = False,
                              Dmax=30,
                              use_adaptive: bool = True,
                              evolutions_adaptive: int = 5,
                              radius_adaptive: float = 1e-4,
                              z_k = None, f_k = None,
                              divide_horizontal=True,
                              debug_plot_domains: bool = False,
                              ):
    """
    When the number of poles that need to be located is large, it can be
    beneficial to subdivide the search domain. This function implements a
    recursive subdivision of the domain that automatically terminates once the
    poles are found to a satisfactory degree.

    A domain is accepted without further subdivision in either of two cases:

    - The subdivision budget `Dmax` is exhausted.
    - At most `max_poles` poles are found inside it, and they agree within
      `tol_pol` with the `suggestions` handed down from the parent domain.

    Otherwise the domain is split into two halves and the search recurses.
    The cut direction alternates between horizontal and vertical.

    Parameters
    ----------
    f: callable
        The function that the pole search is conducted on. It should accept batches of complex numbers and return a complex number per input.
    domain: Domain
        The pole search is limited to this domain.
    N: int
        The initial number of samples. If fewer than N samples are available, a regular grid of about N new samples is added to the domain (see `sample_domain`).
    max_poles: int
        The maximum number of poles that are considered valid within one search. If this number is exceeded, the domain is subdivided.
    cutoff: float
        A cutoff to avoid numerical instability due to large samples close to poles. None disables the cutoff. See also `diffaaable.adaptive.adaptive_aaa`.
    tol_aaa: float
        The tolerance for the AAA algorithm. See also `diffaaable.core.aaa`.
    tol_pol: float
        The tolerance for the pole search. This is used to determine whether a pole has moved significantly since the last domain subdivision.
    suggestions: array
        A list of poles that are already known. This is used internally to recursively call `selective_subdivision_aaa`.
    on_rim: bool
        If True, the initial samples are taken on the rim of the domain. This only applies when `use_adaptive` is False.
        Selecting the samples on the domain border is closely related to similar contour integral approaches.
        It is, however, generally recommended to sample within the domain (default: False).
    Dmax: int
        The maximum number of subdivisions. Once it is exhausted (``Dmax <= 0``), the algorithm stops and returns the current poles.
        TODO: allow the user to specify that when reaching Dmax the algorithm should return no poles to avoid false poles.
    use_adaptive: bool
        If True, the algorithm uses the adaptive AAA algorithm within the subdomains to locate the poles more accurately.
        The samples collected while searching the parent domain are passed to the respective subdomains to minimize computational cost.
        Using the adaptive AAA is generally recommended (default: True). If False, the standard AAA algorithm is used.
    evolutions_adaptive: int
        The number of evolutions for the adaptive AAA algorithm.
    radius_adaptive: float
        The radius for the adaptive AAA algorithm, relative to the domain size. See also `diffaaable.adaptive.adaptive_aaa`.
    z_k: array
        The samples that have already been collected. Only used when `use_adaptive` is True.
        This is used internally to recursively call `selective_subdivision_aaa`. It can also be used by the user to pass samples that are already known.
    f_k: array
        The function values of the samples `z_k`. If `z_k` is given without `f_k`, `f` is evaluated at `z_k`.
    divide_horizontal: bool
        If True, the next domain division will be horizontal. Used internally to alternate between horizontal and vertical divisions during recursion.
    debug_plot_domains: bool
        If True, the algorithm plots the domains that are being searched. This is useful for debugging and understanding the algorithm.

    Returns
    -------
    poles: array
        Poles found inside `domain`, concatenated over all searched subdomains.
    residues: array
        Residues belonging to `poles`.
    eval_count: int
        Number of new evaluations of `f` performed by this call, including all
        recursive calls. Samples passed in via `z_k`/`f_k` are not counted.

    TODO
    ----
    - allow access to samples slightly outside of domain
    - divide horizontal/vertical according to the distribution of poles
    """

    domain_size = np.abs(domain[1]-domain[0])/2

    if debug_plot_domains:
        print(f"Domain: {domain}")
        plot_domain(domain, size=domain_size)

    if cutoff is None:
        cutoff = np.inf

    eval_count = 0
    if use_adaptive:
        if z_k is None:
            z_k = np.empty((0,), dtype=complex)
            f_k = z_k.copy()
        elif f_k is None:
            # samples were provided without function values: evaluate them once
            f_k = f(z_k)
            eval_count += len(z_k)

        if len(z_k) < N:
            z_k_new = sample_domain(domain, N)
            f_k = np.append(f_k, f(z_k_new))
            z_k = np.append(z_k, z_k_new)
            eval_count += len(z_k_new)

        # adaptive_aaa returns every sample it used (inputs included), so only
        # the growth in the number of samples counts as new evaluations
        eval_count -= len(z_k)

        # NOTE: reduced domain with a factor larger than 1
        # actually increases domain size to avoid missing poles right at the border
        z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
            z_k, f, f_k_0=f_k, evolutions=evolutions_adaptive, tol=tol_aaa,
            domain=reduced_domain(domain, 1.07), radius=domain_size*radius_adaptive,
            return_samples=True, cutoff=cutoff
        )
        # TODO pass down samples in buffer zone
        eval_count += len(z_k)

    else:
        if on_rim:
            z_k = sample_rim(domain, N)
        else:
            # enlarge slightly first; sample_domain shrinks the domain again internally
            z_k = sample_domain(reduced_domain(domain, 1.05), N)
        f_k = f(z_k)
        eval_count += len(f_k)
        try:
            z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol=tol_aaa)
        except onp.linalg.LinAlgError:
            # AAA failed on this domain: continue as if no poles were found
            z_n = z_j = f_j = w_j = np.empty((0,))

    poles = z_n[domain_mask(domain, z_n)]

    # `<= 0` (not `== 0`) so that a negative budget cannot recurse without bound
    if (Dmax <= 0 or
        (len(poles) <= max_poles and all_poles_known(poles, suggestions, tol_pol))):
        res = residues(z_j, f_j, w_j, poles)
        return poles, res, eval_count

    pol = np.empty((0,), dtype=complex)
    res = pol.copy()
    for sub in subdomains(domain, divide_horizontal):
        sug = poles[domain_mask(sub, poles)]
        # hand down the samples that fall into the subdomain to avoid re-evaluating f
        sample_mask = domain_mask(sub, z_k)
        known_z_k = z_k[sample_mask]
        known_f_k = f_k[sample_mask]

        p, r, e = selective_subdivision_aaa(
            f, sub, N=N, max_poles=max_poles, cutoff=cutoff,
            tol_aaa=tol_aaa, tol_pol=tol_pol,
            use_adaptive=use_adaptive,
            evolutions_adaptive=evolutions_adaptive,
            radius_adaptive=radius_adaptive, on_rim=on_rim,
            suggestions=sug, Dmax=Dmax-1, z_k=known_z_k, f_k=known_f_k,
            divide_horizontal=not divide_horizontal,
            debug_plot_domains=debug_plot_domains
        )
        pol = np.append(pol, p)
        res = np.append(res, r)
        eval_count += e
    return pol, res, eval_count
@@ -0,0 +1 @@
1
+ from .set_aaa import set_aaa