jaxspec 0.3.1__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. {jaxspec-0.3.1 → jaxspec-0.3.3}/PKG-INFO +2 -6
  2. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/index.md +6 -0
  3. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/fitting.md +8 -0
  4. jaxspec-0.3.3/docs/references/instrument.md +4 -0
  5. {jaxspec-0.3.1 → jaxspec-0.3.3}/mkdocs.yml +3 -4
  6. {jaxspec-0.3.1 → jaxspec-0.3.3}/pyproject.toml +4 -8
  7. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/analysis/results.py +61 -23
  8. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/instrument.py +2 -1
  9. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/obsconf.py +0 -10
  10. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/observation.py +5 -11
  11. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/util.py +2 -2
  12. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/fit/_bayesian_model.py +1 -1
  13. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/fit/_fitter.py +32 -8
  14. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/abc.py +4 -3
  15. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/instrument.py +34 -0
  16. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/online_storage.py +1 -0
  17. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/conftest.py +10 -0
  18. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_background.py +3 -3
  19. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_fakeit.py +13 -0
  20. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_mcmc.py +4 -1
  21. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_results.py +9 -3
  22. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_xspec.py +39 -0
  23. {jaxspec-0.3.1 → jaxspec-0.3.3}/.dockerignore +0 -0
  24. {jaxspec-0.3.1 → jaxspec-0.3.3}/.github/dependabot.yml +0 -0
  25. {jaxspec-0.3.1 → jaxspec-0.3.3}/.github/workflows/documentation-links.yml +0 -0
  26. {jaxspec-0.3.1 → jaxspec-0.3.3}/.github/workflows/publish.yml +0 -0
  27. {jaxspec-0.3.1 → jaxspec-0.3.3}/.github/workflows/test-and-coverage.yml +0 -0
  28. {jaxspec-0.3.1 → jaxspec-0.3.3}/.gitignore +0 -0
  29. {jaxspec-0.3.1 → jaxspec-0.3.3}/.pre-commit-config.yaml +0 -0
  30. {jaxspec-0.3.1 → jaxspec-0.3.3}/.python-version +0 -0
  31. {jaxspec-0.3.1 → jaxspec-0.3.3}/.readthedocs.yaml +0 -0
  32. {jaxspec-0.3.1 → jaxspec-0.3.3}/CODE_OF_CONDUCT.md +0 -0
  33. {jaxspec-0.3.1 → jaxspec-0.3.3}/Dockerfile +0 -0
  34. {jaxspec-0.3.1 → jaxspec-0.3.3}/LICENSE.md +0 -0
  35. {jaxspec-0.3.1 → jaxspec-0.3.3}/README.md +0 -0
  36. {jaxspec-0.3.1 → jaxspec-0.3.3}/codecov.yml +0 -0
  37. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/contribute/index.md +0 -0
  38. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/contribute/internal.md +0 -0
  39. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/contribute/xspec.md +0 -0
  40. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/css/extra.css +0 -0
  41. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/css/material.css +0 -0
  42. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/css/mkdocstrings.css +0 -0
  43. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/css/xarray.css +0 -0
  44. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/dev/index.md +0 -0
  45. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/background.md +0 -0
  46. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/build_model.md +0 -0
  47. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/fakeits.md +0 -0
  48. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/fitting_example.md +0 -0
  49. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/background_comparison.png +0 -0
  50. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/background_gp.png +0 -0
  51. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/background_spectral.png +0 -0
  52. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/fakeits.png +0 -0
  53. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/fitting_example_corner.png +0 -0
  54. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/fitting_example_ppc.png +0 -0
  55. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/model.png +0 -0
  56. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/rmf.png +0 -0
  57. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/subtract_background.png +0 -0
  58. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/examples/statics/subtract_background_with_errors.png +0 -0
  59. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/faq/cookbook.md +0 -0
  60. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/faq/index.md +0 -0
  61. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/faq/statics/cstat_vs_chi2.png +0 -0
  62. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/frontpage/installation.md +0 -0
  63. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/index.md +0 -0
  64. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/javascripts/mathjax.js +0 -0
  65. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/logo/logo_small.svg +0 -0
  66. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/logo/xifu_mini.svg +0 -0
  67. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/abundance.md +0 -0
  68. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/additive.md +0 -0
  69. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/background.md +0 -0
  70. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/data.md +0 -0
  71. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/integrate.md +0 -0
  72. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/model.md +0 -0
  73. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/multiplicative.md +0 -0
  74. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/references/results.md +0 -0
  75. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/runtime/diagram.txt +0 -0
  76. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/runtime/result_table.txt +0 -0
  77. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/diagram.txt +0 -0
  78. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
  79. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
  80. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/theory/background.md +0 -0
  81. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/theory/bayesian_inference.md +0 -0
  82. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/theory/index.md +0 -0
  83. {jaxspec-0.3.1 → jaxspec-0.3.3}/docs/theory/instrument.md +0 -0
  84. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/__init__.py +0 -0
  85. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/analysis/__init__.py +0 -0
  86. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/analysis/_plot.py +0 -0
  87. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/analysis/compare.py +0 -0
  88. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/__init__.py +0 -0
  89. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/data/ogip.py +0 -0
  90. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/__init__.py +0 -0
  91. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/interpolator.py +0 -0
  92. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/interpolator_jax.py +0 -0
  93. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/intrument_models.py +0 -0
  94. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/nested_sampler.py +0 -0
  95. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/experimental/tabulated.py +0 -0
  96. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/fit/__init__.py +0 -0
  97. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/fit/_build_model.py +0 -0
  98. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/__init__.py +0 -0
  99. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/_graph_util.py +0 -0
  100. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/additive.py +0 -0
  101. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/background.py +0 -0
  102. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/list.py +0 -0
  103. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/model/multiplicative.py +0 -0
  104. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/scripts/__init__.py +0 -0
  105. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/scripts/debug.py +0 -0
  106. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/__init__.py +0 -0
  107. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/abundance.py +0 -0
  108. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/integrate.py +0 -0
  109. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/misc.py +0 -0
  110. {jaxspec-0.3.1 → jaxspec-0.3.3}/src/jaxspec/util/typing.py +0 -0
  111. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/data_files.yml +0 -0
  112. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/data_hash.yml +0 -0
  113. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_bayesian_model.py +0 -0
  114. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_bayesian_model_building.py +0 -0
  115. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_instruments.py +0 -0
  116. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_integrate.py +0 -0
  117. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_misc.py +0 -0
  118. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_models.py +0 -0
  119. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_observation.py +0 -0
  120. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_repr.py +0 -0
  121. {jaxspec-0.3.1 → jaxspec-0.3.3}/tests/test_xspec_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxspec
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
6
6
  Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
@@ -16,19 +16,15 @@ Requires-Dist: cmasher<2,>=1.6.3
16
16
  Requires-Dist: flax>0.10.5
17
17
  Requires-Dist: interpax<0.4,>=0.3.5
18
18
  Requires-Dist: jax<0.7,>=0.5.0
19
- Requires-Dist: jaxns<3,>=2.6.7
20
- Requires-Dist: jaxopt<0.9,>=0.8.3
21
19
  Requires-Dist: matplotlib<4,>=3.8.0
22
20
  Requires-Dist: mendeleev<1.2,>=0.15
23
21
  Requires-Dist: networkx~=3.1
24
22
  Requires-Dist: numpy<3.0.0
25
23
  Requires-Dist: numpyro<0.20,>=0.17.0
26
- Requires-Dist: optimistix<0.0.11,>=0.0.10
27
24
  Requires-Dist: pandas<3,>=2.2.0
28
25
  Requires-Dist: pooch<2,>=1.8.2
29
26
  Requires-Dist: scipy<1.16
30
- Requires-Dist: seaborn<0.14,>=0.13.1
31
- Requires-Dist: simpleeval<1.1.0,>=0.9.13
27
+ Requires-Dist: seaborn>=0.13.2
32
28
  Requires-Dist: sparse>0.15
33
29
  Requires-Dist: tinygp<0.4,>=0.3.0
34
30
  Requires-Dist: watermark<3,>=2.4.3
@@ -54,5 +54,11 @@
54
54
 
55
55
  Use `jaxspec` likelihood in other packages
56
56
 
57
+ - [__Cross-calibration between instruments__](calibration.ipynb)
58
+
59
+ ---
60
+
61
+ Add extra gains and shifts to account for the miscalibrations
62
+
57
63
 
58
64
  </div>
@@ -13,3 +13,11 @@
13
13
  show_root_full_path: false
14
14
  show_root_toc_entry: true
15
15
  heading_level: 3
16
+
17
+ ::: jaxspec.fit.VIFitter
18
+ options:
19
+ show_root_heading: true
20
+ show_root_full_path: false
21
+ show_root_toc_entry: true
22
+ heading_level: 3
23
+
@@ -0,0 +1,4 @@
1
+ ::: jaxspec.model.instrument
2
+ options:
3
+ show_root_heading: false
4
+ show_root_toc_entry: false
@@ -18,14 +18,13 @@ nav:
18
18
  - Add a background : examples/background.md
19
19
  - Good practices for MCMC: examples/work_with_arviz.ipynb
20
20
  - Interface with other frameworks : examples/external_samplers.ipynb
21
- #- Models:
22
- # - models/index.md
23
- # - APEC: models/apec.md
21
+ - Cross calibration : examples/calibration.ipynb
24
22
  - API Reference:
25
23
  - Spectral model base: references/model.md
26
24
  - Additive models: references/additive.md
27
25
  - Multiplicative models: references/multiplicative.md
28
26
  - Background models: references/background.md
27
+ - Instrument models: references/instrument.md
29
28
  - Data containers: references/data.md
30
29
  - Fitting: references/fitting.md
31
30
  - Result containers: references/results.md
@@ -87,7 +86,7 @@ theme:
87
86
  plugins:
88
87
  - search
89
88
  - autorefs
90
- # - typeset
89
+ - typeset
91
90
  - mkdocs-jupyter:
92
91
  include_source: True
93
92
  ignore_h1_titles: True
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "jaxspec"
3
- version = "0.3.1"
3
+ version = "0.3.3"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
6
6
  requires-python = ">=3.10,<3.13"
@@ -16,21 +16,17 @@ dependencies = [
16
16
  "matplotlib>=3.8.0,<4",
17
17
  "arviz>=0.17.1,<0.23.0",
18
18
  "chainconsumer>=1.1.2,<2",
19
- "simpleeval>=0.9.13,<1.1.0",
20
19
  "cmasher>=1.6.3,<2",
21
- "jaxopt>=0.8.3,<0.9",
22
20
  "tinygp>=0.3.0,<0.4",
23
- "seaborn>=0.13.1,<0.14",
24
21
  "sparse>0.15",
25
- "optimistix>=0.0.10,<0.0.11",
26
22
  "scipy<1.16",
27
23
  "mendeleev>=0.15,<1.2",
28
- "jaxns>=2.6.7,<3",
29
24
  "pooch>=1.8.2,<2",
30
25
  "interpax>=0.3.5,<0.4",
31
26
  "watermark>=2.4.3,<3",
32
27
  "catppuccin>=2.3.4,<3",
33
28
  "flax>0.10.5",
29
+ "seaborn>=0.13.2",
34
30
  ]
35
31
 
36
32
  [project.urls]
@@ -44,14 +40,14 @@ jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
44
40
  docs = [
45
41
  "mkdocs>=1.6.1,<2",
46
42
  "mkdocs-material>=9.4.6,<10",
47
- "mkdocstrings[python]>=0.24,<0.28",
43
+ "mkdocstrings[python]>=0.24,<1.1",
48
44
  "mkdocs-jupyter>=0.25.0,<0.26",
49
45
  ]
50
46
  test = [
51
47
  "chex>=0.1.83,<0.2",
52
48
  "mktestdocs>=0.2.1,<0.3",
53
49
  "coverage>=7.3.2,<8",
54
- "pytest-cov>=4.1,<7.0",
50
+ "pytest-cov>=4.1,<8.0",
55
51
  "flake8>=7.0.0,<8",
56
52
  "pytest>=8.0.0,<9",
57
53
  "testbook>=0.4.2,<0.5",
@@ -77,8 +77,9 @@ class FitResult:
77
77
  r"""
78
78
  Convergence of the chain as computed by the $\hat{R}$ statistic.
79
79
  """
80
+ rhat = az.rhat(self.inference_data)
80
81
 
81
- return all(az.rhat(self.inference_data) < 1.01)
82
+ return bool((rhat.to_array() < 1.01).all())
82
83
 
83
84
  def _ppc_folded_branches(self, obs_id):
84
85
  obs = self.obsconfs[obs_id]
@@ -167,6 +168,7 @@ class FitResult:
167
168
  e_max: float,
168
169
  unit: Unit = u.photon / u.cm**2 / u.s,
169
170
  register: bool = False,
171
+ n_points: int = 100,
170
172
  ) -> ArrayLike:
171
173
  """
172
174
  Compute the unfolded photon flux in a given energy band. The flux is then added to
@@ -177,6 +179,7 @@ class FitResult:
177
179
  e_max: The upper bound of the energy band in observer frame.
178
180
  unit: The unit of the photon flux.
179
181
  register: Whether to register the flux with the other posterior parameters.
182
+ n_points: The number of points to use for computing the unfolded spectrum.
180
183
 
181
184
  !!! warning
182
185
  Computation of the folded flux is not implemented yet. Feel free to open an
@@ -188,18 +191,25 @@ class FitResult:
188
191
  def vectorized_flux(*pars):
189
192
  parameters_pytree = jax.tree.unflatten(pytree_def, pars)
190
193
  return self.model.photon_flux(
191
- parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
194
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=n_points
192
195
  )[0]
193
196
 
194
197
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
195
198
  flux = vectorized_flux(*flat_tree)
196
- conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
197
- value = flux * conversion_factor
199
+ conversion_factor = float((u.photon / u.cm**2 / u.s).to(unit))
200
+ value = np.asarray(flux * conversion_factor)
198
201
 
199
202
  if register:
200
- self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
201
- list(self.inference_data.posterior.coords),
202
- value,
203
+ self.inference_data.posterior[f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
204
+ xr.DataArray(
205
+ value,
206
+ dims=self.inference_data.posterior.dims,
207
+ coords={
208
+ "chain": self.inference_data.posterior.chain,
209
+ "draw": self.inference_data.posterior.draw,
210
+ },
211
+ name=f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}",
212
+ )
203
213
  )
204
214
 
205
215
  return value
@@ -210,6 +220,7 @@ class FitResult:
210
220
  e_max: float,
211
221
  unit: Unit = u.erg / u.cm**2 / u.s,
212
222
  register: bool = False,
223
+ n_points: int = 100,
213
224
  ) -> ArrayLike:
214
225
  """
215
226
  Compute the unfolded energy flux in a given energy band. The flux is then added to
@@ -220,6 +231,7 @@ class FitResult:
220
231
  e_max: The upper bound of the energy band in observer frame.
221
232
  unit: The unit of the energy flux.
222
233
  register: Whether to register the flux with the other posterior parameters.
234
+ n_points: The number of points to use for computing the unfolded spectrum.
223
235
 
224
236
  !!! warning
225
237
  Computation of the folded flux is not implemented yet. Feel free to open an
@@ -231,18 +243,25 @@ class FitResult:
231
243
  def vectorized_flux(*pars):
232
244
  parameters_pytree = jax.tree.unflatten(pytree_def, pars)
233
245
  return self.model.energy_flux(
234
- parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
246
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=n_points
235
247
  )[0]
236
248
 
237
249
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
238
250
  flux = vectorized_flux(*flat_tree)
239
- conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
240
- value = flux * conversion_factor
251
+ conversion_factor = float((u.keV / u.cm**2 / u.s).to(unit))
252
+ value = np.asarray(flux * conversion_factor)
241
253
 
242
254
  if register:
243
- self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
244
- list(self.inference_data.posterior.coords),
245
- value,
255
+ self.inference_data.posterior[f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
256
+ xr.DataArray(
257
+ value,
258
+ dims=self.inference_data.posterior.dims,
259
+ coords={
260
+ "chain": self.inference_data.posterior.chain,
261
+ "draw": self.inference_data.posterior.draw,
262
+ },
263
+ name=f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}",
264
+ )
246
265
  )
247
266
 
248
267
  return value
@@ -257,6 +276,7 @@ class FitResult:
257
276
  cosmology: Cosmology = Planck18,
258
277
  unit: Unit = u.erg / u.s,
259
278
  register: bool = False,
279
+ n_points=100,
260
280
  ) -> ArrayLike:
261
281
  """
262
282
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
@@ -294,17 +314,26 @@ class FitResult:
294
314
  parameters_pytree,
295
315
  jnp.asarray([e_min]) * (1 + redshift),
296
316
  jnp.asarray([e_max]) * (1 + redshift),
297
- n_points=100,
317
+ n_points=n_points,
298
318
  )[0]
299
319
 
300
320
  flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
301
321
  flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
302
- value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
322
+ value = np.asarray(
323
+ (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
324
+ )
303
325
 
304
326
  if register:
305
- self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
306
- list(self.inference_data.posterior.coords),
307
- value,
327
+ self.inference_data.posterior[f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}"] = (
328
+ xr.DataArray(
329
+ value,
330
+ dims=self.inference_data.posterior.dims,
331
+ coords={
332
+ "chain": self.inference_data.posterior.chain,
333
+ "draw": self.inference_data.posterior.draw,
334
+ },
335
+ name=f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}",
336
+ )
308
337
  )
309
338
 
310
339
  return value
@@ -315,10 +344,13 @@ class FitResult:
315
344
 
316
345
  Parameters:
317
346
  name: The name of the chain.
347
+ parameter_kind: The kind of parameters to keep.
318
348
  """
319
349
 
320
350
  keys_to_drop = [
321
- key for key in self.inference_data.posterior.keys() if not key.startswith("mod")
351
+ key
352
+ for key in self.inference_data.posterior.keys()
353
+ if not key.startswith(parameter_kind)
322
354
  ]
323
355
 
324
356
  reduced_id = az.extract(
@@ -403,6 +435,7 @@ class FitResult:
403
435
  title: str | None = None,
404
436
  figsize: tuple[float, float] = (6, 6),
405
437
  x_lims: tuple[float, float] | None = None,
438
+ rescale_background: bool = False,
406
439
  ) -> list[plt.Figure]:
407
440
  r"""
408
441
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
@@ -423,6 +456,7 @@ class FitResult:
423
456
  title: The title of the plot.
424
457
  figsize: The size of the figure.
425
458
  x_lims: The limits of the x-axis.
459
+ rescale_background: Whether to rescale the background model to the data with backscal ratio.
426
460
 
427
461
  Returns:
428
462
  A list of matplotlib figures for each observation in the model.
@@ -573,10 +607,14 @@ class FitResult:
573
607
  )
574
608
  )
575
609
 
610
+ rescale_background_factor = (
611
+ obsconf.folded_backratio.data if rescale_background else 1.0
612
+ )
613
+
576
614
  model_bkg_plot = _plot_binned_samples_with_error(
577
615
  ax[0],
578
616
  xbins.value,
579
- y_samples_bkg.value,
617
+ y_samples_bkg.value * rescale_background_factor,
580
618
  color=BACKGROUND_COLOR,
581
619
  alpha_envelope=alpha_envelope,
582
620
  n_sigmas=n_sigmas,
@@ -585,9 +623,9 @@ class FitResult:
585
623
  true_bkg_plot = _plot_poisson_data_with_error(
586
624
  ax[0],
587
625
  xbins.value,
588
- y_observed_bkg.value,
589
- y_observed_bkg_low.value,
590
- y_observed_bkg_high.value,
626
+ y_observed_bkg.value * rescale_background_factor,
627
+ y_observed_bkg_low.value * rescale_background_factor,
628
+ y_observed_bkg_high.value * rescale_background_factor,
591
629
  color=BACKGROUND_DATA_COLOR,
592
630
  alpha=0.7,
593
631
  )
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import sparse
2
3
 
3
4
  import numpy as np
4
5
  import xarray as xr
@@ -92,7 +93,7 @@ class Instrument(xr.Dataset):
92
93
 
93
94
  else:
94
95
  specresp = rmf.matrix.sum(axis=0)
95
- rmf.sparse_matrix /= specresp
96
+ rmf.sparse_matrix = sparse.COO( rmf.matrix / specresp )
96
97
 
97
98
  return cls.from_matrix(
98
99
  rmf.sparse_matrix, specresp, rmf.energ_lo, rmf.energ_hi, rmf.e_min, rmf.e_max
@@ -163,10 +163,8 @@ class ObsConfiguration(xr.Dataset):
163
163
 
164
164
  if observation.folded_background is not None:
165
165
  folded_background = observation.folded_background.data[row_idx]
166
- folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
167
166
  else:
168
167
  folded_background = np.zeros_like(folded_counts)
169
- folded_background_unscaled = np.zeros_like(folded_counts)
170
168
 
171
169
  data_dict = {
172
170
  "transfer_matrix": (
@@ -208,14 +206,6 @@ class ObsConfiguration(xr.Dataset):
208
206
  "unit": "photons",
209
207
  },
210
208
  ),
211
- "folded_background_unscaled": (
212
- ["folded_channel"],
213
- folded_background_unscaled,
214
- {
215
- "description": "To be done",
216
- "unit": "photons",
217
- },
218
- ),
219
209
  }
220
210
 
221
211
  return cls(
@@ -46,16 +46,14 @@ class Observation(xr.Dataset):
46
46
  quality,
47
47
  exposure,
48
48
  background=None,
49
- background_unscaled=None,
50
49
  backratio=1.0,
51
50
  attributes: dict | None = None,
52
51
  ):
53
52
  if attributes is None:
54
53
  attributes = {}
55
54
 
56
- if background is None or background_unscaled is None:
55
+ if background is None:
57
56
  background = np.zeros_like(counts, dtype=np.int64)
58
- background_unscaled = np.zeros_like(counts, dtype=np.int64)
59
57
 
60
58
  data_dict = {
61
59
  "counts": (
@@ -86,7 +84,9 @@ class Observation(xr.Dataset):
86
84
  ),
87
85
  "folded_backratio": (
88
86
  ["folded_channel"],
89
- np.asarray(np.ma.filled(grouping @ backratio), dtype=float),
87
+ np.asarray(
88
+ np.ma.filled(grouping @ backratio) / grouping.sum(axis=1).todense(), dtype=float
89
+ ),
90
90
  {"description": "Background scaling after grouping"},
91
91
  ),
92
92
  "background": (
@@ -94,11 +94,6 @@ class Observation(xr.Dataset):
94
94
  np.asarray(background, dtype=np.int64),
95
95
  {"description": "Background counts", "unit": "photons"},
96
96
  ),
97
- "folded_background_unscaled": (
98
- ["folded_channel"],
99
- np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
100
- {"description": "Background counts", "unit": "photons"},
101
- ),
102
97
  "folded_background": (
103
98
  ["folded_channel"],
104
99
  np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
@@ -147,8 +142,7 @@ class Observation(xr.Dataset):
147
142
  pha.quality,
148
143
  pha.exposure,
149
144
  backratio=backratio,
150
- background=bkg.counts * backratio if bkg is not None else None,
151
- background_unscaled=bkg.counts if bkg is not None else None,
145
+ background=bkg.counts if bkg is not None else None,
152
146
  attributes=metadata,
153
147
  )
154
148
 
@@ -152,11 +152,11 @@ def forward_model_with_multiple_inputs(
152
152
  transfer_matrix = BCOO.from_scipy_sparse(
153
153
  obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
154
154
  )
155
+ expected_counts = transfer_matrix @ flux_func(parameters).T
155
156
 
156
157
  else:
157
158
  transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
158
-
159
- expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
159
+ expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
160
160
 
161
161
  # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
162
162
  return jnp.clip(expected_counts, a_min=1e-6)
@@ -135,7 +135,7 @@ class BayesianModel(nnx.Module):
135
135
  with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
136
136
  numpyro.sample(
137
137
  "obs/~/" + name,
138
- Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
138
+ Poisson(obs_countrate + bkg_countrate * observation.folded_backratio.data),
139
139
  obs=observation.folded_counts.data if observed else None,
140
140
  )
141
141
 
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
9
9
  import numpyro
10
10
 
11
11
  from jax import random
12
+ from jax.numpy import concatenate
12
13
  from jax.random import PRNGKey
13
14
  from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
14
15
  from numpyro.infer.autoguide import AutoMultivariateNormal
@@ -52,9 +53,18 @@ class BayesianModelFitter(BayesianModel, ABC):
52
53
  )
53
54
 
54
55
  log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
56
+ if len(log_likelihood.keys()) > 1:
57
+ log_likelihood["full"] = concatenate([ll for _, ll in log_likelihood.items()], axis=1)
58
+ log_likelihood["obs/~/all"] = concatenate(
59
+ [ll for k, ll in log_likelihood.items() if "obs" in k], axis=1
60
+ )
61
+ if self.background_model is not None:
62
+ log_likelihood["bkg/~/all"] = concatenate(
63
+ [ll for k, ll in log_likelihood.items() if "bkg" in k], axis=1
64
+ )
55
65
 
56
66
  seeded_model = numpyro.handlers.substitute(
57
- numpyro.handlers.seed(numpyro_model, keys[3]),
67
+ numpyro.handlers.seed(numpyro_model, keys[2]),
58
68
  substitute_fn=numpyro.infer.init_to_sample,
59
69
  )
60
70
 
@@ -108,12 +118,10 @@ class BayesianModelFitter(BayesianModel, ABC):
108
118
  predictive_parameters = []
109
119
 
110
120
  for key, value in self._observation_container.items():
121
+ predictive_parameters.append(f"obs/~/{key}")
111
122
  if self.background_model is not None:
112
- predictive_parameters.append(f"obs/~/{key}")
113
123
  predictive_parameters.append(f"bkg/~/{key}")
114
124
  # predictive_parameters.append(f"ins/~/{key}")
115
- else:
116
- predictive_parameters.append(f"obs/~/{key}")
117
125
  # predictive_parameters.append(f"ins/~/{key}")
118
126
 
119
127
  inference_data.posterior_predictive = inference_data.posterior_predictive[
@@ -215,13 +223,29 @@ class VIFitter(BayesianModelFitter):
215
223
  self,
216
224
  rng_key: int = 0,
217
225
  num_steps: int = 10_000,
218
- optimizer=numpyro.optim.Adam(step_size=0.0005),
219
- loss=Trace_ELBO(),
226
+ optimizer: numpyro.optim._NumPyroOptim = numpyro.optim.Adam(step_size=0.0005),
227
+ loss: numpyro.infer.elbo.ELBO = Trace_ELBO(),
220
228
  num_samples: int = 1000,
221
- guide=None,
229
+ guide: numpyro.infer.autoguide.AutoGuide | None = None,
222
230
  use_transformed_model: bool = True,
223
231
  plot_diagnostics: bool = False,
224
232
  ) -> FitResult:
233
+ """
234
+ Fit the model to the data using a variational inference approach from numpyro.
235
+
236
+ Parameters:
237
+ rng_key: the random key used to initialize the sampler.
238
+ num_steps: the number of steps for VI.
239
+ optimizer: the optimizer to use.
240
+ num_samples: the number of samples to draw.
241
+ loss: the loss function to use.
242
+ guide: the guide to use.
243
+ use_transformed_model: whether to use the transformed model to build the InferenceData.
244
+ plot_diagnostics: plot the loss during VI.
245
+
246
+ Returns:
247
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
248
+ """
225
249
  bayesian_model = (
226
250
  self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
227
251
  )
@@ -231,7 +255,7 @@ class VIFitter(BayesianModelFitter):
231
255
 
232
256
  svi = SVI(bayesian_model, guide, optimizer, loss=loss)
233
257
 
234
- keys = random.split(random.PRNGKey(rng_key), 3)
258
+ keys = random.split(random.PRNGKey(rng_key), 2)
235
259
  svi_result = svi.run(keys[0], num_steps)
236
260
  params = svi_result.params
237
261
 
@@ -372,9 +372,10 @@ class AdditiveComponent(ModelComponent):
372
372
  continuum = self.continuum(energy)
373
373
  integrated_continuum = self.integrated_continuum(e_low, e_high)
374
374
 
375
- return jsp.integrate.trapezoid(
376
- continuum * energy**2, jnp.log(energy), axis=-1
377
- ) + integrated_continuum * (e_high - e_low)
375
+ return (
376
+ jsp.integrate.trapezoid(continuum * energy**2, jnp.log(energy), axis=-1)
377
+ + integrated_continuum * (e_high + e_low) / 2.0
378
+ )
378
379
 
379
380
  @partial(jax.jit, static_argnums=0, static_argnames="n_points")
380
381
  def photon_flux(self, params, e_low, e_high, n_points=2):
@@ -8,13 +8,26 @@ from numpyro.distributions import Distribution
8
8
 
9
9
 
10
10
  class GainModel(ABC, nnx.Module):
11
+ """
12
+ Generic class for a gain model
13
+ """
14
+
11
15
  @abstractmethod
12
16
  def numpyro_model(self, observation_name: str):
13
17
  pass
14
18
 
15
19
 
16
20
  class ConstantGain(GainModel):
21
+ """
22
+ A constant gain model
23
+ """
24
+
17
25
  def __init__(self, prior_distribution: Distribution):
26
+ """
27
+ Parameters:
28
+ prior_distribution: the prior distribution for the gain value.
29
+ """
30
+
18
31
  self.prior_distribution = prior_distribution
19
32
 
20
33
  def numpyro_model(self, observation_name: str):
@@ -27,13 +40,25 @@ class ConstantGain(GainModel):
27
40
 
28
41
 
29
42
  class ShiftModel(ABC, nnx.Module):
43
+ """
44
+ Generic class for a shift model
45
+ """
46
+
30
47
  @abstractmethod
31
48
  def numpyro_model(self, observation_name: str):
32
49
  pass
33
50
 
34
51
 
35
52
  class ConstantShift(ShiftModel):
53
+ """
54
+ A constant shift model
55
+ """
56
+
36
57
  def __init__(self, prior_distribution: Distribution):
58
+ """
59
+ Parameters:
60
+ prior_distribution: the prior distribution for the shift value.
61
+ """
37
62
  self.prior_distribution = prior_distribution
38
63
 
39
64
  def numpyro_model(self, observation_name: str):
@@ -52,6 +77,15 @@ class InstrumentModel(nnx.Module):
52
77
  gain_model: GainModel | None = None,
53
78
  shift_model: ShiftModel | None = None,
54
79
  ):
80
+ """
81
+ Encapsulate an instrument model, build as a combination of a shift and gain model.
82
+
83
+ Parameters:
84
+ reference_observation_name : The observation to use as a reference
85
+ gain_model : The gain model
86
+ shift_model : The shift model
87
+ """
88
+
55
89
  self.reference = reference_observation_name
56
90
  self.gain_model = gain_model
57
91
  self.shift_model = shift_model
@@ -25,4 +25,5 @@ table_manager = pooch.create(
25
25
  "example_data/NGC7793_ULX4/MOS2.arf": "sha256:a126ff5a95a5f4bb93ed846944cf411d6e1c448626cb73d347e33324663d8b3f",
26
26
  "example_data/NGC7793_ULX4/PNbackground_spectrum.fits": "sha256:55e017e0c19b324245fef049dff2a7a2e49b9a391667ca9c4f667c4f683b1f49",
27
27
  },
28
+ retry_if_failed=10,
28
29
  )
@@ -70,6 +70,7 @@ pooch_dataset = pooch.create(
70
70
  base_url="https://github.com/HEACIT/curated-test-data/raw/main/",
71
71
  path=str(data_directory),
72
72
  registry=data_hash,
73
+ retry_if_failed=10,
73
74
  )
74
75
 
75
76
  for file in data_hash.keys():
@@ -112,6 +113,15 @@ def get_individual_mcmc_results(obs_model_prior):
112
113
  return [MCMCFitter(model, prior, obsconf).fit(num_samples=5000) for obsconf in obsconfs]
113
114
 
114
115
 
116
+ @pytest.fixture(scope="session")
117
+ def get_failed_mcmc_results(obs_model_prior):
118
+ obsconfs, model, prior = obs_model_prior
119
+
120
+ return [
121
+ MCMCFitter(model, prior, obsconf).fit(num_warmup=10, num_samples=10) for obsconf in obsconfs
122
+ ]
123
+
124
+
115
125
  @pytest.fixture(scope="session")
116
126
  def get_joint_mcmc_result(obs_model_prior):
117
127
  obsconfs, model, prior = obs_model_prior
@@ -14,9 +14,9 @@ spectral_model_background = Powerlaw() + Blackbodyrad()
14
14
 
15
15
  prior_background = {
16
16
  "powerlaw_1_alpha": dist.Uniform(0, 5),
17
- "powerlaw_1_norm": dist.LogUniform(1e-8, 1e-3),
17
+ "powerlaw_1_norm": dist.LogUniform(1e-7, 1e-3),
18
18
  "blackbodyrad_1_kT": dist.Uniform(0, 5),
19
- "blackbodyrad_1_norm": dist.LogUniform(1e-6, 1e-1),
19
+ "blackbodyrad_1_norm": dist.LogUniform(1e-5, 1e-1),
20
20
  }
21
21
 
22
22
 
@@ -36,7 +36,7 @@ def test_background_model(obs_model_prior, bkg_model):
36
36
  obs_list, model, prior = obs_model_prior
37
37
  forward = MCMCFitter(model, prior, obs_list[0], background_model=bkg_model)
38
38
  result = forward.fit(
39
- num_chains=4, num_warmup=100, num_samples=100, mcmc_kwargs={"progress_bar": False}
39
+ num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": False}
40
40
  )
41
41
  result.plot_ppc(title=f"Test {bkg_model.__class__.__name__}")
42
42
 
@@ -93,6 +93,19 @@ def test_fakeits_parallel(obsconfs, model, sharded_parameters):
93
93
  chex.assert_type(spectra, int)
94
94
 
95
95
 
96
+ def test_fakeits_sparsify(obsconfs, model, sharded_parameters):
97
+ obsconf = obsconfs[0]
98
+ spectra = fakeit_for_multiple_parameters(
99
+ obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=False
100
+ )
101
+ chex.assert_type(spectra, float)
102
+
103
+ spectra = fakeit_for_multiple_parameters(
104
+ obsconf, model, sharded_parameters, apply_stat=False, sparsify_matrix=True
105
+ )
106
+ chex.assert_type(spectra, float)
107
+
108
+
96
109
  def test_fakeits_multiple_observation(obsconfs, model, multidimensional_parameters):
97
110
  obsconf = obsconfs[0]
98
111
  spectra = fakeit_for_multiple_parameters(
@@ -3,10 +3,13 @@ import pytest
3
3
  from jaxspec.fit import BayesianModel
4
4
 
5
5
 
6
- def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result):
6
+ def test_convergence(get_individual_mcmc_results, get_joint_mcmc_result, get_failed_mcmc_results):
7
7
  for result in get_individual_mcmc_results + get_joint_mcmc_result:
8
8
  assert result.converged
9
9
 
10
+ for result in get_failed_mcmc_results:
11
+ assert not result.converged
12
+
10
13
 
11
14
  def test_ns(obs_model_prior):
12
15
  NSFitter = pytest.importorskip("jaxspec.fit.NSFitter")
@@ -54,14 +54,18 @@ def test_posterior_photon_flux(get_joint_mcmc_result):
54
54
  result = get_joint_mcmc_result[0]
55
55
  e_min, e_max = 0.7, 1.2
56
56
  result.photon_flux(e_min, e_max, register=True)
57
- assert f"photon_flux_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
57
+ assert f"mod/~/photon_flux_{e_min:.1f}_{e_max:.1f}" in list(
58
+ result.inference_data.posterior.keys()
59
+ )
58
60
 
59
61
 
60
62
  def test_posterior_energy_flux(get_joint_mcmc_result):
61
63
  result = get_joint_mcmc_result[0]
62
64
  e_min, e_max = 0.7, 1.2
63
65
  result.energy_flux(e_min, e_max, register=True)
64
- assert f"energy_flux_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
66
+ assert f"mod/~/energy_flux_{e_min:.1f}_{e_max:.1f}" in list(
67
+ result.inference_data.posterior.keys()
68
+ )
65
69
 
66
70
 
67
71
  def test_posterior_luminosity(get_joint_mcmc_result):
@@ -76,4 +80,6 @@ def test_posterior_luminosity(get_joint_mcmc_result):
76
80
 
77
81
  result.luminosity(e_min, e_max, redshift=0.1, register=True)
78
82
 
79
- assert f"luminosity_{e_min:.1f}_{e_max:.1f}" in list(result.inference_data.posterior.keys())
83
+ assert f"mod/~/luminosity_{e_min:.1f}_{e_max:.1f}" in list(
84
+ result.inference_data.posterior.keys()
85
+ )
@@ -1,8 +1,10 @@
1
1
  import os
2
2
 
3
+ import astropy.units as u
3
4
  import numpy as np
4
5
  import pytest
5
6
 
7
+ from jaxspec.model.additive import Powerlaw
6
8
  from jaxspec.util.online_storage import table_manager
7
9
 
8
10
  xspec = pytest.importorskip("xspec")
@@ -78,3 +80,40 @@ def test_bins(load_xspec_data, load_jaxspec_data):
78
80
  assert np.isclose(
79
81
  folding.in_energies, xspec_in_energies
80
82
  ).all(), f"The unfolded channel energy bins are not the same as XSPEC {file_pha}"
83
+
84
+
85
+ def test_flux_computation():
86
+ xspec.AllData.clear()
87
+ xspec.AllModels.clear()
88
+
89
+ xspec.AllData.dummyrsp(lowE=0.2, highE=1.7, nBins=10_000)
90
+ m = xspec.Model("powerlaw")
91
+ m.powerlaw.PhoIndex = 2.0
92
+ m.powerlaw.norm = 1.0
93
+
94
+ xspec.AllModels.calcFlux("0.5 1.5")
95
+
96
+ phflux_xspec = m.flux[3] # ph/cm^2/s
97
+ eflux_xspec = m.flux[0] # erg/cm^2/s
98
+
99
+ factor = (1 * u.keV).to(u.erg).value
100
+ phflux_jaxspec = Powerlaw().photon_flux(
101
+ {"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0}, e_low=0.5, e_high=1.5, n_points=10_000
102
+ )
103
+
104
+ eflux_jaxspec = (
105
+ Powerlaw().energy_flux(
106
+ {"powerlaw_1_norm": 1.0, "powerlaw_1_alpha": 2.0},
107
+ e_low=0.5,
108
+ e_high=1.5,
109
+ n_points=10_000,
110
+ )
111
+ * factor
112
+ )
113
+
114
+ assert np.isclose(
115
+ phflux_xspec, phflux_jaxspec
116
+ ), f"Mismatch between XSPEC and jaxspec on photon flux, got {phflux_xspec} and {phflux_jaxspec}"
117
+ assert np.isclose(
118
+ eflux_xspec, eflux_jaxspec
119
+ ), f"Mismatch between XSPEC and jaxspec on energy flux, got {eflux_xspec} and {eflux_jaxspec}"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes