diffaaable 1.0.1__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffaaable-1.0.1 → diffaaable-1.1.1}/PKG-INFO +9 -7
- {diffaaable-1.0.1 → diffaaable-1.1.1}/README.md +3 -3
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/__init__.py +1 -1
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/adaptive.py +15 -9
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/selective.py +6 -21
- {diffaaable-1.0.1 → diffaaable-1.1.1}/pyproject.toml +8 -6
- {diffaaable-1.0.1 → diffaaable-1.1.1}/LICENSE +0 -0
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/core.py +0 -0
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/lorentz.py +0 -0
- {diffaaable-1.0.1 → diffaaable-1.1.1}/diffaaable/vectorial.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
2
|
Name: diffaaable
|
|
3
|
-
Version: 1.0.1
|
|
3
|
+
Version: 1.1.1
|
|
4
4
|
Summary: JAX-differentiable AAA algorithm
|
|
5
5
|
Keywords: python
|
|
6
6
|
Author-email: Jan David Fischbach <fischbach@kit.edu>
|
|
@@ -15,12 +15,14 @@ Requires-Dist: jax
|
|
|
15
15
|
Requires-Dist: jaxlib
|
|
16
16
|
Requires-Dist: baryrat
|
|
17
17
|
Requires-Dist: jaxopt
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist:
|
|
18
|
+
Requires-Dist: tbump>=6.11.0 ; extra == "dev"
|
|
19
|
+
Requires-Dist: towncrier ; extra == "dev"
|
|
20
20
|
Requires-Dist: pre-commit ; extra == "dev"
|
|
21
21
|
Requires-Dist: pytest ; extra == "dev"
|
|
22
22
|
Requires-Dist: pytest-cov ; extra == "dev"
|
|
23
23
|
Requires-Dist: pytest_regressions ; extra == "dev"
|
|
24
|
+
Requires-Dist: pytest-benchmark ; extra == "dev"
|
|
25
|
+
Requires-Dist: matplotlib ; extra == "dev"
|
|
24
26
|
Requires-Dist: jupytext ; extra == "docs"
|
|
25
27
|
Requires-Dist: matplotlib ; extra == "docs"
|
|
26
28
|
Requires-Dist: jupyter-book ; extra == "docs"
|
|
@@ -28,7 +30,7 @@ Requires-Dist: sphinx_math_dollar ; extra == "docs"
|
|
|
28
30
|
Provides-Extra: dev
|
|
29
31
|
Provides-Extra: docs
|
|
30
32
|
|
|
31
|
-
# diffaaable 1.0.1
|
|
33
|
+
# diffaaable 1.1.1
|
|
32
34
|
|
|
33
35
|

|
|
34
36
|
|
|
@@ -37,7 +39,7 @@ A detailed derivation of the used matrix expressions is provided in the appendix
|
|
|
37
39
|
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
38
40
|
Additionally, the following application-specific extensions to the AAA algorithm are included:
|
|
39
41
|
|
|
40
|
-
- **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
42
|
+
- **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
|
|
41
43
|
- **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector-valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
42
44
|
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
43
45
|
- **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately, by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
@@ -58,7 +60,7 @@ When using this software package for scientific work please cite the associated
|
|
|
58
60
|
+++
|
|
59
61
|
|
|
60
62
|
[^1]: https://arxiv.org/pdf/2403.19404
|
|
61
|
-
[^2]:
|
|
63
|
+
[^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
|
|
62
64
|
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
63
65
|
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
64
66
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# diffaaable 1.0.1
|
|
1
|
+
# diffaaable 1.1.1
|
|
2
2
|
|
|
3
3
|

|
|
4
4
|
|
|
@@ -7,7 +7,7 @@ A detailed derivation of the used matrix expressions is provided in the appendix
|
|
|
7
7
|
Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
|
|
8
8
|
Additionally, the following application-specific extensions to the AAA algorithm are included:
|
|
9
9
|
|
|
10
|
-
- **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
|
|
10
|
+
- **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain.
|
|
11
11
|
- **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector-valued functions $\mathbf{f}(z)$ as presented in [^3].
|
|
12
12
|
- **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
|
|
13
13
|
- **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately, by limiting the number of poles per AAA solve. Suggested in [^4].
|
|
@@ -28,6 +28,6 @@ When using this software package for scientific work please cite the associated
|
|
|
28
28
|
+++
|
|
29
29
|
|
|
30
30
|
[^1]: https://arxiv.org/pdf/2403.19404
|
|
31
|
-
[^2]:
|
|
31
|
+
[^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
|
|
32
32
|
[^3]: https://doi.org/10.1093/imanum/draa098
|
|
33
33
|
[^4]: https://doi.org/10.48550/arXiv.2405.19582
|
|
@@ -113,6 +113,18 @@ def next_samples_heat(
|
|
|
113
113
|
|
|
114
114
|
aaa = jax.tree_util.Partial(aaa)
|
|
115
115
|
|
|
116
|
+
def mask(z_k, f_k, f_k_dot, cutoff):
    """Remove unusable samples and deduplicate the sample points.

    A sample ``k`` is kept only if ``|f_k[k]| < cutoff``, ``f_k[k]`` is not
    NaN, and ``z_k[k]`` is not NaN. When ``f_k`` is 2-D (vector-valued
    samples, one row per sample point), the per-entry checks on ``f_k`` are
    first reduced over axis 1, so a sample survives only if every component
    passes. Only then is that 1-D mask combined with the NaN check on
    ``z_k``. This keeps the ``(N,)`` mask of ``z_k`` from being broadcast
    against an ``(N, M)`` array.

    Remaining samples are deduplicated by ``z_k`` using ``np.unique``. That
    returns the unique points in sorted order. ``f_k`` and ``f_k_dot`` are
    reindexed with the first-occurrence indices so they stay aligned.

    Args:
        z_k: 1-D array of sample points.
        f_k: sampled function values, 1-D or 2-D with ``len(f_k) == len(z_k)``.
        f_k_dot: array indexed along its first axis like ``f_k``.
        cutoff: samples with ``|f_k| >= cutoff`` are treated as diverged.

    Returns:
        Tuple ``(z_k, f_k, f_k_dot)`` containing only the retained, unique
        samples, ordered by sorted ``z_k``.
    """
    m = np.abs(f_k) < cutoff  # filter out values, that have diverged too strongly
    m = np.logical_and(m, ~np.isnan(f_k))  # filter out nans in the function values

    if m.ndim == 2:
        # vector-valued samples: keep a sample only if all components are valid
        m = np.all(m, axis=1)

    # combine with the per-sample check on z_k only once m is 1-D
    m = np.logical_and(m, ~np.isnan(z_k))  # filter out nans in the sample points
    z_k, f_k, f_k_dot = z_k[m], f_k[m], f_k_dot[m]

    # np.unique sorts z_k; return_index keeps f_k / f_k_dot aligned with it
    z_k, idx = np.unique(z_k, return_index=True)  # filter out duplicates
    return z_k, f_k[idx], f_k_dot[idx]
|
|
127
|
+
|
|
116
128
|
def _adaptive_aaa(z_k_0: npt.NDArray,
|
|
117
129
|
f: callable,
|
|
118
130
|
evolutions: int = 2,
|
|
@@ -162,7 +174,7 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
162
174
|
f_k_dot = np.zeros_like(f_k)
|
|
163
175
|
|
|
164
176
|
if cutoff is None:
|
|
165
|
-
cutoff = 1e10*np.
|
|
177
|
+
cutoff = 1e10*np.nanmedian(np.abs(f_k))
|
|
166
178
|
|
|
167
179
|
if domain is None:
|
|
168
180
|
center = np.mean(z_k)
|
|
@@ -175,15 +187,9 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
175
187
|
if prev_z_n is None:
|
|
176
188
|
prev_z_n = np.array([np.inf], dtype=complex)
|
|
177
189
|
|
|
178
|
-
def mask(z_k, f_k, f_k_dot):
|
|
179
|
-
m = np.abs(f_k)<cutoff
|
|
180
|
-
if m.ndim == 2:
|
|
181
|
-
m = np.all(m, axis=1) #filter out values, that have diverged too strongly
|
|
182
|
-
return z_k[m], f_k[m], f_k_dot[m]
|
|
183
|
-
|
|
184
190
|
key = random.key(0)
|
|
185
191
|
for i in range(evolutions):
|
|
186
|
-
z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)
|
|
192
|
+
z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot, cutoff=cutoff)
|
|
187
193
|
z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol, mmax)
|
|
188
194
|
|
|
189
195
|
if i==evolutions-1:
|
|
@@ -211,7 +217,7 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
|
|
|
211
217
|
f_k = np.concatenate([f_k, f_k_new])
|
|
212
218
|
f_k_dot = np.concatenate([f_k_dot, f_k_dot_new])
|
|
213
219
|
|
|
214
|
-
z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)
|
|
220
|
+
z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot, cutoff=cutoff)
|
|
215
221
|
|
|
216
222
|
if collect_tangents:
|
|
217
223
|
return z_k, f_k, f_k_dot
|
|
@@ -65,11 +65,6 @@ def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
|
|
|
65
65
|
return [(subs[1][0], subs[0][1]), (subs[2][0], subs[3][1])]
|
|
66
66
|
return [(subs[2][0], subs[1][1]), (subs[3][0], subs[0][1])]
|
|
67
67
|
|
|
68
|
-
|
|
69
|
-
def cutoff_mask(z_k, f_k, f_k_dot, cutoff):
|
|
70
|
-
m = np.abs(f_k)<cutoff #filter out values, that have diverged too strongly
|
|
71
|
-
return z_k[m], f_k[m], f_k_dot[m]
|
|
72
|
-
|
|
73
68
|
def plot_domain(domain: Domain, size: float=1):
|
|
74
69
|
left_up = domain[0].real + 1j*domain[1].imag
|
|
75
70
|
right_down = domain[1].real + 1j*domain[0].imag
|
|
@@ -82,7 +77,6 @@ def plot_domain(domain: Domain, size: float=1):
|
|
|
82
77
|
def all_poles_known(poles, prev, tol):
|
|
83
78
|
if prev is None or len(prev)!=len(poles):
|
|
84
79
|
return False
|
|
85
|
-
#return True
|
|
86
80
|
|
|
87
81
|
dist = np.abs(poles[:, None] - prev[None, :])
|
|
88
82
|
check = np.all(np.any(dist < tol, axis=1))
|
|
@@ -110,12 +104,9 @@ def selective_refinement_aaa(f: callable,
|
|
|
110
104
|
TODO: allow access to samples slightly outside of domain
|
|
111
105
|
"""
|
|
112
106
|
|
|
113
|
-
print(f"start domain '{debug_name}', {Dmax=}")
|
|
107
|
+
# print(f"start domain '{debug_name}', {Dmax=}")
|
|
114
108
|
folder = f"debug_out/{debug_name:0<33}"
|
|
115
109
|
domain_size = np.abs(domain[1]-domain[0])/2
|
|
116
|
-
size = domain_size/2 # for plotting
|
|
117
|
-
#plot_rect = plot_domain(domain, size=30)
|
|
118
|
-
#color = plot_rect[0].get_color()
|
|
119
110
|
|
|
120
111
|
if cutoff is None:
|
|
121
112
|
cutoff = np.inf
|
|
@@ -135,11 +126,13 @@ def selective_refinement_aaa(f: callable,
|
|
|
135
126
|
z_k = np.append(z_k, z_k_new)
|
|
136
127
|
|
|
137
128
|
eval_count += len(z_k_new)
|
|
138
|
-
print(f"new eval: {eval_count}")
|
|
129
|
+
# print(f"new eval: {eval_count}")
|
|
139
130
|
eval_count -= len(z_k)
|
|
140
131
|
z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
|
|
141
132
|
z_k, f, f_k_0=f_k, evolutions=N*16, tol=tol_aaa,
|
|
142
|
-
domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N),
|
|
133
|
+
domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N),
|
|
134
|
+
# NOTE: reduced domain with a factor larger than 1
|
|
135
|
+
# actually increases domain size
|
|
143
136
|
return_samples=True, sampling=sampling, cutoff=np.inf
|
|
144
137
|
)
|
|
145
138
|
# TODO pass down samples in buffer zone
|
|
@@ -156,18 +149,14 @@ def selective_refinement_aaa(f: callable,
|
|
|
156
149
|
except onp.linalg.LinAlgError as e:
|
|
157
150
|
z_n = z_j = f_j = w_j = np.empty((0,))
|
|
158
151
|
|
|
159
|
-
print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
|
|
152
|
+
# print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
|
|
160
153
|
poles = z_n[domain_mask(domain, z_n)]
|
|
161
154
|
|
|
162
155
|
if (Dmax == 0 or
|
|
163
156
|
(len(poles)<=max_poles and all_poles_known(poles, suggestions, tol_pol))):
|
|
164
157
|
|
|
165
|
-
#plt.scatter(poles.real, poles.imag, color = color, marker="x")#, s=size*3, linewidths=size/2)
|
|
166
|
-
print("I am done here")
|
|
167
|
-
|
|
168
158
|
res = residues(z_j, f_j, w_j, poles)
|
|
169
159
|
return poles, res, eval_count
|
|
170
|
-
#plt.scatter(poles.real, poles.imag, color = color, marker="+", s=0.2, zorder=3)#, s=size, linewidths=size/6)
|
|
171
160
|
|
|
172
161
|
subs = subdomains(domain, divide_horizontal)
|
|
173
162
|
|
|
@@ -190,8 +179,4 @@ def selective_refinement_aaa(f: callable,
|
|
|
190
179
|
pol = np.append(pol, p)
|
|
191
180
|
res = np.append(res, r)
|
|
192
181
|
eval_count += e
|
|
193
|
-
# if len(pol) > 0:
|
|
194
|
-
# plt.xlim(domain[0].real, domain[1].real)
|
|
195
|
-
# plt.ylim(domain[0].imag, domain[1].imag)
|
|
196
|
-
# plt.savefig(f"debug_out/{debug_name:0<33}.png")
|
|
197
182
|
return pol, res, eval_count
|
|
@@ -22,9 +22,7 @@ dependencies = [
|
|
|
22
22
|
"jax",
|
|
23
23
|
"jaxlib",
|
|
24
24
|
"baryrat",
|
|
25
|
-
"jaxopt"
|
|
26
|
-
"pytest",
|
|
27
|
-
"pytest-benchmark"
|
|
25
|
+
"jaxopt"
|
|
28
26
|
]
|
|
29
27
|
description = "JAX-differentiable AAA algorithm"
|
|
30
28
|
keywords = ["python"]
|
|
@@ -32,14 +30,18 @@ license = {file = "LICENSE"}
|
|
|
32
30
|
name = "diffaaable"
|
|
33
31
|
readme = "README.md"
|
|
34
32
|
requires-python = ">=3.9"
|
|
35
|
-
version = "1.0.1"
|
|
33
|
+
version = "1.1.1"
|
|
36
34
|
|
|
37
35
|
[project.optional-dependencies]
|
|
38
36
|
dev = [
|
|
37
|
+
"tbump>=6.11.0",
|
|
38
|
+
"towncrier",
|
|
39
39
|
"pre-commit",
|
|
40
40
|
"pytest",
|
|
41
41
|
"pytest-cov",
|
|
42
|
-
"pytest_regressions"
|
|
42
|
+
"pytest_regressions",
|
|
43
|
+
"pytest-benchmark",
|
|
44
|
+
"matplotlib"
|
|
43
45
|
]
|
|
44
46
|
docs = [
|
|
45
47
|
"jupytext",
|
|
@@ -135,7 +137,7 @@ message_template = "Bump to {new_version}"
|
|
|
135
137
|
tag_template = "v{new_version}"
|
|
136
138
|
|
|
137
139
|
[tool.tbump.version]
|
|
138
|
-
current = "1.0.1"
|
|
140
|
+
current = "1.1.1"
|
|
139
141
|
# Example of a semver regexp.
|
|
140
142
|
# Make sure this matches current_version before
|
|
141
143
|
# using tbump
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|