jaxspec 0.1.3-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit.py CHANGED
@@ -7,126 +7,31 @@ from functools import cached_property
  from typing import Literal
  
  import arviz as az
- import haiku as hk
  import jax
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
- import numpy as np
  import numpyro
  
  from jax import random
- from jax.experimental.sparse import BCOO
  from jax.random import PRNGKey
- from jax.tree_util import tree_map
- from jax.typing import ArrayLike
  from numpyro.contrib.nested_sampling import NestedSampler
- from numpyro.distributions import Distribution, Poisson, TransformedDistribution
+ from numpyro.distributions import Poisson, TransformedDistribution
  from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
  from numpyro.infer.inspect import get_model_relations
  from numpyro.infer.reparam import TransformReparam
  from numpyro.infer.util import log_density
  
- from .analysis._plot import _plot_poisson_data_with_error
+ from ._fit._build_model import build_prior, forward_model
+ from .analysis._plot import (
+     _error_bars_for_observed_data,
+     _plot_binned_samples_with_error,
+     _plot_poisson_data_with_error,
+ )
  from .analysis.results import FitResult
  from .data import ObsConfiguration
  from .model.abc import SpectralModel
  from .model.background import BackgroundModel
- from .util.typing import PriorDictModel, PriorDictType
- 
- 
- def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
-     """
-     Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
-     Must be used within a numpyro model.
-     """
-     parameters = dict(hk.data_structures.to_haiku_dict(prior))
- 
-     for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
-         if isinstance(sample, Distribution):
-             parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
- 
-         elif isinstance(sample, ArrayLike):
-             parameters[m][n] = jnp.ones(expand_shape) * sample
- 
-         else:
-             raise ValueError(
-                 f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
-             )
- 
-     return parameters
- 
- 
- def build_numpyro_model_for_single_obs(
-     obs: ObsConfiguration,
-     model: SpectralModel,
-     background_model: BackgroundModel,
-     name: str = "",
-     sparse: bool = False,
- ) -> Callable:
-     """
-     Build a numpyro model for a given observation and spectral model.
-     """
- 
-     def numpyro_model(prior_params, observed=True):
-         # prior_params = build_prior(prior_distributions, name=name)
-         transformed_model = hk.without_apply_rng(
-             hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
-         )
- 
-         if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
-             bkg_countrate = background_model.numpyro_model(
-                 obs, model, name="bkg_" + name, observed=observed
-             )
-         elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-             raise ValueError(
-                 "Trying to fit a background model but no background is linked to this observation"
-             )
- 
-         else:
-             bkg_countrate = 0.0
- 
-         obs_model = jax.jit(lambda p: transformed_model.apply(None, p))
-         countrate = obs_model(prior_params)
- 
-         # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
-         with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
-             numpyro.sample(
-                 "obs_" + name,
-                 Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
-                 obs=obs.folded_counts.data if observed else None,
-             )
- 
-     return numpyro_model
- 
- 
- class CountForwardModel(hk.Module):
-     """
-     A haiku module which allows to build the function that simulates the measured counts
-     """
- 
-     def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
-         super().__init__()
-         self.model = model
-         self.energies = jnp.asarray(folding.in_energies)
- 
-         if (
-             sparse
-         ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
-             self.transfer_matrix = BCOO.from_scipy_sparse(
-                 folding.transfer_matrix.data.to_scipy_sparse().tocsr()
-             )
- 
-         else:
-             self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
- 
-     def __call__(self, parameters):
-         """
-         Compute the count functions for a given observation.
-         """
- 
-         expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
- 
-         return jnp.clip(expected_counts, a_min=1e-6)
+ from .util.typing import PriorDictType
  
  
  class BayesianModel:
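
This hunk removes the haiku dependency: the prior builder and the count forward model (formerly `CountForwardModel`) move out of `fit.py` into the private `._fit._build_model` module as plain functions `build_prior` and `forward_model`. A minimal sketch of how the relocated helpers fit together, based only on their call sites visible in this diff (the toy prior dictionary and `toy_numpyro_model` are illustrative, not part of jaxspec):

import jax
import numpyro
import numpyro.distributions as dist

from jaxspec._fit._build_model import build_prior, forward_model

# Illustrative prior: one numpyro distribution per model parameter
prior = {"powerlaw_1": {"alpha": dist.Uniform(0, 5), "norm": dist.LogUniform(1e-8, 1e-3)}}

def toy_numpyro_model(model, observations, observed=True):
    # Sample every parameter once and broadcast to one value per observation
    params = build_prior(prior, expand_shape=(len(observations),))

    for i, (name, obs) in enumerate(observations.items()):
        # Slice out the i-th observation's parameters and fold them through the response
        params_i = jax.tree.map(lambda leaf: leaf[i], params)
        countrate = forward_model(model, params_i, obs)

        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
            numpyro.sample(
                "obs_" + name,
                dist.Poisson(countrate),
                obs=obs.folded_counts.data if observed else None,
            )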
@@ -157,15 +62,16 @@ class BayesianModel:
          self.model = model
          self._observations = observations
          self.background_model = background_model
-         self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
          self.sparse = sparsify_matrix
  
          if not callable(prior_distributions):
              # Validate the entry with pydantic
-             prior = PriorDictModel.from_dict(prior_distributions).nested_dict
+             # prior = PriorDictModel.from_dict(prior_distributions).
  
              def prior_distributions_func():
-                 return build_prior(prior, expand_shape=(len(self.observation_container),))
+                 return build_prior(
+                     prior_distributions, expand_shape=(len(self.observation_container),)
+                 )
  
          else:
              prior_distributions_func = prior_distributions
@@ -173,6 +79,22 @@
          self.prior_distributions_func = prior_distributions_func
          self.init_params = self.prior_samples()
  
+         # Check the priors are suited for the observations
+         split_parameters = [
+             (param, shape[-1])
+             for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
+             if (len(shape) > 1)
+             and not param.startswith("_")
+             and not param.startswith("bkg")  # hardcoded for subtracted background
+         ]
+ 
+         for parameter, proposed_number_of_obs in split_parameters:
+             if proposed_number_of_obs != len(self.observation_container):
+                 raise ValueError(
+                     f"Invalid splitting in the prior distribution. "
+                     f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
+                 )
+ 
      @cached_property
      def observation_container(self) -> dict[str, ObsConfiguration]:
          """
@@ -197,22 +119,52 @@
          Build the numpyro model using the observed data, the prior distributions and the spectral model.
          """
  
-         def model(observed=True):
+         def numpyro_model(observed=True):
+             # Instantiate and register the parameters of the spectral model and the background
              prior_params = self.prior_distributions_func()
  
              # Iterate over all the observations in our container and build a single numpyro model for each observation
-             for i, (key, observation) in enumerate(self.observation_container.items()):
+             for i, (name, observation) in enumerate(self.observation_container.items()):
+                 # Check that we can indeed fit a background
+                 if (getattr(observation, "folded_background", None) is not None) and (
+                     self.background_model is not None
+                 ):
+                     # This call should register the parameter and observation of our background model
+                     bkg_countrate = self.background_model.numpyro_model(
+                         observation, name=name, observed=observed
+                     )
+ 
+                 elif (getattr(observation, "folded_background", None) is None) and (
+                     self.background_model is not None
+                 ):
+                     raise ValueError(
+                         "Trying to fit a background model but no background is linked to this observation"
+                     )
+ 
+                 else:
+                     bkg_countrate = 0.0
+ 
                  # We expect that prior_params contains an array of parameters for each observation
                  # They can be identical or different for each observation
-                 params = tree_map(lambda x: x[i], prior_params)
+                 params = jax.tree.map(lambda x: x[i], prior_params)
  
-                 obs_model = build_numpyro_model_for_single_obs(
-                     observation, self.model, self.background_model, name=key, sparse=self.sparse
+                 # Forward model the observation and get the associated countrate
+                 obs_model = jax.jit(
+                     lambda par: forward_model(self.model, par, observation, sparse=self.sparse)
                  )
+                 obs_countrate = obs_model(params)
  
-                 obs_model(params, observed=observed)
+                 # Register the observation as an observed site
+                 with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
+                     numpyro.sample(
+                         "obs_" + name,
+                         Poisson(
+                             obs_countrate + bkg_countrate
+                         ),  # / observation.folded_backratio.data
+                         obs=observation.folded_counts.data if observed else None,
+                     )
  
-         return model
+         return numpyro_model
  
      @cached_property
      def transformed_numpyro_model(self) -> Callable:
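
Two details of this hunk worth noting: the deprecated `tree_map` from `jax.tree_util` (dropped from the imports above) is replaced by the `jax.tree.map` namespace, and the background count rate is no longer divided by `folded_backratio` (the division is commented out in the Poisson rate). Per-observation parameters are obtained by slicing every leaf of the stacked prior draws, for example (toy pytree, two observations):

import jax
import jax.numpy as jnp

prior_params = {"alpha": jnp.array([1.7, 2.1]), "norm": jnp.array([1e-4, 3e-4])}

# Parameters for observation i=0: every leaf is indexed along its first axis
params_obs0 = jax.tree.map(lambda x: x[0], prior_params)
# {'alpha': Array(1.7, ...), 'norm': Array(1e-04, ...)}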
@@ -352,7 +304,9 @@
          return fakeit(key, parameters)
  
      def prior_predictive_coverage(
-         self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
+         self,
+         key: PRNGKey = PRNGKey(0),
+         num_samples: int = 1000,
      ):
          """
          Check if the prior distribution include the observed data.
@@ -362,24 +316,25 @@
          posterior_observations = self.mock_observations(prior_params, key=key_posterior)
  
          for key, value in self.observation_container.items():
-             fig, axs = plt.subplots(
-                 nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
+             fig, ax = plt.subplots(
+                 nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
              )
  
-             _plot_poisson_data_with_error(
-                 axs[0],
+             y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
+                 value.folded_counts.values, 1.0, "ct"
+             )
+ 
+             true_data_plot = _plot_poisson_data_with_error(
+                 ax[0],
                  value.out_energies,
-                 value.folded_counts.values,
-                 percentiles=percentiles,
+                 y_observed.value,
+                 y_observed_low.value,
+                 y_observed_high.value,
+                 alpha=0.7,
              )
  
-             axs[0].stairs(
-                 np.max(posterior_observations["obs_" + key], axis=0),
-                 edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
-                 baseline=np.min(posterior_observations["obs_" + key], axis=0),
-                 alpha=0.3,
-                 fill=True,
-                 color=(0.15, 0.25, 0.45),
+             prior_plot = _plot_binned_samples_with_error(
+                 ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
              )
  
              # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
@@ -393,22 +348,24 @@
  
              rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
  
-             axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
+             ax[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
  
-             axs[1].plot(
+             ax[1].plot(
                  (value.out_energies.min(), value.out_energies.max()),
                  (50, 50),
                  color="black",
                  linestyle="--",
              )
  
-             axs[1].set_xlabel("Energy (keV)")
-             axs[0].set_ylabel("Counts")
-             axs[1].set_ylabel("Rank (%)")
-             axs[1].set_ylim(0, 100)
-             axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
-             axs[0].loglog()
+             ax[1].set_xlabel("Energy (keV)")
+             ax[0].set_ylabel("Counts")
+             ax[1].set_ylabel("Rank (%)")
+             ax[1].set_ylim(0, 100)
+             ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
+             ax[0].loglog()
+             ax[0].legend(loc="upper right")
              plt.suptitle(f"Prior Predictive coverage for {key}")
+             plt.tight_layout()
              plt.show()
  
  
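The lower panel is a rank plot: for each energy bin, the observed count is ranked within the prior predictive draws. `less_than_obs` and `equal_to_obs` (computed outside this hunk) count the draws below and tied with the observation, with ties contributing half. If the prior covers the data, the ranks scatter around the dashed 50% line. A worked toy example of the statistic:

import numpy as np

num_samples = 1000
rng = np.random.default_rng(0)

draws = rng.poisson(10.0, size=(num_samples, 4))  # prior predictive counts in 4 bins
observed = np.array([8, 10, 13, 30])

less_than_obs = (draws < observed).sum(axis=0)
equal_to_obs = (draws == observed).sum(axis=0)
rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
# Bins well inside the prior predictive land near 50; the outlier bin (30 counts)
# lands near 100, flagging data the prior does not cover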
@@ -513,7 +470,11 @@ class BayesianModelFitter(BayesianModel, ABC):
              predictive_parameters
          ]
  
-         parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
+         parameters = [
+             x
+             for x in inference_data.posterior.keys()
+             if not x.endswith("_base") or x.startswith("_")
+         ]
          inference_data.posterior = inference_data.posterior[parameters]
          inference_data.prior = inference_data.prior[parameters]
  
@@ -595,7 +556,6 @@ class MCMCFitter(BayesianModelFitter):
          return FitResult(
              self,
              inference_data,
-             self.model.params,
              background_model=self.background_model,
          )
  
@@ -641,11 +601,13 @@ class NSFitter(BayesianModelFitter):
          ns = NestedSampler(
              bayesian_model,
              constructor_kwargs=dict(
-                 num_parallel_workers=1,
                  verbose=verbose,
                  difficult_model=True,
-                 max_samples=1e6,
+                 max_samples=1e5,
                  parameter_estimation=True,
+                 gradient_guided=True,
+                 devices=jax.devices(),
+                 # init_efficiency_threshold=0.01,
                  num_live_points=num_live_points,
              ),
              termination_kwargs=termination_kwargs if termination_kwargs else dict(),
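
The constructor keyword arguments are forwarded through numpyro's `NestedSampler` wrapper to the underlying jaxns sampler: this release drops `num_parallel_workers`, lowers `max_samples` from 1e6 to 1e5, and enables gradient-guided sampling across all visible devices. A hedged sketch of the equivalent direct call (`bayesian_model` stands for the numpyro model built above; whether every kwarg is accepted depends on the installed jaxns version):

import jax

from numpyro.contrib.nested_sampling import NestedSampler

ns = NestedSampler(
    bayesian_model,
    constructor_kwargs=dict(
        verbose=False,
        difficult_model=True,
        max_samples=1e5,
        parameter_estimation=True,
        gradient_guided=True,
        devices=jax.devices(),
        num_live_points=1000,
    ),
    termination_kwargs=dict(),
)
# ns.run(jax.random.PRNGKey(0)) would then launch the sampling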
@@ -664,6 +626,5 @@
          return FitResult(
              self,
              inference_data,
-             self.model.params,
              background_model=self.background_model,
          )
@@ -0,0 +1,151 @@
+ """Helper functions to deal with the graph logic within model building"""
+ 
+ import re
+ 
+ from collections.abc import Callable
+ from uuid import uuid4
+ 
+ import networkx as nx
+ 
+ 
+ def get_component_names(graph: nx.DiGraph):
+     """
+     Get the set of component names from the nodes of a graph.
+ 
+     Parameters:
+         graph: The graph to get the component names from.
+     """
+     return set(
+         data["name"] for _, data in graph.nodes(data=True) if "component" in data.get("type")
+     )
+ 
+ 
+ def increment_name(name: str, used_names: set):
+     """
+     Increment the suffix number in a name if it is formated as 'name_1'.
+ 
+     Parameters:
+         name: The name to increment.
+         used_names: The set of names that are already used.
+     """
+     # Use regex to extract base name and suffix number
+     match = re.match(r"(.*?)(?:_(\d+))?$", name)
+     base_name = match.group(1)
+     suffix = match.group(2)
+     if suffix:
+         number = int(suffix)
+     else:
+         number = 1  # Start from 1 if there is no suffix
+ 
+     new_name = name
+     while new_name in used_names:
+         number += 1
+         new_name = f"{base_name}_{number}"
+ 
+     return new_name
+ 
+ 
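`increment_name` keeps bumping the numeric suffix until the candidate is free; a name without a suffix starts from 1, so its first collision yields `_2`. For example, with the function above in scope:

increment_name("powerlaw", {"powerlaw"})      # -> 'powerlaw_2'
increment_name("powerlaw_1", {"powerlaw_1"})  # -> 'powerlaw_2'
increment_name("blackbody", {"powerlaw"})     # -> 'blackbody' (already unique)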
+ def compose_with_rename(graph_1: nx.DiGraph, graph_2: nx.DiGraph):
+     """
+     Compose two graphs by updating the 'name' attributes of nodes in graph_2,
+     and return the graph joined on the 'out' node.
+ 
+     Parameters:
+         graph_1: The first graph to compose.
+         graph_2: The second graph to compose.
+     """
+ 
+     # Initialize the set of used names with names from graph_1
+     used_names = get_component_names(graph_1)
+ 
+     # Update the 'name' attributes in graph_2 to make them unique
+     for node, data in graph_2.nodes(data=True):
+         if "component" in data.get("type"):
+             original_name = data["name"]
+             new_name = original_name
+ 
+             if new_name in used_names:
+                 new_name = increment_name(original_name, used_names)
+                 data["name"] = new_name
+                 used_names.add(new_name)
+ 
+             else:
+                 used_names.add(new_name)
+ 
+     # Now you can safely compose the graphs
+     composed_graph = nx.compose(graph_1, graph_2)
+ 
+     return composed_graph
+ 
+ 
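Only nodes whose 'type' contains "component" take part in the renaming; the shared 'out' node is merged by `nx.compose`. A minimal sketch with two single-component graphs, using node attribute conventions inferred from this file:

import networkx as nx

g1 = nx.DiGraph()
g1.add_node("a", type="additive_component", name="powerlaw_1")
g1.add_node("out", type="out")
g1.add_edge("a", "out")

g2 = nx.DiGraph()
g2.add_node("b", type="additive_component", name="powerlaw_1")
g2.add_node("out", type="out")
g2.add_edge("b", "out")

merged = compose_with_rename(g1, g2)
print(get_component_names(merged))  # {'powerlaw_1', 'powerlaw_2'}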
+ def compose(
+     graph_1: nx.DiGraph,
+     graph_2: nx.DiGraph,
+     operation: str = "",
+     operation_func: Callable = lambda x, y: None,
+ ):
+     """
+     Compose two graphs by joining the 'out' node of graph_1 and graph_2, and turning
+     it to an 'operation' node with the relevant operator and add a new 'out' node.
+ 
+     Parameters:
+         graph_1: The first graph to compose.
+         graph_2: The second graph to compose.
+         operation: The string describing the operation to perform.
+         operation_func: The callable that performs the operation.
+     """
+ 
+     combined_graph = compose_with_rename(graph_1, graph_2)
+     node_id = str(uuid4())
+     graph = nx.relabel_nodes(combined_graph, {"out": node_id})
+     nx.set_node_attributes(graph, {node_id: f"{operation}_operation"}, "type")
+     nx.set_node_attributes(graph, {node_id: operation_func}, "operator")
+ 
+     # Now add the output node and link it to the operation node
+     graph.add_node("out", type="out")
+     graph.add_edge(node_id, "out")
+ 
+     # Compute the new depth of each node
+     longest_path = nx.dag_longest_path_length(graph)
+ 
+     for node in graph.nodes:
+         nx.set_node_attributes(
+             graph,
+             {node: longest_path - nx.shortest_path_length(graph, node, "out")},
+             "depth",
+         )
+ 
+     return graph
+ 
+ 
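`compose` reuses the merged 'out' node as the operation node and appends a fresh 'out', so repeated composition builds the expression tree bottom-up; the final loop stores each node's distance from the output as a 'depth' attribute. Continuing the toy graphs from the sketch above:

summed = compose(g1, g2, operation="add", operation_func=lambda x, y: x + y)

# The former shared 'out' became an 'add_operation' node feeding the new output:
# components sit at depth 0, the operation node at depth 1, 'out' at depth 2
print(sorted((data["type"], data["depth"]) for _, data in summed.nodes(data=True)))
# [('add_operation', 1), ('additive_component', 0), ('additive_component', 0), ('out', 2)]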
+ def export_to_mermaid(graph, file=None):
+     mermaid_code = "graph LR\n"  # LR = left to right
+ 
+     # Add nodes
+     for node, attributes in graph.nodes(data=True):
+         if attributes["type"] == "out":
+             mermaid_code += f' {node}("Output")\n'
+ 
+         else:
+             operation_type, node_type = attributes["type"].split("_")
+ 
+             if node_type == "component":
+                 name, number = attributes["name"].split("_")
+                 mermaid_code += f' {node}("{name.capitalize()} ({number})")\n'
+ 
+             elif node_type == "operation":
+                 if operation_type == "add":
+                     mermaid_code += f" {node}{{**+**}}\n"
+ 
+                 elif operation_type == "mul":
+                     mermaid_code += f" {node}{{**x**}}\n"
+ 
+     # Draw connexion between nodes
+     for source, target in graph.edges():
+         mermaid_code += f" {source} --> {target}\n"
+ 
+     if file is None:
+         return mermaid_code
+     else:
+         with open(file, "w") as f:
+             f.write(mermaid_code)
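
The exporter assumes component names carry exactly one underscore (`name_number`) and recognizes only 'add' and 'mul' operation nodes. For the toy `summed` graph composed above, `export_to_mermaid(summed)` would return mermaid code along these lines (the operation node id is a random `uuid4`, shortened here; line order follows node insertion):

graph LR
 a("Powerlaw (1)")
 3f2b...{**+**}
 b("Powerlaw (2)")
 out("Output")
 a --> 3f2b...
 3f2b... --> out
 b --> 3f2b...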