jaxspec 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.gitignore +2 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/PKG-INFO +5 -5
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/index.md +6 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/fitting.md +8 -0
- jaxspec-0.3.2/docs/references/instrument.md +4 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/mkdocs.yml +2 -3
- {jaxspec-0.3.0 → jaxspec-0.3.2}/pyproject.toml +6 -6
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/analysis/results.py +3 -1
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/obsconf.py +0 -10
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/observation.py +5 -11
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/fit/_bayesian_model.py +9 -9
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/fit/_fitter.py +19 -3
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/instrument.py +34 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_background.py +3 -3
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_fakeit.py +3 -3
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.dockerignore +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.github/dependabot.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.github/workflows/documentation-links.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.github/workflows/publish.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.github/workflows/test-and-coverage.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.pre-commit-config.yaml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.python-version +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/.readthedocs.yaml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/CODE_OF_CONDUCT.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/Dockerfile +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/LICENSE.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/README.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/codecov.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/contribute/index.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/contribute/internal.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/contribute/xspec.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/css/extra.css +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/css/material.css +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/css/mkdocstrings.css +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/css/xarray.css +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/dev/index.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/background.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/build_model.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/fakeits.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/fitting_example.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/background_comparison.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/background_gp.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/background_spectral.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/fakeits.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/fitting_example_corner.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/fitting_example_ppc.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/model.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/rmf.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/subtract_background.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/examples/statics/subtract_background_with_errors.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/faq/cookbook.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/faq/index.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/faq/statics/cstat_vs_chi2.png +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/frontpage/installation.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/index.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/javascripts/mathjax.js +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/logo/logo_small.svg +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/logo/xifu_mini.svg +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/abundance.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/additive.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/background.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/data.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/integrate.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/model.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/multiplicative.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/references/results.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/runtime/diagram.txt +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/runtime/result_table.txt +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/runtime/various_model_graphs/diagram.txt +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/runtime/various_model_graphs/model_complex.txt +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/runtime/various_model_graphs/model_simple.txt +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/theory/background.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/theory/bayesian_inference.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/theory/index.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/docs/theory/instrument.md +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/analysis/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/analysis/_plot.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/analysis/compare.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/instrument.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/ogip.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/data/util.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/interpolator.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/interpolator_jax.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/intrument_models.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/nested_sampler.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/experimental/tabulated.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/fit/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/fit/_build_model.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/_graph_util.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/abc.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/additive.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/background.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/model/multiplicative.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/scripts/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/scripts/debug.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/__init__.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/abundance.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/integrate.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/misc.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/online_storage.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/src/jaxspec/util/typing.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/conftest.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/data_files.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/data_hash.yml +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_bayesian_model.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_bayesian_model_building.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_instruments.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_integrate.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_mcmc.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_misc.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_models.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_observation.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_repr.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_results.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_xspec.py +0 -0
- {jaxspec-0.3.0 → jaxspec-0.3.2}/tests/test_xspec_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.3.0
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Project-URL: Homepage, https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
@@ -15,18 +15,18 @@ Requires-Dist: chainconsumer<2,>=1.1.2
|
|
|
15
15
|
Requires-Dist: cmasher<2,>=1.6.3
|
|
16
16
|
Requires-Dist: flax>0.10.5
|
|
17
17
|
Requires-Dist: interpax<0.4,>=0.3.5
|
|
18
|
-
Requires-Dist: jax<0.
|
|
18
|
+
Requires-Dist: jax<0.7,>=0.5.0
|
|
19
19
|
Requires-Dist: jaxns<3,>=2.6.7
|
|
20
20
|
Requires-Dist: jaxopt<0.9,>=0.8.3
|
|
21
21
|
Requires-Dist: matplotlib<4,>=3.8.0
|
|
22
22
|
Requires-Dist: mendeleev<1.2,>=0.15
|
|
23
23
|
Requires-Dist: networkx~=3.1
|
|
24
24
|
Requires-Dist: numpy<3.0.0
|
|
25
|
-
Requires-Dist: numpyro<0.
|
|
26
|
-
Requires-Dist: optimistix<0.0.
|
|
25
|
+
Requires-Dist: numpyro<0.20,>=0.17.0
|
|
26
|
+
Requires-Dist: optimistix<0.0.12,>=0.0.10
|
|
27
27
|
Requires-Dist: pandas<3,>=2.2.0
|
|
28
28
|
Requires-Dist: pooch<2,>=1.8.2
|
|
29
|
-
Requires-Dist: scipy<1.
|
|
29
|
+
Requires-Dist: scipy<1.16
|
|
30
30
|
Requires-Dist: seaborn<0.14,>=0.13.1
|
|
31
31
|
Requires-Dist: simpleeval<1.1.0,>=0.9.13
|
|
32
32
|
Requires-Dist: sparse>0.15
|
|
@@ -18,14 +18,13 @@ nav:
|
|
|
18
18
|
- Add a background : examples/background.md
|
|
19
19
|
- Good practices for MCMC: examples/work_with_arviz.ipynb
|
|
20
20
|
- Interface with other frameworks : examples/external_samplers.ipynb
|
|
21
|
-
|
|
22
|
-
# - models/index.md
|
|
23
|
-
# - APEC: models/apec.md
|
|
21
|
+
- Cross calibration : examples/calibration.ipynb
|
|
24
22
|
- API Reference:
|
|
25
23
|
- Spectral model base: references/model.md
|
|
26
24
|
- Additive models: references/additive.md
|
|
27
25
|
- Multiplicative models: references/multiplicative.md
|
|
28
26
|
- Background models: references/background.md
|
|
27
|
+
- Instrument models: references/instrument.md
|
|
29
28
|
- Data containers: references/data.md
|
|
30
29
|
- Fitting: references/fitting.md
|
|
31
30
|
- Result containers: references/results.md
|
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "jaxspec"
|
|
3
|
-
version = "0.3.0"
|
|
3
|
+
version = "0.3.2"
|
|
4
4
|
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
|
|
5
5
|
authors = [{ name = "sdupourque", email = "sdupourque@irap.omp.eu" }]
|
|
6
6
|
requires-python = ">=3.10,<3.13"
|
|
7
7
|
readme = "README.md"
|
|
8
8
|
license = "MIT"
|
|
9
9
|
dependencies = [
|
|
10
|
-
"jax>=0.5.0,<0.
|
|
10
|
+
"jax>=0.5.0,<0.7",
|
|
11
11
|
"numpy<3.0.0",
|
|
12
12
|
"pandas>=2.2.0,<3",
|
|
13
13
|
"astropy>=6.0.0,<8",
|
|
14
|
-
"numpyro>=0.17.0,<0.
|
|
14
|
+
"numpyro>=0.17.0,<0.20",
|
|
15
15
|
"networkx~=3.1",
|
|
16
16
|
"matplotlib>=3.8.0,<4",
|
|
17
17
|
"arviz>=0.17.1,<0.23.0",
|
|
@@ -22,8 +22,8 @@ dependencies = [
|
|
|
22
22
|
"tinygp>=0.3.0,<0.4",
|
|
23
23
|
"seaborn>=0.13.1,<0.14",
|
|
24
24
|
"sparse>0.15",
|
|
25
|
-
"optimistix>=0.0.10,<0.0.
|
|
26
|
-
"scipy<1.
|
|
25
|
+
"optimistix>=0.0.10,<0.0.12",
|
|
26
|
+
"scipy<1.16",
|
|
27
27
|
"mendeleev>=0.15,<1.2",
|
|
28
28
|
"jaxns>=2.6.7,<3",
|
|
29
29
|
"pooch>=1.8.2,<2",
|
|
@@ -58,7 +58,7 @@ test = [
|
|
|
58
58
|
]
|
|
59
59
|
dev = [
|
|
60
60
|
"pre-commit>=3.5,<5.0",
|
|
61
|
-
"ruff>=0.2.1,<0.
|
|
61
|
+
"ruff>=0.2.1,<0.15.0",
|
|
62
62
|
"jupyterlab>=4.0.7,<5",
|
|
63
63
|
"notebook>=7.0.6,<8",
|
|
64
64
|
"ipywidgets>=8.1.1,<9",
|
|
@@ -122,7 +122,9 @@ class FitResult:
|
|
|
122
122
|
|
|
123
123
|
samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
|
|
124
124
|
|
|
125
|
-
total_shape = tuple(
|
|
125
|
+
total_shape = tuple(
|
|
126
|
+
posterior.sizes[d] for d in posterior.coords if not (("obs" in d) or ("bkg" in d))
|
|
127
|
+
)
|
|
126
128
|
|
|
127
129
|
posterior = {key: posterior[key].data for key in posterior.data_vars}
|
|
128
130
|
|
|
@@ -163,10 +163,8 @@ class ObsConfiguration(xr.Dataset):
|
|
|
163
163
|
|
|
164
164
|
if observation.folded_background is not None:
|
|
165
165
|
folded_background = observation.folded_background.data[row_idx]
|
|
166
|
-
folded_background_unscaled = observation.folded_background_unscaled.data[row_idx]
|
|
167
166
|
else:
|
|
168
167
|
folded_background = np.zeros_like(folded_counts)
|
|
169
|
-
folded_background_unscaled = np.zeros_like(folded_counts)
|
|
170
168
|
|
|
171
169
|
data_dict = {
|
|
172
170
|
"transfer_matrix": (
|
|
@@ -208,14 +206,6 @@ class ObsConfiguration(xr.Dataset):
|
|
|
208
206
|
"unit": "photons",
|
|
209
207
|
},
|
|
210
208
|
),
|
|
211
|
-
"folded_background_unscaled": (
|
|
212
|
-
["folded_channel"],
|
|
213
|
-
folded_background_unscaled,
|
|
214
|
-
{
|
|
215
|
-
"description": "To be done",
|
|
216
|
-
"unit": "photons",
|
|
217
|
-
},
|
|
218
|
-
),
|
|
219
209
|
}
|
|
220
210
|
|
|
221
211
|
return cls(
|
|
@@ -46,16 +46,14 @@ class Observation(xr.Dataset):
|
|
|
46
46
|
quality,
|
|
47
47
|
exposure,
|
|
48
48
|
background=None,
|
|
49
|
-
background_unscaled=None,
|
|
50
49
|
backratio=1.0,
|
|
51
50
|
attributes: dict | None = None,
|
|
52
51
|
):
|
|
53
52
|
if attributes is None:
|
|
54
53
|
attributes = {}
|
|
55
54
|
|
|
56
|
-
if background is None
|
|
55
|
+
if background is None:
|
|
57
56
|
background = np.zeros_like(counts, dtype=np.int64)
|
|
58
|
-
background_unscaled = np.zeros_like(counts, dtype=np.int64)
|
|
59
57
|
|
|
60
58
|
data_dict = {
|
|
61
59
|
"counts": (
|
|
@@ -86,7 +84,9 @@ class Observation(xr.Dataset):
|
|
|
86
84
|
),
|
|
87
85
|
"folded_backratio": (
|
|
88
86
|
["folded_channel"],
|
|
89
|
-
np.asarray(
|
|
87
|
+
np.asarray(
|
|
88
|
+
np.ma.filled(grouping @ backratio) / grouping.sum(axis=1).todense(), dtype=float
|
|
89
|
+
),
|
|
90
90
|
{"description": "Background scaling after grouping"},
|
|
91
91
|
),
|
|
92
92
|
"background": (
|
|
@@ -94,11 +94,6 @@ class Observation(xr.Dataset):
|
|
|
94
94
|
np.asarray(background, dtype=np.int64),
|
|
95
95
|
{"description": "Background counts", "unit": "photons"},
|
|
96
96
|
),
|
|
97
|
-
"folded_background_unscaled": (
|
|
98
|
-
["folded_channel"],
|
|
99
|
-
np.asarray(np.ma.filled(grouping @ background_unscaled), dtype=np.int64),
|
|
100
|
-
{"description": "Background counts", "unit": "photons"},
|
|
101
|
-
),
|
|
102
97
|
"folded_background": (
|
|
103
98
|
["folded_channel"],
|
|
104
99
|
np.asarray(np.ma.filled(grouping @ background), dtype=np.float64),
|
|
@@ -147,8 +142,7 @@ class Observation(xr.Dataset):
|
|
|
147
142
|
pha.quality,
|
|
148
143
|
pha.exposure,
|
|
149
144
|
backratio=backratio,
|
|
150
|
-
background=bkg.counts
|
|
151
|
-
background_unscaled=bkg.counts if bkg is not None else None,
|
|
145
|
+
background=bkg.counts if bkg is not None else None,
|
|
152
146
|
attributes=metadata,
|
|
153
147
|
)
|
|
154
148
|
|
|
@@ -10,9 +10,8 @@ import matplotlib.pyplot as plt
|
|
|
10
10
|
import numpyro
|
|
11
11
|
|
|
12
12
|
from flax import nnx
|
|
13
|
-
from jax.experimental import mesh_utils
|
|
14
13
|
from jax.random import PRNGKey
|
|
15
|
-
from jax.sharding import
|
|
14
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
16
15
|
from numpyro.distributions import Poisson, TransformedDistribution
|
|
17
16
|
from numpyro.infer import Predictive
|
|
18
17
|
from numpyro.infer.inspect import get_model_relations
|
|
@@ -136,7 +135,7 @@ class BayesianModel(nnx.Module):
|
|
|
136
135
|
with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
|
|
137
136
|
numpyro.sample(
|
|
138
137
|
"obs/~/" + name,
|
|
139
|
-
Poisson(obs_countrate + bkg_countrate
|
|
138
|
+
Poisson(obs_countrate + bkg_countrate * observation.folded_backratio.data),
|
|
140
139
|
obs=observation.folded_counts.data if observed else None,
|
|
141
140
|
)
|
|
142
141
|
|
|
@@ -244,7 +243,7 @@ class BayesianModel(nnx.Module):
|
|
|
244
243
|
return log_posterior_prob
|
|
245
244
|
|
|
246
245
|
@cached_property
|
|
247
|
-
def
|
|
246
|
+
def parameter_names(self) -> list[str]:
|
|
248
247
|
"""
|
|
249
248
|
A list of parameter names for the model.
|
|
250
249
|
"""
|
|
@@ -269,7 +268,7 @@ class BayesianModel(nnx.Module):
|
|
|
269
268
|
"""
|
|
270
269
|
input_params = {}
|
|
271
270
|
|
|
272
|
-
for index, key in enumerate(self.
|
|
271
|
+
for index, key in enumerate(self.parameter_names):
|
|
273
272
|
input_params[key] = theta[index]
|
|
274
273
|
|
|
275
274
|
return input_params
|
|
@@ -279,9 +278,9 @@ class BayesianModel(nnx.Module):
|
|
|
279
278
|
Convert a dictionary of parameters to an array of parameters.
|
|
280
279
|
"""
|
|
281
280
|
|
|
282
|
-
theta = jnp.zeros(len(self.
|
|
281
|
+
theta = jnp.zeros(len(self.parameter_names))
|
|
283
282
|
|
|
284
|
-
for index, key in enumerate(self.
|
|
283
|
+
for index, key in enumerate(self.parameter_names):
|
|
285
284
|
theta = theta.at[index].set(dict_of_params[key])
|
|
286
285
|
|
|
287
286
|
return theta
|
|
@@ -298,7 +297,7 @@ class BayesianModel(nnx.Module):
|
|
|
298
297
|
@jax.jit
|
|
299
298
|
def prior_sample(key):
|
|
300
299
|
return Predictive(
|
|
301
|
-
self.numpyro_model, return_sites=self.
|
|
300
|
+
self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
|
|
302
301
|
)(key, observed=False)
|
|
303
302
|
|
|
304
303
|
return prior_sample(key)
|
|
@@ -324,7 +323,8 @@ class BayesianModel(nnx.Module):
|
|
|
324
323
|
"""
|
|
325
324
|
key_prior, key_posterior = jax.random.split(key, 2)
|
|
326
325
|
n_devices = len(jax.local_devices())
|
|
327
|
-
|
|
326
|
+
mesh = jax.make_mesh((n_devices,), ("batch",))
|
|
327
|
+
sharding = NamedSharding(mesh, PartitionSpec("batch"))
|
|
328
328
|
|
|
329
329
|
# Sample from prior and correct if the number of samples is not a multiple of the number of devices
|
|
330
330
|
if num_samples % n_devices != 0:
|
|
@@ -215,13 +215,29 @@ class VIFitter(BayesianModelFitter):
|
|
|
215
215
|
self,
|
|
216
216
|
rng_key: int = 0,
|
|
217
217
|
num_steps: int = 10_000,
|
|
218
|
-
optimizer=numpyro.optim.Adam(step_size=0.0005),
|
|
219
|
-
loss=Trace_ELBO(),
|
|
218
|
+
optimizer: numpyro.optim._NumPyroOptim = numpyro.optim.Adam(step_size=0.0005),
|
|
219
|
+
loss: numpyro.infer.elbo.ELBO = Trace_ELBO(),
|
|
220
220
|
num_samples: int = 1000,
|
|
221
|
-
guide=None,
|
|
221
|
+
guide: numpyro.infer.autoguide.AutoGuide | None = None,
|
|
222
222
|
use_transformed_model: bool = True,
|
|
223
223
|
plot_diagnostics: bool = False,
|
|
224
224
|
) -> FitResult:
|
|
225
|
+
"""
|
|
226
|
+
Fit the model to the data using a variational inference approach from numpyro.
|
|
227
|
+
|
|
228
|
+
Parameters:
|
|
229
|
+
rng_key: the random key used to initialize the sampler.
|
|
230
|
+
num_steps: the number of steps for VI.
|
|
231
|
+
optimizer: the optimizer to use.
|
|
232
|
+
num_samples: the number of samples to draw.
|
|
233
|
+
loss: the loss function to use.
|
|
234
|
+
guide: the guide to use.
|
|
235
|
+
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
236
|
+
plot_diagnostics: plot the loss during VI.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
240
|
+
"""
|
|
225
241
|
bayesian_model = (
|
|
226
242
|
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
227
243
|
)
|
|
@@ -8,13 +8,26 @@ from numpyro.distributions import Distribution
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class GainModel(ABC, nnx.Module):
|
|
11
|
+
"""
|
|
12
|
+
Generic class for a gain model
|
|
13
|
+
"""
|
|
14
|
+
|
|
11
15
|
@abstractmethod
|
|
12
16
|
def numpyro_model(self, observation_name: str):
|
|
13
17
|
pass
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
class ConstantGain(GainModel):
|
|
21
|
+
"""
|
|
22
|
+
A constant gain model
|
|
23
|
+
"""
|
|
24
|
+
|
|
17
25
|
def __init__(self, prior_distribution: Distribution):
|
|
26
|
+
"""
|
|
27
|
+
Parameters:
|
|
28
|
+
prior_distribution: the prior distribution for the gain value.
|
|
29
|
+
"""
|
|
30
|
+
|
|
18
31
|
self.prior_distribution = prior_distribution
|
|
19
32
|
|
|
20
33
|
def numpyro_model(self, observation_name: str):
|
|
@@ -27,13 +40,25 @@ class ConstantGain(GainModel):
|
|
|
27
40
|
|
|
28
41
|
|
|
29
42
|
class ShiftModel(ABC, nnx.Module):
|
|
43
|
+
"""
|
|
44
|
+
Generic class for a shift model
|
|
45
|
+
"""
|
|
46
|
+
|
|
30
47
|
@abstractmethod
|
|
31
48
|
def numpyro_model(self, observation_name: str):
|
|
32
49
|
pass
|
|
33
50
|
|
|
34
51
|
|
|
35
52
|
class ConstantShift(ShiftModel):
|
|
53
|
+
"""
|
|
54
|
+
A constant shift model
|
|
55
|
+
"""
|
|
56
|
+
|
|
36
57
|
def __init__(self, prior_distribution: Distribution):
|
|
58
|
+
"""
|
|
59
|
+
Parameters:
|
|
60
|
+
prior_distribution: the prior distribution for the shift value.
|
|
61
|
+
"""
|
|
37
62
|
self.prior_distribution = prior_distribution
|
|
38
63
|
|
|
39
64
|
def numpyro_model(self, observation_name: str):
|
|
@@ -52,6 +77,15 @@ class InstrumentModel(nnx.Module):
|
|
|
52
77
|
gain_model: GainModel | None = None,
|
|
53
78
|
shift_model: ShiftModel | None = None,
|
|
54
79
|
):
|
|
80
|
+
"""
|
|
81
|
+
Encapsulate an instrument model, build as a combination of a shift and gain model.
|
|
82
|
+
|
|
83
|
+
Parameters:
|
|
84
|
+
reference_observation_name : The observation to use as a reference
|
|
85
|
+
gain_model : The gain model
|
|
86
|
+
shift_model : The shift model
|
|
87
|
+
"""
|
|
88
|
+
|
|
55
89
|
self.reference = reference_observation_name
|
|
56
90
|
self.gain_model = gain_model
|
|
57
91
|
self.shift_model = shift_model
|
|
@@ -14,9 +14,9 @@ spectral_model_background = Powerlaw() + Blackbodyrad()
|
|
|
14
14
|
|
|
15
15
|
prior_background = {
|
|
16
16
|
"powerlaw_1_alpha": dist.Uniform(0, 5),
|
|
17
|
-
"powerlaw_1_norm": dist.LogUniform(1e-
|
|
17
|
+
"powerlaw_1_norm": dist.LogUniform(1e-7, 1e-3),
|
|
18
18
|
"blackbodyrad_1_kT": dist.Uniform(0, 5),
|
|
19
|
-
"blackbodyrad_1_norm": dist.LogUniform(1e-
|
|
19
|
+
"blackbodyrad_1_norm": dist.LogUniform(1e-5, 1e-1),
|
|
20
20
|
}
|
|
21
21
|
|
|
22
22
|
|
|
@@ -36,7 +36,7 @@ def test_background_model(obs_model_prior, bkg_model):
|
|
|
36
36
|
obs_list, model, prior = obs_model_prior
|
|
37
37
|
forward = MCMCFitter(model, prior, obs_list[0], background_model=bkg_model)
|
|
38
38
|
result = forward.fit(
|
|
39
|
-
num_chains=4, num_warmup=
|
|
39
|
+
num_chains=4, num_warmup=1000, num_samples=1000, mcmc_kwargs={"progress_bar": False}
|
|
40
40
|
)
|
|
41
41
|
result.plot_ppc(title=f"Test {bkg_model.__class__.__name__}")
|
|
42
42
|
|
|
@@ -63,10 +63,10 @@ def model():
|
|
|
63
63
|
|
|
64
64
|
@pytest.fixture
|
|
65
65
|
def sharded_parameters(unidimensional_parameters):
|
|
66
|
-
from jax.
|
|
67
|
-
from jax.sharding import PositionalSharding
|
|
66
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
68
67
|
|
|
69
|
-
|
|
68
|
+
mesh = jax.make_mesh((4,), ("batch",))
|
|
69
|
+
sharding = NamedSharding(mesh, PartitionSpec("batch"))
|
|
70
70
|
|
|
71
71
|
return jax.device_put(unidimensional_parameters, sharding)
|
|
72
72
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|