jaxspec 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/model/abc.py CHANGED
@@ -1,12 +1,17 @@
1
1
  from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from uuid import uuid4
5
+
2
6
  import haiku as hk
3
7
  import jax
4
8
  import jax.numpy as jnp
5
9
  import networkx as nx
10
+ import rich
11
+
6
12
  from haiku._src import base
7
- from uuid import uuid4
8
13
  from jax.scipy.integrate import trapezoid
9
- from abc import ABC
14
+ from rich.table import Table
10
15
  from simpleeval import simple_eval
11
16
 
12
17
 
@@ -107,6 +112,30 @@ class SpectralModel:
107
112
  def params(self):
108
113
  return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
109
114
 
115
+ def __rich_repr__(self):
116
+ table = Table(title=str(self))
117
+
118
+ table.add_column("Component", justify="right", style="bold", no_wrap=True)
119
+ table.add_column("Parameter")
120
+
121
+ params = self.params
122
+
123
+ for component in params.keys():
124
+ once = True
125
+
126
+ for parameters in params[component].keys():
127
+ table.add_row(component if once else "", parameters)
128
+ once = False
129
+
130
+ return table
131
+
132
+ def __repr_html_(self):
133
+ return self.__rich_repr__()
134
+
135
+ def __repr__(self):
136
+ rich.print(self.__rich_repr__())
137
+ return ""
138
+
110
139
  def photon_flux(self, params, e_low, e_high, n_points=2):
111
140
  r"""
112
141
  Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
@@ -215,16 +244,18 @@ class SpectralModel:
215
244
  continuum[node_id] = runtime_modules[node_id].continuum(energies)
216
245
 
217
246
  elif node and node["type"] == "operation":
218
- component_1 = list(self.graph.in_edges(node_id))[0][0]
247
+ component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
219
248
  component_2 = list(self.graph.in_edges(node_id))[1][0]
220
- continuum[node_id] = node["function"](continuum[component_1], continuum[component_2])
249
+ continuum[node_id] = node["function"](
250
+ continuum[component_1], continuum[component_2]
251
+ )
221
252
 
222
253
  if n_points == 2:
223
- flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]]
254
+ flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
224
255
  flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
225
256
 
226
257
  else:
227
- flux = continuum[list(self.graph.in_edges("out"))[0][0]]
258
+ flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
228
259
 
229
260
  if energy_flux:
230
261
  continuum_flux = trapezoid(
@@ -234,7 +265,9 @@ class SpectralModel:
234
265
  )
235
266
 
236
267
  else:
237
- continuum_flux = trapezoid(flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0)
268
+ continuum_flux = trapezoid(
269
+ flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
270
+ )
238
271
 
239
272
  # Iterate from the root nodes to the output node and
240
273
  # compute the fine structure contribution for each component
@@ -249,14 +282,18 @@ class SpectralModel:
249
282
  path = nx.shortest_path(self.graph, source=root_node_id, target="out")
250
283
  nodes_id_in_path = [node_id for node_id in path]
251
284
 
252
- flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(e_low, e_high)
285
+ flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
286
+ e_low, e_high
287
+ )
253
288
 
254
289
  multiplicative_nodes = []
255
290
 
256
291
  # Search all multiplicative components connected to this node
257
292
  # and apply them at mean energy
258
293
  for node_id in nodes_id_in_path[::-1]:
259
- multiplicative_nodes.extend([node_id for node_id in self.find_multiplicative_components(node_id)])
294
+ multiplicative_nodes.extend(
295
+ [node_id for node_id in self.find_multiplicative_components(node_id)]
296
+ )
260
297
 
261
298
  for mul_node in multiplicative_nodes:
262
299
  flux_from_component *= runtime_modules[mul_node].continuum(mean_energy)
@@ -309,7 +346,10 @@ class SpectralModel:
309
346
  if component.type == "additive":
310
347
 
311
348
  def lam_func(e):
312
- return component(**kwargs).continuum(e) + component(**kwargs).emission_lines(e, e + 1)[0]
349
+ return (
350
+ component(**kwargs).continuum(e)
351
+ + component(**kwargs).emission_lines(e, e + 1)[0]
352
+ )
313
353
 
314
354
  elif component.type == "multiplicative":
315
355
 
@@ -342,7 +382,9 @@ class SpectralModel:
342
382
 
343
383
  return cls(graph, labels)
344
384
 
345
- def compose(self, other: SpectralModel, operation=None, function=None, name=None) -> SpectralModel:
385
+ def compose(
386
+ self, other: SpectralModel, operation=None, function=None, name=None
387
+ ) -> SpectralModel:
346
388
  """
347
389
  This function operate a composition between the operation graph of two models
348
390
  1) It fuses the two graphs using which joins at the 'out' nodes
@@ -524,3 +566,10 @@ class AdditiveComponent(ModelComponent, ABC):
524
566
 
525
567
  return jnp.trapz(self(x) * dx, x=t)
526
568
  '''
569
+
570
+
571
+ class MultiplicativeComponent(ModelComponent, ABC):
572
+ type = "multiplicative"
573
+
574
+ @abstractmethod
575
+ def continuum(self, energy): ...
jaxspec/model/additive.py CHANGED
@@ -14,8 +14,6 @@ from haiku.initializers import Constant as HaikuConstant
14
14
 
15
15
  from ..util.integrate import integrate_interval
16
16
  from ..util.online_storage import table_manager
17
-
18
- # from ._additive.apec import APEC
19
17
  from .abc import AdditiveComponent
20
18
 
21
19
 
@@ -38,7 +36,7 @@ class Powerlaw(AdditiveComponent):
38
36
  return norm * energy ** (-alpha)
39
37
 
40
38
 
41
- class AdditiveConstant(AdditiveComponent):
39
+ class Additiveconstant(AdditiveComponent):
42
40
  r"""
43
41
  A constant model
44
42
 
@@ -1,7 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- from abc import ABC, abstractmethod
4
-
5
3
  import haiku as hk
6
4
  import jax.numpy as jnp
7
5
  import numpy as np
@@ -10,14 +8,7 @@ from astropy.table import Table
10
8
  from haiku.initializers import Constant as HaikuConstant
11
9
 
12
10
  from ..util.online_storage import table_manager
13
- from .abc import ModelComponent
14
-
15
-
16
- class MultiplicativeComponent(ModelComponent, ABC):
17
- type = "multiplicative"
18
-
19
- @abstractmethod
20
- def continuum(self, energy): ...
11
+ from .abc import MultiplicativeComponent
21
12
 
22
13
 
23
14
  class Expfac(MultiplicativeComponent):
@@ -226,6 +217,7 @@ class Tbpcf(MultiplicativeComponent):
226
217
 
227
218
  return f * jnp.exp(-nh * sigma) + (1 - f)
228
219
 
220
+
229
221
  class FDcut(MultiplicativeComponent):
230
222
  r"""
231
223
  A Fermi-Dirac cutoff model.
@@ -243,4 +235,4 @@ class FDcut(MultiplicativeComponent):
243
235
  cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
244
236
  folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
245
237
 
246
- return (1 + jnp.exp((energy - cutoff)/folding)) ** -1
238
+ return (1 + jnp.exp((energy - cutoff) / folding)) ** -1
jaxspec/util/typing.py CHANGED
@@ -9,6 +9,13 @@ from pydantic import BaseModel, field_validator
9
9
  PriorDictType = dict[str, dict[str, dist.Distribution | ArrayLike]]
10
10
 
11
11
 
12
+ def is_flat_dict(input_data: dict[str, Any]) -> bool:
13
+ """
14
+ Check if the input data is a flat dictionary with string keys and non-dictionary values.
15
+ """
16
+ return all(isinstance(k, str) and not isinstance(v, dict) for k, v in input_data.items())
17
+
18
+
12
19
  class PriorDictModel(BaseModel):
13
20
  """
14
21
  Pydantic model for a nested dictionary of NumPyro distributions or JAX arrays.
@@ -21,6 +28,23 @@ class PriorDictModel(BaseModel):
21
28
  class Config: # noqa D106
22
29
  arbitrary_types_allowed = True
23
30
 
31
+ @classmethod
32
+ def from_dict(cls, input_prior: dict[str, Any]):
33
+ if is_flat_dict(input_prior):
34
+ nested_dict = {}
35
+
36
+ for key, obj in input_prior.items():
37
+ component, component_number, *parameter = key.split("_")
38
+
39
+ sub_dict = nested_dict.get(f"{component}_{component_number}", {})
40
+ sub_dict["_".join(parameter)] = obj
41
+
42
+ nested_dict[f"{component}_{component_number}"] = sub_dict
43
+
44
+ return cls(nested_dict=nested_dict)
45
+
46
+ return cls(nested_dict=input_prior)
47
+
24
48
  @field_validator("nested_dict", mode="before")
25
49
  def check_and_cast_nested_dict(cls, value: dict[str, Any]):
26
50
  if not isinstance(value, dict):
@@ -35,9 +59,10 @@ class PriorDictModel(BaseModel):
35
59
  try:
36
60
  # Attempt to cast to JAX array
37
61
  value[key][inner_key] = jnp.array(obj, dtype=float)
62
+
38
63
  except Exception as e:
39
64
  raise ValueError(
40
- f'The value for key "{inner_key}" in inner dictionary must '
41
- f"be a NumPyro distribution or castable to JAX array. Error: {e}"
65
+ f'The value for key "{inner_key}" in {key} be a NumPyro '
66
+ f"distribution or castable to JAX array. Error: {e}"
42
67
  )
43
68
  return value
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxspec
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Home-page: https://github.com/renecotyfanboy/jaxspec
6
6
  License: MIT
@@ -15,19 +15,18 @@ Requires-Dist: arviz (>=0.17.1,<0.20.0)
15
15
  Requires-Dist: astropy (>=6.0.0,<7.0.0)
16
16
  Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
17
17
  Requires-Dist: cmasher (>=1.6.3,<2.0.0)
18
- Requires-Dist: dm-haiku (>=0.0.11,<0.0.13)
18
+ Requires-Dist: dm-haiku (>=0.0.12,<0.0.13)
19
19
  Requires-Dist: gpjax (>=0.8.0,<0.9.0)
20
20
  Requires-Dist: interpax (>=0.3.3,<0.4.0)
21
- Requires-Dist: jax (>=0.4.30,<0.5.0)
21
+ Requires-Dist: jax (>=0.4.33,<0.5.0)
22
22
  Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
23
- Requires-Dist: jaxns (>=2.5.1,<3.0.0)
23
+ Requires-Dist: jaxns (<2.6)
24
24
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
25
25
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
26
26
  Requires-Dist: mendeleev (>=0.15,<0.18)
27
- Requires-Dist: mkdocstrings (>=0.24,<0.26)
28
27
  Requires-Dist: networkx (>=3.1,<4.0)
29
28
  Requires-Dist: numpy (<2.0.0)
30
- Requires-Dist: numpyro (>=0.15.2,<0.16.0)
29
+ Requires-Dist: numpyro (>=0.15.3,<0.16.0)
31
30
  Requires-Dist: optimistix (>=0.0.7,<0.0.8)
32
31
  Requires-Dist: pandas (>=2.2.0,<3.0.0)
33
32
  Requires-Dist: pooch (>=1.8.2,<2.0.0)
@@ -41,7 +40,14 @@ Requires-Dist: watermark (>=2.4.3,<3.0.0)
41
40
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
42
41
  Description-Content-Type: text/markdown
43
42
 
44
- # jaxspec
43
+ <p align="center">
44
+ <img src="https://raw.githubusercontent.com/renecotyfanboy/jaxspec/main/docs/logo/logo_small.svg" alt="Logo" width="100" height="100">
45
+ </p>
46
+
47
+ <h1 align="center">
48
+ jaxspec
49
+ </h1>
50
+
45
51
 
46
52
  [![PyPI - Version](https://img.shields.io/pypi/v/jaxspec?style=for-the-badge&logo=pypi&color=rgb(37%2C%20150%2C%20190))](https://pypi.org/project/jaxspec/)
47
53
  [![Python package](https://img.shields.io/pypi/pyversions/jaxspec?style=for-the-badge)](https://pypi.org/project/jaxspec/)
@@ -1,33 +1,31 @@
1
1
  jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
2
2
  jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ jaxspec/analysis/_plot.py,sha256=C4XljmuzQz8xQur_jQddgInrBDmKgTn0eugSreLoD5k,862
3
4
  jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
4
5
  jaxspec/analysis/results.py,sha256=Kz3eryxS3N_hiajcFLTWS1dtgTQo5hlh-rDCnJ3A-3c,27811
5
6
  jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
6
7
  jaxspec/data/grouping.py,sha256=hhgBt-voiH0DDSyePacaIGsaMnrYbJM_-ZeU66keC7I,622
7
8
  jaxspec/data/instrument.py,sha256=0pSf1p82g7syDMmKm13eVbYih-Veiq5DnwsyZe6_b4g,3890
8
- jaxspec/data/obsconf.py,sha256=0X9jR-pV-Pk4-EVuUdlVWgl_gBx8ZurVkRNrfKQWdC4,8663
9
+ jaxspec/data/obsconf.py,sha256=gv14sL6azK2avRiMCWuTbyLBPulzm4PwvoLY6iWPEVE,9833
9
10
  jaxspec/data/observation.py,sha256=1UnFu5ihZp9z-vP_I7tsFY8jhhIJunv46JyuE-acrg0,6394
10
11
  jaxspec/data/ogip.py,sha256=sv9p00qHS5pzw61pzWyyF0nV-E-RXySdSFK2tUavokA,9545
11
12
  jaxspec/data/util.py,sha256=ycLPVE-cjn6VpUWYlBU1BGfw73ANXIBilyVAUOYOSj0,9540
12
- jaxspec/fit.py,sha256=C9hQxMJz1nLu47rHWkiKx7J7oS1bXow_kMKwswsJy8U,24791
13
+ jaxspec/fit.py,sha256=hI0koMO4KsNpe9mLlaFm_tNLgm4BVAYVyiMb1E1eyZE,24553
13
14
  jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- jaxspec/model/_additive/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- jaxspec/model/_additive/apec.py,sha256=r7CQqscAgR0BXC_AJqF6B7CPq3Byoo65Z-h9XgACZeU,12460
16
- jaxspec/model/_additive/apec_loaders.py,sha256=jkUoH0ezeYdaNw3oV10V0L-jt848SKp2thanLWLWp9k,2412
17
- jaxspec/model/abc.py,sha256=SWjKOOsqU5UJsVy63Tt9dDq8H2eTIbvK2C9iqgiR0cY,19817
18
- jaxspec/model/additive.py,sha256=CT2K2DVVeHKN1tee9-J3MYdEPqEOolLB2E7HU-RJKZw,22485
15
+ jaxspec/model/abc.py,sha256=MuxEyvn223QPwGoFIJiST8nRMgrZ08ZLkw33oep3tx4,20887
16
+ jaxspec/model/additive.py,sha256=wjY2wL3Io3F45GJpz-UB8xYVnA-W1OFBnZMbj5pWPbQ,22449
19
17
  jaxspec/model/background.py,sha256=QSFFiuyUEvuzXBx3QfkvVneUR8KKEP-VaANEVXcavDE,7865
20
18
  jaxspec/model/list.py,sha256=0RPAoscVz_zM1CWdx_Gd5wfrQWV5Nv4Kd4bSXu2ayUA,860
21
- jaxspec/model/multiplicative.py,sha256=TG3PCgS7oCuHwJ4TM4whw6pz318oo9MVvjSs4sQZVPc,8300
19
+ jaxspec/model/multiplicative.py,sha256=GCQ6JRz92QqbzDBFwWxGZ9SUqTJZQpD7B6ji9VEFXWo,8135
22
20
  jaxspec/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
21
  jaxspec/scripts/debug.py,sha256=RIRykBW_kzt8__PhOohQn2xtDW6oAz49E8rmuR5ewAU,293
24
22
  jaxspec/util/__init__.py,sha256=vKurfp7p2hxHptJjXhXqFAXAikAGXAqISMJUqPeiGTw,1259
25
23
  jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,8546
26
24
  jaxspec/util/integrate.py,sha256=_Ax_knpC7d4et2-QFkOUzVtNeQLX1-cwLvm-FRBxYcw,4505
27
25
  jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
28
- jaxspec/util/typing.py,sha256=qwZMKHivZlozoo0ESsiaQNkG99Dh3PE2Z-5aOQD9zc0,1650
29
- jaxspec-0.1.0.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
30
- jaxspec-0.1.0.dist-info/METADATA,sha256=ONyl92xPTT1LNXKNkonzdH4IlWFlkGdDo-09EEGur9c,3572
31
- jaxspec-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- jaxspec-0.1.0.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
33
- jaxspec-0.1.0.dist-info/RECORD,,
26
+ jaxspec/util/typing.py,sha256=8qK1aJlsqTcVKjYN-BxsDx20BTwtnS-wMw6Bdurpm-o,2459
27
+ jaxspec-0.1.2.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
28
+ jaxspec-0.1.2.dist-info/METADATA,sha256=FE2bTAk-3Xryi6fplV4Y-F2eibUdLZgC9ET9_4HvdOA,3708
29
+ jaxspec-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
30
+ jaxspec-0.1.2.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
31
+ jaxspec-0.1.2.dist-info/RECORD,,
File without changes
@@ -1,316 +0,0 @@
1
- import warnings
2
-
3
- from typing import Literal
4
-
5
- import astropy.units as u
6
- import haiku as hk
7
- import jax
8
- import jax.numpy as jnp
9
-
10
- from astropy.constants import c, m_p
11
- from haiku.initializers import Constant as HaikuConstant
12
- from jax import lax
13
- from jax.lax import fori_loop, scan
14
- from jax.scipy.stats import norm as gaussian
15
-
16
- from ...util.abundance import abundance_table, element_data
17
- from ..abc import AdditiveComponent
18
- from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
19
-
20
-
21
- @jax.jit
22
- def lerp(x, x0, x1, y0, y1):
23
- """
24
- Linear interpolation routine
25
- Return y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
26
- """
27
- return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
28
-
29
-
30
- @jax.jit
31
- def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
32
- """
33
- This function interpolate & integrate the values of a tabulated reference continuum between two energy limits
34
- Sorry for the boilerplate here, but be sure that it works !
35
-
36
- Parameters:
37
- energy_low: lower limit of the integral
38
- energy_high: upper limit of the integral
39
- energy_ref: energy grid of the reference continuum
40
- continuum_ref: continuum values evaluated at energy_ref
41
-
42
- """
43
- energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
44
- start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
45
- end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1
46
-
47
- def body_func(index, value):
48
- integrated_flux, previous_energy, previous_continuum = value
49
- current_energy, current_continuum = energy_ref[index], continuum_ref[index]
50
-
51
- # 5 cases
52
- # Neither current and previous energies are within the integral limits > nothing is added to the integrated flux
53
- # The left limit of the integral is between the current and previous energy > previous energy is set to the limit, previous continuum is interpolated, and then added to the integrated flux
54
- # The right limit of the integral is between the current and previous energy > current energy is set to the limit, current continuum is interpolated, and then added to the integrated flux
55
- # Both current and previous energies are within the integral limits -> add to the integrated flux
56
- # Within
57
-
58
- current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
59
- previous_energy_is_between = (energy_low <= previous_energy) * (
60
- previous_energy < energy_high
61
- )
62
- energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
63
-
64
- case = (
65
- (1 - previous_energy_is_between) * current_energy_is_between * 1
66
- + previous_energy_is_between * (1 - current_energy_is_between) * 2
67
- + (previous_energy_is_between * current_energy_is_between) * 3
68
- + energies_within_bins * 4
69
- )
70
-
71
- term_to_add = lax.switch(
72
- case,
73
- [
74
- lambda pe, pc, ce, cc, el, er: 0.0, # 1
75
- lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
76
- lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
77
- lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
78
- lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
79
- * (er - el)
80
- / 2,
81
- # 5
82
- ],
83
- previous_energy,
84
- previous_continuum,
85
- current_energy,
86
- current_continuum,
87
- energy_low,
88
- energy_high,
89
- )
90
-
91
- return integrated_flux + term_to_add, current_energy, current_continuum
92
-
93
- integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))
94
-
95
- return integrated_flux
96
-
97
-
98
- @jax.jit
99
- def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
100
- energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
101
-
102
- return (
103
- jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
104
- ) / (e_high - e_low)
105
-
106
-
107
- @jax.jit
108
- def interp_flux(energy, energy_ref, continuum_ref, end_index):
109
- """
110
- Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
111
- """
112
-
113
- def scanned_func(carry, unpack):
114
- e_low, e_high = unpack
115
- continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
116
-
117
- return carry, continuum
118
-
119
- _, continuum = scan(scanned_func, 0.0, (energy[:-1], energy[1:]))
120
-
121
- return continuum
122
-
123
-
124
- @jax.jit
125
- def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
126
- """
127
- Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
128
- and weight the flux depending on the abundance of each element
129
- """
130
-
131
- def scanned_func(_, unpack):
132
- energy_ref, continuum_ref, end_idx = unpack
133
- element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
134
-
135
- return _, element_flux
136
-
137
- _, flux = scan(scanned_func, 0.0, (energy_ref, continuum_ref, end_index))
138
-
139
- return abundances @ flux
140
-
141
-
142
- @jax.jit
143
- def get_lines_contribution_broadening(
144
- line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
145
- ):
146
- def body_func(i, flux):
147
- # Notice the -1 in line element to match the 0-based indexing
148
- l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
149
- broadening = l_energy * total_broadening[l_element]
150
- l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
151
- energy[:-1], l_energy, broadening
152
- )
153
- l_flux = l_flux * l_emissivity * abundances[l_element]
154
-
155
- return flux + l_flux
156
-
157
- return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
158
-
159
-
160
- @jax.jit
161
- def continuum_func(energy, kT, abundances):
162
- idx, kT_low, kT_high = get_temperature(kT)
163
- continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
164
- continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
165
-
166
- return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
167
-
168
-
169
- @jax.jit
170
- def pseudo_func(energy, kT, abundances):
171
- idx, kT_low, kT_high = get_temperature(kT)
172
- continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
173
- continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
174
-
175
- return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
176
-
177
-
178
- # @jax.custom_jvp
179
- @jax.jit
180
- def lines_func(energy, kT, abundances, broadening):
181
- idx, kT_low, kT_high = get_temperature(kT)
182
- line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
183
- line_high = get_lines_contribution_broadening(
184
- *get_lines(idx + 1), energy, abundances, broadening
185
- )
186
-
187
- return lerp(kT, kT_low, kT_high, line_low, line_high)
188
-
189
-
190
- class APEC(AdditiveComponent):
191
- """
192
- APEC model implementation in pure JAX for X-ray spectral fitting.
193
-
194
- !!! warning
195
- This implementation is optimised for the CPU, it shows poor performance on the GPU.
196
- """
197
-
198
- def __init__(
199
- self,
200
- continuum: bool = True,
201
- pseudo: bool = True,
202
- lines: bool = True,
203
- thermal_broadening: bool = True,
204
- turbulent_broadening: bool = True,
205
- variant: Literal["none", "v", "vv"] = "none",
206
- abundance_table: Literal[
207
- "angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
208
- ] = "angr",
209
- trace_abundance: float = 1.0,
210
- **kwargs,
211
- ):
212
- super().__init__(**kwargs)
213
-
214
- warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
215
-
216
- self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
217
-
218
- self.abundance_table = abundance_table
219
- self.thermal_broadening = thermal_broadening
220
- self.turbulent_broadening = turbulent_broadening
221
- self.continuum_to_compute = continuum
222
- self.pseudo_to_compute = pseudo
223
- self.lines_to_compute = lines
224
- self.trace_abundance = trace_abundance
225
- self.variant = variant
226
-
227
- def get_thermal_broadening(self):
228
- r"""
229
- Compute the thermal broadening $\sigma_T$ for each element using :
230
-
231
- $$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$
232
-
233
- where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
234
- $T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.
235
- """
236
-
237
- if self.thermal_broadening:
238
- kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
239
- factor = 1 / c * (1 / m_p) ** (1 / 2)
240
- factor = factor.to(u.keV ** (-1 / 2)).value
241
-
242
- # Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
243
- # This return value must be multiplied by the energy of the line to get actual broadening
244
- return factor * jnp.sqrt(kT / self.atomic_weights)
245
-
246
- else:
247
- return jnp.zeros((30,))
248
-
249
- def get_turbulent_broadening(self):
250
- r"""
251
- Return the turbulent broadening using :
252
-
253
- $$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$
254
-
255
- where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
256
- """
257
- if self.turbulent_broadening:
258
- # This return value must be multiplied by the energy of the line to get actual broadening
259
- return (
260
- hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
261
- )
262
- else:
263
- return 0.0
264
-
265
- def get_parameters(self):
266
- none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
267
- v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
268
- trace_elements = (
269
- jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
270
- )
271
-
272
- # Set abundances of trace element (will be overwritten in the vv case)
273
- abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
274
-
275
- if self.variant == "vv":
276
- for i, element in enumerate(abundance_table["Element"]):
277
- if element != "H":
278
- abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
279
-
280
- elif self.variant == "v":
281
- for i, element in enumerate(abundance_table["Element"]):
282
- if element != "H" and element in v_elements:
283
- abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
284
-
285
- else:
286
- Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
287
- for i, element in enumerate(abundance_table["Element"]):
288
- if element != "H" and element in none_elements:
289
- abund = abund.at[i].set(Z)
290
-
291
- if abund != "angr":
292
- abund = abund * jnp.asarray(
293
- abundance_table[self.abundance_table] / abundance_table["angr"]
294
- )
295
-
296
- # Set the temperature, redshift, normalisation
297
- kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
298
- z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
299
- norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))
300
-
301
- return kT, z, norm, abund
302
-
303
- def emission_lines(self, e_low, e_high):
304
- # Get the parameters and extract the relevant data
305
- energy = jnp.hstack([e_low, e_high[-1]])
306
- kT, z, norm, abundances = self.get_parameters()
307
- total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
308
- energy = energy * (1 + z)
309
-
310
- continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
311
- pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
312
- lines = (
313
- lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
314
- )
315
-
316
- return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2