diffaaable 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diffaaable/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """diffaaable - JAX-differentiable AAA algorithm"""
2
+
3
+ __version__ = "1.0.1"
4
+
5
+ __all__ = ["aaa", "adaptive_aaa", "vectorial_aaa", "lorentz_aaa"]
6
+
7
+ from diffaaable.core import aaa
8
+ from diffaaable.adaptive import adaptive_aaa
9
+ from diffaaable.vectorial import vectorial_aaa
10
+ from diffaaable.lorentz import lorentz_aaa
diffaaable/adaptive.py ADDED
@@ -0,0 +1,305 @@
1
+ from typing import Union
2
+ import jax.numpy as np
3
+ import numpy.typing as npt
4
+ from diffaaable import aaa
5
+ import jax
6
+ from jax import random
7
+ import matplotlib.pyplot as plt
8
+ from functools import partial
9
+ Domain = tuple[complex, complex]
10
+
11
def top_right(a: npt.NDArray[complex], b: npt.NDArray[complex]):
  """Elementwise test whether `a` lies to the upper right of `b`.

  True where both the real and the imaginary part of `a` are
  greater than or equal to those of `b`.
  """
  above = a.imag >= b.imag
  right_of = a.real >= b.real
  return above & right_of
13
+
14
def domain_mask(domain: Domain, z_n):
  """Boolean mask of the entries of `z_n` lying inside the rectangle `domain`.

  `domain` is a (lower-left, upper-right) pair of complex corner points.
  """
  lo, hi = domain
  inside_lower = top_right(z_n, lo)
  inside_upper = top_right(hi, z_n)
  return inside_lower & inside_upper
18
+
19
@jax.tree_util.Partial
def next_samples(z_n, prev_z_n, z_k, domain: Domain, radius, randkey, tolerance=1e-9, min_samples=0, max_samples=0):
  """Propose new sample points next to pole estimates that still move.

  Pole estimates inside `domain` are compared to the previous iteration's
  estimates; the ones that moved more than `tolerance` are considered
  unstable and each receives a new sample displaced by `radius` in a
  uniformly random direction. `min_samples`/`max_samples` bound how many
  estimates are refined (0 for `max_samples` means "no upper bound").
  """
  z_n = z_n[domain_mask(domain, z_n)]
  # distance of each current pole estimate to its nearest predecessor
  movement = np.min(np.abs(z_n[:, None] - prev_z_n[None, :]), axis=-1)
  most_mobile_first = np.argsort(-movement)
  unstable = movement > tolerance
  n_unstable = np.sum(unstable)

  if n_unstable > min_samples:
    if max_samples != 0 and n_unstable > max_samples:
      selected = z_n[most_mobile_first[:max_samples]]
    else:
      selected = z_n[unstable]
  else:
    # too few unstable estimates: refine at least `min_samples` of them
    selected = z_n[most_mobile_first[:min_samples]]

  # displace each selected estimate by `radius` in a random direction
  phases = jax.random.uniform(randkey, selected.shape)
  return selected + radius * np.exp(1j * 2 * np.pi * phases)
36
+
37
+
38
def heat(poles, samples, mesh, sigma):
  """Heat map over `mesh` scoring where new samples are most needed.

  Gaussian bumps of width `sigma` are placed around each entry of
  `poles`; the score is damped where `samples` already cover the region.
  NaN entries in `poles`/`samples` are ignored via `nansum`.
  """
  pole_dist = np.abs(poles[:, None, None] - mesh[None, :])
  attraction = np.nansum(np.exp(-pole_dist**2 / sigma**2), axis=0)

  sample_dist = np.abs(mesh[None, :] - samples[:, None, None])
  coverage = np.nansum(sigma**2 / sample_dist**2, axis=0)

  return attraction / coverage
52
+
53
@partial(jax.jit, static_argnames=["resolution", "batchsize"])
def _next_samples_heat(
  poles, prev_poles, samples, domain, radius, randkey, resolution=(101, 101),
  batchsize=1, stop=0.2
):
  """JIT-compiled core of `next_samples_heat`.

  Greedily picks up to `batchsize` mesh points that maximize the `heat`
  score, re-evaluating the score after each pick so the picks spread out.
  A pick whose best remaining score falls below `stop` is recorded as NaN
  (filtered out by the caller).

  Returns
  -------
  add_samples, X, Y, heat_map : the picked points (NaN-padded), the mesh
      grids and the heat map of the original samples (for debugging).
  """
  x = np.linspace(domain[0].real, domain[1].real, resolution[0])
  y = np.linspace(domain[0].imag, domain[1].imag, resolution[1])

  X, Y = np.meshgrid(x, y, indexing="ij")
  mesh = X + 1j * Y

  add_samples = np.empty(0, dtype=complex)
  for _ in range(batchsize):
    heat_map = heat(poles, np.concatenate([samples, add_samples]), mesh, sigma=radius)
    best_i = np.unravel_index(np.nanargmax(heat_map), heat_map.shape)

    # NaN marks "score below `stop`, no further sample needed here".
    # (the local was previously named `next`, shadowing the builtin)
    pick = np.where(heat_map[best_i] < stop, np.nan, mesh[best_i])
    add_samples = np.append(add_samples, pick)

  # np.concat([samples]) in the original was a no-op copy of `samples`
  return add_samples, X, Y, heat(poles, samples, mesh, sigma=radius)
74
+
75
@jax.tree_util.Partial
def next_samples_heat(
  poles, prev_poles, samples, domain, radius, randkey, resolution=(101, 101),
  batchsize=1, stop=0.2, debug=False, debug_known_poles=None
):
  """Heat-map based sampling strategy for `adaptive_aaa`.

  Delegates the actual pick to the jitted `_next_samples_heat` and strips
  the NaN placeholders it uses for "no more samples needed".

  If `debug` is truthy it is used as a directory name and a plot of the
  heat map, current samples and pole estimates is saved there (one file
  per call, named after the current sample count).
  """

  add_samples, X, Y, heat_map = _next_samples_heat(
    poles, prev_poles, samples, domain, radius, randkey, resolution,
    batchsize, stop
  )

  # NaNs mark picks whose heat score fell below `stop`
  add_samples = add_samples[~np.isnan(add_samples)]

  if debug:
    ax = plt.gca()  # remember current axes so we can restore them below
    plt.figure()
    plt.title(f"radius={radius}")
    if resolution[1]>1:
      # 2D heat map over the complex plane
      plt.pcolormesh(X, Y, heat_map, vmax=1)#, alpha=np.clip(heat_map, 0, 1))
      plt.colorbar()
      plt.scatter(samples.real, samples.imag, label="samples", zorder=1)
      plt.scatter(poles.real, poles.imag, color="C1", marker="x", label="est. pole", zorder=2)
      plt.scatter(add_samples.real, add_samples.imag, color="C2", label="next samples", zorder=3)
      if debug_known_poles is not None:
        plt.scatter(debug_known_poles.real, debug_known_poles.imag, color="C3", marker="+", label="known pole")
      plt.xlim(domain[0].real, domain[1].real)
      plt.ylim(domain[0].imag, domain[1].imag)
      plt.legend(loc="lower right")
    else:
      # degenerate 1D resolution: plot the heat map as a curve instead
      plt.plot(np.squeeze(X), np.squeeze(heat_map))
      plt.scatter(samples.real, np.zeros(len(samples)))
      plt.scatter(poles.real, np.zeros(len(poles)), color="C1", marker="x", label="est. pole", zorder=2)
      plt.xlim(domain[0].real, domain[1].real)
    plt.savefig(f"{debug}/{len(samples)}.png")
    plt.close()
    plt.sca(ax)

  return add_samples
113
+
114
# Wrap `aaa` in a Partial so it is a valid pytree leaf and can be passed
# through JAX transformations (e.g. as the `aaa` default argument below).
aaa = jax.tree_util.Partial(aaa)
115
+
116
def _adaptive_aaa(z_k_0: npt.NDArray,
        f: callable,
        evolutions: int = 2,
        cutoff: float = None,
        tol: float = 1e-9,
        mmax: int = 100,
        radius: float = None,
        domain: tuple[complex, complex] = None,
        f_k_0: npt.NDArray = None,
        sampling: Union[callable, str] = next_samples,
        prev_z_n: npt.NDArray = None,
        return_samples: bool = False,
        aaa: callable = aaa,
        f_dot: callable = None):
  """
  Implementation of `adaptive_aaa`.

  Iteratively runs AAA, asks the `sampling` strategy for new sample
  points near unstable pole estimates, evaluates `f` there and repeats
  for up to `evolutions` rounds.

  Parameters
  ----------
  see `adaptive_aaa`

  f_dot: callable
    Tangent of `f`. If provided JVPs of `f` will be collected throughout the
    iterations. For use in custom_jvp; in that case the function returns
    `(z_k, f_k, f_k_dot)` instead of the approximation.
  """

  if sampling == "heat":
    sampling = next_samples_heat

  collect_tangents = f_dot is not None
  z_k = z_k_0
  # largest pairwise distance of the initial samples; used for the
  # default domain size and refinement radius below
  max_dist = np.max(np.abs(z_k_0[:, np.newaxis] - z_k_0[np.newaxis, :]))

  if collect_tangents:
    # `f` is a jax.tree_util.Partial: flatten it (and its tangent) so the
    # pre-filled arguments can be differentiated through jax.jvp
    f_unpartial = f.func
    args, _ = jax.tree.flatten(f)
    args_dot, _ = jax.tree.flatten(f_dot)
    z_k_dot = np.zeros_like(z_k)  # sample positions are not differentiated
    f_k, f_k_dot = jax.jvp(
      f_unpartial, (*args, z_k), (*args_dot, z_k_dot)
    )
  else:
    if f_k_0 is None:
      f_k = f(z_k)
    else:
      f_k = f_k_0
    f_k_dot = np.zeros_like(f_k)

  if cutoff is None:
    cutoff = 1e10*np.median(np.abs(f_k))

  if domain is None:
    center = np.mean(z_k)
    disp = max_dist*(1+1j)
    domain = (center-disp, center+disp)

  if radius is None:
    radius = 1e-3 * max_dist

  if prev_z_n is None:
    # sentinel: everything is "far" from the previous poles in round one
    prev_z_n = np.array([np.inf], dtype=complex)

  def mask(z_k, f_k, f_k_dot):
    # drop samples whose function value exceeds `cutoff`
    m = np.abs(f_k)<cutoff
    if m.ndim == 2:
      m = np.all(m, axis=1) #filter out values, that have diverged too strongly
    return z_k[m], f_k[m], f_k_dot[m]

  key = random.key(0)
  for i in range(evolutions):
    z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)
    z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol, mmax)

    # last round: keep the final approximation, no further refinement
    if i==evolutions-1:
      break

    key, subkey = jax.random.split(key)
    add_z_k = sampling(z_n, prev_z_n, z_k, domain, radius, subkey)
    prev_z_n = z_n
    add_z_k_dot = np.zeros_like(add_z_k)

    # sampling strategy found nothing left to refine -> converged
    if len(add_z_k) == 0:
      break

    if collect_tangents:
      f_k_new , f_k_dot_new = jax.jvp(
        f_unpartial, (*args, add_z_k), (*args_dot, add_z_k_dot)
      )
    else:
      f_k_new = f(add_z_k)
      f_k_dot_new = np.zeros_like(f_k_new)

    z_k = np.append(z_k, add_z_k)
    f_k = np.concatenate([f_k, f_k_new])
    f_k_dot = np.concatenate([f_k_dot, f_k_dot_new])

  z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)

  if collect_tangents:
    return z_k, f_k, f_k_dot
  if return_samples:
    return z_j, f_j, w_j, z_n, z_k, f_k
  return z_j, f_j, w_j, z_n
221
+
222
@jax.custom_jvp
def adaptive_aaa(z_k_0: npt.NDArray,
        f:callable,
        evolutions: int = 2,
        cutoff: float = None,
        tol: float = 1e-9,
        mmax: int = 100,
        radius: float = None,
        domain: Domain = None,
        f_k_0: npt.NDArray = None,
        sampling: callable = next_samples,
        prev_z_n: npt.NDArray = None,
        return_samples: bool = False,
        aaa: callable = aaa):
  """ A 2x adaptive Antoulas–Anderson algorithm for rational approximation of
  meromorphic functions that are costly to evaluate.

  The algorithm iteratively places additional sample points close to estimated
  positions of poles identified during the past iteration. By this refinement
  scheme the number of function evaluations can be reduced.

  It is JAX differentiable wrt. the approximated function `f`, via its other
  arguments besides `z`. `f` should be provided as a `jax.tree_util.Partial`
  with only positional arguments pre-filled!

  Parameters
  ----------
  z_k_0 : np.ndarray
    Array of initial sample points
  f : callable
    function to be approximated. When using gradients `f` should be provided
    as a `jax.tree_util.Partial` with only positional arguments pre-filled.
    Furthermore it should be compatible to `jax.jvp`.
  evolutions: int
    Number of refinement iterations
  cutoff: float
    Maximum absolute value a function evaluation should take
    to be regarded valid. Otherwise the sample point is discarded.
    Defaults to 1e10 times the median of `f(z_k_0)`
  tol: float
    Tolerance used in AAA (see `diffaaable.aaa`)
  mmax: int
    Maximum number of AAA iterations (see `diffaaable.aaa`)
  radius: float
    Distance from the assumed poles for next samples
  domain: tuple[complex, complex]
    Tuple of min (lower left) and max (upper right) values defining a
    rectangle in the complex plane. Assumed poles outside of the domain
    will not receive refinement.
  f_k_0:
    Allows user to provide f evaluated at z_k_0


  Returns
  -------
  z_j: np.array
    chosen samples
  f_j: np.array
    `f(z_j)`
  w_j: np.array
    Weights of Barycentric Approximation
  z_n: np.array
    Poles of Barycentric Approximation
  """
  # thin differentiable wrapper: the actual work happens in `_adaptive_aaa`
  return _adaptive_aaa(
    z_k_0=z_k_0, f=f, evolutions=evolutions, cutoff=cutoff, tol=tol, mmax=mmax,
    radius=radius, domain=domain, f_k_0=f_k_0, sampling=sampling,
    prev_z_n=prev_z_n, return_samples=return_samples, aaa=aaa
  )
289
+
290
@adaptive_aaa.defjvp
def adaptive_aaa_jvp(primals, tangents):
  """Custom JVP of `adaptive_aaa`.

  Reruns the adaptive sampling while collecting tangents of `f`, then
  differentiates a final plain AAA fit of the collected samples.
  """
  z_k_0, f, *rest = primals
  z_dot, f_dot = tangents[:2]

  if np.any(z_dot):
    raise NotImplementedError(
      "Parametrizing the sampling positions z_k is not supported"
    )

  z_k, f_k, f_k_dot = _adaptive_aaa(z_k_0, f, *rest, f_dot=f_dot)

  # sample positions carry zero tangent by construction
  return jax.jvp(aaa, (z_k, f_k), (np.zeros_like(z_k), f_k_dot))
diffaaable/core.py ADDED
@@ -0,0 +1,224 @@
1
+ from jax import config
2
+ config.update("jax_enable_x64", True) #important -> else aaa fails
3
+ import jax.numpy as jnp
4
+ import numpy.typing as npt
5
+ import jax
6
+ import numpy as np
7
+ from baryrat import aaa as oaaa # ordinary aaa
8
+ import functools
9
+ import scipy.linalg
10
+
11
@functools.wraps(oaaa)
@jax.custom_jvp
def aaa(z_k: npt.NDArray, f_k: npt.NDArray, tol: float=1e-13, mmax: int=100):
  """
  Wrapped `baryrat.aaa` to enable JAX based autodiff.

  Returns nodes, values, weights and poles of the barycentric
  approximant instead of a `BarycentricRational` object (see the
  module-level docstring replacement below).
  """
  r = oaaa(z_k, f_k, tol=tol, mmax=mmax)
  z_j = r.nodes
  f_j = r.values
  w_j = r.weights

  # drop nodes with zero weight: they do not contribute to the rational
  mask = w_j!=0
  z_j = z_j[mask]
  f_j = f_j[mask]
  w_j = w_j[mask]

  z_n = poles(z_j, w_j)

  # sort poles by descending magnitude for a deterministic output order
  z_n = z_n[jnp.argsort(-jnp.abs(z_n))]

  return z_j, f_j, w_j, z_n
32
+
33
# Re-indentation separator for splicing the wrapped baryrat docstring below.
delimiter = ' \n'

# Replace the docstring inherited via functools.wraps with one documenting
# the changed return convention (nodes/values/weights/poles instead of a
# ready-to-evaluate BarycentricRational).
aaa.__doc__ = f"""\
This is a wrapped version of `aaa` as provided by `baryrat`,
providing a custom jvp to enable differentiability.

For detailed information on the usage of `aaa` please refer to
the original documentation::

  {delimiter.join(aaa.__doc__.splitlines())}

.. attention::
  Returns nodes, values, weights and poles, in contrast to the
  baryrat implementation that returns the BarycentricRational ready to be
  evaluated. This is done to facilitate differentiability.

Parameters
----------
z_k : array (M,)
  the sampling points of the function. Unlike for interpolation
  algorithms, where a small number of nodes is preferred, since the
  AAA algorithm chooses its support points adaptively, it is better
  to provide a finer mesh over the support.
f_k : array (M,)
  the function to be approximated; can be given as a callable function
  or as an array of function values over `z_k`.
tol : float
  the approximation tolerance
mmax : int
  the maximum number of iterations/degree of the resulting approximant

Returns
-------
z_j : array (m,)
  nodes of the barycentric approximant

f_j : array (m,)
  values of the barycentric approximant

w_j : array (m,)
  weights of the barycentric approximant

z_n : array (m-1,)
  poles of the barycentric approximant (for convenience)

"""
79
+
80
@aaa.defjvp
def aaa_jvp(primals, tangents):
  r"""Derivatives according to [1].
  The implemented matrix expressions are motivated in the appendix of [2]:

  .. topic:: AAA Derivatives

    Here we will briefly elaborate how the derivatives introduced in [1] are
    implemented as JAX compatible Jacobian Vector products (JVPs) in `diffaaable`.

    Given the tangents $\frac{\partial f_k}{\partial p}$ we will use the chain rule on $r(w_j, f_j, z)$ along its weights $w_j$ and values $f_j$ (the nodes $z_j$ are treated as independent of $p$) (Equation 4 of [1]):
    \[
    \frac{\partial f_k}{\partial p} \approx \frac{\partial r_k}{\partial p}= \sum_{j=1}^m\frac{\partial r_k}{\partial f_j}\frac{\partial f_j}{\partial p}+\sum_{j=1}^m\frac{\partial r_k}{\partial w_j}\frac{\partial w_j}{\partial p}
    \]

    To solve this system of equations for $\frac{\partial w_j}{\partial p}$ we express it in matrix form:

    .. math::
      \mathbf{A}\mathbf{w}^\prime = \mathbf{b}

    where $\mathbf{w}^\prime$ is the column vector containing $\frac{\partial w_j}{\partial p}$ and $\mathbf{b}$ and $\mathbf{A}$ are defined element wise:

    .. math::
      \begin{aligned}
      b_k &= \frac{\partial f_k}{\partial p} - \sum_{j=1}^m\frac{\partial r_k}{\partial f_j}\frac{\partial f_j}{\partial p}\\
      A_{kj} &= \frac{\partial r_k}{\partial w_j}
      \end{aligned}

    These are augmented with Equation 5 of \cite{betzEfficientRationalApproximation2024} which removes the ambiguity in $\mathbf{w}^\prime$ associated with a shared phase of all weights.

    The expressions for the derivatives of $r_k$ are found in the definition of $r(z)$ (Equation 1 in [2]): \[ \frac{\partial r_k}{\partial f_j}= \frac{1}{d(z_k)} \frac{w_j}{z_k-z_j}\] and
    \[
    \frac{\partial r_k}{\partial w_j}= \frac{1}{d(z_k)} \frac{f_j-r_k}{z_k-z_j} \approx \frac{1}{d(z_k)} \frac{f_j-f_k}{z_k-z_j}
    \]

  """
  z_k_full, f_k = primals[:2]
  z_dot, f_dot = tangents[:2]

  primal_out = aaa(z_k_full, f_k)
  z_j, f_j, w_j, z_n = primal_out

  # samples chosen as support nodes are treated separately below
  chosen = np.isin(z_k_full, z_j)

  z_k = z_k_full[~chosen]
  f_k = f_k[~chosen]

  # z_dot should be zero anyways
  if np.any(z_dot):
    raise NotImplementedError("Parametrizing the sampling positions z_k is not supported")
  z_k_dot = z_dot[~chosen]
  f_k_dot = f_dot[~chosen] # $\del f_k / \del p$

  ##################################################
  # We have to track which f_dot corresponds to z_k:
  # aaa may reorder the support nodes, so match tangents to the output
  # node order by sorting both sides by |z|.
  sort_orig = jnp.argsort(jnp.abs(z_k_full[chosen]))
  sort_out = jnp.argsort(jnp.argsort(jnp.abs(z_j)))

  z_j_dot = z_dot[chosen][sort_orig][sort_out]
  f_j_dot = f_dot[chosen][sort_orig][sort_out]
  ##################################################

  C = 1/(z_k[:, None]-z_j[None, :]) # Cauchy matrix k x j

  d = C @ w_j # denominator in barycentric formula
  via_f_j = C @ (f_j_dot * w_j) / d # $\sum_j f_j^\prime \frac{\del r}{\del f_j}$

  # linear system A @ w_j_dot = b as derived in the docstring above
  A = (f_j[None, :] - f_k[:, None])*C/d[:, None]
  b = f_k_dot - via_f_j

  # make sure system is not underdetermined according to eq. 5 of [1]
  A = jnp.concatenate([A, np.conj(w_j.reshape(1, -1))])
  b = jnp.append(b, 0)

  with jax.disable_jit(): #otherwise backwards differentiation led to error
    w_j_dot, _, _, _ = jnp.linalg.lstsq(A, b)

  # pole tangents from the implicit condition d(z_n) = 0
  denom = z_n.reshape(1, -1)-z_j.reshape(-1, 1)
  z_n_dot = (
    jnp.sum(w_j_dot.reshape(-1, 1)/denom, axis=0)/
    jnp.sum(w_j.reshape(-1, 1) /denom**2, axis=0)
  )

  tangent_out = z_j_dot, f_j_dot, w_j_dot, z_n_dot

  return primal_out, tangent_out
169
+
170
def poles(z_j, w_j):
  """
  The poles of a barycentric rational with given nodes and weights.

  Poles lifted by zeros of the nominator are included, so the values
  $f_j$ do not contribute and need not be provided. The zeros of the
  denominator are obtained as the finite eigenvalues of a generalized
  eigenvalue problem (modified from `baryrat` to support JAX AD).

  Parameters
  ----------
  z_j : array (m,)
    nodes of the barycentric rational
  w_j : array (m,)
    weights of the barycentric rational

  Returns
  -------
  z_n : array (m-1,)
    poles of the barycentric rational (more strictly zeros of the denominator)
  """
  ones = np.ones_like(z_j)

  m = len(w_j)
  B = np.eye(m + 1)
  B[0, 0] = 0  # makes the problem generalized: spurious modes go to infinity
  E = np.block([[0, w_j],
                [ones[:, None], np.diag(z_j)]])

  lam = scipy.linalg.eigvals(E, B)
  # infinite eigenvalues are artifacts of the embedding, not poles
  return lam[np.isfinite(lam)]
197
+
198
def residues(z_j, f_j, w_j, z_n):
  '''
  Residues of a barycentric rational at the given (simple) poles, via the
  formula for simple poles of quotients of analytic functions:
  res = N(z_n) / D'(z_n).
  The implementation was modified from `baryrat` to support JAX AD.

  Parameters
  ----------
  z_j : array (m,)
    nodes of the barycentric rational
  f_j : array (m,)
    values of the barycentric rational
  w_j : array (m,)
    weights of the barycentric rational
  z_n : array (n,)
    poles of interest of the barycentric rational (n<=m-1)

  Returns
  -------
  r_n : array (n,)
    residues of poles `z_n`
  '''
  cauchy = 1.0 / (z_n[:, None] - z_j[None, :])
  numerator = cauchy @ (f_j * w_j)
  denom_deriv = (-cauchy**2) @ w_j

  # a pole coinciding with a node yields nan -> map to 0
  return jnp.nan_to_num(numerator / denom_deriv)
diffaaable/lorentz.py ADDED
@@ -0,0 +1,88 @@
1
+ from jax import config
2
+ config.update("jax_enable_x64", True) #important -> else aaa fails
3
+ import jax.numpy as np
4
+ import jax
5
+ from diffaaable.vectorial import check_inputs
6
+ import jaxopt
7
+
8
def optimal_weights(A, A_hat, stepsize=0.5):
  """Unit-norm weights minimizing ||A w + A_hat conj(w)||.

  Initialized from the smallest right singular vector of the uncorrected
  Loewner matrix `A`, then refined with LBFGS.

  NOTE(review): `stepsize` is currently unused — confirm whether it was
  meant to be forwarded to the solver.
  """
  # Initial guess (from uncorrected A)
  _, _, Vh = np.linalg.svd(A)
  w_j = Vh[-1, :].conj()

  def obj_fun(w):
    # normalize so the objective is invariant to the scale of w
    w /= np.linalg.norm(w)
    err = (A @ w) + (A_hat @ w.conj())
    return np.linalg.norm(err)

  solver = jaxopt.LBFGS(fun=obj_fun, maxiter=300, tol=1e-13)
  res = solver.run(w_j)
  w_j, state = res
  w_j /= np.linalg.norm(w_j)

  return w_j
24
+
25
def lorentz_aaa(z_k, f_k, tol=1e-9, mmax=100, return_errors=False):
  """AAA variant enforcing pole symmetry about the imaginary axis.

  Follows the greedy vectorial AAA loop but augments the Loewner matrix
  with a correction block built from the mirrored support points
  `-conj(z_j)`, and obtains the weights via `optimal_weights` instead of
  a plain SVD. The mirrored nodes/values/weights are appended to the
  returned arrays so the resulting barycentric rational is symmetric.

  NOTE(review): docstring reconstructed from the code — confirm intent.
  """
  z_k, f_k, M, V = check_inputs(z_k, f_k)

  J = np.ones(M, dtype=bool)  # True for samples not yet used as support points
  z_j = np.empty(0, dtype=z_k.dtype)
  f_j = np.empty((0, V), dtype=f_k.dtype)
  errors = []

  reltol = tol * np.linalg.norm(f_k, np.inf)

  # initial constant approximation
  r_k = np.mean(f_k) * np.ones_like(f_k)
  # approx.

  for m in range(mmax):
    # find largest residual
    jj = np.argmax(np.linalg.norm(f_k - r_k, axis=-1)) #Next sample point to include
    z_j = np.append(z_j, np.array([z_k[jj]]))
    f_j = np.concatenate([f_j, f_k[jj][None, :]])
    J = J.at[jj].set(False)

    # Cauchy matrix containing the basis functions as columns
    C = 1.0 / (z_k[J,None] - z_j[None,:])
    # Loewner matrix
    A = (f_k[J,None] - f_j[None,:]) * C[:,:,None]
    A = np.concatenate(np.moveaxis(A, -1, 0))

    # Lorentz Correction: same construction against the mirrored nodes
    C_hat = 1.0 / (z_k[J,None] + np.conj(z_j)[None,:])
    # Loewner matrix
    A_hat = (f_k[J,None] - np.conj(f_j)[None,:]) * C_hat[:,:,None]
    A_hat = np.concatenate(np.moveaxis(A_hat, -1, 0))

    w_j = optimal_weights(A, A_hat)

    # approximation: numerator / denominator (direct + mirrored parts)
    N = C.dot(w_j[:, None] * f_j)
    N_hat = C_hat.dot(np.conj(w_j[:, None] * f_j))

    D = C.dot(w_j)[:, None]
    D_hat = C_hat.dot(np.conj(w_j))[:, None]

    # update approximation (support points stay exact)
    r_k = f_k.at[J].set((N + N_hat) / (D+D_hat))

    # check for convergence
    errors.append(np.linalg.norm(f_k - r_k, np.inf))
    if errors[-1] <= reltol:
      break

  if V == 1:
    f_j = f_j[:, 0]

  # append the mirrored support points, values and weights
  z_j = np.concatenate([z_j, -np.conj(z_j)])
  f_j = np.concatenate([f_j, np.conj(f_j)])
  w_j = np.concatenate([w_j, np.conj(w_j)])

  if return_errors:
    return z_j, f_j, w_j, errors
  return z_j, f_j, w_j
@@ -0,0 +1,197 @@
1
+ import os
2
+ import pathlib
3
+ import jax.numpy as np
4
+ import numpy as onp
5
+ from jax.tree_util import Partial
6
+ from diffaaable.core import aaa, residues
7
+ from diffaaable.adaptive import Domain, domain_mask, adaptive_aaa, next_samples_heat
8
+ import matplotlib.pyplot as plt
9
+
10
def reduced_domain(domain, reduction=1-1/12):
  """Shrink `domain` towards its center by mixing the two corners.

  With `reduction` r, each corner becomes r*itself + (1-r)*the other
  corner; values above 1 enlarge the domain instead.
  """
  lo, hi = domain
  mix = 1 - reduction
  new_lo = lo * reduction + hi * mix
  new_hi = hi * reduction + lo * mix
  return (new_lo, new_hi)
16
+
17
def sample_cross(domain):
  """Four samples forming a cross through the domain center.

  One point left/right of the center at half the domain width and one
  above/below at half the domain height.
  """
  center = domain_center(domain)
  half = 0.5 * (domain[1] - domain[0])
  offsets = np.array([half.real, -half.real, 1j * half.imag, -1j * half.imag])
  return center + offsets
21
+
22
def sample_domain(domain: Domain, N: int):
  """Roughly N samples on a regular sqrt(N) x sqrt(N) grid over a
  slightly shrunk copy of `domain` (to keep samples off the boundary)."""
  side = np.round(np.sqrt(N)).astype(int)
  inner = reduced_domain(domain)
  re_axis = np.linspace(inner[0].real, inner[1].real, side)
  im_axis = np.linspace(inner[0].imag, inner[1].imag, side)
  RE, IM = np.meshgrid(re_axis, im_axis)
  return (RE + 1j * IM).flatten()
30
+
31
def sample_rim(domain: Domain, N: int):
  """N//4 samples per side on the rim of `domain`, corners excluded."""
  per_side = N // 4
  # drop the first/last linspace points so corners are not sampled
  re = np.linspace(domain[0].real, domain[1].real, per_side + 2)[1:-1]
  im = np.linspace(domain[0].imag, domain[1].imag, per_side + 2)[1:-1] * 1j
  return np.array([
    1j * domain[0].imag + re,  # bottom edge
    1j * domain[1].imag + re,  # top edge
    domain[0].real + im,       # left edge
    domain[1].real + im,       # right edge
  ]).flatten()
41
+
42
def anti_domain(domain: Domain):
  """Swap the imaginary parts of the two corner points.

  Turns a (lower-left, upper-right) pair given via its other diagonal
  into the canonical (lower-left, upper-right) form.
  """
  a, b = domain
  return (a.real + 1j * b.imag, b.real + 1j * a.imag)
47
+
48
def domain_center(domain: Domain):
  """Midpoint of the two corner points of `domain`."""
  corners = np.array(domain)
  return np.mean(corners)
50
+
51
def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
  """Split `domain` into two halves through `center`.

  Returns the two halves as (lower-left, upper-right) pairs: top/bottom
  halves when `divide_horizontal`, otherwise left/right halves.
  """
  if center is None:
    center = domain_center(domain)
  top_left = domain[0].real + 1j * domain[1].imag
  bottom_right = domain[1].real + 1j * domain[0].imag

  # the four quadrants around `center`, each in canonical corner form
  ne = (center, domain[1])
  nw = anti_domain((top_left, center))
  sw = (domain[0], center)
  se = anti_domain((center, bottom_right))

  if divide_horizontal:
    # top half (nw+ne) and bottom half (sw+se)
    return [(nw[0], ne[1]), (sw[0], se[1])]
  # left half (sw+nw) and right half (se+ne)
  return [(sw[0], nw[1]), (se[0], ne[1])]
67
+
68
+
69
+ def cutoff_mask(z_k, f_k, f_k_dot, cutoff):
70
+ m = np.abs(f_k)<cutoff #filter out values, that have diverged too strongly
71
+ return z_k[m], f_k[m], f_k_dot[m]
72
+
73
def plot_domain(domain: Domain, size: float=1):
  """Draw the rectangle `domain` on the current matplotlib axes.

  Returns the list of Line2D artists produced by `plt.plot`.
  """
  top_left = domain[0].real + 1j * domain[1].imag
  bottom_right = domain[1].real + 1j * domain[0].imag

  # closed path around the rectangle (first corner repeated)
  outline = np.array([domain[0], bottom_right, domain[1], top_left, domain[0]])

  return plt.plot(outline.real, outline.imag,
                  lw=size/30, zorder=1)
81
+
82
def all_poles_known(poles, prev, tol):
  """True iff `poles` and `prev` have equal length and every pole has a
  counterpart in `prev` closer than `tol`."""
  if prev is None:
    return False
  if len(prev) != len(poles):
    return False

  pairwise = np.abs(poles[:, None] - prev[None, :])
  has_match = np.any(pairwise < tol, axis=1)
  return np.all(has_match)
90
+
91
+
92
def selective_refinement_aaa(f: callable,
        domain: Domain,
        N: int = 36,
        max_poles: int = 400,
        cutoff: float = None,
        tol_aaa: float = 1e-9,
        tol_pol: float = 1e-5,
        suggestions = None,
        on_rim: bool = False,
        Dmax=30,
        use_adaptive: bool = True,
        z_k = None, f_k = None,
        divide_horizontal=True,
        debug_name = "d",
        stop = 0.1,
        batchsize=10
        ):
  """Recursively locate poles and residues of `f` inside `domain`.

  Fits a (possibly adaptive) AAA approximation on the domain; if the pole
  estimates are not yet confirmed (see `all_poles_known`) and the
  recursion budget `Dmax` is not exhausted, the domain is split into two
  halves (alternating split direction) and each half is refined
  recursively, reusing the samples that fall inside it.

  Returns `(poles, residues, eval_count)` where `eval_count` counts
  evaluations of `f`.

  TODO: allow access to samples slightly outside of domain
  """

  print(f"start domain '{debug_name}', {Dmax=}")
  folder = f"debug_out/{debug_name:0<33}"
  domain_size = np.abs(domain[1]-domain[0])/2
  size = domain_size/2 # for plotting
  #plot_rect = plot_domain(domain, size=30)
  #color = plot_rect[0].get_color()

  if cutoff is None:
    cutoff = np.inf

  eval_count = 0
  if use_adaptive:
    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)
    sampling = Partial(next_samples_heat, debug=folder,
              stop=stop, resolution=(101, 101), batchsize=batchsize)
    if z_k is None:
      z_k = np.empty((0,), dtype=complex)
      f_k = z_k.copy()

    # too few inherited samples: seed the domain with a fresh grid
    if len(z_k) < N/4:
      z_k_new = sample_domain(domain, N)
      f_k = np.append(f_k, f(z_k_new))
      z_k = np.append(z_k, z_k_new)

      eval_count += len(z_k_new)
      print(f"new eval: {eval_count}")

    # offset by the pre-existing samples so only adaptive_aaa's new
    # evaluations are counted when we add len(z_k) back afterwards
    eval_count -= len(z_k)
    z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
      z_k, f, f_k_0=f_k, evolutions=N*16, tol=tol_aaa,
      domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N), #NOTE: actually increased domain :/
      return_samples=True, sampling=sampling, cutoff=np.inf
    )
    # TODO pass down samples in buffer zone
    eval_count += len(z_k)
  else:
    if on_rim:
      z_k = sample_rim(domain, N)
    else:
      z_k = sample_domain(reduced_domain(domain, 1.05), N)
    f_k = f(z_k)
    eval_count += len(f_k)
    try:
      z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol=tol_aaa)
    except onp.linalg.LinAlgError as e:
      # AAA failed on this sample set: report no poles for this domain
      z_n = z_j = f_j = w_j = np.empty((0,))

  print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
  poles = z_n[domain_mask(domain, z_n)]

  # stop recursing when the budget is exhausted or the pole estimates
  # are confirmed against the parent's suggestions
  if (Dmax == 0 or
    (len(poles)<=max_poles and all_poles_known(poles, suggestions, tol_pol))):

    #plt.scatter(poles.real, poles.imag, color = color, marker="x")#, s=size*3, linewidths=size/2)
    print("I am done here")

    res = residues(z_j, f_j, w_j, poles)
    return poles, res, eval_count
  #plt.scatter(poles.real, poles.imag, color = color, marker="+", s=0.2, zorder=3)#, s=size, linewidths=size/6)

  subs = subdomains(domain, divide_horizontal)

  pol = np.empty((0,), dtype=complex)
  res = pol.copy()
  for i,sub in enumerate(subs):
    # current pole estimates inside the half serve as suggestions
    sug = poles[domain_mask(sub, poles)]
    sample_mask = domain_mask(sub, z_k)

    # reuse the samples that fall inside this half
    known_z_k = z_k[sample_mask]
    known_f_k = f_k[sample_mask]

    p, r, e = selective_refinement_aaa(
      f, sub, N, max_poles, cutoff, tol_aaa, tol_pol,
      use_adaptive=use_adaptive,
      suggestions=sug, Dmax=Dmax-1, z_k=known_z_k, f_k=known_f_k,
      divide_horizontal = not divide_horizontal,
      debug_name=f"{debug_name}{i+1}",
    )
    pol = np.append(pol, p)
    res = np.append(res, r)
    eval_count += e
  # if len(pol) > 0:
  #   plt.xlim(domain[0].real, domain[1].real)
  #   plt.ylim(domain[0].imag, domain[1].imag)
  #   plt.savefig(f"debug_out/{debug_name:0<33}.png")
  return pol, res, eval_count
@@ -0,0 +1,114 @@
1
+ from jax import config
2
+ config.update("jax_enable_x64", True) #important -> else aaa fails
3
+ import jax.numpy as np
4
+ from diffaaable.core import poles
5
+
6
def check_inputs(z_k, f_k):
  """Validate and normalize AAA sample inputs.

  Parameters
  ----------
  z_k : array-like (M,)
    sample points
  f_k : array-like (M,) or (M, V)
    function values; 1D input is promoted to a single-column 2D array

  Returns
  -------
  z_k, f_k, M, V
    the normalized arrays plus the number of samples M and the vector
    dimension V

  Raises
  ------
  ValueError
    if `z_k` is not 1D or `f_k`'s first dimension does not match `z_k`
  """
  f_k = np.array(f_k)
  z_k = np.array(z_k)

  if z_k.ndim != 1:
    # BUG FIX: message was a plain string, so {z_k.shape} never interpolated
    raise ValueError(f"z_k should be 1D but has shape {z_k.shape}")
  M = z_k.shape[0]

  if f_k.ndim == 1:
    f_k = f_k[:, np.newaxis]

  if f_k.ndim != 2 or f_k.shape[0] != M:
    # BUG FIX: the implicitly concatenated literals lacked a separating
    # space ("firstdimension")
    raise ValueError("f_k should be 1 or 2D and have the same first "
                     f"dimension as z_k, {f_k.shape=}, {z_k.shape=}")
  V = f_k.shape[1]

  return z_k, f_k, M, V
23
+
24
def vectorial_aaa(z_k, f_k, tol=1e-13, mmax=100, return_errors=False):
  r"""Find a rational approximation to $\mathbf f(z)$ over the points $z_k$ using
  a modified AAA algorithm, as presented in [^4]. Importantly the weights and
  thus also the poles are shared between all entries of $\mathbf f(z)$.

  Parameters
  ----------
  z_k : array (M,):
    M sample points
  f_k : array (M, V):
    vector valued function values
  tol : float
    the approximation tolerance
  mmax : int
    the maximum number of iterations/degree of the resulting approximant
  return_errors : bool
    if True, additionally return the per-iteration error history and
    print SVD progress

  Returns
  -------
  z_j : array (m,)
    nodes of the barycentric approximant
  f_j : array (m, V)
    values of the barycentric approximant (squeezed to (m,) when V == 1
    is NOT done here; shape follows `check_inputs`)
  w_j : array (m,)
    shared weights of the barycentric approximant
  z_n : array (m-1,)
    shared poles of the barycentric approximant
  errors : list, only if `return_errors`
    max-norm approximation error after each iteration
  """
  z_k, f_k, M, V = check_inputs(z_k, f_k)

  J = np.ones(M, dtype=bool)  # True for samples not yet used as support points
  z_j = np.empty(0, dtype=z_k.dtype)
  f_j = np.empty((0, V), dtype=f_k.dtype)
  errors = []

  reltol = tol * np.linalg.norm(f_k, np.inf)

  # initial constant approximation
  r_k = np.mean(f_k) * np.ones_like(f_k)

  # each iteration consumes one support point; keep at least as many
  # remaining samples as support points
  mmax = min(mmax, len(f_k)//2)

  for m in range(mmax):
    # find largest residual
    jj = np.argmax(np.linalg.norm(f_k - r_k, axis=-1)) #Next sample point to include
    z_j = np.append(z_j, np.array([z_k[jj]]))
    f_j = np.concatenate([f_j, f_k[jj][None, :]])
    J = J.at[jj].set(False)

    # Cauchy matrix containing the basis functions as columns
    C = 1.0 / (z_k[J,None] - z_j[None,:])
    # Loewner matrix
    A = (f_k[J,None] - f_j[None,:]) * C[:,:,None]

    # stack the V per-component Loewner matrices vertically so one
    # weight vector is shared across all components
    A = np.concatenate(np.moveaxis(A, -1, 0))

    # compute weights as right singular vector for smallest singular value
    if return_errors:
      print("start SVD")
    _, _, Vh = np.linalg.svd(A)
    if return_errors:
      print("finished SVD")

    w_j = Vh[-1, :].conj()

    # approximation: numerator / denominator
    N = C.dot(w_j[:, None] * f_j) #TODO check it works
    D = C.dot(w_j)[:, None]

    # update residual (support points stay exact)
    r_k = f_k.at[J].set(N / D)

    # check for convergence
    errors.append(np.linalg.norm(f_k - r_k, np.inf))
    if return_errors:
      print(errors[-1])
    if errors[-1] <= reltol:
      break

  z_n = poles(z_j, w_j)
  if return_errors:
    return z_j, f_j, w_j, z_n, errors
  return z_j, f_j, w_j, z_n
100
+
101
+
102
def residues_vec(z_j, f_j, w_j, z_n):
  '''Vectorial residues for given poles via the formula for simple poles
  of quotients of analytic functions: res = N(z_n) / D'(z_n), evaluated
  per component.

  NOTE(review): the broadcasting `f_j * w_j[None, :]` only works when
  `f_j` is laid out as (V, m) — transposed relative to the (m, V) output
  of `vectorial_aaa` — or when V == 1; confirm the intended orientation.
  '''
  cauchy = 1.0 / (z_n[:, None] - z_j[None, :])
  weighted_values = (f_j * w_j[None, :]).T
  numerator = cauchy.dot(weighted_values)
  denom_deriv = (-cauchy**2).dot(w_j)
  result = numerator / denom_deriv[:, None]

  # a pole coinciding with a node yields nan -> map to 0
  return np.nan_to_num(result.T)
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Jan David Fischbach
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,64 @@
1
+ Metadata-Version: 2.1
2
+ Name: diffaaable
3
+ Version: 1.0.1
4
+ Summary: JAX-differentiable AAA algorithm
5
+ Keywords: python
6
+ Author-email: Jan David Fischbach <fischbach@kit.edu>
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ Classifier: Programming Language :: Python :: 3.9
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Operating System :: OS Independent
13
+ Requires-Dist: numpy
14
+ Requires-Dist: jax
15
+ Requires-Dist: jaxlib
16
+ Requires-Dist: baryrat
17
+ Requires-Dist: jaxopt
18
+ Requires-Dist: pytest
19
+ Requires-Dist: pytest-benchmark
20
+ Requires-Dist: pre-commit ; extra == "dev"
21
+ Requires-Dist: pytest ; extra == "dev"
22
+ Requires-Dist: pytest-cov ; extra == "dev"
23
+ Requires-Dist: pytest_regressions ; extra == "dev"
24
+ Requires-Dist: jupytext ; extra == "docs"
25
+ Requires-Dist: matplotlib ; extra == "docs"
26
+ Requires-Dist: jupyter-book ; extra == "docs"
27
+ Requires-Dist: sphinx_math_dollar ; extra == "docs"
28
+ Provides-Extra: dev
29
+ Provides-Extra: docs
30
+
31
+ # diffaaable 1.0.1
32
+
33
+ ![](docs/assets/diffaaable.png)
34
+
35
+ `diffaaable` is a JAX differentiable version of the AAA algorithm. The derivatives are implemented as custom Jacobian Vector products in accordance to [^1].
36
+ A detailed derivation of the used matrix expressions is provided in the appendix of [^2].
37
+ Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
38
+ Additionally, the following application specific extensions to the AAA algorithm are included:
39
+
40
+ - **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluations needed to precisely locate poles within some domain
41
+ - **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
42
+ - **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
43
+ - **Selective Refinement**: Use a divide and conquer scheme to capture many poles simultaneously and accurately, by limiting the number of poles per AAA solve. Suggested in [^4].
44
+
45
+ ## Installation
46
+ to install `diffaaable` run
47
+ `pip install diffaaable`
48
+
49
+ ## Usage
50
+ Please refer to the [quickstart tutorial](./usage.md)
51
+
52
+ ## Contributing
53
+ Feel free to open issues and/or PRs.
54
+
55
+ ## Citation
56
+ When using this software package for scientific work please cite the associated publication [^2].
57
+
58
+ +++
59
+
60
+ [^1]: https://arxiv.org/pdf/2403.19404
61
+ [^2]: Multiscat Resonances (to be published)
62
+ [^3]: https://doi.org/10.1093/imanum/draa098
63
+ [^4]: https://doi.org/10.48550/arXiv.2405.19582
64
+
@@ -0,0 +1,10 @@
1
+ diffaaable/__init__.py,sha256=EiOU3Hg-tdBkK2vZAlL0CkXQM38QzuZyWOnTFxSODnI,310
2
+ diffaaable/adaptive.py,sha256=skpKa9jMpirpAWR-fqnYrD9a0WaXrjrTc9kZk1tMbKI,9667
3
+ diffaaable/core.py,sha256=LOcq4k0IiYHaMffIS8Amgdu9lw9CWVH8WPYtFA5KtbY,7168
4
+ diffaaable/lorentz.py,sha256=BVimKKOo-FuTNGhjwr3di54hziq2_xclQKLNe8YGJ-k,2390
5
+ diffaaable/selective.py,sha256=F0pDH9Sf1vhIzgYSLZrqdRPx0oYY6p_dIJyI67JLHxI,6246
6
+ diffaaable/vectorial.py,sha256=JtuowSk_bfgdjrEK0PsVIQGA1EuKuNOciQqqy_Nxv54,3145
7
+ diffaaable-1.0.1.dist-info/LICENSE,sha256=_GGRQSEhqmML9t-dWxiXXZXwPLdzVHT5HdEISjFjSrU,1076
8
+ diffaaable-1.0.1.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
9
+ diffaaable-1.0.1.dist-info/METADATA,sha256=CdLhlhvE4EzjNWdlAPeP41j6AlrybpHM00lHx3k_9JA,2512
10
+ diffaaable-1.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: flit 3.9.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any