diffaaable 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffaaable-1.0.1/LICENSE +21 -0
- diffaaable-1.0.1/PKG-INFO +64 -0
- diffaaable-1.0.1/README.md +33 -0
- diffaaable-1.0.1/diffaaable/__init__.py +10 -0
- diffaaable-1.0.1/diffaaable/adaptive.py +305 -0
- diffaaable-1.0.1/diffaaable/core.py +224 -0
- diffaaable-1.0.1/diffaaable/lorentz.py +88 -0
- diffaaable-1.0.1/diffaaable/selective.py +197 -0
- diffaaable-1.0.1/diffaaable/vectorial.py +114 -0
- diffaaable-1.0.1/pyproject.toml +188 -0
diffaaable-1.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Jan David Fischbach
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: diffaaable
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: JAX-differentiable AAA algorithm
|
|
5
|
+
Keywords: python
|
|
6
|
+
Author-email: Jan David Fischbach <fischbach@kit.edu>
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: jax
|
|
15
|
+
Requires-Dist: jaxlib
|
|
16
|
+
Requires-Dist: baryrat
|
|
17
|
+
Requires-Dist: jaxopt
|
|
18
|
+
Requires-Dist: pytest
|
|
19
|
+
Requires-Dist: pytest-benchmark
|
|
20
|
+
Requires-Dist: pre-commit ; extra == "dev"
|
|
21
|
+
Requires-Dist: pytest ; extra == "dev"
|
|
22
|
+
Requires-Dist: pytest-cov ; extra == "dev"
|
|
23
|
+
Requires-Dist: pytest_regressions ; extra == "dev"
|
|
24
|
+
Requires-Dist: jupytext ; extra == "docs"
|
|
25
|
+
Requires-Dist: matplotlib ; extra == "docs"
|
|
26
|
+
Requires-Dist: jupyter-book ; extra == "docs"
|
|
27
|
+
Requires-Dist: sphinx_math_dollar ; extra == "docs"
|
|
28
|
+
Provides-Extra: dev
|
|
29
|
+
Provides-Extra: docs
|
|
30
|
+
|
|
31
|
+
# diffaaable 1.0.1
|
|
32
|
+
|
|
33
|
+

|
|
34
|
+
|
|
35
|
+
`diffaaable` is a JAX differentiable version of the AAA algorithm. The derivatives are implemented as custom Jacobian Vector products in accordance to [^1].
|
|
36
|
+
A detailed derivation of the used matrix expressions is provided in the appendix of [^2].
|
|
37
|
+
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
38
|
+
Additionally the following application-specific extensions to the AAA algorithm are included:
|
|
39
|
+
|
|
40
|
+
- **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
41
|
+
- **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
42
|
+
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
43
|
+
- **Selective Refinement**: Use a divide and conquer scheme to capture many poles simultaneously and accurately, by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
to install `diffaaable` run
|
|
47
|
+
`pip install diffaaable`
|
|
48
|
+
|
|
49
|
+
## Usage
|
|
50
|
+
Please refer to the [quickstart tutorial](./usage.md)
|
|
51
|
+
|
|
52
|
+
## Contributing
|
|
53
|
+
Feel free to open issues and/or PRs.
|
|
54
|
+
|
|
55
|
+
## Citation
|
|
56
|
+
When using this software package for scientific work please cite the associated publication [^2].
|
|
57
|
+
|
|
58
|
+
+++
|
|
59
|
+
|
|
60
|
+
[^1]: https://arxiv.org/pdf/2403.19404
|
|
61
|
+
[^2]: Multiscat Resonances (to be published)
|
|
62
|
+
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
63
|
+
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
64
|
+
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# diffaaable 1.0.1
|
|
2
|
+
|
|
3
|
+

|
|
4
|
+
|
|
5
|
+
`diffaaable` is a JAX differentiable version of the AAA algorithm. The derivatives are implemented as custom Jacobian Vector products in accordance to [^1].
|
|
6
|
+
A detailed derivation of the used matrix expressions is provided in the appendix of [^2].
|
|
7
|
+
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
8
|
+
Additionally the following application-specific extensions to the AAA algorithm are included:
|
|
9
|
+
|
|
10
|
+
- **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
11
|
+
- **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
12
|
+
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
13
|
+
- **Selective Refinement**: Use a divide and conquer scheme to capture many poles simultaneously and accurately, by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
to install `diffaaable` run
|
|
17
|
+
`pip install diffaaable`
|
|
18
|
+
|
|
19
|
+
## Usage
|
|
20
|
+
Please refer to the [quickstart tutorial](./usage.md)
|
|
21
|
+
|
|
22
|
+
## Contributing
|
|
23
|
+
Feel free to open issues and/or PRs.
|
|
24
|
+
|
|
25
|
+
## Citation
|
|
26
|
+
When using this software package for scientific work please cite the associated publication [^2].
|
|
27
|
+
|
|
28
|
+
+++
|
|
29
|
+
|
|
30
|
+
[^1]: https://arxiv.org/pdf/2403.19404
|
|
31
|
+
[^2]: Multiscat Resonances (to be published)
|
|
32
|
+
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
33
|
+
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""diffaaable - JAX-differentiable AAA algorithm"""
|
|
2
|
+
|
|
3
|
+
__version__ = "1.0.1"
|
|
4
|
+
|
|
5
|
+
__all__ = ["aaa", "adaptive_aaa", "vectorial_aaa", "lorentz_aaa"]
|
|
6
|
+
|
|
7
|
+
from diffaaable.core import aaa
|
|
8
|
+
from diffaaable.adaptive import adaptive_aaa
|
|
9
|
+
from diffaaable.vectorial import vectorial_aaa
|
|
10
|
+
from diffaaable.lorentz import lorentz_aaa
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
import jax.numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
from diffaaable import aaa
|
|
5
|
+
import jax
|
|
6
|
+
from jax import random
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
from functools import partial
|
|
9
|
+
Domain = tuple[complex, complex]
|
|
10
|
+
|
|
11
|
+
def top_right(a: npt.NDArray[complex], b: npt.NDArray[complex]):
    """True where `a` lies to the upper right of `b` in the complex plane.

    Real and imaginary parts are compared independently; both comparisons
    are inclusive (>=).
    """
    above = a.imag >= b.imag
    right = a.real >= b.real
    return np.logical_and(above, right)
|
|
13
|
+
|
|
14
|
+
def domain_mask(domain: Domain, z_n):
    """Boolean mask selecting the entries of `z_n` inside the rectangle `domain`.

    `domain` is a (lower-left, upper-right) pair of complex corners;
    membership is inclusive on all four edges.
    """
    above_min = top_right(z_n, domain[0])
    below_max = top_right(domain[1], z_n)
    return np.logical_and(above_min, below_max)
|
|
18
|
+
|
|
19
|
+
@jax.tree_util.Partial
def next_samples(z_n, prev_z_n, z_k, domain: Domain, radius, randkey, tolerance=1e-9, min_samples=0, max_samples=0):
    """Propose new sample points near pole estimates that still move.

    A pole estimate is considered "unstable" when it moved by more than
    `tolerance` relative to the closest estimate of the previous iteration.
    Unstable estimates (at least `min_samples`, at most `max_samples` if
    `max_samples` is nonzero) are each offset by `radius` in a random
    direction to produce the next batch of sample points.
    """
    # only consider pole estimates inside the domain of interest
    in_domain = z_n[domain_mask(domain, z_n)]
    # distance each estimate moved w.r.t. the closest previous estimate
    movement = np.min(np.abs(in_domain[:, None] - prev_z_n[None, :]), axis=-1)
    by_movement = np.argsort(-movement)  # most mobile first
    unstable = movement > tolerance
    n_unstable = np.sum(unstable)

    if n_unstable > min_samples:
        if max_samples != 0 and n_unstable > max_samples:
            selected = in_domain[by_movement[:max_samples]]
        else:
            selected = in_domain[unstable]
    else:
        # too few unstable estimates: refine the most mobile ones anyway
        selected = in_domain[by_movement[:min_samples]]

    # offset by `radius` in a random direction to avoid sampling on a pole
    phase = jax.random.uniform(randkey, selected.shape)
    return selected + radius * np.exp(1j * 2 * np.pi * phase)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def heat(poles, samples, mesh, sigma):
    """Heuristic sampling-priority map over `mesh`.

    Gaussian bumps of width `sigma` centered on the estimated `poles`,
    divided by an inverse-square "repulsion" term around the already
    evaluated `samples` — high values mark promising regions to sample next.
    """
    # attraction: Gaussian bump of width sigma around every pole estimate
    attraction = np.nansum(
        np.exp(-np.abs(poles[:, None, None] - mesh[None, :])**2 / sigma**2),
        axis=0,
    )
    # repulsion: inverse-square falloff around every existing sample
    repulsion = np.nansum(
        sigma**2 / np.abs(mesh[None, :] - samples[:, None, None])**2,
        axis=0,
    )
    return attraction / repulsion
|
|
52
|
+
|
|
53
|
+
@partial(jax.jit, static_argnames=["resolution", "batchsize"])
def _next_samples_heat(
    poles, prev_poles, samples, domain, radius, randkey, resolution=(101, 101),
    batchsize=1, stop=0.2
):
    """Jitted core of `next_samples_heat`.

    Greedily picks up to `batchsize` maxima of the heat map computed on a
    regular grid spanning `domain`. Entries are NaN once the peak value
    falls below `stop`.

    Returns (new samples, X, Y, heat map without the new samples).
    """
    re = np.linspace(domain[0].real, domain[1].real, resolution[0])
    im = np.linspace(domain[0].imag, domain[1].imag, resolution[1])
    X, Y = np.meshgrid(re, im, indexing="ij")
    mesh = X + 1j * Y

    add_samples = np.empty(0, dtype=complex)
    for _ in range(batchsize):  # unrolled under jit (batchsize is static)
        heat_map = heat(poles, np.concat([samples, add_samples]), mesh, sigma=radius)
        peak = np.unravel_index(np.nanargmax(heat_map), heat_map.shape)
        # below the stop threshold: emit NaN instead of a new sample
        candidate = np.where(heat_map[peak] < stop, np.nan, mesh[peak])
        add_samples = np.append(add_samples, candidate)

    return add_samples, X, Y, heat(poles, np.concat([samples]), mesh, sigma=radius)
|
|
74
|
+
|
|
75
|
+
@jax.tree_util.Partial
def next_samples_heat(
    poles, prev_poles, samples, domain, radius, randkey, resolution=(101, 101),
    batchsize=1, stop=0.2, debug=False, debug_known_poles=None
):
    """Heat-map based sampling strategy.

    Thin wrapper around the jitted `_next_samples_heat` that removes the
    NaN placeholders and, when `debug` is truthy, saves a diagnostic plot
    to `{debug}/{len(samples)}.png`.
    """
    add_samples, X, Y, heat_map = _next_samples_heat(
        poles, prev_poles, samples, domain, radius, randkey, resolution,
        batchsize, stop
    )

    # NaN entries mark "stop criterion reached" slots from the jitted core
    add_samples = add_samples[~np.isnan(add_samples)]

    if debug:
        prev_ax = plt.gca()
        plt.figure()
        plt.title(f"radius={radius}")
        if resolution[1] > 1:
            # 2D domain: render the heat map with all relevant markers
            plt.pcolormesh(X, Y, heat_map, vmax=1)
            plt.colorbar()
            plt.scatter(samples.real, samples.imag, label="samples", zorder=1)
            plt.scatter(poles.real, poles.imag, color="C1", marker="x", label="est. pole", zorder=2)
            plt.scatter(add_samples.real, add_samples.imag, color="C2", label="next samples", zorder=3)
            if debug_known_poles is not None:
                plt.scatter(debug_known_poles.real, debug_known_poles.imag, color="C3", marker="+", label="known pole")
            plt.xlim(domain[0].real, domain[1].real)
            plt.ylim(domain[0].imag, domain[1].imag)
            plt.legend(loc="lower right")
        else:
            # degenerate (1D) resolution: plot the heat profile along the real axis
            plt.plot(np.squeeze(X), np.squeeze(heat_map))
            plt.scatter(samples.real, np.zeros(len(samples)))
            plt.scatter(poles.real, np.zeros(len(poles)), color="C1", marker="x", label="est. pole", zorder=2)
            plt.xlim(domain[0].real, domain[1].real)
        plt.savefig(f"{debug}/{len(samples)}.png")
        plt.close()
        plt.sca(prev_ax)

    return add_samples
|
|
113
|
+
|
|
114
|
+
aaa = jax.tree_util.Partial(aaa)
|
|
115
|
+
|
|
116
|
+
def _adaptive_aaa(z_k_0: npt.NDArray,
                  f: callable,
                  evolutions: int = 2,
                  cutoff: float = None,
                  tol: float = 1e-9,
                  mmax: int = 100,
                  radius: float = None,
                  domain: tuple[complex, complex] = None,
                  f_k_0: npt.NDArray = None,
                  sampling: Union[callable, str] = next_samples,
                  prev_z_n: npt.NDArray = None,
                  return_samples: bool = False,
                  aaa: callable = aaa,
                  f_dot: callable = None):
    """Implementation of `adaptive_aaa`.

    Parameters
    ----------
    see `adaptive_aaa`

    f_dot: callable
        Tangent of `f`. If provided, JVPs of `f` are collected throughout
        the iterations. For use in the custom JVP rule.
    """
    if sampling == "heat":
        sampling = next_samples_heat

    collect_tangents = f_dot is not None
    z_k = z_k_0
    # largest pairwise distance of the initial samples; basis for defaults
    max_dist = np.max(np.abs(z_k_0[:, np.newaxis] - z_k_0[np.newaxis, :]))

    if collect_tangents:
        # f is a jax.tree_util.Partial; evaluate it together with its tangent
        f_unpartial = f.func
        args, _ = jax.tree.flatten(f)
        args_dot, _ = jax.tree.flatten(f_dot)
        z_k_dot = np.zeros_like(z_k)
        f_k, f_k_dot = jax.jvp(
            f_unpartial, (*args, z_k), (*args_dot, z_k_dot)
        )
    else:
        f_k = f(z_k) if f_k_0 is None else f_k_0
        f_k_dot = np.zeros_like(f_k)

    if cutoff is None:
        cutoff = 1e10 * np.median(np.abs(f_k))

    if domain is None:
        center = np.mean(z_k)
        disp = max_dist * (1 + 1j)
        domain = (center - disp, center + disp)

    if radius is None:
        radius = 1e-3 * max_dist

    if prev_z_n is None:
        prev_z_n = np.array([np.inf], dtype=complex)

    def drop_diverged(z_k, f_k, f_k_dot):
        # filter out samples whose values have diverged too strongly
        keep = np.abs(f_k) < cutoff
        if keep.ndim == 2:
            keep = np.all(keep, axis=1)
        return z_k[keep], f_k[keep], f_k_dot[keep]

    key = random.key(0)
    for evo in range(evolutions):
        z_k, f_k, f_k_dot = drop_diverged(z_k, f_k, f_k_dot)
        z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol, mmax)

        if evo == evolutions - 1:
            break  # final evolution: no further refinement

        key, subkey = jax.random.split(key)
        add_z_k = sampling(z_n, prev_z_n, z_k, domain, radius, subkey)
        prev_z_n = z_n
        add_z_k_dot = np.zeros_like(add_z_k)

        if len(add_z_k) == 0:
            break  # sampling strategy found nothing left to refine

        if collect_tangents:
            f_k_new, f_k_dot_new = jax.jvp(
                f_unpartial, (*args, add_z_k), (*args_dot, add_z_k_dot)
            )
        else:
            f_k_new = f(add_z_k)
            f_k_dot_new = np.zeros_like(f_k_new)

        z_k = np.append(z_k, add_z_k)
        f_k = np.concatenate([f_k, f_k_new])
        f_k_dot = np.concatenate([f_k_dot, f_k_dot_new])

    z_k, f_k, f_k_dot = drop_diverged(z_k, f_k, f_k_dot)

    if collect_tangents:
        return z_k, f_k, f_k_dot
    if return_samples:
        return z_j, f_j, w_j, z_n, z_k, f_k
    return z_j, f_j, w_j, z_n
|
|
221
|
+
|
|
222
|
+
@jax.custom_jvp
def adaptive_aaa(z_k_0: npt.NDArray,
                 f: callable,
                 evolutions: int = 2,
                 cutoff: float = None,
                 tol: float = 1e-9,
                 mmax: int = 100,
                 radius: float = None,
                 domain: Domain = None,
                 f_k_0: npt.NDArray = None,
                 sampling: callable = next_samples,
                 prev_z_n: npt.NDArray = None,
                 return_samples: bool = False,
                 aaa: callable = aaa):
    """A 2x adaptive Antoulas–Anderson algorithm for rational approximation
    of meromorphic functions that are costly to evaluate.

    The algorithm iteratively places additional sample points close to
    estimated positions of poles identified during the past iteration.
    By this refinement scheme the number of function evaluations can be
    reduced.

    It is JAX differentiable wrt. the approximated function `f`, via its
    other arguments besides `z`. `f` should be provided as a
    `jax.tree_util.Partial` with only positional arguments pre-filled!

    Parameters
    ----------
    z_k_0 : np.ndarray
        Array of initial sample points
    f : callable
        Function to be approximated. When using gradients `f` should be
        provided as a `jax.tree_util.Partial` with only positional arguments
        pre-filled. Furthermore it should be compatible with `jax.jvp`.
    evolutions: int
        Number of refinement iterations
    cutoff: float
        Maximum absolute value a function evaluation may take to be regarded
        valid. Otherwise the sample point is discarded.
        Defaults to 1e10 times the median of `f(z_k_0)`
    tol: float
        Tolerance used in AAA (see `diffaaable.aaa`)
    mmax: int
        Maximum number of support points used in AAA (see `diffaaable.aaa`)
    radius: float
        Distance from the assumed poles for next samples
    domain: tuple[complex, complex]
        Tuple of min (lower left) and max (upper right) values defining a
        rectangle in the complex plane. Assumed poles outside of the domain
        will not receive refinement.
    f_k_0: np.ndarray
        Allows the user to provide `f` already evaluated at `z_k_0`
    sampling: callable
        Strategy proposing new samples from the current pole estimates
        (or the string "heat" for the heat-map strategy)
    prev_z_n: np.ndarray
        Pole estimates of a previous run, used to judge pole stability
    return_samples: bool
        Additionally return the collected samples `z_k` and values `f_k`
    aaa: callable
        The AAA implementation used within each evolution

    Returns
    -------
    z_j: np.array
        chosen nodes
    f_j: np.array
        `f(z_j)`
    w_j: np.array
        Weights of Barycentric Approximation
    z_n: np.array
        Poles of Barycentric Approximation
    """
    return _adaptive_aaa(
        z_k_0=z_k_0, f=f, evolutions=evolutions, cutoff=cutoff, tol=tol,
        mmax=mmax, radius=radius, domain=domain, f_k_0=f_k_0,
        sampling=sampling, prev_z_n=prev_z_n, return_samples=return_samples,
        aaa=aaa
    )
|
|
289
|
+
|
|
290
|
+
@adaptive_aaa.defjvp
def adaptive_aaa_jvp(primals, tangents):
    """Custom JVP of `adaptive_aaa`.

    Reruns the adaptive sampling while collecting the tangents of `f`,
    then delegates the actual differentiation to the custom JVP of `aaa`
    applied to the collected samples.
    """
    z_k_0, f = primals[:2]
    z_dot, f_dot = tangents[:2]

    if np.any(z_dot):
        raise NotImplementedError(
            "Parametrizing the sampling positions z_k is not supported"
        )

    # collect all samples plus the tangents of their function values
    z_k, f_k, f_k_dot = _adaptive_aaa(z_k_0, f, *primals[2:], f_dot=f_dot)

    z_k_dot = np.zeros_like(z_k)

    return jax.jvp(aaa, (z_k, f_k), (z_k_dot, f_k_dot))
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from jax import config
|
|
2
|
+
config.update("jax_enable_x64", True) #important -> else aaa fails
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy.typing as npt
|
|
5
|
+
import jax
|
|
6
|
+
import numpy as np
|
|
7
|
+
from baryrat import aaa as oaaa # ordinary aaa
|
|
8
|
+
import functools
|
|
9
|
+
import scipy.linalg
|
|
10
|
+
|
|
11
|
+
@functools.wraps(oaaa)
@jax.custom_jvp
def aaa(z_k: npt.NDArray, f_k: npt.NDArray, tol: float=1e-13, mmax: int=100):
    """Wrapped `baryrat.aaa` enabling JAX-based autodiff via a custom JVP."""
    approximant = oaaa(z_k, f_k, tol=tol, mmax=mmax)
    # unpack the BarycentricRational into plain arrays (facilitates AD)
    z_j = approximant.nodes
    f_j = approximant.values
    w_j = approximant.weights

    # drop support points with exactly zero weight: they do not contribute
    # to the rational and would otherwise produce spurious poles
    nonzero = w_j != 0
    z_j = z_j[nonzero]
    f_j = f_j[nonzero]
    w_j = w_j[nonzero]

    # poles of the approximant, sorted by descending magnitude
    z_n = poles(z_j, w_j)
    z_n = z_n[jnp.argsort(-jnp.abs(z_n))]

    return z_j, f_j, w_j, z_n
|
|
32
|
+
|
|
33
|
+
delimiter = ' \n'
|
|
34
|
+
|
|
35
|
+
aaa.__doc__ = f"""\
|
|
36
|
+
This is a wrapped version of `aaa` as provided by `baryrat`,
|
|
37
|
+
providing a custom jvp to enable differentiability.
|
|
38
|
+
|
|
39
|
+
For detailed information on the usage of `aaa` please refer to
|
|
40
|
+
the original documentation::
|
|
41
|
+
|
|
42
|
+
{delimiter.join(aaa.__doc__.splitlines())}
|
|
43
|
+
|
|
44
|
+
.. attention::
|
|
45
|
+
Returns nodes, values, weights and poles, in contrast to the
|
|
46
|
+
baryrat implementation that returns the BarycentricRational ready to be
|
|
47
|
+
evaluated. This is done to facilitate differentiability.
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
z_k : array (M,)
|
|
52
|
+
the sampling points of the function. Unlike for interpolation
|
|
53
|
+
algorithms, where a small number of nodes is preferred, since the
|
|
54
|
+
AAA algorithm chooses its support points adaptively, it is better
|
|
55
|
+
to provide a finer mesh over the support.
|
|
56
|
+
f_k : array (M,)
|
|
57
|
+
the function to be approximated; can be given as a callable function
|
|
58
|
+
or as an array of function values over `z_k`.
|
|
59
|
+
tol : float
|
|
60
|
+
the approximation tolerance
|
|
61
|
+
mmax : int
|
|
62
|
+
the maximum number of iterations/degree of the resulting approximant
|
|
63
|
+
|
|
64
|
+
Returns
|
|
65
|
+
-------
|
|
66
|
+
z_j : array (m,)
|
|
67
|
+
nodes of the barycentric approximant
|
|
68
|
+
|
|
69
|
+
f_j : array (m,)
|
|
70
|
+
values of the barycentric approximant
|
|
71
|
+
|
|
72
|
+
w_j : array (m,)
|
|
73
|
+
weights of the barycentric approximant
|
|
74
|
+
|
|
75
|
+
z_n : array (m-1,)
|
|
76
|
+
poles of the barycentric approximant (for convenience)
|
|
77
|
+
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
@aaa.defjvp
def aaa_jvp(primals, tangents):
    r"""Derivatives according to [1].
    The implemented matrix expressions are motivated in the appendix of [2]:

    .. topic:: AAA Derivatives

      Here we briefly elaborate how the derivatives introduced in [1] are
      implemented as JAX-compatible Jacobian vector products (JVPs) in
      `diffaaable`.

      Given the tangents $\frac{\partial f_k}{\partial p}$ we use the chain
      rule on $r(w_j, f_j, z)$ along its weights $w_j$ and values $f_j$ (the
      nodes $z_j$ are treated as independent of $p$) (Equation 4 of [1]):
      \[
      \frac{\partial f_k}{\partial p} \approx \frac{\partial r_k}{\partial p}= \sum_{j=1}^m\frac{\partial r_k}{\partial f_j}\frac{\partial f_j}{\partial p}+\sum_{j=1}^m\frac{\partial r_k}{\partial w_j}\frac{\partial w_j}{\partial p}
      \]

      To solve this system of equations for $\frac{\partial w_j}{\partial p}$
      we express it in matrix form:

      .. math::
        \mathbf{A}\mathbf{w}^\prime = \mathbf{b}

      where $\mathbf{w}^\prime$ is the column vector containing
      $\frac{\partial w_j}{\partial p}$ and $\mathbf{b}$ and $\mathbf{A}$ are
      defined element-wise:

      .. math::
        \begin{aligned}
        b_k &= \frac{\partial f_k}{\partial p} - \sum_{j=1}^m\frac{\partial r_k}{\partial f_j}\frac{\partial f_j}{\partial p}\\
        A_{kj} &= \frac{\partial r_k}{\partial w_j}
        \end{aligned}

      These are augmented with Equation 5 of [1], which removes the ambiguity
      in $\mathbf{w}^\prime$ associated with a shared phase of all weights.

      The expressions for the derivatives of $r_k$ follow from the definition
      of $r(z)$ (Equation 1 in [2]):
      \[ \frac{\partial r_k}{\partial f_j}= \frac{1}{d(z_k)} \frac{w_j}{z_k-z_j}\]
      and
      \[
      \frac{\partial r_k}{\partial w_j}= \frac{1}{d(z_k)} \frac{f_j-r_k}{z_k-z_j} \approx \frac{1}{d(z_k)} \frac{f_j-f_k}{z_k-z_j}
      \]
    """
    z_k_full, f_k = primals[:2]
    z_dot, f_dot = tangents[:2]

    primal_out = aaa(z_k_full, f_k)
    z_j, f_j, w_j, z_n = primal_out

    # split the samples into support nodes (chosen by AAA) and the rest
    chosen = np.isin(z_k_full, z_j)
    z_k = z_k_full[~chosen]
    f_k = f_k[~chosen]

    # z_dot should be zero anyways
    if np.any(z_dot):
        raise NotImplementedError("Parametrizing the sampling positions z_k is not supported")
    z_k_dot = z_dot[~chosen]
    f_k_dot = f_dot[~chosen]  # $\del f_k / \del p$

    ##################################################
    # We have to track which f_dot corresponds to which z_j: bridge the
    # input ordering and aaa's output ordering via sorting by magnitude.
    sort_orig = jnp.argsort(jnp.abs(z_k_full[chosen]))
    sort_out = jnp.argsort(jnp.argsort(jnp.abs(z_j)))

    z_j_dot = z_dot[chosen][sort_orig][sort_out]
    f_j_dot = f_dot[chosen][sort_orig][sort_out]
    ##################################################

    C = 1/(z_k[:, None]-z_j[None, :])  # Cauchy matrix, shape (k, j)

    d = C @ w_j  # denominator of the barycentric formula at z_k
    # contribution of the value tangents: $\sum_j f_j^\prime \del r/\del f_j$
    via_f_j = C @ (f_j_dot * w_j) / d

    # linear system A w' = b for the weight tangents
    A = (f_j[None, :] - f_k[:, None])*C/d[:, None]
    b = f_k_dot - via_f_j

    # make sure the system is not underdetermined according to eq. 5 of [1]
    A = jnp.concatenate([A, np.conj(w_j.reshape(1, -1))])
    b = jnp.append(b, 0)

    with jax.disable_jit():  # otherwise backwards differentiation led to error
        w_j_dot, _, _, _ = jnp.linalg.lstsq(A, b)

    # propagate the weight tangents to the poles
    denom = z_n.reshape(1, -1)-z_j.reshape(-1, 1)
    z_n_dot = (
        jnp.sum(w_j_dot.reshape(-1, 1)/denom, axis=0)/
        jnp.sum(w_j.reshape(-1, 1) /denom**2, axis=0)
    )

    tangent_out = z_j_dot, f_j_dot, w_j_dot, z_n_dot

    return primal_out, tangent_out
|
|
169
|
+
|
|
170
|
+
def poles(z_j, w_j):
    """
    The poles of a barycentric rational with given nodes and weights.
    Poles lifted by zeros of the nominator are included.
    Thus the values $f_j$ do not contribute and don't need to be provided.
    The implementation was modified from `baryrat` to support JAX AD.

    Parameters
    ----------
    z_j : array (m,)
        nodes of the barycentric rational
    w_j : array (m,)
        weights of the barycentric rational

    Returns
    -------
    z_n : array (m-1,)
        poles of the barycentric rational (more strictly zeros of the denominator)
    """
    # values are irrelevant for the denominator's zeros -> use ones
    f_j = np.ones_like(z_j)

    # arrowhead generalized eigenvalue problem E v = lambda B v;
    # the singular B introduces spurious infinite eigenvalues that are
    # filtered out below
    B = np.eye(len(w_j) + 1)
    B[0, 0] = 0
    E = np.block([[0, w_j],
                  [f_j[:, None], np.diag(z_j)]])
    eigenvalues = scipy.linalg.eigvals(E, B)
    return eigenvalues[np.isfinite(eigenvalues)]
|
|
197
|
+
|
|
198
|
+
def residues(z_j, f_j, w_j, z_n):
    '''
    Residues for given poles via the formula for simple poles
    of quotients of analytic functions (res = n(z_n) / d'(z_n)).
    The implementation was modified from `baryrat` to support JAX AD.

    Parameters
    ----------
    z_j : array (m,)
        nodes of the barycentric rational
    f_j : array (m,)
        values of the barycentric rational at the nodes
    w_j : array (m,)
        weights of the barycentric rational
    z_n : array (n,)
        poles of interest of the barycentric rational (n<=m-1)

    Returns
    -------
    r_n : array (n,)
        residues of poles `z_n`
    '''
    cauchy = 1.0 / (z_n[:, None] - z_j[None, :])
    numerator = cauchy @ (f_j * w_j)
    # derivative of the denominator sum_j w_j/(z - z_j) at the poles
    denominator_deriv = (-cauchy**2) @ w_j
    return jnp.nan_to_num(numerator / denominator_deriv)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from jax import config
|
|
2
|
+
config.update("jax_enable_x64", True) #important -> else aaa fails
|
|
3
|
+
import jax.numpy as np
|
|
4
|
+
import jax
|
|
5
|
+
from diffaaable.vectorial import check_inputs
|
|
6
|
+
import jaxopt
|
|
7
|
+
|
|
8
|
+
def optimal_weights(A, A_hat, stepsize=0.5):
    """Unit-norm weights minimizing ||A w + A_hat conj(w)||.

    The smallest right singular vector of the uncorrected Loewner matrix
    `A` serves as the initial guess; LBFGS then accounts for the Lorentz
    correction term `A_hat`.

    NOTE(review): `stepsize` is currently unused — presumably it was meant
    to be forwarded to the solver; confirm before relying on it.
    """
    # initial guess from the plain (uncorrected) least-squares problem
    _, _, Vh = np.linalg.svd(A)
    w_init = Vh[-1, :].conj()

    def objective(w):
        w = w / np.linalg.norm(w)
        residual = (A @ w) + (A_hat @ w.conj())
        return np.linalg.norm(residual)

    solver = jaxopt.LBFGS(fun=objective, maxiter=300, tol=1e-13)
    w_opt, _state = solver.run(w_init)
    return w_opt / np.linalg.norm(w_opt)
|
|
24
|
+
|
|
25
|
+
def lorentz_aaa(z_k, f_k, tol=1e-9, mmax=100, return_errors=False):
    """AAA variant enforcing poles symmetric about the imaginary axis.

    Greedy AAA-style iteration in which the weights are chosen (via
    `optimal_weights`) to also account for the mirrored "Lorentz" basis
    1/(z + conj(z_j)) with conjugated weights. The returned nodes,
    values and weights include the mirrored counterparts
    (-conj(z_j), conj(f_j), conj(w_j)).

    Parameters
    ----------
    z_k : array (M,)
        sample points
    f_k : array (M,) or (M, V)
        (possibly vector-valued) function samples
    tol : float
        relative convergence tolerance
    mmax : int
        maximum number of greedy iterations / support points
    return_errors : bool
        additionally return the per-iteration error history

    Returns
    -------
    z_j, f_j, w_j : arrays (2m,)
        nodes, values and weights including the mirrored counterparts
    errors : list (only if `return_errors`)
    """
    z_k, f_k, M, V = check_inputs(z_k, f_k)

    active = np.ones(M, dtype=bool)  # samples not yet chosen as nodes
    z_j = np.empty(0, dtype=z_k.dtype)
    f_j = np.empty((0, V), dtype=f_k.dtype)
    errors = []

    reltol = tol * np.linalg.norm(f_k, np.inf)

    # initial constant approximation
    r_k = np.mean(f_k) * np.ones_like(f_k)

    for m in range(mmax):
        # greedily include the sample with the largest residual as a node
        worst = np.argmax(np.linalg.norm(f_k - r_k, axis=-1))
        z_j = np.append(z_j, np.array([z_k[worst]]))
        f_j = np.concatenate([f_j, f_k[worst][None, :]])
        active = active.at[worst].set(False)

        # Cauchy matrix containing the basis functions as columns
        C = 1.0 / (z_k[active, None] - z_j[None, :])
        # Loewner matrix
        A = (f_k[active, None] - f_j[None, :]) * C[:, :, None]
        A = np.concatenate(np.moveaxis(A, -1, 0))

        # Lorentz correction: mirrored nodes -conj(z_j), conjugated data
        C_hat = 1.0 / (z_k[active, None] + np.conj(z_j)[None, :])
        A_hat = (f_k[active, None] - np.conj(f_j)[None, :]) * C_hat[:, :, None]
        A_hat = np.concatenate(np.moveaxis(A_hat, -1, 0))

        w_j = optimal_weights(A, A_hat)

        # approximation: numerator / denominator (plain + mirrored parts)
        N = C.dot(w_j[:, None] * f_j)
        N_hat = C_hat.dot(np.conj(w_j[:, None] * f_j))
        D = C.dot(w_j)[:, None]
        D_hat = C_hat.dot(np.conj(w_j))[:, None]

        # update the approximation on the remaining sample points
        r_k = f_k.at[active].set((N + N_hat) / (D + D_hat))

        # check for convergence
        errors.append(np.linalg.norm(f_k - r_k, np.inf))
        if errors[-1] <= reltol:
            break

    if V == 1:
        f_j = f_j[:, 0]  # scalar-valued input: drop the vector dimension

    # append the mirrored counterparts enforcing the Lorentz symmetry
    z_j = np.concatenate([z_j, -np.conj(z_j)])
    f_j = np.concatenate([f_j, np.conj(f_j)])
    w_j = np.concatenate([w_j, np.conj(w_j)])

    if return_errors:
        return z_j, f_j, w_j, errors
    return z_j, f_j, w_j
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pathlib
|
|
3
|
+
import jax.numpy as np
|
|
4
|
+
import numpy as onp
|
|
5
|
+
from jax.tree_util import Partial
|
|
6
|
+
from diffaaable.core import aaa, residues
|
|
7
|
+
from diffaaable.adaptive import Domain, domain_mask, adaptive_aaa, next_samples_heat
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
|
|
10
|
+
def reduced_domain(domain, reduction=1-1/12):
    """Shrink a rectangular domain towards its midpoint.

    Each corner is moved a fraction ``1 - reduction`` of the way towards
    the opposite corner, so ``reduction=1`` leaves the domain unchanged.
    """
    keep = reduction
    mix = 1 - keep
    lower = domain[0] * keep + domain[1] * mix
    upper = domain[1] * keep + domain[0] * mix
    return (lower, upper)
|
|
16
|
+
|
|
17
|
+
def sample_cross(domain):
    """Return four samples forming a cross around the domain center.

    The points sit at the center offset by the half-width (left/right)
    and the half-height (up/down) of the domain rectangle.
    """
    mid = domain_center(domain)
    half = 0.5 * (domain[1] - domain[0])
    offsets = np.array([half.real, -half.real, 1j * half.imag, -1j * half.imag])
    return mid + offsets
|
|
21
|
+
|
|
22
|
+
def sample_domain(domain: Domain, N: int):
    """Sample approximately N points on a regular grid inside the domain.

    The domain is first shrunk slightly (via ``reduced_domain``) so that
    no sample falls directly on the boundary; the grid has
    ``round(sqrt(N))`` points per side.
    """
    n_side = np.round(np.sqrt(N)).astype(int)
    domain = reduced_domain(domain)
    re_axis = np.linspace(domain[0].real, domain[1].real, n_side)
    im_axis = np.linspace(domain[0].imag, domain[1].imag, n_side)
    grid_re, grid_im = np.meshgrid(re_axis, im_axis)
    return (grid_re + 1j * grid_im).flatten()
|
|
30
|
+
|
|
31
|
+
def sample_rim(domain: Domain, N: int):
    """Place ``N // 4`` samples on each of the four edges of the domain.

    Corner points are excluded by trimming the endpoints of each
    linspace, so edges never share a sample.
    """
    per_side = N // 4
    horiz = np.linspace(domain[0].real, domain[1].real, per_side + 2)[1:-1]
    vert = np.linspace(domain[0].imag, domain[1].imag, per_side + 2)[1:-1] * 1j
    edges = [
        1j * domain[0].imag + horiz,  # bottom edge
        1j * domain[1].imag + horiz,  # top edge
        domain[0].real + vert,        # left edge
        domain[1].real + vert,        # right edge
    ]
    return np.array(edges).flatten()
|
|
41
|
+
|
|
42
|
+
def anti_domain(domain: Domain):
    """Swap the imaginary parts of the two corner points.

    Turns a (lower-left, upper-right) corner pair into the corresponding
    (upper-left, lower-right) pair and vice versa.
    """
    first = domain[0].real + 1j * domain[1].imag
    second = domain[1].real + 1j * domain[0].imag
    return (first, second)
|
|
47
|
+
|
|
48
|
+
def domain_center(domain: Domain):
    """Return the midpoint of the two domain corners."""
    corners = np.array(domain)
    return np.mean(corners)
|
|
50
|
+
|
|
51
|
+
def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
    """Split the rectangular domain into two halves through ``center``.

    With ``divide_horizontal`` the cut runs horizontally (upper and lower
    half); otherwise vertically (left and right half). ``center``
    defaults to the domain midpoint.
    """
    if center is None:
        center = domain_center(domain)
    top_left = domain[0].real + 1j * domain[1].imag
    bottom_right = domain[1].real + 1j * domain[0].imag

    # the four quadrants around `center`, each as a (lower-left, upper-right) pair
    quadrants = [
        (center, domain[1]),                  # upper right
        anti_domain((top_left, center)),      # upper left
        (domain[0], center),                  # lower left
        anti_domain((center, bottom_right)),  # lower right
    ]

    if divide_horizontal:
        upper = (quadrants[1][0], quadrants[0][1])
        lower = (quadrants[2][0], quadrants[3][1])
        return [upper, lower]
    left = (quadrants[2][0], quadrants[1][1])
    right = (quadrants[3][0], quadrants[0][1])
    return [left, right]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def cutoff_mask(z_k, f_k, f_k_dot, cutoff):
    """Drop samples whose magnitude |f_k| reaches ``cutoff``.

    Filters out values that have diverged too strongly; the same mask is
    applied to the sample points, values and derivatives.
    """
    keep = np.abs(f_k) < cutoff
    return z_k[keep], f_k[keep], f_k_dot[keep]
|
|
72
|
+
|
|
73
|
+
def plot_domain(domain: Domain, size: float=1):
    """Draw the rectangle outline of ``domain`` into the current pyplot axes.

    Returns the list of Line2D artists produced by ``plt.plot``.
    """
    top_left = domain[0].real + 1j * domain[1].imag
    bottom_right = domain[1].real + 1j * domain[0].imag

    # closed path: lower-left -> lower-right -> upper-right -> upper-left -> start
    outline = np.array([domain[0], bottom_right, domain[1], top_left, domain[0]])

    return plt.plot(outline.real, outline.imag,
                    lw=size/30, zorder=1)
|
|
81
|
+
|
|
82
|
+
def all_poles_known(poles, prev, tol):
    """Check that every pole has a counterpart in ``prev`` within ``tol``.

    Requires both collections to have equal length; when ``prev`` is
    unset the answer is always False.
    """
    if prev is None or len(prev) != len(poles):
        return False

    pairwise = np.abs(poles[:, None] - prev[None, :])
    matched = np.any(pairwise < tol, axis=1)
    return np.all(matched)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def selective_refinement_aaa(f: callable,
                             domain: Domain,
                             N: int = 36,
                             max_poles: int = 400,
                             cutoff: float = None,
                             tol_aaa: float = 1e-9,
                             tol_pol: float = 1e-5,
                             suggestions = None,
                             on_rim: bool = False,
                             Dmax=30,
                             use_adaptive: bool = True,
                             z_k = None, f_k = None,
                             divide_horizontal=True,
                             debug_name = "d",
                             stop = 0.1,
                             batchsize=10
                             ):
    """Find poles and residues of `f` inside `domain` by recursive bisection.

    Each (sub)domain is approximated with (adaptive) AAA. If the poles found
    are all re-confirmed by `suggestions` (the candidates handed down by the
    parent call) and not too numerous, the recursion stops; otherwise the
    domain is split into two halves (alternating horizontal/vertical cuts)
    and each half is refined, up to `Dmax` remaining levels.

    Parameters
    ----------
    f : callable
        function under analysis; called with arrays of complex sample points
    domain : Domain
        rectangle as a (lower-left, upper-right) corner pair
    N : int
        target number of fresh samples per (sub)domain
    max_poles : int
        maximum number of poles accepted within a single domain
    cutoff : float
        magnitude above which samples count as diverged (None -> inf).
        NOTE(review): only `np.inf` is forwarded to `adaptive_aaa` below,
        so a user-supplied cutoff is effectively unused there — confirm.
    tol_aaa : float
        AAA convergence tolerance
    tol_pol : float
        distance below which a pole counts as re-confirmed
    suggestions : array or None
        pole candidates from the parent invocation
    on_rim : bool
        non-adaptive mode only: sample the domain rim instead of a grid
    Dmax : int
        remaining recursion depth; 0 forces acceptance of the current poles
    use_adaptive : bool
        use `adaptive_aaa` with heat-map sampling instead of plain `aaa`
    z_k, f_k : arrays or None
        previously known sample points / values to reuse
    divide_horizontal : bool
        orientation of the next bisection cut
    debug_name : str
        hierarchical name used for logging and debug output folders
    stop, batchsize :
        forwarded to the heat-map sampling strategy

    Returns
    -------
    (poles, residues, eval_count) : (array, array, int)
        poles found inside `domain`, their residues, and the number of
        function evaluations spent (including recursive calls).

    TODO: allow access to samples slightly outside of domain
    """

    print(f"start domain '{debug_name}', {Dmax=}")
    folder = f"debug_out/{debug_name:0<33}"  # zero-pad so folder names align
    domain_size = np.abs(domain[1]-domain[0])/2
    size = domain_size/2 # for plotting
    #plot_rect = plot_domain(domain, size=30)
    #color = plot_rect[0].get_color()

    if cutoff is None:
        cutoff = np.inf

    eval_count = 0
    if use_adaptive:
        pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
        sampling = Partial(next_samples_heat, debug=folder,
            stop=stop, resolution=(101, 101), batchsize=batchsize)
        if z_k is None:
            # no inherited samples: start from empty complex arrays
            z_k = np.empty((0,), dtype=complex)
            f_k = z_k.copy()

        # too few inherited samples -> seed with a regular grid over the domain
        if len(z_k) < N/4:
            z_k_new = sample_domain(domain, N)
            f_k = np.append(f_k, f(z_k_new))
            z_k = np.append(z_k, z_k_new)

            eval_count += len(z_k_new)
        print(f"new eval: {eval_count}")
        # subtract inherited/seed samples so only adaptive_aaa's own
        # evaluations are added back after the call
        eval_count -= len(z_k)
        z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
            z_k, f, f_k_0=f_k, evolutions=N*16, tol=tol_aaa,
            domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N), #NOTE: actually increased domain :/
            return_samples=True, sampling=sampling, cutoff=np.inf
        )
        # TODO pass down samples in buffer zone
        eval_count += len(z_k)
    else:
        if on_rim:
            z_k = sample_rim(domain, N)
        else:
            z_k = sample_domain(reduced_domain(domain, 1.05), N)
        f_k = f(z_k)
        eval_count += len(f_k)
        try:
            z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol=tol_aaa)
        except onp.linalg.LinAlgError as e:
            # AAA failed (e.g. SVD did not converge): report no poles
            z_n = z_j = f_j = w_j = np.empty((0,))

    print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
    # keep only the poles lying inside this domain
    poles = z_n[domain_mask(domain, z_n)]

    # accept if depth is exhausted, or the pole count is plausible and every
    # pole was already suggested by the parent (i.e. re-confirmed)
    if (Dmax == 0 or
        (len(poles)<=max_poles and all_poles_known(poles, suggestions, tol_pol))):

        #plt.scatter(poles.real, poles.imag, color = color, marker="x")#, s=size*3, linewidths=size/2)
        print("I am done here")

        res = residues(z_j, f_j, w_j, poles)
        return poles, res, eval_count
    #plt.scatter(poles.real, poles.imag, color = color, marker="+", s=0.2, zorder=3)#, s=size, linewidths=size/6)

    subs = subdomains(domain, divide_horizontal)

    pol = np.empty((0,), dtype=complex)
    res = pol.copy()
    for i,sub in enumerate(subs):
        # poles expected inside this half become its suggestions
        sug = poles[domain_mask(sub, poles)]
        sample_mask = domain_mask(sub, z_k)

        # reuse the samples already falling into this half
        known_z_k = z_k[sample_mask]
        known_f_k = f_k[sample_mask]

        p, r, e = selective_refinement_aaa(
            f, sub, N, max_poles, cutoff, tol_aaa, tol_pol,
            use_adaptive=use_adaptive,
            suggestions=sug, Dmax=Dmax-1, z_k=known_z_k, f_k=known_f_k,
            divide_horizontal = not divide_horizontal,
            debug_name=f"{debug_name}{i+1}",
        )
        pol = np.append(pol, p)
        res = np.append(res, r)
        eval_count += e
    # if len(pol) > 0:
    #   plt.xlim(domain[0].real, domain[1].real)
    #   plt.ylim(domain[0].imag, domain[1].imag)
    #   plt.savefig(f"debug_out/{debug_name:0<33}.png")
    return pol, res, eval_count
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from jax import config
|
|
2
|
+
config.update("jax_enable_x64", True) #important -> else aaa fails
|
|
3
|
+
import jax.numpy as np
|
|
4
|
+
from diffaaable.core import poles
|
|
5
|
+
|
|
6
|
+
def check_inputs(z_k, f_k):
    """Validate and normalize AAA sample inputs.

    Parameters
    ----------
    z_k : array-like (M,)
        sample points
    f_k : array-like (M,) or (M, V)
        (vector valued) function samples; a 1D input is promoted to a
        single-column 2D array

    Returns
    -------
    (z_k, f_k, M, V) : (array (M,), array (M, V), int, int)
        normalized arrays plus the number of sample points M and the
        number of vector components V

    Raises
    ------
    ValueError
        if ``z_k`` is not 1D, or ``f_k`` is not 1/2D with first dimension M
    """
    f_k = np.array(f_k)
    z_k = np.array(z_k)

    if z_k.ndim != 1:
        # fix: the f-prefix was missing, so the shape was never interpolated
        raise ValueError(f"z_k should be 1D but has shape {z_k.shape}")
    M = z_k.shape[0]

    if f_k.ndim == 1:
        # promote scalar-valued samples to a single-component vector
        f_k = f_k[:, np.newaxis]

    if f_k.ndim != 2 or f_k.shape[0] != M:
        # fix: added the missing space between "first" and "dimension"
        raise ValueError("f_k should be 1 or 2D and have the same first "
                         f"dimension as z_k, {f_k.shape=}, {z_k.shape=}")
    V = f_k.shape[1]

    return z_k, f_k, M, V
|
|
23
|
+
|
|
24
|
+
def vectorial_aaa(z_k, f_k, tol=1e-13, mmax=100, return_errors=False):
    r"""Find a rational approximation to $\mathbf f(z)$ over the points $z_k$ using
    a modified AAA algorithm, as presented in [^4]. Importantly the weights and
    thus also the poles are shared between all entries of $\mathbf f(z)$.

    Parameters
    ----------
    z_k : array (M,):
        M sample points
    f_k : array (M, V):
        vector valued function values
    tol : float
        the approximation tolerance
    mmax : int
        the maximum number of iterations/degree of the resulting approximant
    return_errors : bool
        additionally return the per-iteration error history
        (also enables progress printouts)

    Returns
    -------
    z_j : array (m,)
        selected support points
    f_j : array (m, V)
        function values at the support points
    w_j : array (m,)
        barycentric weights shared across all V components
    z_n : array
        poles of the rational approximant
    errors : list, only if ``return_errors``
        max-norm approximation error after each iteration
    """
    z_k, f_k, M, V = check_inputs(z_k, f_k)

    # J masks the sample points not yet promoted to support points
    J = np.ones(M, dtype=bool)
    z_j = np.empty(0, dtype=z_k.dtype)
    f_j = np.empty((0, V), dtype=f_k.dtype)
    errors = []

    reltol = tol * np.linalg.norm(f_k, np.inf)

    # initial approximation: constant mean over all samples and components
    r_k = np.mean(f_k) * np.ones_like(f_k)

    # more than M/2 support points would leave the Loewner system underdetermined
    mmax = min(mmax, len(f_k)//2)

    for m in range(mmax):
        # find largest residual
        jj = np.argmax(np.linalg.norm(f_k - r_k, axis=-1)) #Next sample point to include
        z_j = np.append(z_j, np.array([z_k[jj]]))
        f_j = np.concatenate([f_j, f_k[jj][None, :]])
        J = J.at[jj].set(False)

        # Cauchy matrix containing the basis functions as columns
        C = 1.0 / (z_k[J,None] - z_j[None,:])
        # Loewner matrix
        A = (f_k[J,None] - f_j[None,:]) * C[:,:,None]

        # TODO: stack A
        # stack the V component blocks vertically so a single SVD couples them
        A = np.concatenate(np.moveaxis(A, -1, 0))

        # compute weights as right singular vector for smallest singular value
        if return_errors:
            print("start SVD")
        _, _, Vh = np.linalg.svd(A)
        if return_errors:
            print("finished SVD")

        w_j = Vh[-1, :].conj()

        # approximation: numerator / denominator
        N = C.dot(w_j[:, None] * f_j) #TODO check it works
        D = C.dot(w_j)[:, None]

        # update residual (support points themselves are reproduced exactly)
        r_k = f_k.at[J].set(N / D)

        # check for convergence
        errors.append(np.linalg.norm(f_k - r_k, np.inf))
        if return_errors:
            print(errors[-1])
        if errors[-1] <= reltol:
            break

    z_n = poles(z_j, w_j)
    if return_errors:
        return z_j, f_j, w_j, z_n, errors
    return z_j, f_j, w_j, z_n
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def residues_vec(z_j, f_j, w_j, z_n):
    '''Vectorial residues at the poles ``z_n``.

    Applies the simple-pole formula res = N(z_n) / D'(z_n) for quotients
    of analytic functions, with N and D the barycentric numerator and
    denominator of order ``m``. Non-finite entries are replaced by zero.
    '''

    cauchy = 1.0 / (z_n[:, None] - z_j[None, :])
    numerator = cauchy.dot((f_j * w_j[None, :]).T)
    denom_deriv = (-cauchy**2).dot(w_j)
    residues_at_poles = numerator / denom_deriv[:, None]

    return np.nan_to_num(residues_at_poles.T)
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
# https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
|
|
2
|
+
|
|
3
|
+
[build-system]
|
|
4
|
+
build-backend = "flit_core.buildapi"
|
|
5
|
+
requires = ["flit_core >=3.2,<4"]
|
|
6
|
+
|
|
7
|
+
[tool.ruff.lint.pydocstyle]
|
|
8
|
+
convention = "google"
|
|
9
|
+
|
|
10
|
+
[project]
|
|
11
|
+
authors = [
|
|
12
|
+
{name = "Jan David Fischbach", email = "fischbach@kit.edu"}
|
|
13
|
+
]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Programming Language :: Python :: 3.9",
|
|
16
|
+
"Programming Language :: Python :: 3.10",
|
|
17
|
+
"Programming Language :: Python :: 3.11",
|
|
18
|
+
"Operating System :: OS Independent"
|
|
19
|
+
]
|
|
20
|
+
dependencies = [
|
|
21
|
+
"numpy",
|
|
22
|
+
"jax",
|
|
23
|
+
"jaxlib",
|
|
24
|
+
"baryrat",
|
|
25
|
+
"jaxopt",
|
|
26
|
+
"pytest",
|
|
27
|
+
"pytest-benchmark"
|
|
28
|
+
]
|
|
29
|
+
description = "JAX-differentiable AAA algorithm"
|
|
30
|
+
keywords = ["python"]
|
|
31
|
+
license = {file = "LICENSE"}
|
|
32
|
+
name = "diffaaable"
|
|
33
|
+
readme = "README.md"
|
|
34
|
+
requires-python = ">=3.9"
|
|
35
|
+
version = "1.0.1"
|
|
36
|
+
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
dev = [
|
|
39
|
+
"pre-commit",
|
|
40
|
+
"pytest",
|
|
41
|
+
"pytest-cov",
|
|
42
|
+
"pytest_regressions"
|
|
43
|
+
]
|
|
44
|
+
docs = [
|
|
45
|
+
"jupytext",
|
|
46
|
+
"matplotlib",
|
|
47
|
+
"jupyter-book",
|
|
48
|
+
"sphinx_math_dollar"
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
[tool.black]
|
|
52
|
+
exclude = ''' # Specify the files/dirs that should be ignored by the black formatter
|
|
53
|
+
/(
|
|
54
|
+
\.eggs
|
|
55
|
+
| \.git
|
|
56
|
+
| \.hg
|
|
57
|
+
| \.mypy_cache
|
|
58
|
+
| \.tox
|
|
59
|
+
| \.venv
|
|
60
|
+
| env
|
|
61
|
+
| _build
|
|
62
|
+
| buck-out
|
|
63
|
+
| build
|
|
64
|
+
| dist
|
|
65
|
+
)/
|
|
66
|
+
'''
|
|
67
|
+
include = '\.pyi?$'
|
|
68
|
+
line-length = 88
|
|
69
|
+
target-version = ['py39']
|
|
70
|
+
|
|
71
|
+
[tool.codespell]
|
|
72
|
+
ignore-words-list = ""
|
|
73
|
+
|
|
74
|
+
[tool.mypy]
|
|
75
|
+
python_version = "3.9"
|
|
76
|
+
strict = true
|
|
77
|
+
|
|
78
|
+
[tool.pytest.ini_options]
|
|
79
|
+
# addopts = --tb=no
|
|
80
|
+
addopts = '--tb=short'
|
|
81
|
+
norecursedirs = ["extra/*.py"]
|
|
82
|
+
python_files = ["diffaaable/*.py", "notebooks/*.ipynb", "tests/*.py"]
|
|
83
|
+
testpaths = ["diffaaable/", "tests"]
|
|
84
|
+
|
|
85
|
+
[tool.ruff]
|
|
86
|
+
fix = true
|
|
87
|
+
lint.ignore = [
|
|
88
|
+
"E501", # line too long, handled by black
|
|
89
|
+
"B008", # do not perform function calls in argument defaults
|
|
90
|
+
"C901", # too complex
|
|
91
|
+
"B905", # `zip()` without an explicit `strict=` parameter
|
|
92
|
+
"C408" # C408 Unnecessary `dict` call (rewrite as a literal)
|
|
93
|
+
]
|
|
94
|
+
lint.select = [
|
|
95
|
+
"E", # pycodestyle errors
|
|
96
|
+
"W", # pycodestyle warnings
|
|
97
|
+
"F", # pyflakes
|
|
98
|
+
"I", # isort
|
|
99
|
+
"C", # flake8-comprehensions
|
|
100
|
+
"B", # flake8-bugbear
|
|
101
|
+
"UP"
|
|
102
|
+
]
|
|
103
|
+
|
|
104
|
+
[tool.setuptools.packages]
|
|
105
|
+
find = {}
|
|
106
|
+
|
|
107
|
+
[tool.tbump]
|
|
108
|
+
|
|
109
|
+
[[tool.tbump.before_commit]]
|
|
110
|
+
cmd = "towncrier build --yes --version {new_version}"
|
|
111
|
+
name = "create & check changelog"
|
|
112
|
+
|
|
113
|
+
[[tool.tbump.before_commit]]
|
|
114
|
+
cmd = "git add CHANGELOG.md"
|
|
115
|
+
name = "create & check changelog"
|
|
116
|
+
|
|
117
|
+
[[tool.tbump.before_commit]]
|
|
118
|
+
cmd = "grep -q -F {new_version} CHANGELOG.md"
|
|
119
|
+
name = "create & check changelog"
|
|
120
|
+
|
|
121
|
+
# For each file to patch, add a [[file]] config
|
|
122
|
+
# section containing the path of the file, relative to the
|
|
123
|
+
# tbump.toml location.
|
|
124
|
+
[[tool.tbump.file]]
|
|
125
|
+
src = "README.md"
|
|
126
|
+
|
|
127
|
+
[[tool.tbump.file]]
|
|
128
|
+
src = "pyproject.toml"
|
|
129
|
+
|
|
130
|
+
[[tool.tbump.file]]
|
|
131
|
+
src = "diffaaable/__init__.py"
|
|
132
|
+
|
|
133
|
+
[tool.tbump.git]
|
|
134
|
+
message_template = "Bump to {new_version}"
|
|
135
|
+
tag_template = "v{new_version}"
|
|
136
|
+
|
|
137
|
+
[tool.tbump.version]
|
|
138
|
+
current = "1.0.1"
|
|
139
|
+
# Example of a semver regexp.
|
|
140
|
+
# Make sure this matches current_version before
|
|
141
|
+
# using tbump
|
|
142
|
+
regex = '''
|
|
143
|
+
(?P<major>\d+)
|
|
144
|
+
\.
|
|
145
|
+
(?P<minor>\d+)
|
|
146
|
+
\.
|
|
147
|
+
(?P<patch>\d+)
|
|
148
|
+
'''
|
|
149
|
+
|
|
150
|
+
[tool.towncrier]
|
|
151
|
+
directory = ".changelog.d"
|
|
152
|
+
filename = "CHANGELOG.md"
|
|
153
|
+
issue_format = "[#{issue}](https://github.com/jan-david-fischbach/diffaaable/issues/{issue})"
|
|
154
|
+
package = "diffaaable"
|
|
155
|
+
start_string = "<!-- towncrier release notes start -->\n"
|
|
156
|
+
template = ".changelog.d/changelog_template.jinja"
|
|
157
|
+
title_format = "## [{version}](https://github.com/jan-david-fischbach/diffaaable/releases/tag/v{version}) - {project_date}"
|
|
158
|
+
underlines = ["", "", ""]
|
|
159
|
+
|
|
160
|
+
[[tool.towncrier.type]]
|
|
161
|
+
directory = "security"
|
|
162
|
+
name = "Security"
|
|
163
|
+
showcontent = true
|
|
164
|
+
|
|
165
|
+
[[tool.towncrier.type]]
|
|
166
|
+
directory = "removed"
|
|
167
|
+
name = "Removed"
|
|
168
|
+
showcontent = true
|
|
169
|
+
|
|
170
|
+
[[tool.towncrier.type]]
|
|
171
|
+
directory = "deprecated"
|
|
172
|
+
name = "Deprecated"
|
|
173
|
+
showcontent = true
|
|
174
|
+
|
|
175
|
+
[[tool.towncrier.type]]
|
|
176
|
+
directory = "added"
|
|
177
|
+
name = "Added"
|
|
178
|
+
showcontent = true
|
|
179
|
+
|
|
180
|
+
[[tool.towncrier.type]]
|
|
181
|
+
directory = "changed"
|
|
182
|
+
name = "Changed"
|
|
183
|
+
showcontent = true
|
|
184
|
+
|
|
185
|
+
[[tool.towncrier.type]]
|
|
186
|
+
directory = "fixed"
|
|
187
|
+
name = "Fixed"
|
|
188
|
+
showcontent = true
|