jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/fit/_fitter.py ADDED
@@ -0,0 +1,255 @@
1
+ import warnings
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Literal
5
+
6
+ import arviz as az
7
+ import jax
8
+ import matplotlib.pyplot as plt
9
+ import numpyro
10
+
11
+ from jax import random
12
+ from jax.random import PRNGKey
13
+ from numpyro.infer import AIES, ESS, MCMC, NUTS, SVI, Predictive, Trace_ELBO
14
+ from numpyro.infer.autoguide import AutoMultivariateNormal
15
+
16
+ from ..analysis.results import FitResult
17
+ from ._bayesian_model import BayesianModel
18
+
19
+
20
class BayesianModelFitter(BayesianModel, ABC):
    """Base class for fitting a Bayesian spectral model to observed spectra.

    Concrete subclasses implement `fit`, which must return a
    [`FitResult`][jaxspec.analysis.results.FitResult].
    """

    def build_inference_data(
        self,
        posterior_samples,
        num_chains: int = 1,
        num_predictive_samples: int = 1000,
        key: PRNGKey = PRNGKey(42),
        use_transformed_model: bool = False,
        filter_inference_data: bool = True,
    ) -> az.InferenceData:
        """
        Build an [InferenceData][arviz.InferenceData] object from posterior samples.

        Parameters:
            posterior_samples: the samples from the posterior distribution.
            num_chains: the number of chains used to sample the posterior.
            num_predictive_samples: the number of samples to draw from the prior.
            key: the random key used to initialize the sampler.
            use_transformed_model: whether to use the transformed model to build the InferenceData.
            filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
        """

        numpyro_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        keys = random.split(key, 3)

        posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)

        prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
            keys[1], observed=False
        )

        log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)

        # BUG FIX: `keys` has exactly 3 entries (indices 0..2). The previous
        # code indexed `keys[3]`, which only "worked" because JAX clamps
        # out-of-bounds gather indices and silently returned the last key.
        # Index the intended third key explicitly.
        seeded_model = numpyro.handlers.substitute(
            numpyro.handlers.seed(numpyro_model, keys[2]),
            substitute_fn=numpyro.infer.init_to_sample,
        )

        # Collect the observed sample sites of the model so they can be
        # attached as `observed_data` on the InferenceData.
        observations = {
            name: site["value"]
            for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
            if site["type"] == "sample" and site["is_observed"]
        }

        def reshape_first_dimension(arr):
            # Split the flat leading sample dimension into (num_chains, draws).
            # Assumes the leading dimension is a multiple of `num_chains`.
            draws_per_chain = arr.shape[0] // num_chains
            return arr.reshape((num_chains, draws_per_chain) + arr.shape[1:])

        # NOTE: the comprehension variable is named `name` so it does not
        # shadow the `key` (PRNGKey) parameter of this method.
        posterior_samples = {
            name: reshape_first_dimension(value) for name, value in posterior_samples.items()
        }
        prior = {name: value[None, :] for name, value in prior.items()}
        posterior_predictive = {
            name: reshape_first_dimension(value)
            for name, value in posterior_predictive.items()
        }
        log_likelihood = {
            name: reshape_first_dimension(value) for name, value in log_likelihood.items()
        }

        inference_data = az.from_dict(
            posterior_samples,
            prior=prior,
            posterior_predictive=posterior_predictive,
            log_likelihood=log_likelihood,
            observed_data=observations,
        )

        return (
            self.filter_inference_data(inference_data) if filter_inference_data else inference_data
        )

    def filter_inference_data(
        self,
        inference_data: az.InferenceData,
    ) -> az.InferenceData:
        """
        Filter the inference data to keep only the relevant parameters for the observations.

        - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
        - Removes parameters built from reparametrised variables (e.g. ending with `"_base"`)
        """

        predictive_parameters = []

        # The background-model check is loop-invariant, and only the keys of
        # the observation container are needed (the previous version iterated
        # `.items()` without using the values and duplicated the `obs/~/`
        # append in both branches).
        for name in self._observation_container:
            predictive_parameters.append(f"obs/~/{name}")
            if self.background_model is not None:
                predictive_parameters.append(f"bkg/~/{name}")
            # predictive_parameters.append(f"ins/~/{name}")

        inference_data.posterior_predictive = inference_data.posterior_predictive[
            predictive_parameters
        ]

        # Drop reparametrisation artefacts (`*_base`) and private (`_*`) sites
        # from both the posterior and the prior groups.
        parameters = [
            x
            for x in inference_data.posterior.keys()
            if not (x.endswith("_base") or x.startswith("_"))
        ]
        inference_data.posterior = inference_data.posterior[parameters]
        inference_data.prior = inference_data.prior[parameters]

        return inference_data

    @abstractmethod
    def fit(self, **kwargs) -> FitResult:
        """Run the inference and return a FitResult. Implemented by subclasses."""
        ...
135
+
136
+
137
class MCMCFitter(BayesianModelFitter):
    """
    A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
    from numpyro to perform the inference on the model parameters.
    """

    # Mapping from user-facing sampler names to numpyro kernel classes.
    kernel_dict = {
        "nuts": NUTS,
        "aies": AIES,
        "ess": ESS,
    }

    def fit(
        self,
        rng_key: int = 0,
        num_chains: int = len(jax.devices()),
        num_warmup: int = 1000,
        num_samples: int = 1000,
        sampler: Literal["nuts", "aies", "ess"] = "nuts",
        use_transformed_model: bool = True,
        kernel_kwargs: dict | None = None,
        mcmc_kwargs: dict | None = None,
    ) -> FitResult:
        """
        Fit the model to the data using a MCMC sampler from numpyro.

        Parameters:
            rng_key: the random key used to initialize the sampler.
            num_chains: the number of chains to run.
            num_warmup: the number of warmup steps.
            num_samples: the number of samples to draw.
            sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
            use_transformed_model: whether to use the transformed model to build the InferenceData.
            kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
            mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """

        # BUG FIX: the defaults were mutable dicts (`kernel_kwargs: dict = {}`),
        # which are shared across every call of the method. Use None sentinels
        # and build fresh dicts per call instead.
        kernel_kwargs = {} if kernel_kwargs is None else kernel_kwargs
        mcmc_kwargs = {} if mcmc_kwargs is None else mcmc_kwargs

        bayesian_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        chain_kwargs = {
            "num_warmup": num_warmup,
            "num_samples": num_samples,
            "num_chains": num_chains,
        }

        kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)

        # Caller-provided mcmc_kwargs take precedence over the chain defaults.
        mcmc_kwargs = chain_kwargs | mcmc_kwargs

        # The ensemble samplers only support the vectorized chain method.
        if sampler in ("aies", "ess") and mcmc_kwargs.get("chain_method") != "vectorized":
            mcmc_kwargs["chain_method"] = "vectorized"
            warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")

        mcmc = MCMC(kernel, **mcmc_kwargs)
        keys = random.split(random.PRNGKey(rng_key), 3)

        mcmc.run(keys[0])

        posterior = mcmc.get_samples()

        inference_data = self.build_inference_data(
            posterior, num_chains=num_chains, use_transformed_model=use_transformed_model
        )

        return FitResult(
            self,
            inference_data,
            background_model=self.background_model,
        )
211
+
212
+
213
class VIFitter(BayesianModelFitter):
    """Fit a model to observations using stochastic variational inference (SVI) from numpyro."""

    def fit(
        self,
        rng_key: int = 0,
        num_steps: int = 10_000,
        optimizer=None,
        loss=None,
        num_samples: int = 1000,
        guide=None,
        use_transformed_model: bool = True,
        plot_diagnostics: bool = False,
    ) -> FitResult:
        """
        Fit the model to the data using stochastic variational inference.

        Parameters:
            rng_key: the random key used to initialize the optimisation and the posterior sampling.
            num_steps: the number of optimisation steps to run.
            optimizer: the numpyro optimizer to use. Defaults to `numpyro.optim.Adam(step_size=0.0005)`.
            loss: the ELBO loss to minimise. Defaults to `Trace_ELBO()`.
            num_samples: the number of samples to draw from the fitted guide.
            guide: the variational guide. Defaults to an `AutoMultivariateNormal` over the model.
            use_transformed_model: whether to use the transformed model to build the InferenceData.
            plot_diagnostics: whether to plot the ELBO loss history.

        Returns:
            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
        """

        # BUG FIX: the optimizer and loss defaults were built once at import
        # time (`optimizer=numpyro.optim.Adam(...)`, `loss=Trace_ELBO()`) and
        # shared across every call. Build them per call instead.
        if optimizer is None:
            optimizer = numpyro.optim.Adam(step_size=0.0005)
        if loss is None:
            loss = Trace_ELBO()

        bayesian_model = (
            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
        )

        if guide is None:
            guide = AutoMultivariateNormal(bayesian_model)

        svi = SVI(bayesian_model, guide, optimizer, loss=loss)

        keys = random.split(random.PRNGKey(rng_key), 3)
        svi_result = svi.run(keys[0], num_steps)
        params = svi_result.params

        if plot_diagnostics:
            plt.plot(svi_result.losses)
            plt.xlabel("Steps")
            plt.ylabel("ELBO loss")
            plt.semilogy()

        # Draw posterior samples from the fitted guide.
        predictive = Predictive(guide, params=params, num_samples=num_samples)
        posterior = predictive(keys[1])

        inference_data = self.build_inference_data(
            posterior, num_chains=1, use_transformed_model=use_transformed_model
        )

        return FitResult(
            self,
            inference_data,
            background_model=self.background_model,
        )
jaxspec/model/abc.py CHANGED
@@ -11,8 +11,7 @@ import jax
11
11
  import jax.numpy as jnp
12
12
  import jax.scipy as jsp
13
13
  import networkx as nx
14
-
15
- from simpleeval import simple_eval
14
+ import numpy as np
16
15
 
17
16
  from jaxspec.util.typing import PriorDictType
18
17
 
@@ -31,14 +30,14 @@ def set_parameters(params: PriorDictType, state: nnx.State) -> nnx.State:
31
30
  A spectral model with the newly set parameters.
32
31
  """
33
32
 
34
- state_dict = state.to_pure_dict() # haiku-like 2 level dictionary
33
+ state_dict = nnx.to_pure_dict(state) # haiku-like 2 level dictionary
35
34
 
36
35
  for key, value in params.items():
37
36
  # Split the key to extract the module name and parameter name
38
37
  module_name, param_name = key.rsplit("_", 1)
39
- state_dict["modules"][module_name][param_name] = value
38
+ state_dict["components"][module_name][param_name] = value
40
39
 
41
- state.replace_by_pure_dict(state_dict)
40
+ nnx.replace_by_pure_dict(state, state_dict)
42
41
 
43
42
  return state
44
43
 
@@ -71,52 +70,16 @@ class Composable(ABC):
71
70
 
72
71
 
73
72
  class SpectralModel(nnx.Module, Composable):
74
- # graph: nx.DiGraph = eqx.field(static=True)
75
- # modules: dict
73
+ _graph: nx.DiGraph
76
74
 
77
75
  def __init__(self, graph: nx.DiGraph):
78
- self.graph = graph
79
- self.modules = {}
76
+ self._graph = graph
77
+ self.components = {}
78
+ self._energy_grid = np.geomspace(0.1, 50, 1000, dtype=np.float64)
80
79
 
81
- for node, data in self.graph.nodes(data=True):
80
+ for node, data in self._graph.nodes(data=True):
82
81
  if "component" in data["type"]:
83
- self.modules[data["name"]] = data["component"] # (**data['kwargs'])
84
-
85
- @classmethod
86
- def from_string(cls, string: str) -> SpectralModel:
87
- """
88
- This constructor enable to build a model from a string. The string should be a valid python expression, with
89
- the following constraints :
90
-
91
- * The model components should be defined in the jaxspec.model.list module
92
- * The model components should be separated by a * or a + (no convolution yet)
93
- * The model components should be written with their parameters in parentheses
94
-
95
- Parameters:
96
- string : The string to parse
97
-
98
- Examples:
99
- An absorbed model with a powerlaw and a blackbody:
100
-
101
- >>> model = SpectralModel.from_string("Tbabs()*(Powerlaw() + Blackbody())")
102
- """
103
-
104
- from .list import model_components
105
-
106
- return simple_eval(string, functions=model_components)
107
-
108
- def to_string(self) -> str:
109
- """
110
- This method return the string representation of the model.
111
-
112
- Examples:
113
- Build a model from a string and convert it back to a string:
114
-
115
- >>> model = SpectralModel.from_string("Tbabs()*(Powerlaw() + Blackbody())")
116
- >>> model.to_string()
117
- "Tbabs()*(Powerlaw() + Blackbody())"
118
- """
119
- return str(self)
82
+ self.components[data["name"]] = data["component"] # (**data['kwargs'])
120
83
 
121
84
  """
122
85
  def __str__(self) -> str:
@@ -153,7 +116,7 @@ class SpectralModel(nnx.Module, Composable):
153
116
  """
154
117
 
155
118
  composed_graph = compose(
156
- self.graph, other.graph, operation=operation, operation_func=operation_func
119
+ self._graph, other._graph, operation=operation, operation_func=operation_func
157
120
  )
158
121
 
159
122
  return SpectralModel(composed_graph)
@@ -161,7 +124,7 @@ class SpectralModel(nnx.Module, Composable):
161
124
  @classmethod
162
125
  def from_component(cls, component):
163
126
  node_id = str(uuid4())
164
- graph = nx.DiGraph()
127
+ _graph = nx.DiGraph()
165
128
 
166
129
  node_properties = {
167
130
  "type": f"{component.type}_component",
@@ -170,26 +133,26 @@ class SpectralModel(nnx.Module, Composable):
170
133
  "depth": 0,
171
134
  }
172
135
 
173
- graph.add_node(node_id, **node_properties)
174
- graph.add_node("out", type="out", depth=1)
175
- graph.add_edge(node_id, "out")
136
+ _graph.add_node(node_id, **node_properties)
137
+ _graph.add_node("out", type="out", depth=1)
138
+ _graph.add_edge(node_id, "out")
176
139
 
177
- return cls(graph)
140
+ return cls(_graph)
178
141
 
179
142
  def _find_multiplicative_components(self, node_id):
180
143
  """
181
144
  Recursively finds all the multiplicative components connected to the node with the given ID.
182
145
  """
183
- node = self.graph.nodes[node_id]
146
+ node = self._graph.nodes[node_id]
184
147
  multiplicative_nodes = []
185
148
 
186
149
  if node.get("type") == "mul_operation":
187
150
  # Recursively find all the multiplicative components using the predecessors
188
- predecessors = self.graph.pred[node_id]
151
+ predecessors = self._graph.pred[node_id]
189
152
  for node_id in predecessors:
190
- if "multiplicative_component" == self.graph.nodes[node_id].get("type"):
153
+ if "multiplicative_component" == self._graph.nodes[node_id].get("type"):
191
154
  multiplicative_nodes.append(node_id)
192
- elif "mul_operation" == self.graph.nodes[node_id].get("type"):
155
+ elif "mul_operation" == self._graph.nodes[node_id].get("type"):
193
156
  multiplicative_nodes.extend(self._find_multiplicative_components(node_id))
194
157
 
195
158
  return multiplicative_nodes
@@ -198,8 +161,8 @@ class SpectralModel(nnx.Module, Composable):
198
161
  def root_nodes(self) -> list[str]:
199
162
  return [
200
163
  node_id
201
- for node_id, in_degree in self.graph.in_degree(self.graph.nodes)
202
- if in_degree == 0 and ("additive" in self.graph.nodes[node_id].get("type"))
164
+ for node_id, in_degree in self._graph.in_degree(self._graph.nodes)
165
+ if in_degree == 0 and ("additive" in self._graph.nodes[node_id].get("type"))
203
166
  ]
204
167
 
205
168
  @property
@@ -207,8 +170,8 @@ class SpectralModel(nnx.Module, Composable):
207
170
  branches = []
208
171
 
209
172
  for root_node_id in self.root_nodes:
210
- root_node_name = self.graph.nodes[root_node_id].get("name")
211
- path = nx.shortest_path(self.graph, source=root_node_id, target="out")
173
+ root_node_name = self._graph.nodes[root_node_id].get("name")
174
+ path = nx.shortest_path(self._graph, source=root_node_id, target="out")
212
175
  multiplicative_components = []
213
176
 
214
177
  # Search all multiplicative components connected to this node
@@ -218,10 +181,12 @@ class SpectralModel(nnx.Module, Composable):
218
181
  [node_id for node_id in self._find_multiplicative_components(node_id)]
219
182
  )
220
183
 
184
+ multiplicative_components = set(multiplicative_components)
185
+
221
186
  branch = ""
222
187
 
223
188
  for multiplicative_node_id in multiplicative_components:
224
- multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
189
+ multiplicative_node_name = self._graph.nodes[multiplicative_node_id].get("name")
225
190
  branch += f"{multiplicative_node_name}*"
226
191
 
227
192
  branch += f"{root_node_name}"
@@ -233,12 +198,12 @@ class SpectralModel(nnx.Module, Composable):
233
198
  continuum = {}
234
199
 
235
200
  ## Evaluate the expected contribution for each component
236
- for node_id in nx.dag.topological_sort(self.graph):
237
- node = self.graph.nodes[node_id]
201
+ for node_id in nx.dag.topological_sort(self._graph):
202
+ node = self._graph.nodes[node_id]
238
203
 
239
204
  if node["type"] == "additive_component":
240
205
  node_name = node["name"]
241
- runtime_modules = self.modules[node_name]
206
+ runtime_modules = self.components[node_name]
242
207
 
243
208
  if not energy_flux:
244
209
  continuum[node_name] = runtime_modules._photon_flux(
@@ -252,8 +217,8 @@ class SpectralModel(nnx.Module, Composable):
252
217
 
253
218
  elif node["type"] == "multiplicative_component":
254
219
  node_name = node["name"]
255
- runtime_modules = self.modules[node_name]
256
- continuum[node_name] = runtime_modules.factor((e_low + e_high) / 2)
220
+ runtime_modules = self.components[node_name]
221
+ continuum[node_name] = runtime_modules._factor(e_low, e_high, n_points=n_points)
257
222
 
258
223
  else:
259
224
  pass
@@ -264,10 +229,10 @@ class SpectralModel(nnx.Module, Composable):
264
229
  branches = {}
265
230
 
266
231
  for root_node_id in root_nodes:
267
- root_node_name = self.graph.nodes[root_node_id].get("name")
232
+ root_node_name = self._graph.nodes[root_node_id].get("name")
268
233
  root_continuum = continuum[root_node_name]
269
234
 
270
- path = nx.shortest_path(self.graph, source=root_node_id, target="out")
235
+ path = nx.shortest_path(self._graph, source=root_node_id, target="out")
271
236
  multiplicative_components = []
272
237
 
273
238
  # Search all multiplicative components connected to this node
@@ -279,8 +244,8 @@ class SpectralModel(nnx.Module, Composable):
279
244
 
280
245
  branch = ""
281
246
 
282
- for multiplicative_node_id in multiplicative_components:
283
- multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
247
+ for multiplicative_node_id in set(multiplicative_components):
248
+ multiplicative_node_name = self._graph.nodes[multiplicative_node_id].get("name")
284
249
  root_continuum *= continuum[multiplicative_node_name]
285
250
  branch += f"{multiplicative_node_name}*"
286
251
 
@@ -302,7 +267,7 @@ class SpectralModel(nnx.Module, Composable):
302
267
  Returns:
303
268
  A string containing the mermaid representation of the model.
304
269
  """
305
- return export_to_mermaid(self.graph, file)
270
+ return export_to_mermaid(self._graph, file)
306
271
 
307
272
  @partial(jax.jit, static_argnums=0, static_argnames=("n_points", "split_branches"))
308
273
  def photon_flux(self, params, e_low, e_high, n_points=2, split_branches=False):
@@ -327,12 +292,12 @@ class SpectralModel(nnx.Module, Composable):
327
292
  instead.
328
293
  """
329
294
 
330
- graphdef, state = nnx.split(self)
331
- state = set_parameters(params, state)
295
+ graphdef, parameters, tables = nnx.split(self, nnx.Param, ...)
296
+ parameters = set_parameters(params, parameters)
332
297
 
333
- return nnx.call((graphdef, state)).turbo_flux(
298
+ return nnx.merge(graphdef, parameters, tables).turbo_flux(
334
299
  e_low, e_high, n_points=n_points, return_branches=split_branches
335
- )[0]
300
+ )
336
301
 
337
302
  @partial(jax.jit, static_argnums=0, static_argnames="n_points")
338
303
  def energy_flux(self, params, e_low, e_high, n_points=2):
@@ -357,12 +322,12 @@ class SpectralModel(nnx.Module, Composable):
357
322
  instead.
358
323
  """
359
324
 
360
- graphdef, state = nnx.split(self)
361
- state = set_parameters(params, state)
325
+ graphdef, parameters, tables = nnx.split(self, nnx.Param, ...)
326
+ parameters = set_parameters(params, parameters)
362
327
 
363
- return nnx.call((graphdef, state)).turbo_flux(
328
+ return nnx.merge(graphdef, parameters, tables).turbo_flux(
364
329
  e_low, e_high, n_points=n_points, energy_flux=True
365
- )[0]
330
+ )
366
331
 
367
332
 
368
333
  class ModelComponent(nnx.Module, Composable, ABC):
@@ -427,6 +392,13 @@ class AdditiveComponent(ModelComponent):
427
392
  class MultiplicativeComponent(ModelComponent):
428
393
  type = "multiplicative"
429
394
 
395
+ def _factor(self, e_low, e_high, n_points=2):
396
+ energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
397
+ factor = self.factor(energy)
398
+
399
+ return jsp.integrate.trapezoid(factor * energy, jnp.log(energy), axis=-1) / (e_high - e_low)
400
+ # return jnp.mean(factor, axis = -1)
401
+
430
402
  def factor(self, energy):
431
403
  """
432
404
  Absorption factor applied for a given energy
jaxspec/model/additive.py CHANGED
@@ -166,19 +166,28 @@ class Gauss(AdditiveComponent):
166
166
  self.sigma = nnx.Param(1e-2)
167
167
  self.norm = nnx.Param(1.0)
168
168
 
169
+ def continuum(self, energy):
170
+ return (
171
+ self.norm
172
+ * jsp.stats.norm.pdf(energy, loc=jnp.asarray(self.El), scale=jnp.asarray(self.sigma))
173
+ / (1 - jsp.special.erf(-self.El / (self.sigma * jnp.sqrt(2))))
174
+ )
175
+
176
+ """
169
177
  def integrated_continuum(self, e_low, e_high):
170
178
  return self.norm * (
171
179
  jsp.stats.norm.cdf(
172
180
  e_high,
173
- loc=jnp.asarray(self.El, dtype=jnp.float64),
174
- scale=jnp.asarray(self.sigma, dtype=jnp.float64),
181
+ loc=jnp.asarray(self.El),
182
+ scale=jnp.asarray(self.sigma),
175
183
  )
176
184
  - jsp.stats.norm.cdf(
177
185
  e_low,
178
- loc=jnp.asarray(self.El, dtype=jnp.float64),
179
- scale=jnp.asarray(self.sigma, dtype=jnp.float64),
180
- )
186
+ loc=jnp.asarray(self.El),
187
+ scale=jnp.asarray(self.sigma),
188
+ ) #/ (1 - jsp.special.erf(- self.El / (self.sigma * jnp.sqrt(2))))
181
189
  )
190
+ """
182
191
 
183
192
 
184
193
  class Cutoffpl(AdditiveComponent):
@@ -5,15 +5,15 @@ import jax.numpy as jnp
5
5
  import numpyro
6
6
  import numpyro.distributions as dist
7
7
 
8
+ from flax import nnx
8
9
  from jax.scipy.integrate import trapezoid
9
10
  from numpyro.distributions import Poisson
10
11
  from tinygp import GaussianProcess, kernels
11
12
 
12
- from .._fit._build_model import build_prior, forward_model
13
13
  from .abc import SpectralModel
14
14
 
15
15
 
16
- class BackgroundModel(ABC):
16
+ class BackgroundModel(ABC, nnx.Module):
17
17
  """
18
18
  Handles the background modelling in our spectra. This is handled in a separate class for now
19
19
  since backgrounds can be phenomenological models fitted directly on the folded spectrum. This is not the case for
@@ -42,7 +42,7 @@ class SubtractedBackground(BackgroundModel):
42
42
 
43
43
  def numpyro_model(self, observation, name: str = "", observed=True):
44
44
  _, observed_counts = observation.out_energies, observation.folded_background.data
45
- numpyro.deterministic(f"bkg_{name}", observed_counts)
45
+ numpyro.deterministic(f"bkg/~/{name}", observed_counts)
46
46
 
47
47
  return observed_counts
48
48
 
@@ -61,11 +61,11 @@ class BackgroundWithError(BackgroundModel):
61
61
  _, observed_counts = obs.out_energies, obs.folded_background.data
62
62
  alpha = observed_counts + 1
63
63
  beta = 1
64
- countrate = numpyro.sample(f"_bkg_{name}_countrate", dist.Gamma(alpha, rate=beta))
64
+ countrate = numpyro.sample(f"bkg/~/_{name}_countrate", dist.Gamma(alpha, rate=beta))
65
65
 
66
- with numpyro.plate(f"bkg_{name}_plate", len(observed_counts)):
66
+ with numpyro.plate(f"bkg/~/{name}_plate", len(observed_counts)):
67
67
  numpyro.sample(
68
- f"bkg_{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
68
+ f"bkg/~/{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
69
69
  )
70
70
 
71
71
  return countrate
@@ -111,27 +111,28 @@ class GaussianProcessBackground(BackgroundModel):
111
111
 
112
112
  # The parameters of the GP model
113
113
  mean = numpyro.sample(
114
- f"_bkg_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
114
+ f"bkg/~/_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
115
115
  )
116
- sigma = numpyro.sample(f"_bkg_{name}_sigma", dist.HalfNormal(3.0))
117
- rho = numpyro.sample(f"_bkg_{name}_rho", dist.HalfNormal(10.0))
116
+ sigma = numpyro.sample(f"bkg/~/_{name}_sigma", dist.HalfNormal(3.0))
117
+ rho = numpyro.sample(f"bkg/~/_{name}_rho", dist.HalfNormal(10.0))
118
118
 
119
119
  # Set up the kernel and GP objects
120
120
  kernel = sigma**2 * self.kernel(rho)
121
121
  nodes = jnp.linspace(0, 1, self.n_nodes)
122
122
  gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
123
123
 
124
- log_rate = numpyro.sample(f"_bkg_{name}_log_rate_nodes", gp.numpyro_dist())
124
+ log_rate = numpyro.sample(f"bkg/~/_{name}_log_rate_nodes", gp.numpyro_dist())
125
125
 
126
126
  interp_count_rate = jnp.exp(
127
127
  jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
128
128
  )
129
+
129
130
  count_rate = trapezoid(interp_count_rate, energy, axis=0)
130
131
 
131
132
  # Finally, our observation model is Poisson
132
- with numpyro.plate("bkg_plate_" + name, len(observed_counts)):
133
+ with numpyro.plate("bkg/~/plate_" + name, len(observed_counts)):
133
134
  numpyro.sample(
134
- f"bkg_{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
135
+ f"bkg/~/{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
135
136
  )
136
137
 
137
138
  return count_rate
@@ -144,15 +145,17 @@ class SpectralModelBackground(BackgroundModel):
144
145
  self.sparse = sparse
145
146
 
146
147
  def numpyro_model(self, observation, name: str = "", observed=True):
148
+ from jaxspec.fit._build_model import build_prior, forward_model
149
+
147
150
  params = build_prior(self.prior, prefix=f"_bkg_{name}_")
148
151
  bkg_model = jax.jit(
149
152
  lambda par: forward_model(self.spectral_model, par, observation, sparse=self.sparse)
150
153
  )
151
154
  bkg_countrate = bkg_model(params)
152
155
 
153
- with numpyro.plate("bkg_plate_" + name, len(observation.folded_background)):
156
+ with numpyro.plate("bkg/~/plate_" + name, len(observation.folded_background)):
154
157
  numpyro.sample(
155
- "bkg_" + name,
158
+ "bkg/~/" + name,
156
159
  Poisson(bkg_countrate),
157
160
  obs=observation.folded_background.data if observed else None,
158
161
  )