jaxspec 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +35 -0
- jaxspec/data/obsconf.py +34 -0
- jaxspec/fit.py +247 -261
- jaxspec/model/abc.py +60 -11
- jaxspec/model/additive.py +1 -3
- jaxspec/model/multiplicative.py +3 -11
- jaxspec/util/typing.py +27 -2
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/METADATA +13 -7
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/RECORD +12 -14
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +0 -316
- jaxspec/model/_additive/apec_loaders.py +0 -73
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/WHEEL +0 -0
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/entry_points.txt +0 -0
jaxspec/model/abc.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
2
6
|
import haiku as hk
|
|
3
7
|
import jax
|
|
4
8
|
import jax.numpy as jnp
|
|
5
9
|
import networkx as nx
|
|
10
|
+
import rich
|
|
11
|
+
|
|
6
12
|
from haiku._src import base
|
|
7
|
-
from uuid import uuid4
|
|
8
13
|
from jax.scipy.integrate import trapezoid
|
|
9
|
-
from
|
|
14
|
+
from rich.table import Table
|
|
10
15
|
from simpleeval import simple_eval
|
|
11
16
|
|
|
12
17
|
|
|
@@ -107,6 +112,30 @@ class SpectralModel:
|
|
|
107
112
|
def params(self):
|
|
108
113
|
return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
|
|
109
114
|
|
|
115
|
+
def __rich_repr__(self):
|
|
116
|
+
table = Table(title=str(self))
|
|
117
|
+
|
|
118
|
+
table.add_column("Component", justify="right", style="bold", no_wrap=True)
|
|
119
|
+
table.add_column("Parameter")
|
|
120
|
+
|
|
121
|
+
params = self.params
|
|
122
|
+
|
|
123
|
+
for component in params.keys():
|
|
124
|
+
once = True
|
|
125
|
+
|
|
126
|
+
for parameters in params[component].keys():
|
|
127
|
+
table.add_row(component if once else "", parameters)
|
|
128
|
+
once = False
|
|
129
|
+
|
|
130
|
+
return table
|
|
131
|
+
|
|
132
|
+
def __repr_html_(self):
|
|
133
|
+
return self.__rich_repr__()
|
|
134
|
+
|
|
135
|
+
def __repr__(self):
|
|
136
|
+
rich.print(self.__rich_repr__())
|
|
137
|
+
return ""
|
|
138
|
+
|
|
110
139
|
def photon_flux(self, params, e_low, e_high, n_points=2):
|
|
111
140
|
r"""
|
|
112
141
|
Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
|
|
@@ -215,16 +244,18 @@ class SpectralModel:
|
|
|
215
244
|
continuum[node_id] = runtime_modules[node_id].continuum(energies)
|
|
216
245
|
|
|
217
246
|
elif node and node["type"] == "operation":
|
|
218
|
-
component_1 = list(self.graph.in_edges(node_id))[0][0]
|
|
247
|
+
component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
|
|
219
248
|
component_2 = list(self.graph.in_edges(node_id))[1][0]
|
|
220
|
-
continuum[node_id] = node["function"](
|
|
249
|
+
continuum[node_id] = node["function"](
|
|
250
|
+
continuum[component_1], continuum[component_2]
|
|
251
|
+
)
|
|
221
252
|
|
|
222
253
|
if n_points == 2:
|
|
223
|
-
flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]]
|
|
254
|
+
flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
224
255
|
flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
|
|
225
256
|
|
|
226
257
|
else:
|
|
227
|
-
flux = continuum[list(self.graph.in_edges("out"))[0][0]]
|
|
258
|
+
flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
228
259
|
|
|
229
260
|
if energy_flux:
|
|
230
261
|
continuum_flux = trapezoid(
|
|
@@ -234,7 +265,9 @@ class SpectralModel:
|
|
|
234
265
|
)
|
|
235
266
|
|
|
236
267
|
else:
|
|
237
|
-
continuum_flux = trapezoid(
|
|
268
|
+
continuum_flux = trapezoid(
|
|
269
|
+
flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
|
|
270
|
+
)
|
|
238
271
|
|
|
239
272
|
# Iterate from the root nodes to the output node and
|
|
240
273
|
# compute the fine structure contribution for each component
|
|
@@ -249,14 +282,18 @@ class SpectralModel:
|
|
|
249
282
|
path = nx.shortest_path(self.graph, source=root_node_id, target="out")
|
|
250
283
|
nodes_id_in_path = [node_id for node_id in path]
|
|
251
284
|
|
|
252
|
-
flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
|
|
285
|
+
flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
|
|
286
|
+
e_low, e_high
|
|
287
|
+
)
|
|
253
288
|
|
|
254
289
|
multiplicative_nodes = []
|
|
255
290
|
|
|
256
291
|
# Search all multiplicative components connected to this node
|
|
257
292
|
# and apply them at mean energy
|
|
258
293
|
for node_id in nodes_id_in_path[::-1]:
|
|
259
|
-
multiplicative_nodes.extend(
|
|
294
|
+
multiplicative_nodes.extend(
|
|
295
|
+
[node_id for node_id in self.find_multiplicative_components(node_id)]
|
|
296
|
+
)
|
|
260
297
|
|
|
261
298
|
for mul_node in multiplicative_nodes:
|
|
262
299
|
flux_from_component *= runtime_modules[mul_node].continuum(mean_energy)
|
|
@@ -309,7 +346,10 @@ class SpectralModel:
|
|
|
309
346
|
if component.type == "additive":
|
|
310
347
|
|
|
311
348
|
def lam_func(e):
|
|
312
|
-
return
|
|
349
|
+
return (
|
|
350
|
+
component(**kwargs).continuum(e)
|
|
351
|
+
+ component(**kwargs).emission_lines(e, e + 1)[0]
|
|
352
|
+
)
|
|
313
353
|
|
|
314
354
|
elif component.type == "multiplicative":
|
|
315
355
|
|
|
@@ -342,7 +382,9 @@ class SpectralModel:
|
|
|
342
382
|
|
|
343
383
|
return cls(graph, labels)
|
|
344
384
|
|
|
345
|
-
def compose(
|
|
385
|
+
def compose(
|
|
386
|
+
self, other: SpectralModel, operation=None, function=None, name=None
|
|
387
|
+
) -> SpectralModel:
|
|
346
388
|
"""
|
|
347
389
|
This function operate a composition between the operation graph of two models
|
|
348
390
|
1) It fuses the two graphs using which joins at the 'out' nodes
|
|
@@ -524,3 +566,10 @@ class AdditiveComponent(ModelComponent, ABC):
|
|
|
524
566
|
|
|
525
567
|
return jnp.trapz(self(x) * dx, x=t)
|
|
526
568
|
'''
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
class MultiplicativeComponent(ModelComponent, ABC):
|
|
572
|
+
type = "multiplicative"
|
|
573
|
+
|
|
574
|
+
@abstractmethod
|
|
575
|
+
def continuum(self, energy): ...
|
jaxspec/model/additive.py
CHANGED
|
@@ -14,8 +14,6 @@ from haiku.initializers import Constant as HaikuConstant
|
|
|
14
14
|
|
|
15
15
|
from ..util.integrate import integrate_interval
|
|
16
16
|
from ..util.online_storage import table_manager
|
|
17
|
-
|
|
18
|
-
# from ._additive.apec import APEC
|
|
19
17
|
from .abc import AdditiveComponent
|
|
20
18
|
|
|
21
19
|
|
|
@@ -38,7 +36,7 @@ class Powerlaw(AdditiveComponent):
|
|
|
38
36
|
return norm * energy ** (-alpha)
|
|
39
37
|
|
|
40
38
|
|
|
41
|
-
class
|
|
39
|
+
class Additiveconstant(AdditiveComponent):
|
|
42
40
|
r"""
|
|
43
41
|
A constant model
|
|
44
42
|
|
jaxspec/model/multiplicative.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from abc import ABC, abstractmethod
|
|
4
|
-
|
|
5
3
|
import haiku as hk
|
|
6
4
|
import jax.numpy as jnp
|
|
7
5
|
import numpy as np
|
|
@@ -10,14 +8,7 @@ from astropy.table import Table
|
|
|
10
8
|
from haiku.initializers import Constant as HaikuConstant
|
|
11
9
|
|
|
12
10
|
from ..util.online_storage import table_manager
|
|
13
|
-
from .abc import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class MultiplicativeComponent(ModelComponent, ABC):
|
|
17
|
-
type = "multiplicative"
|
|
18
|
-
|
|
19
|
-
@abstractmethod
|
|
20
|
-
def continuum(self, energy): ...
|
|
11
|
+
from .abc import MultiplicativeComponent
|
|
21
12
|
|
|
22
13
|
|
|
23
14
|
class Expfac(MultiplicativeComponent):
|
|
@@ -226,6 +217,7 @@ class Tbpcf(MultiplicativeComponent):
|
|
|
226
217
|
|
|
227
218
|
return f * jnp.exp(-nh * sigma) + (1 - f)
|
|
228
219
|
|
|
220
|
+
|
|
229
221
|
class FDcut(MultiplicativeComponent):
|
|
230
222
|
r"""
|
|
231
223
|
A Fermi-Dirac cutoff model.
|
|
@@ -243,4 +235,4 @@ class FDcut(MultiplicativeComponent):
|
|
|
243
235
|
cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
|
|
244
236
|
folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
|
|
245
237
|
|
|
246
|
-
return (1 + jnp.exp((energy - cutoff)/folding)) ** -1
|
|
238
|
+
return (1 + jnp.exp((energy - cutoff) / folding)) ** -1
|
jaxspec/util/typing.py
CHANGED
|
@@ -9,6 +9,13 @@ from pydantic import BaseModel, field_validator
|
|
|
9
9
|
PriorDictType = dict[str, dict[str, dist.Distribution | ArrayLike]]
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
def is_flat_dict(input_data: dict[str, Any]) -> bool:
|
|
13
|
+
"""
|
|
14
|
+
Check if the input data is a flat dictionary with string keys and non-dictionary values.
|
|
15
|
+
"""
|
|
16
|
+
return all(isinstance(k, str) and not isinstance(v, dict) for k, v in input_data.items())
|
|
17
|
+
|
|
18
|
+
|
|
12
19
|
class PriorDictModel(BaseModel):
|
|
13
20
|
"""
|
|
14
21
|
Pydantic model for a nested dictionary of NumPyro distributions or JAX arrays.
|
|
@@ -21,6 +28,23 @@ class PriorDictModel(BaseModel):
|
|
|
21
28
|
class Config: # noqa D106
|
|
22
29
|
arbitrary_types_allowed = True
|
|
23
30
|
|
|
31
|
+
@classmethod
|
|
32
|
+
def from_dict(cls, input_prior: dict[str, Any]):
|
|
33
|
+
if is_flat_dict(input_prior):
|
|
34
|
+
nested_dict = {}
|
|
35
|
+
|
|
36
|
+
for key, obj in input_prior.items():
|
|
37
|
+
component, component_number, *parameter = key.split("_")
|
|
38
|
+
|
|
39
|
+
sub_dict = nested_dict.get(f"{component}_{component_number}", {})
|
|
40
|
+
sub_dict["_".join(parameter)] = obj
|
|
41
|
+
|
|
42
|
+
nested_dict[f"{component}_{component_number}"] = sub_dict
|
|
43
|
+
|
|
44
|
+
return cls(nested_dict=nested_dict)
|
|
45
|
+
|
|
46
|
+
return cls(nested_dict=input_prior)
|
|
47
|
+
|
|
24
48
|
@field_validator("nested_dict", mode="before")
|
|
25
49
|
def check_and_cast_nested_dict(cls, value: dict[str, Any]):
|
|
26
50
|
if not isinstance(value, dict):
|
|
@@ -35,9 +59,10 @@ class PriorDictModel(BaseModel):
|
|
|
35
59
|
try:
|
|
36
60
|
# Attempt to cast to JAX array
|
|
37
61
|
value[key][inner_key] = jnp.array(obj, dtype=float)
|
|
62
|
+
|
|
38
63
|
except Exception as e:
|
|
39
64
|
raise ValueError(
|
|
40
|
-
f'The value for key "{inner_key}" in
|
|
41
|
-
f"
|
|
65
|
+
f'The value for key "{inner_key}" in {key} be a NumPyro '
|
|
66
|
+
f"distribution or castable to JAX array. Error: {e}"
|
|
42
67
|
)
|
|
43
68
|
return value
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Home-page: https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
License: MIT
|
|
@@ -15,19 +15,18 @@ Requires-Dist: arviz (>=0.17.1,<0.20.0)
|
|
|
15
15
|
Requires-Dist: astropy (>=6.0.0,<7.0.0)
|
|
16
16
|
Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
|
|
17
17
|
Requires-Dist: cmasher (>=1.6.3,<2.0.0)
|
|
18
|
-
Requires-Dist: dm-haiku (>=0.0.
|
|
18
|
+
Requires-Dist: dm-haiku (>=0.0.12,<0.0.13)
|
|
19
19
|
Requires-Dist: gpjax (>=0.8.0,<0.9.0)
|
|
20
20
|
Requires-Dist: interpax (>=0.3.3,<0.4.0)
|
|
21
|
-
Requires-Dist: jax (>=0.4.
|
|
21
|
+
Requires-Dist: jax (>=0.4.33,<0.5.0)
|
|
22
22
|
Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
|
|
23
|
-
Requires-Dist: jaxns (
|
|
23
|
+
Requires-Dist: jaxns (<2.6)
|
|
24
24
|
Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
|
|
25
25
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
|
26
26
|
Requires-Dist: mendeleev (>=0.15,<0.18)
|
|
27
|
-
Requires-Dist: mkdocstrings (>=0.24,<0.26)
|
|
28
27
|
Requires-Dist: networkx (>=3.1,<4.0)
|
|
29
28
|
Requires-Dist: numpy (<2.0.0)
|
|
30
|
-
Requires-Dist: numpyro (>=0.15.
|
|
29
|
+
Requires-Dist: numpyro (>=0.15.3,<0.16.0)
|
|
31
30
|
Requires-Dist: optimistix (>=0.0.7,<0.0.8)
|
|
32
31
|
Requires-Dist: pandas (>=2.2.0,<3.0.0)
|
|
33
32
|
Requires-Dist: pooch (>=1.8.2,<2.0.0)
|
|
@@ -41,7 +40,14 @@ Requires-Dist: watermark (>=2.4.3,<3.0.0)
|
|
|
41
40
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
42
41
|
Description-Content-Type: text/markdown
|
|
43
42
|
|
|
44
|
-
|
|
43
|
+
<p align="center">
|
|
44
|
+
<img src="https://raw.githubusercontent.com/renecotyfanboy/jaxspec/main/docs/logo/logo_small.svg" alt="Logo" width="100" height="100">
|
|
45
|
+
</p>
|
|
46
|
+
|
|
47
|
+
<h1 align="center">
|
|
48
|
+
jaxspec
|
|
49
|
+
</h1>
|
|
50
|
+
|
|
45
51
|
|
|
46
52
|
[)](https://pypi.org/project/jaxspec/)
|
|
47
53
|
[](https://pypi.org/project/jaxspec/)
|
|
@@ -1,33 +1,31 @@
|
|
|
1
1
|
jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
|
|
2
2
|
jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
jaxspec/analysis/_plot.py,sha256=C4XljmuzQz8xQur_jQddgInrBDmKgTn0eugSreLoD5k,862
|
|
3
4
|
jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
|
|
4
5
|
jaxspec/analysis/results.py,sha256=Kz3eryxS3N_hiajcFLTWS1dtgTQo5hlh-rDCnJ3A-3c,27811
|
|
5
6
|
jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
|
|
6
7
|
jaxspec/data/grouping.py,sha256=hhgBt-voiH0DDSyePacaIGsaMnrYbJM_-ZeU66keC7I,622
|
|
7
8
|
jaxspec/data/instrument.py,sha256=0pSf1p82g7syDMmKm13eVbYih-Veiq5DnwsyZe6_b4g,3890
|
|
8
|
-
jaxspec/data/obsconf.py,sha256=
|
|
9
|
+
jaxspec/data/obsconf.py,sha256=gv14sL6azK2avRiMCWuTbyLBPulzm4PwvoLY6iWPEVE,9833
|
|
9
10
|
jaxspec/data/observation.py,sha256=1UnFu5ihZp9z-vP_I7tsFY8jhhIJunv46JyuE-acrg0,6394
|
|
10
11
|
jaxspec/data/ogip.py,sha256=sv9p00qHS5pzw61pzWyyF0nV-E-RXySdSFK2tUavokA,9545
|
|
11
12
|
jaxspec/data/util.py,sha256=ycLPVE-cjn6VpUWYlBU1BGfw73ANXIBilyVAUOYOSj0,9540
|
|
12
|
-
jaxspec/fit.py,sha256=
|
|
13
|
+
jaxspec/fit.py,sha256=hI0koMO4KsNpe9mLlaFm_tNLgm4BVAYVyiMb1E1eyZE,24553
|
|
13
14
|
jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
jaxspec/model/
|
|
15
|
-
jaxspec/model/
|
|
16
|
-
jaxspec/model/_additive/apec_loaders.py,sha256=jkUoH0ezeYdaNw3oV10V0L-jt848SKp2thanLWLWp9k,2412
|
|
17
|
-
jaxspec/model/abc.py,sha256=SWjKOOsqU5UJsVy63Tt9dDq8H2eTIbvK2C9iqgiR0cY,19817
|
|
18
|
-
jaxspec/model/additive.py,sha256=CT2K2DVVeHKN1tee9-J3MYdEPqEOolLB2E7HU-RJKZw,22485
|
|
15
|
+
jaxspec/model/abc.py,sha256=MuxEyvn223QPwGoFIJiST8nRMgrZ08ZLkw33oep3tx4,20887
|
|
16
|
+
jaxspec/model/additive.py,sha256=wjY2wL3Io3F45GJpz-UB8xYVnA-W1OFBnZMbj5pWPbQ,22449
|
|
19
17
|
jaxspec/model/background.py,sha256=QSFFiuyUEvuzXBx3QfkvVneUR8KKEP-VaANEVXcavDE,7865
|
|
20
18
|
jaxspec/model/list.py,sha256=0RPAoscVz_zM1CWdx_Gd5wfrQWV5Nv4Kd4bSXu2ayUA,860
|
|
21
|
-
jaxspec/model/multiplicative.py,sha256=
|
|
19
|
+
jaxspec/model/multiplicative.py,sha256=GCQ6JRz92QqbzDBFwWxGZ9SUqTJZQpD7B6ji9VEFXWo,8135
|
|
22
20
|
jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
21
|
jaxspec/scripts/debug.py,sha256=RIRykBW_kzt8__PhOohQn2xtDW6oAz49E8rmuR5ewAU,293
|
|
24
22
|
jaxspec/util/__init__.py,sha256=vKurfp7p2hxHptJjXhXqFAXAikAGXAqISMJUqPeiGTw,1259
|
|
25
23
|
jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,8546
|
|
26
24
|
jaxspec/util/integrate.py,sha256=_Ax_knpC7d4et2-QFkOUzVtNeQLX1-cwLvm-FRBxYcw,4505
|
|
27
25
|
jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
|
|
28
|
-
jaxspec/util/typing.py,sha256=
|
|
29
|
-
jaxspec-0.1.
|
|
30
|
-
jaxspec-0.1.
|
|
31
|
-
jaxspec-0.1.
|
|
32
|
-
jaxspec-0.1.
|
|
33
|
-
jaxspec-0.1.
|
|
26
|
+
jaxspec/util/typing.py,sha256=8qK1aJlsqTcVKjYN-BxsDx20BTwtnS-wMw6Bdurpm-o,2459
|
|
27
|
+
jaxspec-0.1.2.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
|
|
28
|
+
jaxspec-0.1.2.dist-info/METADATA,sha256=FE2bTAk-3Xryi6fplV4Y-F2eibUdLZgC9ET9_4HvdOA,3708
|
|
29
|
+
jaxspec-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
30
|
+
jaxspec-0.1.2.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
|
|
31
|
+
jaxspec-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
jaxspec/model/_additive/apec.py
DELETED
|
@@ -1,316 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
|
|
3
|
-
from typing import Literal
|
|
4
|
-
|
|
5
|
-
import astropy.units as u
|
|
6
|
-
import haiku as hk
|
|
7
|
-
import jax
|
|
8
|
-
import jax.numpy as jnp
|
|
9
|
-
|
|
10
|
-
from astropy.constants import c, m_p
|
|
11
|
-
from haiku.initializers import Constant as HaikuConstant
|
|
12
|
-
from jax import lax
|
|
13
|
-
from jax.lax import fori_loop, scan
|
|
14
|
-
from jax.scipy.stats import norm as gaussian
|
|
15
|
-
|
|
16
|
-
from ...util.abundance import abundance_table, element_data
|
|
17
|
-
from ..abc import AdditiveComponent
|
|
18
|
-
from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@jax.jit
|
|
22
|
-
def lerp(x, x0, x1, y0, y1):
|
|
23
|
-
"""
|
|
24
|
-
Linear interpolation routine
|
|
25
|
-
Return y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
|
|
26
|
-
"""
|
|
27
|
-
return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@jax.jit
|
|
31
|
-
def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
|
|
32
|
-
"""
|
|
33
|
-
This function interpolate & integrate the values of a tabulated reference continuum between two energy limits
|
|
34
|
-
Sorry for the boilerplate here, but be sure that it works !
|
|
35
|
-
|
|
36
|
-
Parameters:
|
|
37
|
-
energy_low: lower limit of the integral
|
|
38
|
-
energy_high: upper limit of the integral
|
|
39
|
-
energy_ref: energy grid of the reference continuum
|
|
40
|
-
continuum_ref: continuum values evaluated at energy_ref
|
|
41
|
-
|
|
42
|
-
"""
|
|
43
|
-
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
44
|
-
start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
|
|
45
|
-
end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1
|
|
46
|
-
|
|
47
|
-
def body_func(index, value):
|
|
48
|
-
integrated_flux, previous_energy, previous_continuum = value
|
|
49
|
-
current_energy, current_continuum = energy_ref[index], continuum_ref[index]
|
|
50
|
-
|
|
51
|
-
# 5 cases
|
|
52
|
-
# Neither current and previous energies are within the integral limits > nothing is added to the integrated flux
|
|
53
|
-
# The left limit of the integral is between the current and previous energy > previous energy is set to the limit, previous continuum is interpolated, and then added to the integrated flux
|
|
54
|
-
# The right limit of the integral is between the current and previous energy > current energy is set to the limit, current continuum is interpolated, and then added to the integrated flux
|
|
55
|
-
# Both current and previous energies are within the integral limits -> add to the integrated flux
|
|
56
|
-
# Within
|
|
57
|
-
|
|
58
|
-
current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
|
|
59
|
-
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
60
|
-
previous_energy < energy_high
|
|
61
|
-
)
|
|
62
|
-
energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
|
|
63
|
-
|
|
64
|
-
case = (
|
|
65
|
-
(1 - previous_energy_is_between) * current_energy_is_between * 1
|
|
66
|
-
+ previous_energy_is_between * (1 - current_energy_is_between) * 2
|
|
67
|
-
+ (previous_energy_is_between * current_energy_is_between) * 3
|
|
68
|
-
+ energies_within_bins * 4
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
term_to_add = lax.switch(
|
|
72
|
-
case,
|
|
73
|
-
[
|
|
74
|
-
lambda pe, pc, ce, cc, el, er: 0.0, # 1
|
|
75
|
-
lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
|
|
76
|
-
lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
|
|
77
|
-
lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
|
|
78
|
-
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
79
|
-
* (er - el)
|
|
80
|
-
/ 2,
|
|
81
|
-
# 5
|
|
82
|
-
],
|
|
83
|
-
previous_energy,
|
|
84
|
-
previous_continuum,
|
|
85
|
-
current_energy,
|
|
86
|
-
current_continuum,
|
|
87
|
-
energy_low,
|
|
88
|
-
energy_high,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
return integrated_flux + term_to_add, current_energy, current_continuum
|
|
92
|
-
|
|
93
|
-
integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))
|
|
94
|
-
|
|
95
|
-
return integrated_flux
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
@jax.jit
|
|
99
|
-
def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
|
|
100
|
-
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
101
|
-
|
|
102
|
-
return (
|
|
103
|
-
jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
|
|
104
|
-
) / (e_high - e_low)
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@jax.jit
|
|
108
|
-
def interp_flux(energy, energy_ref, continuum_ref, end_index):
|
|
109
|
-
"""
|
|
110
|
-
Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
|
|
111
|
-
"""
|
|
112
|
-
|
|
113
|
-
def scanned_func(carry, unpack):
|
|
114
|
-
e_low, e_high = unpack
|
|
115
|
-
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
116
|
-
|
|
117
|
-
return carry, continuum
|
|
118
|
-
|
|
119
|
-
_, continuum = scan(scanned_func, 0.0, (energy[:-1], energy[1:]))
|
|
120
|
-
|
|
121
|
-
return continuum
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@jax.jit
|
|
125
|
-
def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
|
|
126
|
-
"""
|
|
127
|
-
Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
|
|
128
|
-
and weight the flux depending on the abundance of each element
|
|
129
|
-
"""
|
|
130
|
-
|
|
131
|
-
def scanned_func(_, unpack):
|
|
132
|
-
energy_ref, continuum_ref, end_idx = unpack
|
|
133
|
-
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
|
|
134
|
-
|
|
135
|
-
return _, element_flux
|
|
136
|
-
|
|
137
|
-
_, flux = scan(scanned_func, 0.0, (energy_ref, continuum_ref, end_index))
|
|
138
|
-
|
|
139
|
-
return abundances @ flux
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
@jax.jit
|
|
143
|
-
def get_lines_contribution_broadening(
|
|
144
|
-
line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
|
|
145
|
-
):
|
|
146
|
-
def body_func(i, flux):
|
|
147
|
-
# Notice the -1 in line element to match the 0-based indexing
|
|
148
|
-
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
149
|
-
broadening = l_energy * total_broadening[l_element]
|
|
150
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
151
|
-
energy[:-1], l_energy, broadening
|
|
152
|
-
)
|
|
153
|
-
l_flux = l_flux * l_emissivity * abundances[l_element]
|
|
154
|
-
|
|
155
|
-
return flux + l_flux
|
|
156
|
-
|
|
157
|
-
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
@jax.jit
|
|
161
|
-
def continuum_func(energy, kT, abundances):
|
|
162
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
163
|
-
continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
|
|
164
|
-
continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
|
|
165
|
-
|
|
166
|
-
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
@jax.jit
|
|
170
|
-
def pseudo_func(energy, kT, abundances):
|
|
171
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
172
|
-
continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
|
|
173
|
-
continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
|
|
174
|
-
|
|
175
|
-
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
# @jax.custom_jvp
|
|
179
|
-
@jax.jit
|
|
180
|
-
def lines_func(energy, kT, abundances, broadening):
|
|
181
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
182
|
-
line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
|
|
183
|
-
line_high = get_lines_contribution_broadening(
|
|
184
|
-
*get_lines(idx + 1), energy, abundances, broadening
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
return lerp(kT, kT_low, kT_high, line_low, line_high)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
class APEC(AdditiveComponent):
|
|
191
|
-
"""
|
|
192
|
-
APEC model implementation in pure JAX for X-ray spectral fitting.
|
|
193
|
-
|
|
194
|
-
!!! warning
|
|
195
|
-
This implementation is optimised for the CPU, it shows poor performance on the GPU.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
def __init__(
|
|
199
|
-
self,
|
|
200
|
-
continuum: bool = True,
|
|
201
|
-
pseudo: bool = True,
|
|
202
|
-
lines: bool = True,
|
|
203
|
-
thermal_broadening: bool = True,
|
|
204
|
-
turbulent_broadening: bool = True,
|
|
205
|
-
variant: Literal["none", "v", "vv"] = "none",
|
|
206
|
-
abundance_table: Literal[
|
|
207
|
-
"angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
|
|
208
|
-
] = "angr",
|
|
209
|
-
trace_abundance: float = 1.0,
|
|
210
|
-
**kwargs,
|
|
211
|
-
):
|
|
212
|
-
super().__init__(**kwargs)
|
|
213
|
-
|
|
214
|
-
warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
|
|
215
|
-
|
|
216
|
-
self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
|
|
217
|
-
|
|
218
|
-
self.abundance_table = abundance_table
|
|
219
|
-
self.thermal_broadening = thermal_broadening
|
|
220
|
-
self.turbulent_broadening = turbulent_broadening
|
|
221
|
-
self.continuum_to_compute = continuum
|
|
222
|
-
self.pseudo_to_compute = pseudo
|
|
223
|
-
self.lines_to_compute = lines
|
|
224
|
-
self.trace_abundance = trace_abundance
|
|
225
|
-
self.variant = variant
|
|
226
|
-
|
|
227
|
-
def get_thermal_broadening(self):
|
|
228
|
-
r"""
|
|
229
|
-
Compute the thermal broadening $\sigma_T$ for each element using :
|
|
230
|
-
|
|
231
|
-
$$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$
|
|
232
|
-
|
|
233
|
-
where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
|
|
234
|
-
$T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.
|
|
235
|
-
"""
|
|
236
|
-
|
|
237
|
-
if self.thermal_broadening:
|
|
238
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
239
|
-
factor = 1 / c * (1 / m_p) ** (1 / 2)
|
|
240
|
-
factor = factor.to(u.keV ** (-1 / 2)).value
|
|
241
|
-
|
|
242
|
-
# Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
|
|
243
|
-
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
244
|
-
return factor * jnp.sqrt(kT / self.atomic_weights)
|
|
245
|
-
|
|
246
|
-
else:
|
|
247
|
-
return jnp.zeros((30,))
|
|
248
|
-
|
|
249
|
-
def get_turbulent_broadening(self):
|
|
250
|
-
r"""
|
|
251
|
-
Return the turbulent broadening using :
|
|
252
|
-
|
|
253
|
-
$$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$
|
|
254
|
-
|
|
255
|
-
where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
|
|
256
|
-
"""
|
|
257
|
-
if self.turbulent_broadening:
|
|
258
|
-
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
259
|
-
return (
|
|
260
|
-
hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
|
|
261
|
-
)
|
|
262
|
-
else:
|
|
263
|
-
return 0.0
|
|
264
|
-
|
|
265
|
-
def get_parameters(self):
|
|
266
|
-
none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
267
|
-
v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
268
|
-
trace_elements = (
|
|
269
|
-
jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
# Set abundances of trace element (will be overwritten in the vv case)
|
|
273
|
-
abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
|
|
274
|
-
|
|
275
|
-
if self.variant == "vv":
|
|
276
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
277
|
-
if element != "H":
|
|
278
|
-
abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
|
|
279
|
-
|
|
280
|
-
elif self.variant == "v":
|
|
281
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
282
|
-
if element != "H" and element in v_elements:
|
|
283
|
-
abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
|
|
284
|
-
|
|
285
|
-
else:
|
|
286
|
-
Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
|
|
287
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
288
|
-
if element != "H" and element in none_elements:
|
|
289
|
-
abund = abund.at[i].set(Z)
|
|
290
|
-
|
|
291
|
-
if abund != "angr":
|
|
292
|
-
abund = abund * jnp.asarray(
|
|
293
|
-
abundance_table[self.abundance_table] / abundance_table["angr"]
|
|
294
|
-
)
|
|
295
|
-
|
|
296
|
-
# Set the temperature, redshift, normalisation
|
|
297
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
298
|
-
z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
|
|
299
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))
|
|
300
|
-
|
|
301
|
-
return kT, z, norm, abund
|
|
302
|
-
|
|
303
|
-
def emission_lines(self, e_low, e_high):
|
|
304
|
-
# Get the parameters and extract the relevant data
|
|
305
|
-
energy = jnp.hstack([e_low, e_high[-1]])
|
|
306
|
-
kT, z, norm, abundances = self.get_parameters()
|
|
307
|
-
total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
|
|
308
|
-
energy = energy * (1 + z)
|
|
309
|
-
|
|
310
|
-
continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
|
|
311
|
-
pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
|
|
312
|
-
lines = (
|
|
313
|
-
lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
|