diffaaable 1.0.1__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: diffaaable
3
- Version: 1.0.1
3
+ Version: 1.1.1
4
4
  Summary: JAX-differentiable AAA algorithm
5
5
  Keywords: python
6
6
  Author-email: Jan David Fischbach <fischbach@kit.edu>
@@ -15,12 +15,14 @@ Requires-Dist: jax
15
15
  Requires-Dist: jaxlib
16
16
  Requires-Dist: baryrat
17
17
  Requires-Dist: jaxopt
18
- Requires-Dist: pytest
19
- Requires-Dist: pytest-benchmark
18
+ Requires-Dist: tbump>=6.11.0 ; extra == "dev"
19
+ Requires-Dist: towncrier ; extra == "dev"
20
20
  Requires-Dist: pre-commit ; extra == "dev"
21
21
  Requires-Dist: pytest ; extra == "dev"
22
22
  Requires-Dist: pytest-cov ; extra == "dev"
23
23
  Requires-Dist: pytest_regressions ; extra == "dev"
24
+ Requires-Dist: pytest-benchmark ; extra == "dev"
25
+ Requires-Dist: matplotlib ; extra == "dev"
24
26
  Requires-Dist: jupytext ; extra == "docs"
25
27
  Requires-Dist: matplotlib ; extra == "docs"
26
28
  Requires-Dist: jupyter-book ; extra == "docs"
@@ -28,7 +30,7 @@ Requires-Dist: sphinx_math_dollar ; extra == "docs"
28
30
  Provides-Extra: dev
29
31
  Provides-Extra: docs
30
32
 
31
- # diffaaable 1.0.1
33
+ # diffaaable 1.1.1
32
34
 
33
35
  ![](docs/assets/diffaaable.png)
34
36
 
@@ -37,7 +39,7 @@ A detailed derivation of the used matrix expressions is provided in the appendix
37
39
  Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
38
40
  Additionally, the following application-specific extensions to the AAA algorithm are included:
39
41
 
40
- - **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
42
+ - **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain
41
43
  - **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
42
44
  - **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
43
45
  - **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
@@ -58,7 +60,7 @@ When using this software package for scientific work please cite the associated
58
60
  +++
59
61
 
60
62
  [^1]: https://arxiv.org/pdf/2403.19404
61
- [^2]: Multiscat Resonances (to be publlished)
63
+ [^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
62
64
  [^3]: https://doi.org/10.1093/imanum/draa098
63
65
  [^4]: https://doi.org/10.48550/arXiv.2405.19582
64
66
 
@@ -1,4 +1,4 @@
1
- # diffaaable 1.0.1
1
+ # diffaaable 1.1.1
2
2
 
3
3
  ![](docs/assets/diffaaable.png)
4
4
 
@@ -7,7 +7,7 @@ A detailed derivation of the used matrix expressions is provided in the appendix
7
7
  Under the hood `diffaaable` uses the AAA implementation of [`baryrat`](https://github.com/c-f-h/baryrat).
8
8
  Additionally, the following application-specific extensions to the AAA algorithm are included:
9
9
 
10
- - **Adaptive**: Adaptive refinement strategy to minimize the number of function evaluation needed to precisely locate poles within some domain
10
+ - **Adaptive**: Adaptive refinement strategy (called Iterative Sample Refinement (ISR) in the corresponding paper) to minimize the number of function evaluations needed to precisely locate poles within some domain
11
11
  - **Vectorial** (also referred to as set-valued): AAA algorithm acting on vector valued functions $\mathbf{f}(z)$ as presented in [^3].
12
12
  - **Lorentz**: Variant that enforces symmetric poles around the imaginary axis.
13
13
  - **Selective Refinement**: Use a divide-and-conquer scheme to capture many poles simultaneously and accurately by limiting the number of poles per AAA solve. Suggested in [^4].
@@ -28,6 +28,6 @@ When using this software package for scientific work please cite the associated
28
28
  +++
29
29
 
30
30
  [^1]: https://arxiv.org/pdf/2403.19404
31
- [^2]: Multiscat Resonances (to be publlished)
31
+ [^2]: "A framework to compute resonances arising from multiple scattering", https://arxiv.org/abs/2409.05563
32
32
  [^3]: https://doi.org/10.1093/imanum/draa098
33
33
  [^4]: https://doi.org/10.48550/arXiv.2405.19582
@@ -1,6 +1,6 @@
1
1
  """diffaaable - JAX-differentiable AAA algorithm"""
2
2
 
3
- __version__ = "1.0.1"
3
+ __version__ = "1.1.1"
4
4
 
5
5
  __all__ = ["aaa", "adaptive_aaa", "vectorial_aaa", "lorentz_aaa"]
6
6
 
@@ -113,6 +113,18 @@ def next_samples_heat(
113
113
 
114
114
  aaa = jax.tree_util.Partial(aaa)
115
115
 
116
def mask(z_k, f_k, f_k_dot, cutoff):
    """Drop unusable samples and remove duplicate sample locations.

    A sample i is kept only if |f_k[i]| < cutoff, f_k[i] is not NaN and
    z_k[i] is not NaN. For vector-valued samples (``f_k.ndim == 2``, one row
    per sample) a row is kept only if *every* component passes the value
    checks. Afterwards duplicate entries of ``z_k`` are removed.

    Args:
        z_k: 1-D array of sample locations.
        f_k: sample values; either 1-D (one value per sample) or 2-D with
            one row per sample.
        f_k_dot: array indexed along its first axis in lockstep with ``f_k``
            (assumed to share ``f_k``'s leading dimension).
        cutoff: magnitude threshold; samples with |f_k| >= cutoff are
            treated as diverged and discarded.

    Returns:
        Tuple ``(z_k, f_k, f_k_dot)`` restricted to the kept samples. Because
        de-duplication uses ``unique``, ``z_k`` comes back sorted and, for
        repeated locations, the first occurrence's values are kept.
    """
    # Per-entry validity of the function values: filter out values that
    # have diverged too strongly, and NaNs.
    keep = np.abs(f_k) < cutoff
    keep = np.logical_and(keep, ~np.isnan(f_k))

    # Collapse to one flag per sample *before* combining with the z_k mask:
    # z_k has shape (N,), and AND-ing it with an (N, M) mask would broadcast
    # along the trailing axis (error for M != N, silent (N, N) blow-up for M == 1).
    if keep.ndim == 2:
        keep = np.all(keep, axis=1)
    keep = np.logical_and(keep, ~np.isnan(z_k))  # filter out NaN locations

    z_k, f_k, f_k_dot = z_k[keep], f_k[keep], f_k_dot[keep]

    # Filter out duplicate locations; idx maps back to the first occurrence.
    z_k, idx = np.unique(z_k, return_index=True)
    return z_k, f_k[idx], f_k_dot[idx]
+
116
128
  def _adaptive_aaa(z_k_0: npt.NDArray,
117
129
  f: callable,
118
130
  evolutions: int = 2,
@@ -162,7 +174,7 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
162
174
  f_k_dot = np.zeros_like(f_k)
163
175
 
164
176
  if cutoff is None:
165
- cutoff = 1e10*np.median(np.abs(f_k))
177
+ cutoff = 1e10*np.nanmedian(np.abs(f_k))
166
178
 
167
179
  if domain is None:
168
180
  center = np.mean(z_k)
@@ -175,15 +187,9 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
175
187
  if prev_z_n is None:
176
188
  prev_z_n = np.array([np.inf], dtype=complex)
177
189
 
178
- def mask(z_k, f_k, f_k_dot):
179
- m = np.abs(f_k)<cutoff
180
- if m.ndim == 2:
181
- m = np.all(m, axis=1) #filter out values, that have diverged too strongly
182
- return z_k[m], f_k[m], f_k_dot[m]
183
-
184
190
  key = random.key(0)
185
191
  for i in range(evolutions):
186
- z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)
192
+ z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot, cutoff=cutoff)
187
193
  z_j, f_j, w_j, z_n = aaa(z_k, f_k, tol, mmax)
188
194
 
189
195
  if i==evolutions-1:
@@ -211,7 +217,7 @@ def _adaptive_aaa(z_k_0: npt.NDArray,
211
217
  f_k = np.concatenate([f_k, f_k_new])
212
218
  f_k_dot = np.concatenate([f_k_dot, f_k_dot_new])
213
219
 
214
- z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot)
220
+ z_k, f_k, f_k_dot = mask(z_k, f_k, f_k_dot, cutoff=cutoff)
215
221
 
216
222
  if collect_tangents:
217
223
  return z_k, f_k, f_k_dot
@@ -65,11 +65,6 @@ def subdomains(domain: Domain, divide_horizontal: bool, center: complex=None):
65
65
  return [(subs[1][0], subs[0][1]), (subs[2][0], subs[3][1])]
66
66
  return [(subs[2][0], subs[1][1]), (subs[3][0], subs[0][1])]
67
67
 
68
-
69
- def cutoff_mask(z_k, f_k, f_k_dot, cutoff):
70
- m = np.abs(f_k)<cutoff #filter out values, that have diverged too strongly
71
- return z_k[m], f_k[m], f_k_dot[m]
72
-
73
68
  def plot_domain(domain: Domain, size: float=1):
74
69
  left_up = domain[0].real + 1j*domain[1].imag
75
70
  right_down = domain[1].real + 1j*domain[0].imag
@@ -82,7 +77,6 @@ def plot_domain(domain: Domain, size: float=1):
82
77
  def all_poles_known(poles, prev, tol):
83
78
  if prev is None or len(prev)!=len(poles):
84
79
  return False
85
- #return True
86
80
 
87
81
  dist = np.abs(poles[:, None] - prev[None, :])
88
82
  check = np.all(np.any(dist < tol, axis=1))
@@ -110,12 +104,9 @@ def selective_refinement_aaa(f: callable,
110
104
  TODO: allow access to samples slightly outside of domain
111
105
  """
112
106
 
113
- print(f"start domain '{debug_name}', {Dmax=}")
107
+ # print(f"start domain '{debug_name}', {Dmax=}")
114
108
  folder = f"debug_out/{debug_name:0<33}"
115
109
  domain_size = np.abs(domain[1]-domain[0])/2
116
- size = domain_size/2 # for plotting
117
- #plot_rect = plot_domain(domain, size=30)
118
- #color = plot_rect[0].get_color()
119
110
 
120
111
  if cutoff is None:
121
112
  cutoff = np.inf
@@ -135,11 +126,13 @@ def selective_refinement_aaa(f: callable,
135
126
  z_k = np.append(z_k, z_k_new)
136
127
 
137
128
  eval_count += len(z_k_new)
138
- print(f"new eval: {eval_count}")
129
+ # print(f"new eval: {eval_count}")
139
130
  eval_count -= len(z_k)
140
131
  z_j, f_j, w_j, z_n, z_k, f_k = adaptive_aaa(
141
132
  z_k, f, f_k_0=f_k, evolutions=N*16, tol=tol_aaa,
142
- domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N), #NOTE: actually increased domain :/
133
+ domain=reduced_domain(domain, 1.07), radius=4*domain_size/(N),
134
+ # NOTE: reduced domain with a factor larger than 1
135
+ # actually increases domain size
143
136
  return_samples=True, sampling=sampling, cutoff=np.inf
144
137
  )
145
138
  # TODO pass down samples in buffer zone
@@ -156,18 +149,14 @@ def selective_refinement_aaa(f: callable,
156
149
  except onp.linalg.LinAlgError as e:
157
150
  z_n = z_j = f_j = w_j = np.empty((0,))
158
151
 
159
- print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
152
+ # print(f"domain '{debug_name}' done: {domain} -> eval: {eval_count}")
160
153
  poles = z_n[domain_mask(domain, z_n)]
161
154
 
162
155
  if (Dmax == 0 or
163
156
  (len(poles)<=max_poles and all_poles_known(poles, suggestions, tol_pol))):
164
157
 
165
- #plt.scatter(poles.real, poles.imag, color = color, marker="x")#, s=size*3, linewidths=size/2)
166
- print("I am done here")
167
-
168
158
  res = residues(z_j, f_j, w_j, poles)
169
159
  return poles, res, eval_count
170
- #plt.scatter(poles.real, poles.imag, color = color, marker="+", s=0.2, zorder=3)#, s=size, linewidths=size/6)
171
160
 
172
161
  subs = subdomains(domain, divide_horizontal)
173
162
 
@@ -190,8 +179,4 @@ def selective_refinement_aaa(f: callable,
190
179
  pol = np.append(pol, p)
191
180
  res = np.append(res, r)
192
181
  eval_count += e
193
- # if len(pol) > 0:
194
- # plt.xlim(domain[0].real, domain[1].real)
195
- # plt.ylim(domain[0].imag, domain[1].imag)
196
- # plt.savefig(f"debug_out/{debug_name:0<33}.png")
197
182
  return pol, res, eval_count
@@ -22,9 +22,7 @@ dependencies = [
22
22
  "jax",
23
23
  "jaxlib",
24
24
  "baryrat",
25
- "jaxopt",
26
- "pytest",
27
- "pytest-benchmark"
25
+ "jaxopt"
28
26
  ]
29
27
  description = "JAX-differentiable AAA algorithm"
30
28
  keywords = ["python"]
@@ -32,14 +30,18 @@ license = {file = "LICENSE"}
32
30
  name = "diffaaable"
33
31
  readme = "README.md"
34
32
  requires-python = ">=3.9"
35
- version = "1.0.1"
33
+ version = "1.1.1"
36
34
 
37
35
  [project.optional-dependencies]
38
36
  dev = [
37
+ "tbump>=6.11.0",
38
+ "towncrier",
39
39
  "pre-commit",
40
40
  "pytest",
41
41
  "pytest-cov",
42
- "pytest_regressions"
42
+ "pytest_regressions",
43
+ "pytest-benchmark",
44
+ "matplotlib"
43
45
  ]
44
46
  docs = [
45
47
  "jupytext",
@@ -135,7 +137,7 @@ message_template = "Bump to {new_version}"
135
137
  tag_template = "v{new_version}"
136
138
 
137
139
  [tool.tbump.version]
138
- current = "1.0.1"
140
+ current = "1.1.1"
139
141
  # Example of a semver regexp.
140
142
  # Make sure this matches current_version before
141
143
  # using tbump
File without changes