jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -10,11 +10,12 @@ import arviz as az
  import jax
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
- import numpy as np
  import numpyro

  from jax import random
+ from jax.experimental import mesh_utils
  from jax.random import PRNGKey
+ from jax.sharding import PositionalSharding
  from numpyro.contrib.nested_sampling import NestedSampler
  from numpyro.distributions import Poisson, TransformedDistribution
  from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
@@ -23,12 +24,16 @@ from numpyro.infer.reparam import TransformReparam
  from numpyro.infer.util import log_density

  from ._fit._build_model import build_prior, forward_model
- from .analysis._plot import _plot_poisson_data_with_error
+ from .analysis._plot import (
+     _error_bars_for_observed_data,
+     _plot_binned_samples_with_error,
+     _plot_poisson_data_with_error,
+ )
  from .analysis.results import FitResult
  from .data import ObsConfiguration
  from .model.abc import SpectralModel
  from .model.background import BackgroundModel
- from .util.typing import PriorDictModel, PriorDictType
+ from .util.typing import PriorDictType


  class BayesianModel:
@@ -63,10 +68,12 @@ class BayesianModel:

          if not callable(prior_distributions):
              # Validate the entry with pydantic
-             prior = PriorDictModel.from_dict(prior_distributions).nested_dict
+             # prior = PriorDictModel.from_dict(prior_distributions).

              def prior_distributions_func():
-                 return build_prior(prior, expand_shape=(len(self.observation_container),))
+                 return build_prior(
+                     prior_distributions, expand_shape=(len(self.observation_container),)
+                 )

          else:
              prior_distributions_func = prior_distributions
@@ -74,6 +81,22 @@ class BayesianModel:
          self.prior_distributions_func = prior_distributions_func
          self.init_params = self.prior_samples()

+         # Check the priors are suited for the observations
+         split_parameters = [
+             (param, shape[-1])
+             for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
+             if (len(shape) > 1)
+             and not param.startswith("_")
+             and not param.startswith("bkg")  # hardcoded for subtracted background
+         ]
+
+         for parameter, proposed_number_of_obs in split_parameters:
+             if proposed_number_of_obs != len(self.observation_container):
+                 raise ValueError(
+                     f"Invalid splitting in the prior distribution. "
+                     f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
+                 )
+
      @cached_property
      def observation_container(self) -> dict[str, ObsConfiguration]:
          """
@@ -137,7 +160,9 @@ class BayesianModel:
            with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
                numpyro.sample(
                    "obs_" + name,
-                   Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
+                   Poisson(
+                       obs_countrate + bkg_countrate
+                   ),  # / observation.folded_backratio.data
                    obs=observation.folded_counts.data if observed else None,
                )

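The sampling statement above follows the standard numpyro pattern of one conditionally independent Poisson draw per spectral channel. A self-contained sketch of that pattern, with illustrative site names:

    import numpyro
    import numpyro.distributions as dist

    def likelihood(countrate, observed_counts):
        # One independent Poisson sample per channel, conditioned on the data
        with numpyro.plate("channels", len(observed_counts)):
            numpyro.sample("obs", dist.Poisson(countrate), obs=observed_counts)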
@@ -289,41 +314,48 @@ class BayesianModel:
          Check if the prior distribution include the observed data.
          """
          key_prior, key_posterior = jax.random.split(key, 2)
+         n_devices = len(jax.local_devices())
+         sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
+
+         # Sample from prior and correct if the number of samples is not a multiple of the number of devices
+         if num_samples % n_devices != 0:
+             num_samples = num_samples + n_devices - (num_samples % n_devices)
+
          prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
-         posterior_observations = self.mock_observations(prior_params, key=key_posterior)
+
+         # Split the parameters on every device
+         sharded_parameters = jax.device_put(prior_params, sharding)
+         posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)

          for key, value in self.observation_container.items():
-             fig, axs = plt.subplots(
+             fig, ax = plt.subplots(
                  nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
              )

-             _plot_poisson_data_with_error(
-                 axs[0],
+             legend_plots = []
+             legend_labels = []
+
+             y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
+                 value.folded_counts.values, 1.0, "ct"
+             )
+
+             true_data_plot = _plot_poisson_data_with_error(
+                 ax[0],
                  value.out_energies,
-                 value.folded_counts.values,
-                 percentiles=(16, 84),
+                 y_observed.value,
+                 y_observed_low.value,
+                 y_observed_high.value,
+                 alpha=0.7,
              )

-             for i, (envelop_percentiles, color, alpha) in enumerate(
-                 zip(
-                     [(16, 86), (2.5, 97.5), (0.15, 99.85)],
-                     ["#03045e", "#0077b6", "#00b4d8"],
-                     [0.5, 0.4, 0.3],
-                 )
-             ):
-                 lower, upper = np.percentile(
-                     posterior_observations["obs_" + key], envelop_percentiles, axis=0
-                 )
+             prior_plot = _plot_binned_samples_with_error(
+                 ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
+             )

-                 axs[0].stairs(
-                     upper,
-                     edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
-                     baseline=lower,
-                     alpha=alpha,
-                     fill=True,
-                     color=color,
-                     label=rf"${1+i}\sigma$",
-                 )
+             legend_plots.append((true_data_plot,))
+             legend_labels.append("Observed")
+             legend_plots += prior_plot
+             legend_labels.append("Prior Predictive")

              # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
              counts = posterior_observations["obs_" + key]
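A minimal sketch of the single-axis sharding pattern introduced in this hunk, assuming one-dimensional per-parameter sample arrays (which is why num_samples is rounded up to a multiple of the device count):

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import PositionalSharding

    n_devices = len(jax.local_devices())
    sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))

    num_samples = 1000
    if num_samples % n_devices != 0:
        # Round up so the leading axis divides evenly across devices
        num_samples += n_devices - (num_samples % n_devices)

    samples = jnp.zeros((num_samples,))  # stand-in for one parameter's prior samples
    sharded = jax.device_put(samples, sharding)  # leading axis split across devices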
@@ -336,22 +368,22 @@ class BayesianModel:

              rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100

-             axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
+             ax[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])

-             axs[1].plot(
+             ax[1].plot(
                  (value.out_energies.min(), value.out_energies.max()),
                  (50, 50),
                  color="black",
                  linestyle="--",
              )

-             axs[1].set_xlabel("Energy (keV)")
-             axs[0].set_ylabel("Counts")
-             axs[1].set_ylabel("Rank (%)")
-             axs[1].set_ylim(0, 100)
-             axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
-             axs[0].loglog()
-             axs[0].legend(loc="upper right")
+             ax[1].set_xlabel("Energy (keV)")
+             ax[0].set_ylabel("Counts")
+             ax[1].set_ylabel("Rank (%)")
+             ax[1].set_ylim(0, 100)
+             ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
+             ax[0].loglog()
+             ax[0].legend(legend_plots, legend_labels)
              plt.suptitle(f"Prior Predictive coverage for {key}")
              plt.tight_layout()
              plt.show()
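As a worked example of the rank statistic in the lower panel: if, in a given channel, 600 of 1000 prior predictive draws fall below the observed counts and 10 tie with it, the rank is (600 + 0.5 × 10) / 1000 × 100 = 60.5%. Well-calibrated priors spread the ranks around the dashed 50% line instead of pinning them at 0 or 100.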
@@ -544,7 +576,6 @@ class MCMCFitter(BayesianModelFitter):
          return FitResult(
              self,
              inference_data,
-             self.model.params,
              background_model=self.background_model,
          )

@@ -590,11 +621,13 @@ class NSFitter(BayesianModelFitter):
          ns = NestedSampler(
              bayesian_model,
              constructor_kwargs=dict(
-                 num_parallel_workers=1,
                  verbose=verbose,
                  difficult_model=True,
-                 max_samples=1e6,
+                 max_samples=1e5,
                  parameter_estimation=True,
+                 gradient_guided=True,
+                 devices=jax.devices(),
+                 # init_efficiency_threshold=0.01,
                  num_live_points=num_live_points,
              ),
              termination_kwargs=termination_kwargs if termination_kwargs else dict(),
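The constructor_kwargs above are forwarded to jaxns via numpyro's NestedSampler wrapper. As a hedged sketch of that wrapper's standalone usage (the toy model and sample counts are illustrative):

    import jax
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.nested_sampling import NestedSampler

    def model():
        numpyro.sample("x", dist.Normal(0.0, 1.0))

    ns = NestedSampler(
        model,
        constructor_kwargs=dict(max_samples=1e5, num_live_points=500),
        termination_kwargs=dict(),
    )
    ns.run(jax.random.PRNGKey(0))
    samples = ns.get_samples(jax.random.PRNGKey(1), num_samples=1000)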
@@ -613,6 +646,5 @@ class NSFitter(BayesianModelFitter):
          return FitResult(
              self,
              inference_data,
-             self.model.params,
              background_model=self.background_model,
          )
@@ -0,0 +1,151 @@
+ """Helper functions to deal with the graph logic within model building"""
+
+ import re
+
+ from collections.abc import Callable
+ from uuid import uuid4
+
+ import networkx as nx
+
+
+ def get_component_names(graph: nx.DiGraph):
+     """
+     Get the set of component names from the nodes of a graph.
+
+     Parameters:
+         graph: The graph to get the component names from.
+     """
+     return set(
+         data["name"] for _, data in graph.nodes(data=True) if "component" in data.get("type")
+     )
+
+
+ def increment_name(name: str, used_names: set):
+     """
+     Increment the suffix number in a name if it is formatted as 'name_1'.
+
+     Parameters:
+         name: The name to increment.
+         used_names: The set of names that are already used.
+     """
+     # Use regex to extract base name and suffix number
+     match = re.match(r"(.*?)(?:_(\d+))?$", name)
+     base_name = match.group(1)
+     suffix = match.group(2)
+     if suffix:
+         number = int(suffix)
+     else:
+         number = 1  # Start from 1 if there is no suffix
+
+     new_name = name
+     while new_name in used_names:
+         number += 1
+         new_name = f"{base_name}_{number}"
+
+     return new_name
+
+
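A quick usage sketch of increment_name; the component names are illustrative:

    used = {"powerlaw_1", "powerlaw_2", "blackbody_1"}
    increment_name("powerlaw_1", used)   # -> "powerlaw_3"
    increment_name("blackbody_1", used)  # -> "blackbody_2"
    increment_name("apec", used)         # -> "apec" (no clash, returned unchanged)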
+ def compose_with_rename(graph_1: nx.DiGraph, graph_2: nx.DiGraph):
+     """
+     Compose two graphs by updating the 'name' attributes of nodes in graph_2,
+     and return the graph joined on the 'out' node.
+
+     Parameters:
+         graph_1: The first graph to compose.
+         graph_2: The second graph to compose.
+     """
+
+     # Initialize the set of used names with names from graph_1
+     used_names = get_component_names(graph_1)
+
+     # Update the 'name' attributes in graph_2 to make them unique
+     for node, data in graph_2.nodes(data=True):
+         if "component" in data.get("type"):
+             original_name = data["name"]
+             new_name = original_name
+
+             if new_name in used_names:
+                 new_name = increment_name(original_name, used_names)
+                 data["name"] = new_name
+                 used_names.add(new_name)
+
+             else:
+                 used_names.add(new_name)
+
+     # Now you can safely compose the graphs
+     composed_graph = nx.compose(graph_1, graph_2)
+
+     return composed_graph
+
+
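A hedged sketch of the rename-on-collision behaviour, using two toy graphs that follow the node conventions this module assumes ('component' somewhere in the type attribute, a shared 'out' node):

    import networkx as nx

    g1 = nx.DiGraph()
    g1.add_node("a", type="additive_component", name="powerlaw_1")
    g1.add_node("out", type="out")
    g1.add_edge("a", "out")

    g2 = nx.DiGraph()
    g2.add_node("b", type="additive_component", name="powerlaw_1")
    g2.add_node("out", type="out")
    g2.add_edge("b", "out")

    merged = compose_with_rename(g1, g2)
    get_component_names(merged)  # {"powerlaw_1", "powerlaw_2"}: the clash was renamed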
+ def compose(
+     graph_1: nx.DiGraph,
+     graph_2: nx.DiGraph,
+     operation: str = "",
+     operation_func: Callable = lambda x, y: None,
+ ):
+     """
+     Compose two graphs by joining the 'out' nodes of graph_1 and graph_2,
+     turning the joined node into an 'operation' node with the relevant
+     operator, and adding a new 'out' node.
+
+     Parameters:
+         graph_1: The first graph to compose.
+         graph_2: The second graph to compose.
+         operation: The string describing the operation to perform.
+         operation_func: The callable that performs the operation.
+     """
+
+     combined_graph = compose_with_rename(graph_1, graph_2)
+     node_id = str(uuid4())
+     graph = nx.relabel_nodes(combined_graph, {"out": node_id})
+     nx.set_node_attributes(graph, {node_id: f"{operation}_operation"}, "type")
+     nx.set_node_attributes(graph, {node_id: operation_func}, "operator")
+
+     # Now add the output node and link it to the operation node
+     graph.add_node("out", type="out")
+     graph.add_edge(node_id, "out")
+
+     # Compute the new depth of each node
+     longest_path = nx.dag_longest_path_length(graph)
+
+     for node in graph.nodes:
+         nx.set_node_attributes(
+             graph,
+             {node: longest_path - nx.shortest_path_length(graph, node, "out")},
+             "depth",
+         )
+
+     return graph
+
+
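Continuing the toy graphs from the sketch above, compose fuses the two 'out' nodes into an operation node and appends a fresh output (the operation name and callable are illustrative):

    model_graph = compose(g1, g2, operation="add", operation_func=lambda x, y: x + y)

    # The former "out" nodes are merged into one "add_operation" node feeding a
    # new "out" node; depth counts down from the longest path to the output:
    # the two components sit at depth 0, the add node at 1, "out" at 2.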
+ def export_to_mermaid(graph, file=None):
+     mermaid_code = "graph LR\n"  # LR = left to right
+
+     # Add nodes
+     for node, attributes in graph.nodes(data=True):
+         if attributes["type"] == "out":
+             mermaid_code += f' {node}("Output")\n'
+
+         else:
+             operation_type, node_type = attributes["type"].split("_")
+
+             if node_type == "component":
+                 name, number = attributes["name"].split("_")
+                 mermaid_code += f' {node}("{name.capitalize()} ({number})")\n'
+
+             elif node_type == "operation":
+                 if operation_type == "add":
+                     mermaid_code += f" {node}{{**+**}}\n"
+
+                 elif operation_type == "mul":
+                     mermaid_code += f" {node}{{**x**}}\n"
+
+     # Draw connections between nodes
+     for source, target in graph.edges():
+         mermaid_code += f" {source} --> {target}\n"
+
+     if file is None:
+         return mermaid_code
+     else:
+         with open(file, "w") as f:
+             f.write(mermaid_code)
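Exporting the composed toy graph from the earlier sketches should yield mermaid source along these lines (the operation node id is a generated UUID, abbreviated here as uuid1; node order may vary):

    print(export_to_mermaid(model_graph))
    # graph LR
    #  a("Powerlaw (1)")
    #  uuid1{**+**}
    #  b("Powerlaw (2)")
    #  out("Output")
    #  a --> uuid1
    #  b --> uuid1
    #  uuid1 --> out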