jaxspec 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121) hide show
  1. {jaxspec-0.3.2 → jaxspec-0.3.4}/PKG-INFO +4 -8
  2. {jaxspec-0.3.2 → jaxspec-0.3.4}/mkdocs.yml +1 -1
  3. {jaxspec-0.3.2 → jaxspec-0.3.4}/pyproject.toml +6 -10
  4. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/analysis/results.py +83 -32
  5. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/__init__.py +2 -0
  6. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/instrument.py +2 -1
  7. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/util.py +2 -2
  8. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/fit/_fitter.py +18 -5
  9. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/abc.py +4 -3
  10. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/additive.py +17 -21
  11. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/multiplicative.py +54 -3
  12. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/online_storage.py +1 -0
  13. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/conftest.py +10 -0
  14. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_fakeit.py +13 -0
  15. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_mcmc.py +4 -1
  16. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_results.py +9 -3
  17. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_xspec.py +39 -0
  18. {jaxspec-0.3.2 → jaxspec-0.3.4}/.dockerignore +0 -0
  19. {jaxspec-0.3.2 → jaxspec-0.3.4}/.github/dependabot.yml +0 -0
  20. {jaxspec-0.3.2 → jaxspec-0.3.4}/.github/workflows/documentation-links.yml +0 -0
  21. {jaxspec-0.3.2 → jaxspec-0.3.4}/.github/workflows/publish.yml +0 -0
  22. {jaxspec-0.3.2 → jaxspec-0.3.4}/.github/workflows/test-and-coverage.yml +0 -0
  23. {jaxspec-0.3.2 → jaxspec-0.3.4}/.gitignore +0 -0
  24. {jaxspec-0.3.2 → jaxspec-0.3.4}/.pre-commit-config.yaml +0 -0
  25. {jaxspec-0.3.2 → jaxspec-0.3.4}/.python-version +0 -0
  26. {jaxspec-0.3.2 → jaxspec-0.3.4}/.readthedocs.yaml +0 -0
  27. {jaxspec-0.3.2 → jaxspec-0.3.4}/CODE_OF_CONDUCT.md +0 -0
  28. {jaxspec-0.3.2 → jaxspec-0.3.4}/Dockerfile +0 -0
  29. {jaxspec-0.3.2 → jaxspec-0.3.4}/LICENSE.md +0 -0
  30. {jaxspec-0.3.2 → jaxspec-0.3.4}/README.md +0 -0
  31. {jaxspec-0.3.2 → jaxspec-0.3.4}/codecov.yml +0 -0
  32. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/contribute/index.md +0 -0
  33. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/contribute/internal.md +0 -0
  34. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/contribute/xspec.md +0 -0
  35. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/css/extra.css +0 -0
  36. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/css/material.css +0 -0
  37. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/css/mkdocstrings.css +0 -0
  38. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/css/xarray.css +0 -0
  39. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/dev/index.md +0 -0
  40. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/background.md +0 -0
  41. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/build_model.md +0 -0
  42. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/fakeits.md +0 -0
  43. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/fitting_example.md +0 -0
  44. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/index.md +0 -0
  45. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/background_comparison.png +0 -0
  46. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/background_gp.png +0 -0
  47. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/background_spectral.png +0 -0
  48. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/fakeits.png +0 -0
  49. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/fitting_example_corner.png +0 -0
  50. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/fitting_example_ppc.png +0 -0
  51. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/model.png +0 -0
  52. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/rmf.png +0 -0
  53. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/subtract_background.png +0 -0
  54. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/examples/statics/subtract_background_with_errors.png +0 -0
  55. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/faq/cookbook.md +0 -0
  56. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/faq/index.md +0 -0
  57. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/faq/statics/cstat_vs_chi2.png +0 -0
  58. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/frontpage/installation.md +0 -0
  59. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/index.md +0 -0
  60. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/javascripts/mathjax.js +0 -0
  61. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/logo/logo_small.svg +0 -0
  62. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/logo/xifu_mini.svg +0 -0
  63. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/abundance.md +0 -0
  64. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/additive.md +0 -0
  65. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/background.md +0 -0
  66. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/data.md +0 -0
  67. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/fitting.md +0 -0
  68. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/instrument.md +0 -0
  69. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/integrate.md +0 -0
  70. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/model.md +0 -0
  71. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/multiplicative.md +0 -0
  72. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/references/results.md +0 -0
  73. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/runtime/diagram.txt +0 -0
  74. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/runtime/result_table.txt +0 -0
  75. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/runtime/various_model_graphs/diagram.txt +0 -0
  76. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
  77. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
  78. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/theory/background.md +0 -0
  79. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/theory/bayesian_inference.md +0 -0
  80. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/theory/index.md +0 -0
  81. {jaxspec-0.3.2 → jaxspec-0.3.4}/docs/theory/instrument.md +0 -0
  82. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/__init__.py +0 -0
  83. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/analysis/__init__.py +0 -0
  84. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/analysis/_plot.py +0 -0
  85. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/analysis/compare.py +0 -0
  86. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/obsconf.py +0 -0
  87. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/observation.py +0 -0
  88. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/data/ogip.py +0 -0
  89. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/__init__.py +0 -0
  90. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/interpolator.py +0 -0
  91. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/interpolator_jax.py +0 -0
  92. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/intrument_models.py +0 -0
  93. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/nested_sampler.py +0 -0
  94. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/experimental/tabulated.py +0 -0
  95. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/fit/__init__.py +0 -0
  96. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/fit/_bayesian_model.py +0 -0
  97. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/fit/_build_model.py +0 -0
  98. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/__init__.py +0 -0
  99. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/_graph_util.py +0 -0
  100. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/background.py +0 -0
  101. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/instrument.py +0 -0
  102. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/model/list.py +0 -0
  103. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/scripts/__init__.py +0 -0
  104. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/scripts/debug.py +0 -0
  105. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/__init__.py +0 -0
  106. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/abundance.py +0 -0
  107. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/integrate.py +0 -0
  108. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/misc.py +0 -0
  109. {jaxspec-0.3.2 → jaxspec-0.3.4}/src/jaxspec/util/typing.py +0 -0
  110. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/data_files.yml +0 -0
  111. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/data_hash.yml +0 -0
  112. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_background.py +0 -0
  113. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_bayesian_model.py +0 -0
  114. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_bayesian_model_building.py +0 -0
  115. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_instruments.py +0 -0
  116. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_integrate.py +0 -0
  117. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_misc.py +0 -0
  118. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_models.py +0 -0
  119. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_observation.py +0 -0
  120. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_repr.py +0 -0
  121. {jaxspec-0.3.2 → jaxspec-0.3.4}/tests/test_xspec_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -8,7 +8,7 @@ Author-email: sdupourque <sdupourque@irap.omp.eu>
8
8
  License-Expression: MIT
9
9
  License-File: LICENSE.md
10
10
  Requires-Python: <3.13,>=3.10
11
- Requires-Dist: arviz<0.23.0,>=0.17.1
11
+ Requires-Dist: arviz<0.24.0,>=0.17.1
12
12
  Requires-Dist: astropy<8,>=6.0.0
13
13
  Requires-Dist: catppuccin<3,>=2.3.4
14
14
  Requires-Dist: chainconsumer<2,>=1.1.2
@@ -16,19 +16,15 @@ Requires-Dist: cmasher<2,>=1.6.3
16
16
  Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
18
  Requires-Dist: jax<0.7,>=0.5.0
19
- Requires-Dist: jaxns<3,>=2.6.7
20
- Requires-Dist: jaxopt<0.9,>=0.8.3
21
19
  Requires-Dist: matplotlib<4,>=3.8.0
22
20
  Requires-Dist: mendeleev<1.2,>=0.15
23
21
  Requires-Dist: networkx~=3.1
24
22
  Requires-Dist: numpy<3.0.0
25
- Requires-Dist: numpyro<0.20,>=0.17.0
26
- Requires-Dist: optimistix<0.0.12,>=0.0.10
23
+ Requires-Dist: numpyro<0.21,>=0.17.0
27
24
  Requires-Dist: pandas<3,>=2.2.0
28
25
  Requires-Dist: pooch<2,>=1.8.2
29
26
  Requires-Dist: scipy<1.16
30
- Requires-Dist: seaborn<0.14,>=0.13.1
31
- Requires-Dist: simpleeval<1.1.0,>=0.9.13
27
+ Requires-Dist: seaborn>=0.13.2
32
28
  Requires-Dist: sparse>0.15
33
29
  Requires-Dist: tinygp<0.4,>=0.3.0
34
30
  Requires-Dist: watermark<3,>=2.4.3
@@ -86,7 +86,7 @@ theme:
86
86
  plugins:
87
87
  - search
88
88
  - autorefs
89
- # - typeset
89
+ - typeset
90
90
  - mkdocs-jupyter:
91
91
  include_source: True
92
92
  ignore_h1_titles: True
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jaxspec"
3
- version = "0.3.2"
3
+ version = "0.3.4"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
6
6
  requires-python = ">=3.10,<3.13"
@@ -11,26 +11,22 @@ dependencies = [
11
11
  "numpy<3.0.0",
12
12
  "pandas>=2.2.0,<3",
13
13
  "astropy>=6.0.0,<8",
14
- "numpyro>=0.17.0,<0.20",
14
+ "numpyro>=0.17.0,<0.21",
15
15
  "networkx~=3.1",
16
16
  "matplotlib>=3.8.0,<4",
17
- "arviz>=0.17.1,<0.23.0",
17
+ "arviz>=0.17.1,<0.24.0",
18
18
  "chainconsumer>=1.1.2,<2",
19
- "simpleeval>=0.9.13,<1.1.0",
20
19
  "cmasher>=1.6.3,<2",
21
- "jaxopt>=0.8.3,<0.9",
22
20
  "tinygp>=0.3.0,<0.4",
23
- "seaborn>=0.13.1,<0.14",
24
21
  "sparse>0.15",
25
- "optimistix>=0.0.10,<0.0.12",
26
22
  "scipy<1.16",
27
23
  "mendeleev>=0.15,<1.2",
28
- "jaxns>=2.6.7,<3",
29
24
  "pooch>=1.8.2,<2",
30
25
  "interpax>=0.3.5,<0.4",
31
26
  "watermark>=2.4.3,<3",
32
27
  "catppuccin>=2.3.4,<3",
33
28
  "flax>0.10.5",
29
+ "seaborn>=0.13.2",
34
30
  ]
35
31
 
36
32
  [project.urls]
@@ -44,14 +40,14 @@ jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
44
40
  docs = [
45
41
  "mkdocs>=1.6.1,<2",
46
42
  "mkdocs-material>=9.4.6,<10",
47
- "mkdocstrings[python]>=0.24,<0.28",
43
+ "mkdocstrings[python]>=0.24,<1.1",
48
44
  "mkdocs-jupyter>=0.25.0,<0.26",
49
45
  ]
50
46
  test = [
51
47
  "chex>=0.1.83,<0.2",
52
48
  "mktestdocs>=0.2.1,<0.3",
53
49
  "coverage>=7.3.2,<8",
54
- "pytest-cov>=4.1,<7.0",
50
+ "pytest-cov>=4.1,<8.0",
55
51
  "flake8>=7.0.0,<8",
56
52
  "pytest>=8.0.0,<9",
57
53
  "testbook>=0.4.2,<0.5",
@@ -77,8 +77,9 @@ class FitResult:
77
77
  r"""
78
78
  Convergence of the chain as computed by the $\hat{R}$ statistic.
79
79
  """
80
+ rhat = az.rhat(self.inference_data)
80
81
 
81
- return all(az.rhat(self.inference_data) < 1.01)
82
+ return bool((rhat.to_array() < 1.01).all())
82
83
 
83
84
  def _ppc_folded_branches(self, obs_id):
84
85
  obs = self.obsconfs[obs_id]
@@ -167,6 +168,8 @@ class FitResult:
167
168
  e_max: float,
168
169
  unit: Unit = u.photon / u.cm**2 / u.s,
169
170
  register: bool = False,
171
+ n_points: int = 5,
172
+ n_grid: int = 1_000,
170
173
  ) -> ArrayLike:
171
174
  """
172
175
  Compute the unfolded photon flux in a given energy band. The flux is then added to
@@ -177,29 +180,40 @@ class FitResult:
177
180
  e_max: The upper bound of the energy band in observer frame.
178
181
  unit: The unit of the photon flux.
179
182
  register: Whether to register the flux with the other posterior parameters.
183
+ n_points: The number of points per bin to use for computing the unfolded spectrum.
184
+ n_grid: The number of grid points to use for computing the unfolded spectrum.
180
185
 
181
186
  !!! warning
182
187
  Computation of the folded flux is not implemented yet. Feel free to open an
183
188
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
184
189
  """
185
190
 
191
+ energy_grid = np.linspace(e_min, e_max, n_grid)
192
+
186
193
  @jax.jit
187
194
  @jnp.vectorize
188
195
  def vectorized_flux(*pars):
189
196
  parameters_pytree = jax.tree.unflatten(pytree_def, pars)
190
197
  return self.model.photon_flux(
191
- parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
192
- )[0]
198
+ parameters_pytree, energy_grid[:-1], energy_grid[1:], n_points=n_points
199
+ )
193
200
 
194
201
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
195
- flux = vectorized_flux(*flat_tree)
196
- conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
197
- value = flux * conversion_factor
202
+ flux = vectorized_flux(*flat_tree).sum(axis=-1) # Sum over all bins
203
+ conversion_factor = float((u.photon / u.cm**2 / u.s).to(unit))
204
+ value = np.asarray(flux * conversion_factor)
198
205
 
199
206
  if register:
200
- self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
201
- list(self.inference_data.posterior.coords),
202
- value,
207
+ self.inference_data.posterior[f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
208
+ xr.DataArray(
209
+ value,
210
+ dims=self.inference_data.posterior.dims,
211
+ coords={
212
+ "chain": self.inference_data.posterior.chain,
213
+ "draw": self.inference_data.posterior.draw,
214
+ },
215
+ name=f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}",
216
+ )
203
217
  )
204
218
 
205
219
  return value
@@ -210,6 +224,8 @@ class FitResult:
210
224
  e_max: float,
211
225
  unit: Unit = u.erg / u.cm**2 / u.s,
212
226
  register: bool = False,
227
+ n_points: int = 5,
228
+ n_grid: int = 1_000,
213
229
  ) -> ArrayLike:
214
230
  """
215
231
  Compute the unfolded energy flux in a given energy band. The flux is then added to
@@ -220,29 +236,40 @@ class FitResult:
220
236
  e_max: The upper bound of the energy band in observer frame.
221
237
  unit: The unit of the energy flux.
222
238
  register: Whether to register the flux with the other posterior parameters.
239
+ n_points: The number of points per bin to use for computing the unfolded spectrum.
240
+ n_grid: The number of grid points to use for computing the unfolded spectrum.
223
241
 
224
242
  !!! warning
225
243
  Computation of the folded flux is not implemented yet. Feel free to open an
226
244
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
227
245
  """
228
246
 
247
+ energy_grid = np.linspace(e_min, e_max, n_grid)
248
+
229
249
  @jax.jit
230
250
  @jnp.vectorize
231
251
  def vectorized_flux(*pars):
232
252
  parameters_pytree = jax.tree.unflatten(pytree_def, pars)
233
253
  return self.model.energy_flux(
234
- parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
235
- )[0]
254
+ parameters_pytree, energy_grid[:-1], energy_grid[1:], n_points=n_points
255
+ )
236
256
 
237
257
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
238
- flux = vectorized_flux(*flat_tree)
239
- conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
240
- value = flux * conversion_factor
241
-
258
+ flux = vectorized_flux(*flat_tree).sum(axis=-1) # Sum over all bins
259
+ conversion_factor = float((u.keV / u.cm**2 / u.s).to(unit))
260
+ value = np.asarray(flux * conversion_factor)
261
+ # TODO : ADD TESTS WITH BACKGROUND
242
262
  if register:
243
- self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
244
- list(self.inference_data.posterior.coords),
245
- value,
263
+ self.inference_data.posterior[f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
264
+ xr.DataArray(
265
+ value,
266
+ dims=self.inference_data.posterior.dims,
267
+ coords={
268
+ "chain": self.inference_data.posterior.chain,
269
+ "draw": self.inference_data.posterior.draw,
270
+ },
271
+ name=f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}",
272
+ )
246
273
  )
247
274
 
248
275
  return value
@@ -257,6 +284,8 @@ class FitResult:
257
284
  cosmology: Cosmology = Planck18,
258
285
  unit: Unit = u.erg / u.s,
259
286
  register: bool = False,
287
+ n_points: int = 5,
288
+ n_grid: int = 1_000,
260
289
  ) -> ArrayLike:
261
290
  """
262
291
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
@@ -270,8 +299,12 @@ class FitResult:
270
299
  cosmology: Chosen cosmology.
271
300
  unit: The unit of the luminosity.
272
301
  register: Whether to register the flux with the other posterior parameters.
302
+ n_points: The number of points per bin to use for computing the unfolded spectrum.
303
+ n_grid: The number of grid points to use for computing the unfolded spectrum.
273
304
  """
274
305
 
306
+ energy_grid = np.linspace(e_min, e_max, n_grid)
307
+
275
308
  if not observer_frame:
276
309
  raise NotImplementedError()
277
310
 
@@ -292,19 +325,28 @@ class FitResult:
292
325
  parameters_pytree = jax.tree.unflatten(pytree_def, pars)
293
326
  return self.model.energy_flux(
294
327
  parameters_pytree,
295
- jnp.asarray([e_min]) * (1 + redshift),
296
- jnp.asarray([e_max]) * (1 + redshift),
297
- n_points=100,
298
- )[0]
328
+ energy_grid[:-1] * (1 + redshift),
329
+ energy_grid[1:] * (1 + redshift),
330
+ n_points=n_points,
331
+ )
299
332
 
300
333
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
301
- flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
302
- value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
334
+ flux = vectorized_flux(*flat_tree).sum(axis=-1) * (u.keV / u.cm**2 / u.s)
335
+ value = np.asarray(
336
+ (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
337
+ )
303
338
 
304
339
  if register:
305
- self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
306
- list(self.inference_data.posterior.coords),
307
- value,
340
+ self.inference_data.posterior[f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}"] = (
341
+ xr.DataArray(
342
+ value,
343
+ dims=self.inference_data.posterior.dims,
344
+ coords={
345
+ "chain": self.inference_data.posterior.chain,
346
+ "draw": self.inference_data.posterior.draw,
347
+ },
348
+ name=f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}",
349
+ )
308
350
  )
309
351
 
310
352
  return value
@@ -315,10 +357,13 @@ class FitResult:
315
357
 
316
358
  Parameters:
317
359
  name: The name of the chain.
360
+ parameter_kind: The kind of parameters to keep.
318
361
  """
319
362
 
320
363
  keys_to_drop = [
321
- key for key in self.inference_data.posterior.keys() if not key.startswith("mod")
364
+ key
365
+ for key in self.inference_data.posterior.keys()
366
+ if not key.startswith(parameter_kind)
322
367
  ]
323
368
 
324
369
  reduced_id = az.extract(
@@ -403,6 +448,7 @@ class FitResult:
403
448
  title: str | None = None,
404
449
  figsize: tuple[float, float] = (6, 6),
405
450
  x_lims: tuple[float, float] | None = None,
451
+ rescale_background: bool = False,
406
452
  ) -> list[plt.Figure]:
407
453
  r"""
408
454
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
@@ -423,6 +469,7 @@ class FitResult:
423
469
  title: The title of the plot.
424
470
  figsize: The size of the figure.
425
471
  x_lims: The limits of the x-axis.
472
+ rescale_background: Whether to rescale the background model to the data with backscal ratio.
426
473
 
427
474
  Returns:
428
475
  A list of matplotlib figures for each observation in the model.
@@ -573,10 +620,14 @@ class FitResult:
573
620
  )
574
621
  )
575
622
 
623
+ rescale_background_factor = (
624
+ obsconf.folded_backratio.data if rescale_background else 1.0
625
+ )
626
+
576
627
  model_bkg_plot = _plot_binned_samples_with_error(
577
628
  ax[0],
578
629
  xbins.value,
579
- y_samples_bkg.value,
630
+ y_samples_bkg.value * rescale_background_factor,
580
631
  color=BACKGROUND_COLOR,
581
632
  alpha_envelope=alpha_envelope,
582
633
  n_sigmas=n_sigmas,
@@ -585,9 +636,9 @@ class FitResult:
585
636
  true_bkg_plot = _plot_poisson_data_with_error(
586
637
  ax[0],
587
638
  xbins.value,
588
- y_observed_bkg.value,
589
- y_observed_bkg_low.value,
590
- y_observed_bkg_high.value,
639
+ y_observed_bkg.value * rescale_background_factor,
640
+ y_observed_bkg_low.value * rescale_background_factor,
641
+ y_observed_bkg_high.value * rescale_background_factor,
591
642
  color=BACKGROUND_DATA_COLOR,
592
643
  alpha=0.7,
593
644
  )
@@ -6,5 +6,7 @@ from .observation import Observation
6
6
 
7
7
  u.add_enabled_aliases({"counts": u.count})
8
8
  u.add_enabled_aliases({"channel": u.dimensionless_unscaled})
9
+ u.add_enabled_aliases({"ADU": u.dimensionless_unscaled}) # Appears in SIXTE outputs
10
+
9
11
  # Arbitrary units are found in .rsp files , let's hope it is compatible with what we would expect as the rmf x arf
10
12
  # u.add_enabled_aliases({"au": u.dimensionless_unscaled})
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import sparse
2
3
 
3
4
  import numpy as np
4
5
  import xarray as xr
@@ -92,7 +93,7 @@ class Instrument(xr.Dataset):
92
93
 
93
94
  else:
94
95
  specresp = rmf.matrix.sum(axis=0)
95
- rmf.sparse_matrix /= specresp
96
+ rmf.sparse_matrix = sparse.COO( rmf.matrix / specresp )
96
97
 
97
98
  return cls.from_matrix(
98
99
  rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
@@ -152,11 +152,11 @@ def forward_model_with_multiple_inputs(
152
152
  transfer_matrix = BCOO.from_scipy_sparse(
153
153
  obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
154
154
  )
155
+ expected_counts = transfer_matrix @ flux_func(parameters).T
155
156
 
156
157
  else:
157
158
  transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
158
-
159
- expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
159
+ expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
160
160
 
161
161
  # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
162
162
  return jnp.clip(expected_counts, a_min=1e-6)
@@ -9,11 +9,13 @@ import matplotlib.pyplot as plt
9
9
  import numpyro
10
10
 
11
11
  from jax import random
12
+ from jax.numpy import concatenate
12
13
  from jax.random import PRNGKey
13
14
  from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
14
15
  from numpyro.infer.autoguide import AutoMultivariateNormal
15
16
 
16
17
  from ..analysis.results import FitResult
18
+ from ..model.background import SubtractedBackground
17
19
  from ._bayesian_model import BayesianModel
18
20
 
19
21
 
@@ -52,9 +54,22 @@ class BayesianModelFitter(BayesianModel, ABC):
52
54
  )
53
55
 
54
56
  log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
57
+ if len(log_likelihood.keys()) > 1:
58
+ log_likelihood["full"] = concatenate([ll for _, ll in log_likelihood.items()], axis=1)
59
+ log_likelihood["obs/~/all"] = concatenate(
60
+ [ll for k, ll in log_likelihood.items() if "obs" in k], axis=1
61
+ )
62
+
63
+ # Subtracted background is not fitted so there is no likelihood
64
+ if self.background_model is not None and not isinstance(
65
+ self.background_model, SubtractedBackground
66
+ ):
67
+ log_likelihood["bkg/~/all"] = concatenate(
68
+ [ll for k, ll in log_likelihood.items() if "bkg" in k], axis=1
69
+ )
55
70
 
56
71
  seeded_model = numpyro.handlers.substitute(
57
- numpyro.handlers.seed(numpyro_model, keys[3]),
72
+ numpyro.handlers.seed(numpyro_model, keys[2]),
58
73
  substitute_fn=numpyro.infer.init_to_sample,
59
74
  )
60
75
 
@@ -108,12 +123,10 @@ class BayesianModelFitter(BayesianModel, ABC):
108
123
  predictive_parameters = []
109
124
 
110
125
  for key, value in self._observation_container.items():
126
+ predictive_parameters.append(f"obs/~/{key}")
111
127
  if self.background_model is not None:
112
- predictive_parameters.append(f"obs/~/{key}")
113
128
  predictive_parameters.append(f"bkg/~/{key}")
114
129
  # predictive_parameters.append(f"ins/~/{key}")
115
- else:
116
- predictive_parameters.append(f"obs/~/{key}")
117
130
  # predictive_parameters.append(f"ins/~/{key}")
118
131
 
119
132
  inference_data.posterior_predictive = inference_data.posterior_predictive[
@@ -247,7 +260,7 @@ class VIFitter(BayesianModelFitter):
247
260
 
248
261
  svi = SVI(bayesian_model, guide, optimizer, loss=loss)
249
262
 
250
- keys = random.split(random.PRNGKey(rng_key), 3)
263
+ keys = random.split(random.PRNGKey(rng_key), 2)
251
264
  svi_result = svi.run(keys[0], num_steps)
252
265
  params = svi_result.params
253
266
 
@@ -372,9 +372,10 @@ class AdditiveComponent(ModelComponent):
372
372
  continuum = self.continuum(energy)
373
373
  integrated_continuum = self.integrated_continuum(e_low, e_high)
374
374
 
375
- return jsp.integrate.trapezoid(
376
- continuum * energy**2, jnp.log(energy), axis=-1
377
- ) + integrated_continuum * (e_high - e_low)
375
+ return (
376
+ jsp.integrate.trapezoid(continuum * energy**2, jnp.log(energy), axis=-1)
377
+ + integrated_continuum * (e_high + e_low) / 2.0
378
+ )
378
379
 
379
380
  @partial(jax.jit, static_argnums=0, static_argnames="n_points")
380
381
  def photon_flux(self, params, e_low, e_high, n_points=2):
@@ -34,8 +34,10 @@ class Powerlaw(AdditiveComponent):
34
34
  self.alpha = nnx.Param(1.7)
35
35
  self.norm = nnx.Param(1e-4)
36
36
 
37
- def continuum(self, energy):
38
- return self.norm * energy ** (-self.alpha)
37
+ def integrated_continuum(self, e_low, e_high):
38
+ return (
39
+ self.norm / (1 - self.alpha) * (e_high ** (1 - self.alpha) - e_low ** (1 - self.alpha))
40
+ )
39
41
 
40
42
 
41
43
  class Additiveconstant(AdditiveComponent):
@@ -166,28 +168,22 @@ class Gauss(AdditiveComponent):
166
168
  self.sigma = nnx.Param(1e-2)
167
169
  self.norm = nnx.Param(1.0)
168
170
 
169
- def continuum(self, energy):
170
- return (
171
- self.norm
172
- * jsp.stats.norm.pdf(energy, loc=jnp.asarray(self.El), scale=jnp.asarray(self.sigma))
173
- / (1 - jsp.special.erf(-self.El / (self.sigma * jnp.sqrt(2))))
171
+ def integrated_continuum(self, e_low, e_high):
172
+ upper = jsp.stats.norm.cdf(
173
+ e_high,
174
+ loc=jnp.asarray(self.El),
175
+ scale=jnp.asarray(self.sigma),
174
176
  )
175
177
 
176
- """
177
- def integrated_continuum(self, e_low, e_high):
178
- return self.norm * (
179
- jsp.stats.norm.cdf(
180
- e_high,
181
- loc=jnp.asarray(self.El),
182
- scale=jnp.asarray(self.sigma),
183
- )
184
- - jsp.stats.norm.cdf(
185
- e_low,
186
- loc=jnp.asarray(self.El),
187
- scale=jnp.asarray(self.sigma),
188
- ) #/ (1 - jsp.special.erf(- self.El / (self.sigma * jnp.sqrt(2))))
178
+ lower = jsp.stats.norm.cdf(
179
+ e_low,
180
+ loc=jnp.asarray(self.El),
181
+ scale=jnp.asarray(self.sigma),
189
182
  )
190
- """
183
+
184
+ factor = 2 / (1 - jsp.special.erf(-self.El / (self.sigma * jnp.sqrt(2))))
185
+
186
+ return self.norm * (upper - lower) * factor
191
187
 
192
188
 
193
189
  class Cutoffpl(AdditiveComponent):
@@ -49,7 +49,9 @@ class Expfac(MultiplicativeComponent):
49
49
  self.E_c = nnx.Param(1.0)
50
50
 
51
51
  def factor(self, energy):
52
- return jnp.where(energy >= self.E_c, 1.0 + self.A * jnp.exp(-self.f * energy), 1.0)
52
+ return jnp.where(
53
+ energy >= self.E_c, 1.0 + self.A * jnp.exp(-self.f * energy), 1.0
54
+ )
53
55
 
54
56
 
55
57
  class Tbabs(MultiplicativeComponent):
@@ -91,6 +93,49 @@ class Tbabs(MultiplicativeComponent):
91
93
  return jnp.exp(-self.nh * sigma)
92
94
 
93
95
 
96
+ class zTbabs(MultiplicativeComponent):
97
+ r"""
98
+ The redshifted Tuebingen-Boulder ISM absorption model. See `Tbabs` for more details.
99
+ From Xspec manual:
100
+ This model assumes that 20% of the hydrogen is molecular
101
+ and that there is NO MATERIAL IN GRAINS.
102
+
103
+ $$
104
+ \mathcal{M}(E) = \exp^{-N_{\text{H}}\sigma(E)}
105
+ $$
106
+
107
+ !!! abstract "Parameters"
108
+ * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
109
+ * $z$ (`z`) $\left[\text{dimensionless}\right]$ : Redshift
110
+
111
+
112
+ !!! note
113
+ Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
114
+
115
+ """
116
+
117
+ def __init__(self):
118
+ table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
119
+ self._energy = np.asarray(table["ENERGY"], dtype=np.float64)
120
+ self._sigma = np.asarray(table["SIGMA"], dtype=np.float64)
121
+ self.nh = nnx.Param(1.0)
122
+ self.z = nnx.Param(1.0)
123
+
124
+ def factor(self, energy):
125
+ z = jnp.asarray(self.z)
126
+ sigma = jnp.exp(
127
+ jnp.interp(
128
+ jnp.log(energy) + jnp.log1p(z),
129
+ jnp.log(self._energy),
130
+ jnp.log(self._sigma),
131
+ left="extrapolate",
132
+ right="extrapolate",
133
+ )
134
+ )
135
+
136
+ return jnp.exp(-self.nh * sigma)
137
+
138
+
94
139
  class Phabs(MultiplicativeComponent):
95
140
  r"""
96
141
  A photoelectric absorption model.
@@ -215,7 +260,9 @@ class Zedge(MultiplicativeComponent):
215
260
 
216
261
  def factor(self, energy):
217
262
  return jnp.where(
218
- energy <= self.Ec, 1.0, jnp.exp(-self.D * (energy * (1 + self.z) / self.Ec) ** 3)
263
+ energy <= self.Ec,
264
+ 1.0,
265
+ jnp.exp(-self.D * (energy * (1 + self.z) / self.Ec) ** 3),
219
266
  )
220
267
 
221
268
 
@@ -246,7 +293,11 @@ class Tbpcf(MultiplicativeComponent):
246
293
  def factor(self, energy):
247
294
  sigma = jnp.exp(
248
295
  jnp.interp(
249
- energy, self._energy, jnp.log(self._sigma), left="extrapolate", right="extrapolate"
296
+ energy,
297
+ self._energy,
298
+ jnp.log(self._sigma),
299
+ left="extrapolate",
300
+ right="extrapolate",
250
301
  )
251
302
  )
252
303
 
@@ -25,4 +25,5 @@ table_manager = pooch.create(
25
25
  "example_data/NGC7793_ULX4/MOS2.arf": "sha256:a126ff5a95a5f4bb93ed846944cf411d6e1c448626cb73d347e33324663d8b3f",
26
26
  "example_data/NGC7793_ULX4/PNbackground_spectrum.fits": "sha256:55e017e0c19b324245fef049dff2a7a2e49b9a391667ca9c4f667c4f683b1f49",
27
27
  },
28
+ retry_if_failed=10,
28
29
  )
@@ -70,6 +70,7 @@ pooch_dataset = pooch.create(
70
70
  base_url="https://github.com/HEACIT/curated-test-data/raw/main/",
71
71
  path=str(data_directory),
72
72
  registry=data_hash,
73
+ retry_if_failed=10,
73
74
  )
74
75
 
75
76
  for file in data_hash.keys():
@@ -112,6 +113,15 @@ def get_individual_mcmc_results(obs_model_prior):
112
113
  return [MCMCFitter(model, prior, obsconf).fit(num_samples=5000) for obsconf in obsconfs]
113
114
 
114
115
 
116
+ @pytest.fixture(scope="session")
117
+ def get_failed_mcmc_results(obs_model_prior):
118
+ obsconfs, model, prior = obs_model_prior
119
+
120
+ return [
121
+ MCMCFitter(model, prior, obsconf).fit(num_warmup=10, num_samples=10) for obsconf in obsconfs
122
+ ]
123
+
124
+
115
125
  @pytest.fixture(scope="session")
116
126
  def get_joint_mcmc_result(obs_model_prior):
117
127
  obsconfs, model, prior = obs_model_prior
@@ -93,6 +93,19 @@ def test_fakeits_parallel(obsconfs, model, sharded_parameters):
93
93
  chex.assert_type(spectra, int)
94
94
 
95
95
 
96
+ def test_fakeits_sparsify(obsconfs, model, sharded_parameters):
97
+ obsconf = obsconfs[0]
98
+ spectra = fakeit_for_multiple_parameters(
99
+ obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=False
100
+ )
101
+ chex.assert_type(spectra, float)
102
+
103
+ spectra = fakeit_for_multiple_parameters(
104
+ obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=True
105
+ )
106
+ chex.assert_type(spectra, float)
107
+
108
+
96
109
  def test_fakeits_multiple_observation(obsconfs, model, multidimensional_parameters):
97
110
  obsconf = obsconfs[0]
98
111
  spectra = fakeit_for_multiple_parameters(
@@ -3,10 +3,13 @@ import pytest
3
3
  from jaxspec.fit import BayesianModel
4
4
 
5
5
 
6
- def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result):
6
+ def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result, get_failed_mcmc_results):
7
7
  for result in get_individual_mcmc_results + get_joint_mcmc_result:
8
8
  assert result.converged
9
9
 
10
+ for result in get_failed_mcmc_results:
11
+ assert not result.converged
12
+
10
13
 
11
14
  def test_ns(obs_model_prior):
12
15
  NSFitter = pytest.importorskip("jaxspec.fit.NSFitter")
@@ -54,14 +54,18 @@ def test_posterior_photon_flux(get_joint_mcmc_result):
54
54
  result = get_joint_mcmc_result[0]
55
55
  e_min, e_max = 0.7, 1.2
56
56
  result.photon_flux(e_min, e_max, register=True)
57
- assert f"photon_flux_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
57
+ assert f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}" in list(
58
+ result.inference_data.posterior.keys()
59
+ )
58
60
 
59
61
 
60
62
  def test_posterior_energy_flux(get_joint_mcmc_result):
61
63
  result = get_joint_mcmc_result[0]
62
64
  e_min, e_max = 0.7, 1.2
63
65
  result.energy_flux(e_min, e_max, register=True)
64
- assert f"energy_flux_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
66
+ assert f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}" in list(
67
+ result.inference_data.posterior.keys()
68
+ )
65
69
 
66
70
 
67
71
  def test_posterior_luminosity(get_joint_mcmc_result):
@@ -76,4 +80,6 @@ def test_posterior_luminosity(get_joint_mcmc_result):
76
80
 
77
81
  result.luminosity(e_min, e_max, redshift=0.1, register=True)
78
82
 
79
- assert f"luminosity_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
83
+ assert f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}" in list(
84
+ result.inference_data.posterior.keys()
85
+ )
@@ -1,8 +1,10 @@
1
1
  import os
2
2
 
3
+ import astropy.units as u
3
4
  import numpy as np
4
5
  import pytest
5
6
 
7
+ from jaxspec.model.additive import Powerlaw
6
8
  from jaxspec.util.online_storage import table_manager
7
9
 
8
10
  xspec = pytest.importorskip("xspec")
@@ -78,3 +80,40 @@ def test_bins(load_xspec_data, load_jaxspec_data):
78
80
  assert np.isclose(
79
81
  folding.in_energies, xspec_in_energies
80
82
  ).all(), f"The unfolded channel energy bins are not the same as XSPEC {file_pha}"
83
+
84
+
85
+ def test_flux_computation():
86
+ xspec.AllData.clear()
87
+ xspec.AllModels.clear()
88
+
89
+ xspec.AllData.dummyrsp(lowE=0.2, highE=1.7, nBins=10_000)
90
+ m = xspec.Model("powerlaw")
91
+ m.powerlaw.PhoIndex = 2.0
92
+ m.powerlaw.norm = 1.0
93
+
94
+ xspec.AllModels.calcFlux("0.5 1.5")
95
+
96
+ phflux_xspec = m.flux[3] # ph/cm^2/s
97
+ eflux_xspec = m.flux[0] # erg/cm^2/s
98
+
99
+ factor = (1 * u.keV).to(u.erg).value
100
+ phflux_jaxspec = Powerlaw().photon_flux(
101
+ {"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0}, e_low=0.5, e_high=1.5, n_points=10_000
102
+ )
103
+
104
+ eflux_jaxspec = (
105
+ Powerlaw().energy_flux(
106
+ {"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0},
107
+ e_low=0.5,
108
+ e_high=1.5,
109
+ n_points=10_000,
110
+ )
111
+ * factor
112
+ )
113
+
114
+ assert np.isclose(
115
+ phflux_xspec, phflux_jaxspec
116
+ ), f"Mismatch between XSPEC and jaxspec on photon flux, got {phflux_xspec} and {phflux_jaxspec}"
117
+ assert np.isclose(
118
+ eflux_xspec, eflux_jaxspec
119
+ ), f"Mismatch between XSPEC and jaxspec on energy flux, got {eflux_xspec} and {eflux_jaxspec}"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes