diffaaable 1.1.0__tar.gz → 1.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffaaable-1.1.0 → diffaaable-1.2.1}/PKG-INFO +11 -8
- {diffaaable-1.1.0 → diffaaable-1.2.1}/README.md +5 -4
- diffaaable-1.2.1/diffaaable/__init__.py +15 -0
- {diffaaable-1.1.0 → diffaaable-1.2.1}/diffaaable/adaptive.py +48 -8
- {diffaaable-1.1.0 → diffaaable-1.2.1}/diffaaable/core.py +20 -68
- diffaaable-1.2.1/diffaaable/selective.py +267 -0
- diffaaable-1.2.1/diffaaable/set_aaa/__init__.py +1 -0
- diffaaable-1.2.1/diffaaable/set_aaa/set_aaa.m +407 -0
- diffaaable-1.2.1/diffaaable/set_aaa/set_aaa.py +168 -0
- diffaaable-1.2.1/diffaaable/set_aaa/test_set_aaa.m +25 -0
- diffaaable-1.2.1/diffaaable/tensor.py +84 -0
- diffaaable-1.2.1/diffaaable/util.py +59 -0
- {diffaaable-1.1.0 → diffaaable-1.2.1}/diffaaable/vectorial.py +3 -5
- {diffaaable-1.1.0 → diffaaable-1.2.1}/pyproject.toml +13 -5
- diffaaable-1.1.0/diffaaable/__init__.py +0 -10
- diffaaable-1.1.0/diffaaable/selective.py +0 -197
- {diffaaable-1.1.0 → diffaaable-1.2.1}/LICENSE +0 -0
- {diffaaable-1.1.0 → diffaaable-1.2.1}/diffaaable/lorentz.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: diffaaable
|
|
3
|
-
Version: 1.1
|
|
3
|
+
Version: 1.2.1
|
|
4
4
|
Summary: JAX-differentiable AAA algorithm
|
|
5
5
|
Keywords: python
|
|
6
6
|
Author-email: Jan David Fischbach <fischbach@kit.edu>
|
|
@@ -10,11 +10,13 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.10
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
12
|
Classifier: Operating System :: OS Independent
|
|
13
|
-
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: numpy>2
|
|
14
15
|
Requires-Dist: jax
|
|
15
16
|
Requires-Dist: jaxlib
|
|
16
17
|
Requires-Dist: baryrat
|
|
17
18
|
Requires-Dist: jaxopt
|
|
19
|
+
Requires-Dist: pandas>=2.2.3
|
|
18
20
|
Requires-Dist: tbump>=6.11.0 ; extra == "dev"
|
|
19
21
|
Requires-Dist: towncrier ; extra == "dev"
|
|
20
22
|
Requires-Dist: pre-commit ; extra == "dev"
|
|
@@ -25,12 +27,12 @@ Requires-Dist: pytest-benchmark ; extra == "dev"
|
|
|
25
27
|
Requires-Dist: matplotlib ; extra == "dev"
|
|
26
28
|
Requires-Dist: jupytext ; extra == "docs"
|
|
27
29
|
Requires-Dist: matplotlib ; extra == "docs"
|
|
28
|
-
Requires-Dist: jupyter-book ; extra == "docs"
|
|
30
|
+
Requires-Dist: jupyter-book<2 ; extra == "docs"
|
|
29
31
|
Requires-Dist: sphinx_math_dollar ; extra == "docs"
|
|
30
32
|
Provides-Extra: dev
|
|
31
33
|
Provides-Extra: docs
|
|
32
34
|
|
|
33
|
-
# diffaaable 1.1
|
|
35
|
+
# diffaaable 1.2.1
|
|
34
36
|
|
|
35
37
|

|
|
36
38
|
|
|
@@ -39,8 +41,9 @@ A detailed derivation of the used matrix expressions is provided in the appendix
|
|
|
39
41
|
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
40
42
|
Additionally, the following application-specific extensions to the AAA algorithm are included:
|
|
41
43
|
|
|
42
|
-
- **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
43
|
-
- **Vectorial
|
|
44
|
+
- **Adaptive**: Adaptive refinement strategy (called *Iterative Sample Refinement* (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
|
|
45
|
+
- **Vectorial**: AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
46
|
+
- **Tensor**: Convenience alternative to the vector valued AAA algorithm (`vectorial`) accepting a tensor valued function F_k (so arbitrary dimensionality) instead of the single dimension that `vectorial` requires.
|
|
44
47
|
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
45
48
|
- **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
46
49
|
|
|
@@ -60,7 +63,7 @@ When using this software package for scientific work please cite the associated
|
|
|
60
63
|
+++
|
|
61
64
|
|
|
62
65
|
[^1]: https://arxiv.org/pdf/2403.19404
|
|
63
|
-
[^2]: "A framework to compute resonances arising from multiple scattering", https://
|
|
66
|
+
[^2]: "A framework to compute resonances arising from multiple scattering", https://doi.org/10.1002/adts.202400989
|
|
64
67
|
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
65
68
|
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
66
69
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# diffaaable 1.1
|
|
1
|
+
# diffaaable 1.2.1
|
|
2
2
|
|
|
3
3
|

|
|
4
4
|
|
|
@@ -7,8 +7,9 @@ A detailed derivation of the used matrix expressions is provided in the appendix
|
|
|
7
7
|
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
8
8
|
Additionally, the following application-specific extensions to the AAA algorithm are included:
|
|
9
9
|
|
|
10
|
-
- **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
11
|
-
- **Vectorial
|
|
10
|
+
- **Adaptive**: Adaptive refinement strategy (called *Iterative Sample Refinement* (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
|
|
11
|
+
- **Vectorial**: AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
12
|
+
- **Tensor**: Convenience alternative to the vector valued AAA algorithm (`vectorial`) accepting a tensor valued function F_k (so arbitrary dimensionality) instead of the single dimension that `vectorial` requires.
|
|
12
13
|
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
13
14
|
- **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
14
15
|
|
|
@@ -28,6 +29,6 @@ When using this software package for scientific work please cite the associated
|
|
|
28
29
|
+++
|
|
29
30
|
|
|
30
31
|
[^1]: https://arxiv.org/pdf/2403.19404
|
|
31
|
-
[^2]: "A framework to compute resonances arising from multiple scattering", https://
|
|
32
|
+
[^2]: "A framework to compute resonances arising from multiple scattering", https://doi.org/10.1002/adts.202400989
|
|
32
33
|
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
33
34
|
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""diffaaable - JAX-differentiable AAA algorithm"""
|
|
2
|
+
|
|
3
|
+
__version__ = "1.2.1"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from diffaaable.core import aaa
|
|
7
|
+
from diffaaable.adaptive import adaptive_aaa
|
|
8
|
+
from diffaaable.lorentz import lorentz_aaa
|
|
9
|
+
from diffaaable.selective import selective_subdivision_aaa
|
|
10
|
+
from diffaaable.set_aaa import set_aaa
|
|
11
|
+
from diffaaable.tensor import tensor_aaa
|
|
12
|
+
from diffaaable.vectorial import vectorial_aaa
|
|
13
|
+
from diffaaable.util import poles, residues
|
|
14
|
+
|
|
15
|
+
__all__ = ["aaa", "adaptive_aaa", "lorentz_aaa", "selective_subdivision_aaa", "set_aaa", "tensor_aaa", "vectorial_aaa", "poles", "residues"]
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
from typing import Union
|
|
2
3
|
import jax.numpy as np
|
|
3
4
|
import numpy.typing as npt
|
|
@@ -8,7 +9,7 @@ import matplotlib.pyplot as plt
|
|
|
8
9
|
from functools import partial
|
|
9
10
|
Domain = tuple[complex, complex]
|
|
10
11
|
|
|
11
|
-
def top_right(a: npt.NDArray
|
|
12
|
+
def top_right(a: npt.NDArray, b: npt.NDArray):
|
|
12
13
|
return np.logical_and(a.imag>=b.imag, a.real>=b.real)
|
|
13
14
|
|
|
14
15
|
def domain_mask(domain: Domain, z_n):
|
|
@@ -111,11 +112,14 @@ def next_samples_heat(
|
|
|
111
112
|
|
|
112
113
|
return add_samples
|
|
113
114
|
|
|
114
|
-
|
|
115
|
+
vanilla_aaa = jax.tree_util.Partial(aaa)
|
|
115
116
|
|
|
116
117
|
def mask(z_k, f_k, f_k_dot, cutoff):
|
|
117
|
-
|
|
118
|
-
|
|
118
|
+
def all_f(f):
|
|
119
|
+
# to make sure all f_k (in case of tensor valued functions) behave nice
|
|
120
|
+
return np.squeeze(np.apply_over_axes(np.all, f, np.arange(f.ndim-1)+1))
|
|
121
|
+
m = all_f(np.abs(f_k)<cutoff) #filter out values, that have diverged too strongly
|
|
122
|
+
m = np.logical_and(m, all_f(~np.isnan(f_k))) #filter out nans
|
|
119
123
|
m = np.logical_and(m, ~np.isnan(z_k)) #filter out nans
|
|
120
124
|
|
|
121
125
|
if m.ndim == 2:
|
|
@@ -150,6 +154,11 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
150
154
|
Tangent of `f`. If provided JVPs of `f` will be collected throughout the
|
|
151
155
|
iterations. For use in custom_jvp
|
|
152
156
|
"""
|
|
157
|
+
if sampling is None:
|
|
158
|
+
sampling = next_samples
|
|
159
|
+
|
|
160
|
+
if aaa is None:
|
|
161
|
+
aaa = vanilla_aaa
|
|
153
162
|
|
|
154
163
|
if sampling == "heat":
|
|
155
164
|
sampling = next_samples_heat
|
|
@@ -235,16 +244,18 @@ def adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
235
244
|
radius: float = None,
|
|
236
245
|
domain: Domain = None,
|
|
237
246
|
f_k_0: npt.NDArray = None,
|
|
238
|
-
sampling: callable =
|
|
247
|
+
sampling: callable = None,
|
|
239
248
|
prev_z_n: npt.NDArray = None,
|
|
240
249
|
return_samples: bool = False,
|
|
241
|
-
aaa: callable =
|
|
250
|
+
aaa: callable = None):
|
|
242
251
|
""" An 2x adaptive Antoulas–Anderson algorithm for rational approximation of
|
|
243
252
|
meromorphic functions that are costly to evaluate.
|
|
244
253
|
|
|
245
254
|
The algorithm iteratively places additional sample points close to estimated
|
|
246
255
|
positions of poles identified during the past iteration. By this refinement
|
|
247
|
-
scheme the number of function evaluations can be reduced.
|
|
256
|
+
scheme the number of function evaluations can be reduced. A more detailed
|
|
257
|
+
description of the iterative sample refinement (ISR) algorithm is provided
|
|
258
|
+
in (https://doi.org/10.1002/adts.202400989).
|
|
248
259
|
|
|
249
260
|
It is JAX differentiable wrt. the approximated function `f`, via its other
|
|
250
261
|
arguments besides `z`. `f` should be provided as a `jax.tree_util.Partial`
|
|
@@ -267,13 +278,36 @@ def adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
267
278
|
tol: float
|
|
268
279
|
Tolerance used in AAA (see `diffaaable.aaa`)
|
|
269
280
|
radius: float
|
|
270
|
-
Distance from the assumed poles for
|
|
281
|
+
Distance from the assumed poles for next samples
|
|
271
282
|
domain: tuple[complex, complex]
|
|
272
283
|
Tuple of min (lower left) and max (upper right) values defining a
|
|
273
284
|
rectangle in the complex plane. Assumed poles outside of the domain
|
|
274
285
|
will not receive refinement.
|
|
275
286
|
f_k_0:
|
|
276
287
|
Allows user to provide f evaluated at z_k_0
|
|
288
|
+
sampling: callable
|
|
289
|
+
strategy to determine the next sample points. The function should
|
|
290
|
+
accept the following arguments:
|
|
291
|
+
- z_n: np.ndarray
|
|
292
|
+
estimated poles
|
|
293
|
+
- prev_z_n: np.ndarray
|
|
294
|
+
previous estimated poles
|
|
295
|
+
- samples: np.ndarray
|
|
296
|
+
current sample points
|
|
297
|
+
- domain: Domain
|
|
298
|
+
domain in which to refine poles
|
|
299
|
+
- radius: float
|
|
300
|
+
distance from the assumed poles for next samples
|
|
301
|
+
- randkey: jax.random.PRNGKey
|
|
302
|
+
random key for sampling
|
|
303
|
+
prev_z_n: np.ndarray
|
|
304
|
+
the previous poles that will be passed to the first evaluation of the sampling strategy
|
|
305
|
+
return_samples: bool
|
|
306
|
+
If True, the function returns the samples used for the AAA approximation
|
|
307
|
+
and the function evaluations at these points at the 4th and 5th position (zero-based) of the returned tuple.
|
|
308
|
+
aaa: callable
|
|
309
|
+
The AAA variant to be used. By default `diffaaable.aaa` is used.
|
|
310
|
+
If you want to use the tensor AAA, you can pass `diffaaable.tensor.tensor_aaa`.
|
|
277
311
|
|
|
278
312
|
|
|
279
313
|
Returns
|
|
@@ -286,6 +320,12 @@ def adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
286
320
|
Weights of Barycentric Approximation
|
|
287
321
|
z_n: np.array
|
|
288
322
|
Poles of Barycentric Approximation
|
|
323
|
+
|
|
324
|
+
z_k_final: np.array
|
|
325
|
+
all sample points used for the AAA approximation. Only returned if
|
|
326
|
+
`return_samples` is True.
|
|
327
|
+
f_k_final: np.array
|
|
328
|
+
`f(z_k_final)` Only returned if `return_samples` is True.
|
|
289
329
|
"""
|
|
290
330
|
return _adaptive_aaa(
|
|
291
331
|
z_k_0=z_k_0, f=f, evolutions=evolutions, cutoff=cutoff, tol=tol, mmax=mmax,
|
|
@@ -5,8 +5,11 @@ import numpy.typing as npt
|
|
|
5
5
|
import jax
|
|
6
6
|
import numpy as np
|
|
7
7
|
from baryrat import aaa as oaaa # ordinary aaa
|
|
8
|
+
from .tensor import tensor_aaa
|
|
8
9
|
import functools
|
|
9
|
-
import
|
|
10
|
+
from .util import poles
|
|
11
|
+
|
|
12
|
+
USE_SETAAA = False
|
|
10
13
|
|
|
11
14
|
@functools.wraps(oaaa)
|
|
12
15
|
@jax.custom_jvp
|
|
@@ -14,17 +17,22 @@ def aaa(z_k: npt.NDArray, f_k: npt.NDArray, tol: float=1e-13, mmax: int=100):
|
|
|
14
17
|
"""
|
|
15
18
|
Wrapped aaa to enable JAX-based autodiff.
|
|
16
19
|
"""
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
20
|
+
if USE_SETAAA:
|
|
21
|
+
print("Using set_aaa")
|
|
22
|
+
z_j, f_j, w_j, z_n = tensor_aaa(z_k, f_k[:, None], tol_aaa=tol, mmax_aaa=mmax)
|
|
23
|
+
f_j = f_j[:, 0]
|
|
24
|
+
else:
|
|
25
|
+
r = oaaa(z_k, f_k, tol=tol, mmax=mmax)
|
|
26
|
+
z_j = r.nodes
|
|
27
|
+
f_j = r.values
|
|
28
|
+
w_j = r.weights
|
|
29
|
+
|
|
30
|
+
mask = w_j!=0
|
|
31
|
+
z_j = z_j[mask]
|
|
32
|
+
f_j = f_j[mask]
|
|
33
|
+
w_j = w_j[mask]
|
|
34
|
+
|
|
35
|
+
z_n = poles(z_j, w_j)
|
|
28
36
|
|
|
29
37
|
z_n = z_n[jnp.argsort(-jnp.abs(z_n))]
|
|
30
38
|
|
|
@@ -166,59 +174,3 @@ def aaa_jvp(primals, tangents):
|
|
|
166
174
|
tangent_out = z_j_dot, f_j_dot, w_j_dot, z_n_dot
|
|
167
175
|
|
|
168
176
|
return primal_out, tangent_out
|
|
169
|
-
|
|
170
|
-
def poles(z_j,w_j):
|
|
171
|
-
"""
|
|
172
|
-
The poles of a barycentric rational with given nodes and weights.
|
|
173
|
-
Poles lifted by zeros of the nominator are included.
|
|
174
|
-
Thus the values $f_j$ do not contribute and don't need to be provided
|
|
175
|
-
The implementation was modified from `baryrat` to support JAX AD.
|
|
176
|
-
|
|
177
|
-
Parameters
|
|
178
|
-
----------
|
|
179
|
-
z_j : array (m,)
|
|
180
|
-
nodes of the barycentric rational
|
|
181
|
-
w_j : array (m,)
|
|
182
|
-
weights of the barycentric rational
|
|
183
|
-
|
|
184
|
-
Returns
|
|
185
|
-
-------
|
|
186
|
-
z_n : array (m-1,)
|
|
187
|
-
poles of the barycentric rational (more strictly zeros of the denominator)
|
|
188
|
-
"""
|
|
189
|
-
f_j = np.ones_like(z_j)
|
|
190
|
-
|
|
191
|
-
B = np.eye(len(w_j) + 1)
|
|
192
|
-
B[0,0] = 0
|
|
193
|
-
E = np.block([[0, w_j],
|
|
194
|
-
[f_j[:,None], np.diag(z_j)]])
|
|
195
|
-
evals = scipy.linalg.eigvals(E, B)
|
|
196
|
-
return evals[np.isfinite(evals)]
|
|
197
|
-
|
|
198
|
-
def residues(z_j,f_j,w_j,z_n):
|
|
199
|
-
'''
|
|
200
|
-
Residues for given poles via formula for simple poles
|
|
201
|
-
of quotients of analytic functions.
|
|
202
|
-
The implementation was modified from `baryrat` to support JAX AD.
|
|
203
|
-
|
|
204
|
-
Parameters
|
|
205
|
-
----------
|
|
206
|
-
z_j : array (m,)
|
|
207
|
-
nodes of the barycentric rational
|
|
208
|
-
w_j : array (m,)
|
|
209
|
-
weights of the barycentric rational
|
|
210
|
-
z_n : array (n,)
|
|
211
|
-
poles of interest of the barycentric rational (n<=m-1)
|
|
212
|
-
|
|
213
|
-
Returns
|
|
214
|
-
-------
|
|
215
|
-
r_n : array (n,)
|
|
216
|
-
residues of poles `z_n`
|
|
217
|
-
'''
|
|
218
|
-
|
|
219
|
-
C_pol = 1.0 / (z_n[:,None] - z_j[None,:])
|
|
220
|
-
N_pol = C_pol.dot(f_j*w_j)
|
|
221
|
-
Ddiff_pol = (-C_pol**2).dot(w_j)
|
|
222
|
-
res = N_pol / Ddiff_pol
|
|
223
|
-
|
|
224
|
-
return jnp.nan_to_num(res)
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
import jax.numpy as np
|
|
2
|
+
import numpy as onp
|
|
3
|
+
from diffaaable.core import aaa
|
|
4
|
+
from diffaaable.util import residues
|
|
5
|
+
from diffaaable.adaptive import Domain, domain_mask, adaptive_aaa
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
|
|
8
|
+
def reduced_domain(domain, reduction=1-1/12):
|
|
9
|
+
"""
|
|
10
|
+
Utility: Rescale the domain. Can be used to shrink or enlarge the domain.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
domain: Domain
|
|
15
|
+
The domain to rescale.
|
|
16
|
+
reduction: float
|
|
17
|
+
The interpolation weight used to rescale the domain about its center: each side length is multiplied by `2*reduction - 1`. A value of 1 will not change the domain. A value of 0.75 will shrink each side by half. A value of 1.5 will enlarge each side by a factor of 2.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
r = reduction
|
|
21
|
+
return (
|
|
22
|
+
domain[0]*r+domain[1]*(1-r),
|
|
23
|
+
domain[1]*r+domain[0]*(1-r)
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def sample_cross(domain):
|
|
27
|
+
center = domain_center(domain)
|
|
28
|
+
dist = 0.5 * (domain[1]-domain[0])
|
|
29
|
+
return center+np.array([dist.real, -dist.real, 1j*dist.imag, -1j*dist.imag])
|
|
30
|
+
|
|
31
|
+
def sample_domain(domain: Domain, N: int):
|
|
32
|
+
sqrt_N = np.round(np.sqrt(N)).astype(int)
|
|
33
|
+
domain = reduced_domain(domain)
|
|
34
|
+
z_k_r = np.linspace(domain[0].real, domain[1].real, sqrt_N)
|
|
35
|
+
z_k_i = np.linspace(domain[0].imag, domain[1].imag, sqrt_N)
|
|
36
|
+
Z_r, Z_i = np.meshgrid(z_k_r, z_k_i)
|
|
37
|
+
z_k = (Z_r+1j*Z_i).flatten()
|
|
38
|
+
return z_k
|
|
39
|
+
|
|
40
|
+
def sample_rim(domain: Domain, N: int):
|
|
41
|
+
side_N = N//4
|
|
42
|
+
z_k_r = np.linspace(domain[0].real, domain[1].real, side_N+2)[1:-1]
|
|
43
|
+
z_k_i = np.linspace(domain[0].imag, domain[1].imag, side_N+2)[1:-1] * 1j
|
|
44
|
+
return np.array([
|
|
45
|
+
1j*domain[0].imag + z_k_r,
|
|
46
|
+
1j*domain[1].imag + z_k_r,
|
|
47
|
+
domain[0].real + z_k_i,
|
|
48
|
+
domain[1].real + z_k_i
|
|
49
|
+
]).flatten()
|
|
50
|
+
|
|
51
|
+
def anti_domain(domain: Domain):
|
|
52
|
+
return (
|
|
53
|
+
domain[0].real + 1j*domain[1].imag,
|
|
54
|
+
domain[1].real + 1j*domain[0].imag
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
def domain_center(domain: Domain):
|
|
58
|
+
return np.mean(np.array(domain))
|
|
59
|
+
|
|
60
|
+
def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
|
|
61
|
+
"""
|
|
62
|
+
Utility: Subdivide the domain into two subdomains.
|
|
63
|
+
|
|
64
|
+
Parameters
|
|
65
|
+
----------
|
|
66
|
+
domain: Domain
|
|
67
|
+
The domain to subdivide.
|
|
68
|
+
divide_horizontal: bool
|
|
69
|
+
If True, the domain is divided horizontally. If False, the domain is divided vertically.
|
|
70
|
+
center: complex
|
|
71
|
+
The center through which to divide the domain. If None, the center is calculated as the mean of the domain.
|
|
72
|
+
"""
|
|
73
|
+
if center is None:
|
|
74
|
+
center = domain_center(domain)
|
|
75
|
+
left_up = domain[0].real + 1j*domain[1].imag
|
|
76
|
+
right_down = domain[1].real + 1j*domain[0].imag
|
|
77
|
+
|
|
78
|
+
subs = [
|
|
79
|
+
(center, domain[1]),
|
|
80
|
+
anti_domain((left_up, center)),
|
|
81
|
+
(domain[0], center),
|
|
82
|
+
anti_domain((center, right_down)),
|
|
83
|
+
]
|
|
84
|
+
|
|
85
|
+
if divide_horizontal:
|
|
86
|
+
return [(subs[1][0], subs[0][1]), (subs[2][0], subs[3][1])]
|
|
87
|
+
return [(subs[2][0], subs[1][1]), (subs[3][0], subs[0][1])]
|
|
88
|
+
|
|
89
|
+
def plot_domain(domain: Domain, size: float=1):
|
|
90
|
+
"""
|
|
91
|
+
Utility: Plot the domain as a rectangle.
|
|
92
|
+
|
|
93
|
+
Parameters
|
|
94
|
+
----------
|
|
95
|
+
domain: Domain
|
|
96
|
+
The domain to plot.
|
|
97
|
+
size: float
|
|
98
|
+
The relative size to scale the linewidth of the rectangle.
|
|
99
|
+
|
|
100
|
+
TODO
|
|
101
|
+
----
|
|
102
|
+
- Add a color argument to the function.
|
|
103
|
+
"""
|
|
104
|
+
left_up = domain[0].real + 1j*domain[1].imag
|
|
105
|
+
right_down = domain[1].real + 1j*domain[0].imag
|
|
106
|
+
|
|
107
|
+
points = np.array([domain[0], right_down, domain[1], left_up, domain[0]])
|
|
108
|
+
|
|
109
|
+
return plt.plot(points.real, points.imag,
|
|
110
|
+
lw=size/30, zorder=1)
|
|
111
|
+
|
|
112
|
+
def all_poles_known(poles, prev, tol):
|
|
113
|
+
if prev is None or len(prev)!=len(poles):
|
|
114
|
+
return False
|
|
115
|
+
|
|
116
|
+
dist = np.abs(poles[:, None] - prev[None, :])
|
|
117
|
+
check = np.all(np.any(dist < tol, axis=1))
|
|
118
|
+
return check
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def selective_subdivision_aaa(f: callable,
|
|
122
|
+
domain: Domain,
|
|
123
|
+
N: int = 36,
|
|
124
|
+
max_poles: int = 400,
|
|
125
|
+
cutoff: float = None,
|
|
126
|
+
tol_aaa: float = 1e-9,
|
|
127
|
+
tol_pol: float = 1e-5,
|
|
128
|
+
suggestions = None,
|
|
129
|
+
on_rim: bool = False,
|
|
130
|
+
Dmax=30,
|
|
131
|
+
use_adaptive: bool = True,
|
|
132
|
+
evolutions_adaptive: int = 5,
|
|
133
|
+
radius_adaptive: float = 1e-4,
|
|
134
|
+
z_k = None, f_k = None,
|
|
135
|
+
divide_horizontal=True,
|
|
136
|
+
debug_plot_domains: bool = False,
|
|
137
|
+
):
|
|
138
|
+
"""
|
|
139
|
+
When the number of poles that need to be located is large it can be beneficial to subdivide the search domain.
|
|
140
|
+
This function implements a recursive subdivision of the domain, that automatically terminates once the poles are found to a satisfactory degree.
|
|
141
|
+
|
|
142
|
+
Parameters
|
|
143
|
+
----------
|
|
144
|
+
f: callable
|
|
145
|
+
The function that the pole search is conducted on. It should accept batches of complex numbers and return a complex number per input.
|
|
146
|
+
domain: Domain
|
|
147
|
+
The pole search is limited to this domain.
|
|
148
|
+
N: int
|
|
149
|
+
The initial number of samples. If the number of samples drops below N the algorithm will add N new samples to the domain.
|
|
150
|
+
max_poles: int
|
|
151
|
+
The maximum number of poles that are considered valid within one search. If this number is exceeded the domain will be subdivided.
|
|
152
|
+
cutoff: float
|
|
153
|
+
A cutoff to avoid numerical instability due to large samples close to poles. See also `diffaaable.adaptive.adaptive_aaa`.
|
|
154
|
+
tol_aaa: float
|
|
155
|
+
The tolerance for the AAA algorithm. See also `diffaaable.core.aaa`.
|
|
156
|
+
tol_pol: float
|
|
157
|
+
The tolerance for the pole search. This is used to determine if a pole has moved significantly since the last domain subdivision.
|
|
158
|
+
suggestions: array
|
|
159
|
+
A list of poles that are already known. This is used internally to recursively call `selective_subdivision_aaa`.
|
|
160
|
+
on_rim: bool
|
|
161
|
+
If True, the initial samples are taken on the rim of the domain. Selecting the samples on the domain border is closely
|
|
162
|
+
related to similar contour integral approaches. It is however generally recommended to sample within the domain (default: False).
|
|
163
|
+
Dmax: int
|
|
164
|
+
The maximum number of subdivisions. If this number is exceeded the algorithm will stop and return the current poles.
|
|
165
|
+
TODO: allow the user to specify that when reaching Dmax the algorithm should return no poles to avoid false poles.
|
|
166
|
+
use_adaptive: bool
|
|
167
|
+
If True, the algorithm will use the adaptive AAA algorithm within the subdomains to locate the poles more accurately.
|
|
168
|
+
The samples collected while searching the parent domain are passed to the respective subdomains to minimize computational cost.
|
|
169
|
+
Using the adaptive aaa is generally recommended (default: True). If False, the algorithm will use the standard AAA algorithm.
|
|
170
|
+
evolutions_adaptive: int
|
|
171
|
+
The number of evolutions for the adaptive AAA algorithm.
|
|
172
|
+
radius_adaptive: float
|
|
173
|
+
The radius for the adaptive AAA algorithm. See also `diffaaable.adaptive.adaptive_aaa`.
|
|
174
|
+
z_k: array
|
|
175
|
+
The samples that have already been collected. This is used internally to recursively call `selective_subdivision_aaa`. It can also be used by the user to pass samples that are already known.
|
|
176
|
+
f_k: array
|
|
177
|
+
The function values of the samples that have already been collected. This is used internally to recursively call `selective_subdivision_aaa`. It can also be used by the user to pass samples that are already known.
|
|
178
|
+
divide_horizontal: bool
|
|
179
|
+
If True, the next domain division will be horizontal. Used internally to alternate between horizontal and vertical divisions during recursion.
|
|
180
|
+
debug_plot_domains: bool
|
|
181
|
+
If True, the algorithm will plot the domains that are being searched. This is useful for debugging and understanding the algorithm.
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
TODO
|
|
185
|
+
----
|
|
186
|
+
- allow access to samples slightly outside of domain
|
|
187
|
+
- divide horizontal/vertical according to the distribution of poles
|
|
188
|
+
"""
|
|
189
|
+
|
|
190
|
+
domain_size = np.abs(domain[1]-domain[0])/2
|
|
191
|
+
|
|
192
|
+
if debug_plot_domains:
|
|
193
|
+
print(f"Domain: {domain}")
|
|
194
|
+
plot_domain(domain, size=domain_size)
|
|
195
|
+
|
|
196
|
+
if cutoff is None:
|
|
197
|
+
cutoff = np.inf
|
|
198
|
+
|
|
199
|
+
eval_count = 0
|
|
200
|
+
if use_adaptive:
|
|
201
|
+
if z_k is None:
|
|
202
|
+
z_k = np.empty((0,), dtype=complex)
|
|
203
|
+
f_k = z_k.copy()
|
|
204
|
+
|
|
205
|
+
if len(z_k) < N:
|
|
206
|
+
z_k_new = sample_domain(domain, N)
|
|
207
|
+
f_k = np.append(f_k, f(z_k_new))
|
|
208
|
+
z_k = np.append(z_k, z_k_new)
|
|
209
|
+
|
|
210
|
+
eval_count += len(z_k_new)
|
|
211
|
+
|
|
212
|
+
eval_count -= len(z_k)
|
|
213
|
+
|
|
214
|
+
# NOTE: reduced domain with a factor larger than 1
|
|
215
|
+
# actually increases domain size to avoid missing poles right at the border
|
|
216
|
+
z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
|
|
217
|
+
z_k, f, f_k_0=f_k, evolutions=evolutions_adaptive, tol=tol_aaa,
|
|
218
|
+
domain=reduced_domain(domain, 1.07), radius=domain_size*radius_adaptive,
|
|
219
|
+
return_samples=True, cutoff=cutoff
|
|
220
|
+
)
|
|
221
|
+
# TODO pass down samples in buffer zone
|
|
222
|
+
eval_count += len(z_k)
|
|
223
|
+
|
|
224
|
+
else:
|
|
225
|
+
if on_rim:
|
|
226
|
+
z_k = sample_rim(domain, N)
|
|
227
|
+
else:
|
|
228
|
+
z_k = sample_domain(reduced_domain(domain, 1.05), N)
|
|
229
|
+
f_k = f(z_k)
|
|
230
|
+
eval_count += len(f_k)
|
|
231
|
+
try:
|
|
232
|
+
z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol=tol_aaa)
|
|
233
|
+
except onp.linalg.LinAlgError as e:
|
|
234
|
+
z_n = z_j = f_j = w_j = np.empty((0,))
|
|
235
|
+
|
|
236
|
+
poles = z_n[domain_mask(domain, z_n)]
|
|
237
|
+
|
|
238
|
+
if (Dmax == 0 or
|
|
239
|
+
(len(poles)<=max_poles and all_poles_known(poles, suggestions, tol_pol))):
|
|
240
|
+
|
|
241
|
+
res = residues(z_j, f_j, w_j, poles)
|
|
242
|
+
return poles, res, eval_count
|
|
243
|
+
|
|
244
|
+
subs = subdomains(domain, divide_horizontal)
|
|
245
|
+
|
|
246
|
+
pol = np.empty((0,), dtype=complex)
|
|
247
|
+
res = pol.copy()
|
|
248
|
+
for i,sub in enumerate(subs):
|
|
249
|
+
sug = poles[domain_mask(sub, poles)]
|
|
250
|
+
sample_mask = domain_mask(sub, z_k)
|
|
251
|
+
|
|
252
|
+
known_z_k = z_k[sample_mask]
|
|
253
|
+
known_f_k = f_k[sample_mask]
|
|
254
|
+
|
|
255
|
+
p, r, e = selective_subdivision_aaa(
|
|
256
|
+
f, sub, N, max_poles, cutoff, tol_aaa, tol_pol,
|
|
257
|
+
use_adaptive=use_adaptive,
|
|
258
|
+
evolutions_adaptive=evolutions_adaptive,
|
|
259
|
+
radius_adaptive=radius_adaptive, on_rim=on_rim,
|
|
260
|
+
suggestions=sug, Dmax=Dmax-1, z_k=known_z_k, f_k=known_f_k,
|
|
261
|
+
divide_horizontal = not divide_horizontal,
|
|
262
|
+
debug_plot_domains=debug_plot_domains
|
|
263
|
+
)
|
|
264
|
+
pol = np.append(pol, p)
|
|
265
|
+
res = np.append(res, r)
|
|
266
|
+
eval_count += e
|
|
267
|
+
return pol, res, eval_count
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .set_aaa import set_aaa
|