jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +231 -332
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +60 -70
- jaxspec/fit.py +76 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +102 -86
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/METADATA +13 -14
- jaxspec-0.2.1.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/entry_points.txt +0 -0
jaxspec/fit.py
CHANGED
|
@@ -10,11 +10,12 @@ import arviz as az
|
|
|
10
10
|
import jax
|
|
11
11
|
import jax.numpy as jnp
|
|
12
12
|
import matplotlib.pyplot as plt
|
|
13
|
-
import numpy as np
|
|
14
13
|
import numpyro
|
|
15
14
|
|
|
16
15
|
from jax import random
|
|
16
|
+
from jax.experimental import mesh_utils
|
|
17
17
|
from jax.random import PRNGKey
|
|
18
|
+
from jax.sharding import PositionalSharding
|
|
18
19
|
from numpyro.contrib.nested_sampling import NestedSampler
|
|
19
20
|
from numpyro.distributions import Poisson, TransformedDistribution
|
|
20
21
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
|
|
@@ -23,12 +24,16 @@ from numpyro.infer.reparam import TransformReparam
|
|
|
23
24
|
from numpyro.infer.util import log_density
|
|
24
25
|
|
|
25
26
|
from ._fit._build_model import build_prior, forward_model
|
|
26
|
-
from .analysis._plot import
|
|
27
|
+
from .analysis._plot import (
|
|
28
|
+
_error_bars_for_observed_data,
|
|
29
|
+
_plot_binned_samples_with_error,
|
|
30
|
+
_plot_poisson_data_with_error,
|
|
31
|
+
)
|
|
27
32
|
from .analysis.results import FitResult
|
|
28
33
|
from .data import ObsConfiguration
|
|
29
34
|
from .model.abc import SpectralModel
|
|
30
35
|
from .model.background import BackgroundModel
|
|
31
|
-
from .util.typing import
|
|
36
|
+
from .util.typing import PriorDictType
|
|
32
37
|
|
|
33
38
|
|
|
34
39
|
class BayesianModel:
|
|
@@ -63,10 +68,12 @@ class BayesianModel:
|
|
|
63
68
|
|
|
64
69
|
if not callable(prior_distributions):
|
|
65
70
|
# Validate the entry with pydantic
|
|
66
|
-
prior = PriorDictModel.from_dict(prior_distributions).
|
|
71
|
+
# prior = PriorDictModel.from_dict(prior_distributions).
|
|
67
72
|
|
|
68
73
|
def prior_distributions_func():
|
|
69
|
-
return build_prior(
|
|
74
|
+
return build_prior(
|
|
75
|
+
prior_distributions, expand_shape=(len(self.observation_container),)
|
|
76
|
+
)
|
|
70
77
|
|
|
71
78
|
else:
|
|
72
79
|
prior_distributions_func = prior_distributions
|
|
@@ -74,6 +81,22 @@ class BayesianModel:
|
|
|
74
81
|
self.prior_distributions_func = prior_distributions_func
|
|
75
82
|
self.init_params = self.prior_samples()
|
|
76
83
|
|
|
84
|
+
# Check the priors are suited for the observations
|
|
85
|
+
split_parameters = [
|
|
86
|
+
(param, shape[-1])
|
|
87
|
+
for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
|
|
88
|
+
if (len(shape) > 1)
|
|
89
|
+
and not param.startswith("_")
|
|
90
|
+
and not param.startswith("bkg") # hardcoded for subtracted background
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
for parameter, proposed_number_of_obs in split_parameters:
|
|
94
|
+
if proposed_number_of_obs != len(self.observation_container):
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"Invalid splitting in the prior distribution. "
|
|
97
|
+
f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
|
|
98
|
+
)
|
|
99
|
+
|
|
77
100
|
@cached_property
|
|
78
101
|
def observation_container(self) -> dict[str, ObsConfiguration]:
|
|
79
102
|
"""
|
|
@@ -137,7 +160,9 @@ class BayesianModel:
|
|
|
137
160
|
with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
|
|
138
161
|
numpyro.sample(
|
|
139
162
|
"obs_" + name,
|
|
140
|
-
Poisson(
|
|
163
|
+
Poisson(
|
|
164
|
+
obs_countrate + bkg_countrate
|
|
165
|
+
), # / observation.folded_backratio.data
|
|
141
166
|
obs=observation.folded_counts.data if observed else None,
|
|
142
167
|
)
|
|
143
168
|
|
|
@@ -289,41 +314,48 @@ class BayesianModel:
|
|
|
289
314
|
Check if the prior distribution include the observed data.
|
|
290
315
|
"""
|
|
291
316
|
key_prior, key_posterior = jax.random.split(key, 2)
|
|
317
|
+
n_devices = len(jax.local_devices())
|
|
318
|
+
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
|
|
319
|
+
|
|
320
|
+
# Sample from prior and correct if the number of samples is not a multiple of the number of devices
|
|
321
|
+
if num_samples % n_devices != 0:
|
|
322
|
+
num_samples = num_samples + n_devices - (num_samples % n_devices)
|
|
323
|
+
|
|
292
324
|
prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
|
|
293
|
-
|
|
325
|
+
|
|
326
|
+
# Split the parameters on every device
|
|
327
|
+
sharded_parameters = jax.device_put(prior_params, sharding)
|
|
328
|
+
posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
|
|
294
329
|
|
|
295
330
|
for key, value in self.observation_container.items():
|
|
296
|
-
fig,
|
|
331
|
+
fig, ax = plt.subplots(
|
|
297
332
|
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
|
|
298
333
|
)
|
|
299
334
|
|
|
300
|
-
|
|
301
|
-
|
|
335
|
+
legend_plots = []
|
|
336
|
+
legend_labels = []
|
|
337
|
+
|
|
338
|
+
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
339
|
+
value.folded_counts.values, 1.0, "ct"
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
true_data_plot = _plot_poisson_data_with_error(
|
|
343
|
+
ax[0],
|
|
302
344
|
value.out_energies,
|
|
303
|
-
value
|
|
304
|
-
|
|
345
|
+
y_observed.value,
|
|
346
|
+
y_observed_low.value,
|
|
347
|
+
y_observed_high.value,
|
|
348
|
+
alpha=0.7,
|
|
305
349
|
)
|
|
306
350
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
["#03045e", "#0077b6", "#00b4d8"],
|
|
311
|
-
[0.5, 0.4, 0.3],
|
|
312
|
-
)
|
|
313
|
-
):
|
|
314
|
-
lower, upper = np.percentile(
|
|
315
|
-
posterior_observations["obs_" + key], envelop_percentiles, axis=0
|
|
316
|
-
)
|
|
351
|
+
prior_plot = _plot_binned_samples_with_error(
|
|
352
|
+
ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
|
|
353
|
+
)
|
|
317
354
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
alpha=alpha,
|
|
323
|
-
fill=True,
|
|
324
|
-
color=color,
|
|
325
|
-
label=rf"${1+i}\sigma$",
|
|
326
|
-
)
|
|
355
|
+
legend_plots.append((true_data_plot,))
|
|
356
|
+
legend_labels.append("Observed")
|
|
357
|
+
legend_plots += prior_plot
|
|
358
|
+
legend_labels.append("Prior Predictive")
|
|
327
359
|
|
|
328
360
|
# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
|
|
329
361
|
counts = posterior_observations["obs_" + key]
|
|
@@ -336,22 +368,22 @@ class BayesianModel:
|
|
|
336
368
|
|
|
337
369
|
rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
|
|
338
370
|
|
|
339
|
-
|
|
371
|
+
ax[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
|
|
340
372
|
|
|
341
|
-
|
|
373
|
+
ax[1].plot(
|
|
342
374
|
(value.out_energies.min(), value.out_energies.max()),
|
|
343
375
|
(50, 50),
|
|
344
376
|
color="black",
|
|
345
377
|
linestyle="--",
|
|
346
378
|
)
|
|
347
379
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
380
|
+
ax[1].set_xlabel("Energy (keV)")
|
|
381
|
+
ax[0].set_ylabel("Counts")
|
|
382
|
+
ax[1].set_ylabel("Rank (%)")
|
|
383
|
+
ax[1].set_ylim(0, 100)
|
|
384
|
+
ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
|
|
385
|
+
ax[0].loglog()
|
|
386
|
+
ax[0].legend(legend_plots, legend_labels)
|
|
355
387
|
plt.suptitle(f"Prior Predictive coverage for {key}")
|
|
356
388
|
plt.tight_layout()
|
|
357
389
|
plt.show()
|
|
@@ -544,7 +576,6 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
544
576
|
return FitResult(
|
|
545
577
|
self,
|
|
546
578
|
inference_data,
|
|
547
|
-
self.model.params,
|
|
548
579
|
background_model=self.background_model,
|
|
549
580
|
)
|
|
550
581
|
|
|
@@ -590,11 +621,13 @@ class NSFitter(BayesianModelFitter):
|
|
|
590
621
|
ns = NestedSampler(
|
|
591
622
|
bayesian_model,
|
|
592
623
|
constructor_kwargs=dict(
|
|
593
|
-
num_parallel_workers=1,
|
|
594
624
|
verbose=verbose,
|
|
595
625
|
difficult_model=True,
|
|
596
|
-
max_samples=
|
|
626
|
+
max_samples=1e5,
|
|
597
627
|
parameter_estimation=True,
|
|
628
|
+
gradient_guided=True,
|
|
629
|
+
devices=jax.devices(),
|
|
630
|
+
# init_efficiency_threshold=0.01,
|
|
598
631
|
num_live_points=num_live_points,
|
|
599
632
|
),
|
|
600
633
|
termination_kwargs=termination_kwargs if termination_kwargs else dict(),
|
|
@@ -613,6 +646,5 @@ class NSFitter(BayesianModelFitter):
|
|
|
613
646
|
return FitResult(
|
|
614
647
|
self,
|
|
615
648
|
inference_data,
|
|
616
|
-
self.model.params,
|
|
617
649
|
background_model=self.background_model,
|
|
618
650
|
)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Helper functions to deal with the graph logic within model building"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
import networkx as nx
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_component_names(graph: nx.DiGraph):
    """
    Get the set of component names from the nodes of a graph.

    A node is treated as a component when its ``type`` attribute contains the
    substring ``"component"``; its ``name`` attribute is then collected. Nodes
    whose ``type`` is missing or ``None`` are skipped instead of raising a
    ``TypeError``.

    Parameters:
        graph: The graph to get the component names from.

    Returns:
        A new set of component names (callers may mutate it).
    """
    return {
        data["name"]
        for _, data in graph.nodes(data=True)
        # `or ""` keeps the substring test valid for nodes without a type
        if "component" in (data.get("type") or "")
    }
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def increment_name(name: str, used_names: set):
    """
    Return a variant of ``name`` that does not appear in ``used_names``.

    ``name`` is split into a stem and an optional trailing ``_<digits>`` suffix.
    If ``name`` is not already used it is returned as-is. Otherwise the suffix
    (taken to be 1 when absent) is increased until ``f"{stem}_{n}"`` is free.
    ``used_names`` is only read, never modified.

    Parameters:
        name: The name to increment, e.g. ``"powerlaw_1"`` or ``"powerlaw"``.
        used_names: The set of names that are already used.

    Returns:
        ``name`` itself, or the first unused ``"<stem>_<n>"``.
    """
    stem, digits = re.match(r"(.*?)(?:_(\d+))?$", name).groups()
    counter = int(digits) if digits else 1  # a bare name counts as suffix 1

    if name not in used_names:
        return name

    while True:
        counter += 1
        candidate = f"{stem}_{counter}"
        if candidate not in used_names:
            return candidate
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def compose_with_rename(graph_1: nx.DiGraph, graph_2: nx.DiGraph):
    """
    Compose two graphs by updating the 'name' attributes of nodes in graph_2,
    and return the graph joined on the 'out' node.

    Every component node of ``graph_2`` whose ``name`` is already taken (by a
    component of ``graph_1`` or by a name assigned earlier in this loop) gets a
    fresh name from ``increment_name``. The new name is written into
    ``graph_2``'s own node attribute dicts, so ``graph_2`` is modified in place.
    Nodes whose ``type`` is missing or ``None`` are not treated as components.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose. Its component names may be
            updated in place.

    Returns:
        The result of ``nx.compose(graph_1, graph_2)``; nodes sharing the same
        id (such as ``'out'``) are merged into one.
    """
    # Names already taken by the components of graph_1
    used_names = get_component_names(graph_1)

    for _, data in graph_2.nodes(data=True):
        if "component" not in (data.get("type") or ""):
            continue

        name = data["name"]
        if name in used_names:
            name = increment_name(name, used_names)
            data["name"] = name
        used_names.add(name)

    return nx.compose(graph_1, graph_2)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def compose(
    graph_1: nx.DiGraph,
    graph_2: nx.DiGraph,
    operation: str = "",
    operation_func: Callable = lambda x, y: None,
):
    """
    Compose two graphs by joining the 'out' node of graph_1 and graph_2, and turning
    it to an 'operation' node with the relevant operator and add a new 'out' node.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose (its component names may be
            renamed in place, see ``compose_with_rename``).
        operation: The string describing the operation to perform.
        operation_func: The callable that performs the operation.

    Returns:
        A new graph where the merged former 'out' node has a random UUID id,
        ``type == f"{operation}_operation"`` and ``operator == operation_func``.
        It feeds a fresh 'out' node. Every node gets a ``depth`` attribute equal
        to the graph's longest path length minus its distance to 'out'.

    Raises:
        nx.NetworkXNoPath: If some node has no path to the 'out' node.
    """
    combined_graph = compose_with_rename(graph_1, graph_2)

    # The shared 'out' node becomes the operation node, under a unique id
    node_id = str(uuid4())
    graph = nx.relabel_nodes(combined_graph, {"out": node_id})
    nx.set_node_attributes(graph, {node_id: f"{operation}_operation"}, "type")
    nx.set_node_attributes(graph, {node_id: operation_func}, "operator")

    # Now add the output node and link it to the operation node
    graph.add_node("out", type="out")
    graph.add_edge(node_id, "out")

    # Compute the new depth of each node
    longest_path = nx.dag_longest_path_length(graph)

    # A single search towards 'out' yields every node's distance at once,
    # instead of one shortest-path search per node.
    distance_to_out = nx.shortest_path_length(graph, target="out")
    unreachable = [node for node in graph.nodes if node not in distance_to_out]
    if unreachable:
        raise nx.NetworkXNoPath(f"No path between {unreachable[0]} and out.")

    nx.set_node_attributes(
        graph,
        {node: longest_path - distance for node, distance in distance_to_out.items()},
        "depth",
    )

    return graph
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def export_to_mermaid(graph, file=None):
    """
    Render a model graph as Mermaid flowchart code with a left-to-right layout.

    Node declarations depend on the ``type`` attribute:

    * ``"out"``: labelled ``Output``.
    * ``"<op>_component"``: labelled with the capitalized component name. A
      trailing ``_<suffix>`` is shown in parentheses (``"powerlaw_1"`` gives
      ``Powerlaw (1)``). Names without an underscore are just capitalized, and
      names with several underscores split on the last one.
    * ``"add_operation"`` / ``"mul_operation"``: a ``+`` / ``x`` rhombus.

    Other node types get no declaration line. Every edge becomes a
    ``source --> target`` line.

    Parameters:
        graph: The graph to export; it must provide ``nodes(data=True)`` and
            ``edges()``.
        file: Optional path. When given, the code is written there (UTF-8) and
            ``None`` is returned.

    Returns:
        The Mermaid code as a string when ``file`` is ``None``, otherwise ``None``.
    """
    lines = ["graph LR\n"]  # LR = left to right
    operation_symbols = {"add": "+", "mul": "x"}

    # Add nodes
    for node, attributes in graph.nodes(data=True):
        if attributes["type"] == "out":
            lines.append(f' {node}("Output")\n')
            continue

        operation_type, _, node_type = attributes["type"].rpartition("_")

        if node_type == "component":
            # Split on the last underscore only: a plain split("_") crashed on
            # names with zero or several underscores.
            base, sep, suffix = attributes["name"].rpartition("_")
            if sep:
                label = f"{base.capitalize()} ({suffix})"
            else:
                label = attributes["name"].capitalize()
            lines.append(f' {node}("{label}")\n')

        elif node_type == "operation":
            symbol = operation_symbols.get(operation_type)
            if symbol is not None:
                lines.append(f" {node}{{**{symbol}**}}\n")

    # Draw connexion between nodes
    lines.extend(f" {source} --> {target}\n" for source, target in graph.edges())

    mermaid_code = "".join(lines)

    if file is None:
        return mermaid_code

    with open(file, "w", encoding="utf-8") as f:
        f.write(mermaid_code)
|