jaxspec 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {jaxspec-0.1.0 → jaxspec-0.1.1}/PKG-INFO +2 -2
  2. {jaxspec-0.1.0 → jaxspec-0.1.1}/pyproject.toml +6 -6
  3. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/fit.py +147 -9
  4. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/abc.py +34 -11
  5. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/additive.py +1 -1
  6. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/multiplicative.py +3 -11
  7. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/typing.py +27 -2
  8. {jaxspec-0.1.0 → jaxspec-0.1.1}/LICENSE.md +0 -0
  9. {jaxspec-0.1.0 → jaxspec-0.1.1}/README.md +0 -0
  10. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/__init__.py +0 -0
  11. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/__init__.py +0 -0
  12. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/compare.py +0 -0
  13. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/analysis/results.py +0 -0
  14. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/__init__.py +0 -0
  15. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/grouping.py +0 -0
  16. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/instrument.py +0 -0
  17. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/obsconf.py +0 -0
  18. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/observation.py +0 -0
  19. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/ogip.py +0 -0
  20. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/data/util.py +0 -0
  21. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/__init__.py +0 -0
  22. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/__init__.py +0 -0
  23. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/apec.py +0 -0
  24. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/_additive/apec_loaders.py +0 -0
  25. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/background.py +0 -0
  26. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/model/list.py +0 -0
  27. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/scripts/__init__.py +0 -0
  28. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/scripts/debug.py +0 -0
  29. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/__init__.py +0 -0
  30. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/abundance.py +0 -0
  31. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/integrate.py +0 -0
  32. {jaxspec-0.1.0 → jaxspec-0.1.1}/src/jaxspec/util/online_storage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxspec
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  Home-page: https://github.com/renecotyfanboy/jaxspec
6
6
  License: MIT
@@ -24,7 +24,7 @@ Requires-Dist: jaxns (>=2.5.1,<3.0.0)
24
24
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
25
25
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
26
26
  Requires-Dist: mendeleev (>=0.15,<0.18)
27
- Requires-Dist: mkdocstrings (>=0.24,<0.26)
27
+ Requires-Dist: mkdocstrings (>=0.24,<0.27)
28
28
  Requires-Dist: networkx (>=3.1,<4.0)
29
29
  Requires-Dist: numpy (<2.0.0)
30
30
  Requires-Dist: numpyro (>=0.15.2,<0.16.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "jaxspec"
3
- version = "0.1.0"
3
+ version = "0.1.1"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = ["sdupourque <sdupourque@irap.omp.eu>"]
6
6
  license = "MIT"
@@ -28,7 +28,7 @@ gpjax = "^0.8.0"
28
28
  jaxopt = "^0.8.1"
29
29
  tinygp = "^0.3.0"
30
30
  seaborn = "^0.13.1"
31
- mkdocstrings = ">=0.24,<0.26"
31
+ mkdocstrings = ">=0.24,<0.27"
32
32
  sparse = "^0.15.1"
33
33
  optimistix = "^0.0.7"
34
34
  scipy = "<1.15"
@@ -43,8 +43,8 @@ watermark = "^2.4.3"
43
43
  [tool.poetry.group.docs.dependencies]
44
44
  mkdocs = "^1.5.3"
45
45
  mkdocs-material = "^9.4.6"
46
- mkdocstrings = {extras = ["python"], version = ">=0.24,<0.26"}
47
- mkdocs-jupyter = "^0.24.6"
46
+ mkdocstrings = {extras = ["python"], version = ">=0.24,<0.27"}
47
+ mkdocs-jupyter = ">=0.24.6,<0.26.0"
48
48
 
49
49
 
50
50
  [tool.poetry.group.test.dependencies]
@@ -59,7 +59,7 @@ testbook = "^0.4.2"
59
59
 
60
60
  [tool.poetry.group.dev.dependencies]
61
61
  pre-commit = "^3.5.0"
62
- ruff = ">=0.2.1,<0.6.0"
62
+ ruff = ">=0.2.1,<0.7.0"
63
63
  jupyterlab = "^4.0.7"
64
64
  notebook = "^7.0.6"
65
65
  ipywidgets = "^8.1.1"
@@ -118,4 +118,4 @@ requires = ["poetry-core"]
118
118
  build-backend = "poetry.core.masonry.api"
119
119
 
120
120
  [tool.poetry.scripts]
121
- jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
121
+ jaxspec-debug-info = "jaxspec.scripts.debug:debug_info"
@@ -1,4 +1,5 @@
1
1
  import operator
2
+ import warnings
2
3
 
3
4
  from abc import ABC, abstractmethod
4
5
  from collections.abc import Callable
@@ -20,7 +21,7 @@ from jax.tree_util import tree_map
20
21
  from jax.typing import ArrayLike
21
22
  from numpyro.contrib.nested_sampling import NestedSampler
22
23
  from numpyro.distributions import Distribution, Poisson, TransformedDistribution
23
- from numpyro.infer import MCMC, NUTS, Predictive
24
+ from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
24
25
  from numpyro.infer.initialization import init_to_value
25
26
  from numpyro.infer.inspect import get_model_relations
26
27
  from numpyro.infer.reparam import TransformReparam
@@ -181,7 +182,7 @@ class BayesianModel:
181
182
 
182
183
  if not callable(prior_distributions):
183
184
  # Validate the entry with pydantic
184
- prior = PriorDictModel(nested_dict=prior_distributions).nested_dict
185
+ prior = PriorDictModel.from_dict(prior_distributions).nested_dict
185
186
 
186
187
  def prior_distributions_func():
187
188
  return build_prior(prior, expand_shape=(len(self.observation_container),))
@@ -293,6 +294,9 @@ class BayesianModel:
293
294
  that can be fetched with the [`parameter_names`][jaxspec.fit.BayesianModel.parameter_names].
294
295
  """
295
296
 
297
+ # This is required as numpyro.infer.util.log_densities does not check parameter validity by itself
298
+ numpyro.enable_validation()
299
+
296
300
  @jax.jit
297
301
  def log_posterior_prob(constrained_params):
298
302
  log_posterior_prob, _ = log_density(
@@ -350,6 +354,69 @@ class BayesianModel:
350
354
 
351
355
 
352
356
  class BayesianModelFitter(BayesianModel, ABC):
357
+ def build_inference_data(
358
+ self,
359
+ posterior_samples,
360
+ num_chains: int = 1,
361
+ num_predictive_samples: int = 1000,
362
+ key: PRNGKey = PRNGKey(0),
363
+ use_transformed_model: bool = False,
364
+ ) -> az.InferenceData:
365
+ """
366
+ Build an InferenceData object from the posterior samples.
367
+ """
368
+
369
+ numpyro_model = (
370
+ self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
371
+ )
372
+
373
+ keys = random.split(key, 3)
374
+
375
+ posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
376
+
377
+ prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
378
+ keys[1], observed=False
379
+ )
380
+
381
+ log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
382
+
383
+ seeded_model = numpyro.handlers.substitute(
384
+ numpyro.handlers.seed(numpyro_model, keys[3]),
385
+ substitute_fn=numpyro.infer.init_to_sample,
386
+ )
387
+
388
+ observations = {
389
+ name: site["value"]
390
+ for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
391
+ if site["type"] == "sample" and site["is_observed"]
392
+ }
393
+
394
+ def reshape_first_dimension(arr):
395
+ new_dim = arr.shape[0] // num_chains
396
+ new_shape = (num_chains, new_dim) + arr.shape[1:]
397
+ reshaped_array = arr.reshape(new_shape)
398
+
399
+ return reshaped_array
400
+
401
+ posterior_samples = {
402
+ key: reshape_first_dimension(value) for key, value in posterior_samples.items()
403
+ }
404
+ prior = {key: value[None, :] for key, value in prior.items()}
405
+ posterior_predictive = {
406
+ key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
407
+ }
408
+ log_likelihood = {
409
+ key: reshape_first_dimension(value) for key, value in log_likelihood.items()
410
+ }
411
+
412
+ return az.from_dict(
413
+ posterior_samples,
414
+ prior=prior,
415
+ posterior_predictive=posterior_predictive,
416
+ log_likelihood=log_likelihood,
417
+ observed_data=observations,
418
+ )
419
+
353
420
  @abstractmethod
354
421
  def fit(self, **kwargs) -> FitResult: ...
355
422
 
@@ -411,18 +478,89 @@ class NUTSFitter(BayesianModelFitter):
411
478
 
412
479
  mcmc.run(keys[0])
413
480
 
414
- posterior_predictive = Predictive(bayesian_model, mcmc.get_samples())(
415
- keys[1], observed=False
416
- )
481
+ posterior = mcmc.get_samples()
417
482
 
418
- prior = Predictive(bayesian_model, num_samples=num_samples)(keys[2], observed=False)
483
+ inference_data = filter_inference_data(
484
+ self.build_inference_data(posterior, num_chains=num_chains),
485
+ self.observation_container,
486
+ self.background_model,
487
+ )
419
488
 
420
- inference_data = az.from_numpyro(
421
- mcmc, prior=prior, posterior_predictive=posterior_predictive
489
+ return FitResult(
490
+ self,
491
+ inference_data,
492
+ self.model.params,
493
+ background_model=self.background_model,
422
494
  )
423
495
 
496
+
497
+ class MCMCFitter(BayesianModelFitter):
498
+ """
499
+ A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
500
+ from numpyro to perform the inference on the model parameters.
501
+ """
502
+
503
+ kernel_dict = {
504
+ "nuts": NUTS,
505
+ "aies": AIES,
506
+ "ess": ESS,
507
+ }
508
+
509
+ def fit(
510
+ self,
511
+ rng_key: int = 0,
512
+ num_chains: int = len(jax.devices()),
513
+ num_warmup: int = 1000,
514
+ num_samples: int = 1000,
515
+ sampler: Literal["nuts", "aies", "ess"] = "nuts",
516
+ kernel_kwargs: dict = {},
517
+ mcmc_kwargs: dict = {},
518
+ ) -> FitResult:
519
+ """
520
+ Fit the model to the data using a MCMC sampler from numpyro.
521
+
522
+ Parameters:
523
+ rng_key: the random key used to initialize the sampler.
524
+ num_chains: the number of chains to run.
525
+ num_warmup: the number of warmup steps.
526
+ num_samples: the number of samples to draw.
527
+ max_tree_depth: the recursion depth of NUTS sampler.
528
+ target_accept_prob: the target acceptance probability for the NUTS sampler.
529
+ dense_mass: whether to use a dense mass for the NUTS sampler.
530
+ mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
531
+
532
+ Returns:
533
+ A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
534
+ """
535
+
536
+ bayesian_model = self.transformed_numpyro_model
537
+ # bayesian_model = self.numpyro_model(prior_distributions)
538
+
539
+ chain_kwargs = {
540
+ "num_warmup": num_warmup,
541
+ "num_samples": num_samples,
542
+ "num_chains": num_chains,
543
+ }
544
+
545
+ kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
546
+
547
+ mcmc_kwargs = chain_kwargs | mcmc_kwargs
548
+
549
+ if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
550
+ mcmc_kwargs["chain_method"] = "vectorized"
551
+ warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
552
+
553
+ mcmc = MCMC(kernel, **mcmc_kwargs)
554
+ keys = random.split(random.PRNGKey(rng_key), 3)
555
+
556
+ mcmc.run(keys[0])
557
+
558
+ posterior = mcmc.get_samples()
559
+
424
560
  inference_data = filter_inference_data(
425
- inference_data, self.observation_container, self.background_model
561
+ self.build_inference_data(posterior, num_chains=num_chains),
562
+ self.observation_container,
563
+ self.background_model,
426
564
  )
427
565
 
428
566
  return FitResult(
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from uuid import uuid4
5
+
2
6
  import haiku as hk
3
7
  import jax
4
8
  import jax.numpy as jnp
5
9
  import networkx as nx
10
+
6
11
  from haiku._src import base
7
- from uuid import uuid4
8
12
  from jax.scipy.integrate import trapezoid
9
- from abc import ABC
10
13
  from simpleeval import simple_eval
11
14
 
12
15
 
@@ -215,16 +218,18 @@ class SpectralModel:
215
218
  continuum[node_id] = runtime_modules[node_id].continuum(energies)
216
219
 
217
220
  elif node and node["type"] == "operation":
218
- component_1 = list(self.graph.in_edges(node_id))[0][0]
221
+ component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
219
222
  component_2 = list(self.graph.in_edges(node_id))[1][0]
220
- continuum[node_id] = node["function"](continuum[component_1], continuum[component_2])
223
+ continuum[node_id] = node["function"](
224
+ continuum[component_1], continuum[component_2]
225
+ )
221
226
 
222
227
  if n_points == 2:
223
- flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]]
228
+ flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
224
229
  flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
225
230
 
226
231
  else:
227
- flux = continuum[list(self.graph.in_edges("out"))[0][0]]
232
+ flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
228
233
 
229
234
  if energy_flux:
230
235
  continuum_flux = trapezoid(
@@ -234,7 +239,9 @@ class SpectralModel:
234
239
  )
235
240
 
236
241
  else:
237
- continuum_flux = trapezoid(flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0)
242
+ continuum_flux = trapezoid(
243
+ flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
244
+ )
238
245
 
239
246
  # Iterate from the root nodes to the output node and
240
247
  # compute the fine structure contribution for each component
@@ -249,14 +256,18 @@ class SpectralModel:
249
256
  path = nx.shortest_path(self.graph, source=root_node_id, target="out")
250
257
  nodes_id_in_path = [node_id for node_id in path]
251
258
 
252
- flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(e_low, e_high)
259
+ flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
260
+ e_low, e_high
261
+ )
253
262
 
254
263
  multiplicative_nodes = []
255
264
 
256
265
  # Search all multiplicative components connected to this node
257
266
  # and apply them at mean energy
258
267
  for node_id in nodes_id_in_path[::-1]:
259
- multiplicative_nodes.extend([node_id for node_id in self.find_multiplicative_components(node_id)])
268
+ multiplicative_nodes.extend(
269
+ [node_id for node_id in self.find_multiplicative_components(node_id)]
270
+ )
260
271
 
261
272
  for mul_node in multiplicative_nodes:
262
273
  flux_from_component *= runtime_modules[mul_node].continuum(mean_energy)
@@ -309,7 +320,10 @@ class SpectralModel:
309
320
  if component.type == "additive":
310
321
 
311
322
  def lam_func(e):
312
- return component(**kwargs).continuum(e) + component(**kwargs).emission_lines(e, e + 1)[0]
323
+ return (
324
+ component(**kwargs).continuum(e)
325
+ + component(**kwargs).emission_lines(e, e + 1)[0]
326
+ )
313
327
 
314
328
  elif component.type == "multiplicative":
315
329
 
@@ -342,7 +356,9 @@ class SpectralModel:
342
356
 
343
357
  return cls(graph, labels)
344
358
 
345
- def compose(self, other: SpectralModel, operation=None, function=None, name=None) -> SpectralModel:
359
+ def compose(
360
+ self, other: SpectralModel, operation=None, function=None, name=None
361
+ ) -> SpectralModel:
346
362
  """
347
363
  This function operate a composition between the operation graph of two models
348
364
  1) It fuses the two graphs using which joins at the 'out' nodes
@@ -524,3 +540,10 @@ class AdditiveComponent(ModelComponent, ABC):
524
540
 
525
541
  return jnp.trapz(self(x) * dx, x=t)
526
542
  '''
543
+
544
+
545
+ class MultiplicativeComponent(ModelComponent, ABC):
546
+ type = "multiplicative"
547
+
548
+ @abstractmethod
549
+ def continuum(self, energy): ...
@@ -38,7 +38,7 @@ class Powerlaw(AdditiveComponent):
38
38
  return norm * energy ** (-alpha)
39
39
 
40
40
 
41
- class AdditiveConstant(AdditiveComponent):
41
+ class Additiveconstant(AdditiveComponent):
42
42
  r"""
43
43
  A constant model
44
44
 
@@ -1,7 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- from abc import ABC, abstractmethod
4
-
5
3
  import haiku as hk
6
4
  import jax.numpy as jnp
7
5
  import numpy as np
@@ -10,14 +8,7 @@ from astropy.table import Table
10
8
  from haiku.initializers import Constant as HaikuConstant
11
9
 
12
10
  from ..util.online_storage import table_manager
13
- from .abc import ModelComponent
14
-
15
-
16
- class MultiplicativeComponent(ModelComponent, ABC):
17
- type = "multiplicative"
18
-
19
- @abstractmethod
20
- def continuum(self, energy): ...
11
+ from .abc import MultiplicativeComponent
21
12
 
22
13
 
23
14
  class Expfac(MultiplicativeComponent):
@@ -226,6 +217,7 @@ class Tbpcf(MultiplicativeComponent):
226
217
 
227
218
  return f * jnp.exp(-nh * sigma) + (1 - f)
228
219
 
220
+
229
221
  class FDcut(MultiplicativeComponent):
230
222
  r"""
231
223
  A Fermi-Dirac cutoff model.
@@ -243,4 +235,4 @@ class FDcut(MultiplicativeComponent):
243
235
  cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
244
236
  folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
245
237
 
246
- return (1 + jnp.exp((energy - cutoff)/folding)) ** -1
238
+ return (1 + jnp.exp((energy - cutoff) / folding)) ** -1
@@ -9,6 +9,13 @@ from pydantic import BaseModel, field_validator
9
9
  PriorDictType = dict[str, dict[str, dist.Distribution | ArrayLike]]
10
10
 
11
11
 
12
+ def is_flat_dict(input_data: dict[str, Any]) -> bool:
13
+ """
14
+ Check if the input data is a flat dictionary with string keys and non-dictionary values.
15
+ """
16
+ return all(isinstance(k, str) and not isinstance(v, dict) for k, v in input_data.items())
17
+
18
+
12
19
  class PriorDictModel(BaseModel):
13
20
  """
14
21
  Pydantic model for a nested dictionary of NumPyro distributions or JAX arrays.
@@ -21,6 +28,23 @@ class PriorDictModel(BaseModel):
21
28
  class Config: # noqa D106
22
29
  arbitrary_types_allowed = True
23
30
 
31
+ @classmethod
32
+ def from_dict(cls, input_prior: dict[str, Any]):
33
+ if is_flat_dict(input_prior):
34
+ nested_dict = {}
35
+
36
+ for key, obj in input_prior.items():
37
+ component, component_number, *parameter = key.split("_")
38
+
39
+ sub_dict = nested_dict.get(f"{component}_{component_number}", {})
40
+ sub_dict["_".join(parameter)] = obj
41
+
42
+ nested_dict[f"{component}_{component_number}"] = sub_dict
43
+
44
+ return cls(nested_dict=nested_dict)
45
+
46
+ return cls(nested_dict=input_prior)
47
+
24
48
  @field_validator("nested_dict", mode="before")
25
49
  def check_and_cast_nested_dict(cls, value: dict[str, Any]):
26
50
  if not isinstance(value, dict):
@@ -35,9 +59,10 @@ class PriorDictModel(BaseModel):
35
59
  try:
36
60
  # Attempt to cast to JAX array
37
61
  value[key][inner_key] = jnp.array(obj, dtype=float)
62
+
38
63
  except Exception as e:
39
64
  raise ValueError(
40
- f'The value for key "{inner_key}" in inner dictionary must '
41
- f"be a NumPyro distribution or castable to JAX array. Error: {e}"
65
+ f'The value for key "{inner_key}" in {key} be a NumPyro '
66
+ f"distribution or castable to JAX array. Error: {e}"
42
67
  )
43
68
  return value
File without changes
File without changes
File without changes