jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/data/__init__.py CHANGED
@@ -1,9 +1,9 @@
1
- # precommit is suppressing these imports
2
- from .obsconf import ObsConfiguration # noqa: F401
3
- from .instrument import Instrument # noqa: F401
4
- from .observation import Observation # noqa: F401
5
1
  import astropy.units as u
6
2
 
3
+ from .instrument import Instrument
4
+ from .obsconf import ObsConfiguration
5
+ from .observation import Observation
6
+
7
7
  u.add_enabled_aliases({"counts": u.count})
8
8
  u.add_enabled_aliases({"channel": u.dimensionless_unscaled})
9
9
  # Arbitrary units are found in .rsp files , let's hope it is compatible with what we would expect as the rmf x arf
jaxspec/data/obsconf.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import numpy as np
2
- import xarray as xr
3
- import sparse
4
2
  import scipy
3
+ import sparse
4
+ import xarray as xr
5
+
5
6
  from .instrument import Instrument
6
7
  from .observation import Observation
7
8
 
@@ -62,8 +63,26 @@ class ObsConfiguration(xr.Dataset):
62
63
 
63
64
  @classmethod
64
65
  def from_pha_file(
65
- cls, pha_path, rmf_path=None, arf_path=None, bkg_path=None, low_energy: float = 1e-20, high_energy: float = 1e20
66
+ cls,
67
+ pha_path,
68
+ rmf_path: str | None = None,
69
+ arf_path: str | None = None,
70
+ bkg_path: str | None = None,
71
+ low_energy: float = 1e-20,
72
+ high_energy: float = 1e20,
66
73
  ):
74
+ r"""
75
+ Build the observation configuration from a PHA file.
76
+
77
+ Parameters:
78
+ pha_path: The path to the PHA file.
79
+ rmf_path: The path to the RMF file.
80
+ arf_path: The path to the ARF file.
81
+ bkg_path: The path to the background file.
82
+ low_energy: The lower bound of the energy range to consider.
83
+ high_energy: The upper bound of the energy range to consider.
84
+ """
85
+
67
86
  from .util import data_path_finder
68
87
 
69
88
  arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(pha_path)
@@ -75,16 +94,34 @@ class ObsConfiguration(xr.Dataset):
75
94
  instrument = Instrument.from_ogip_file(rmf_path, arf_path=arf_path)
76
95
  observation = Observation.from_pha_file(pha_path, bkg_path=bkg_path)
77
96
 
78
- return cls.from_instrument(instrument, observation, low_energy=low_energy, high_energy=high_energy)
97
+ return cls.from_instrument(
98
+ instrument, observation, low_energy=low_energy, high_energy=high_energy
99
+ )
79
100
 
80
101
  @classmethod
81
102
  def from_instrument(
82
- cls, instrument: Instrument, observation: Observation, low_energy: float = 1e-20, high_energy: float = 1e20
103
+ cls,
104
+ instrument: Instrument,
105
+ observation: Observation,
106
+ low_energy: float = 1e-20,
107
+ high_energy: float = 1e20,
83
108
  ):
109
+ r"""
110
+ Build the observation configuration from an [`Instrument`][jaxspec.data.Instrument] and an [`Observation`][jaxspec.data.Observation] object.
111
+
112
+ Parameters:
113
+ instrument: The instrument object.
114
+ observation: The observation object.
115
+ low_energy: The lower bound of the energy range to consider.
116
+ high_energy: The upper bound of the energy range to consider.
117
+
118
+ """
84
119
  # First we unpack all the xarray data to classical np array for efficiency
85
120
  # We also exclude the bins that are flagged with bad quality on the instrument
86
121
  quality_filter = observation.quality.data == 0
87
- grouping = scipy.sparse.csr_array(observation.grouping.data.to_scipy_sparse()) * quality_filter
122
+ grouping = (
123
+ scipy.sparse.csr_array(observation.grouping.data.to_scipy_sparse()) * quality_filter
124
+ )
88
125
  e_min_channel = instrument.coords["e_min_channel"].data
89
126
  e_max_channel = instrument.coords["e_max_channel"].data
90
127
  e_min_unfolded = instrument.coords["e_min_unfolded"].data
@@ -134,7 +171,10 @@ class ObsConfiguration(xr.Dataset):
134
171
  "area": (
135
172
  ["unfolded_channel"],
136
173
  area,
137
- {"description": "Effective area with the same restrictions as the transfer matrix.", "units": "cm^2"},
174
+ {
175
+ "description": "Effective area with the same restrictions as the transfer matrix.",
176
+ "units": "cm^2",
177
+ },
138
178
  ),
139
179
  "exposure": ([], exposure, {"description": "Total exposure", "unit": "s"}),
140
180
  "folded_counts": (
@@ -148,7 +188,9 @@ class ObsConfiguration(xr.Dataset):
148
188
  "folded_backratio": (
149
189
  ["folded_channel"],
150
190
  folded_backratio,
151
- {"description": "Background scaling after grouping, with the same restrictions as the transfer matrix."},
191
+ {
192
+ "description": "Background scaling after grouping, with the same restrictions as the transfer matrix."
193
+ },
152
194
  ),
153
195
  "folded_background": (
154
196
  ["folded_channel"],
@@ -186,3 +228,6 @@ class ObsConfiguration(xr.Dataset):
186
228
  },
187
229
  attrs=observation.attrs | instrument.attrs,
188
230
  )
231
+
232
+ def plot_counts(self, **kwargs):
233
+ return self.folded_counts.plot.step(x="e_min_folded", where="post", **kwargs)
jaxspec/data/util.py CHANGED
@@ -1,102 +1,130 @@
1
- import importlib.resources
2
- import numpyro
1
+ from collections.abc import Mapping
2
+ from pathlib import Path
3
+ from typing import Literal, TypeVar
4
+
5
+ import haiku as hk
3
6
  import jax
4
7
  import numpy as np
5
- import haiku as hk
6
- from pathlib import Path
7
- from numpy.typing import ArrayLike
8
- from collections.abc import Mapping
9
- from typing import TypeVar, Tuple
8
+ import numpyro
9
+
10
10
  from astropy.io import fits
11
+ from numpy.typing import ArrayLike
12
+ from numpyro import handlers
11
13
 
12
- from . import Observation, Instrument, ObsConfiguration
13
- from ..model.abc import SpectralModel
14
14
  from ..fit import CountForwardModel
15
- from numpyro import handlers
15
+ from ..model.abc import SpectralModel
16
+ from ..util.online_storage import table_manager
17
+ from . import Instrument, ObsConfiguration, Observation
16
18
 
17
19
  K = TypeVar("K")
18
20
  V = TypeVar("V")
19
21
 
20
22
 
21
- def load_example_observations():
23
+ def load_example_pha(
24
+ source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
25
+ ) -> (Observation, list[Observation] | dict[str, Observation]):
22
26
  """
23
27
  Load some example observations from the package data.
28
+
29
+ Parameters:
30
+ source: The source to be loaded. Can be either "NGC7793_ULX4_PN" or "NGC7793_ULX4_ALL".
24
31
  """
25
32
 
26
- example_observations = {
27
- "PN": Observation.from_pha_file(
28
- str(importlib.resources.files("jaxspec") / "data/example_data/PN_spectrum_grp20.fits"),
29
- low_energy=0.3,
30
- high_energy=7.5,
31
- ),
32
- "MOS1": Observation.from_pha_file(
33
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1_spectrum_grp.fits"),
34
- low_energy=0.3,
35
- high_energy=7,
36
- ),
37
- "MOS2": Observation.from_pha_file(
38
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2_spectrum_grp.fits"),
39
- low_energy=0.3,
40
- high_energy=7,
41
- ),
42
- }
43
-
44
- return example_observations
45
-
46
-
47
- def load_example_instruments():
33
+ if source == "NGC7793_ULX4_PN":
34
+ return Observation.from_pha_file(
35
+ table_manager.fetch("example_data/NGC7793_ULX4/PN_spectrum_grp20.fits"),
36
+ bkg_path=table_manager.fetch("example_data/NGC7793_ULX4/PNbackground_spectrum.fits"),
37
+ )
38
+
39
+ elif source == "NGC7793_ULX4_ALL":
40
+ return {
41
+ "PN": Observation.from_pha_file(
42
+ table_manager.fetch("example_data/NGC7793_ULX4/PN_spectrum_grp20.fits"),
43
+ bkg_path=table_manager.fetch(
44
+ "example_data/NGC7793_ULX4/PNbackground_spectrum.fits"
45
+ ),
46
+ ),
47
+ "MOS1": Observation.from_pha_file(
48
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1_spectrum_grp.fits"),
49
+ bkg_path=table_manager.fetch(
50
+ "example_data/NGC7793_ULX4/MOS1background_spectrum.fits"
51
+ ),
52
+ ),
53
+ "MOS2": Observation.from_pha_file(
54
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2_spectrum_grp.fits"),
55
+ bkg_path=table_manager.fetch(
56
+ "example_data/NGC7793_ULX4/MOS2background_spectrum.fits"
57
+ ),
58
+ ),
59
+ }
60
+
61
+ else:
62
+ raise ValueError(f"{source} not recognized.")
63
+
64
+
65
+ def load_example_instruments(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"]):
48
66
  """
49
67
  Load some example instruments from the package data.
68
+
69
+ Parameters:
70
+ source: The source to be loaded. Can be either "NGC7793_ULX4_PN" or "NGC7793_ULX4_ALL".
71
+
50
72
  """
73
+ if source == "NGC7793_ULX4_PN":
74
+ return Instrument.from_ogip_file(
75
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.rmf"),
76
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.arf"),
77
+ )
51
78
 
52
- example_instruments = {
53
- "PN": Instrument.from_ogip_file(
54
- str(importlib.resources.files("jaxspec") / "data/example_data/PN.rmf"),
55
- str(importlib.resources.files("jaxspec") / "data/example_data/PN.arf"),
56
- ),
57
- "MOS1": Instrument.from_ogip_file(
58
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.rmf"),
59
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS1.arf"),
60
- ),
61
- "MOS2": Instrument.from_ogip_file(
62
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.rmf"),
63
- str(importlib.resources.files("jaxspec") / "data/example_data/MOS2.arf"),
64
- ),
65
- }
66
-
67
- return example_instruments
68
-
69
-
70
- def load_example_foldings():
79
+ elif source == "NGC7793_ULX4_ALL":
80
+ return {
81
+ "PN": Instrument.from_ogip_file(
82
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.rmf"),
83
+ table_manager.fetch("example_data/NGC7793_ULX4/PN.arf"),
84
+ ),
85
+ "MOS1": Instrument.from_ogip_file(
86
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1.rmf"),
87
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS1.arf"),
88
+ ),
89
+ "MOS2": Instrument.from_ogip_file(
90
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2.rmf"),
91
+ table_manager.fetch("example_data/NGC7793_ULX4/MOS2.arf"),
92
+ ),
93
+ }
94
+
95
+ else:
96
+ raise ValueError(f"{source} not recognized.")
97
+
98
+
99
+ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"]):
71
100
  """
72
- Load some example instruments from the package data.
101
+ Load some example ObsConfigurations.
102
+
103
+ Parameters:
104
+ source: The source to be loaded. Can be either "NGC7793_ULX4_PN" or "NGC7793_ULX4_ALL".
73
105
  """
74
106
 
75
- example_instruments = load_example_instruments()
76
- example_observations = load_example_observations()
77
-
78
- example_foldings = {
79
- "PN": ObsConfiguration.from_instrument(
80
- example_instruments["PN"],
81
- example_observations["PN"],
82
- low_energy=0.3,
83
- high_energy=7.5,
84
- ),
85
- "MOS1": ObsConfiguration.from_instrument(
86
- example_instruments["MOS1"],
87
- example_observations["MOS1"],
88
- low_energy=0.3,
89
- high_energy=7,
90
- ),
91
- "MOS2": ObsConfiguration.from_instrument(
92
- example_instruments["MOS2"],
93
- example_observations["MOS2"],
94
- low_energy=0.3,
95
- high_energy=7,
96
- ),
97
- }
98
-
99
- return example_foldings
107
+ if source in "NGC7793_ULX4_PN":
108
+ instrument = load_example_instruments(source)
109
+ observation = load_example_pha(source)
110
+
111
+ return ObsConfiguration.from_instrument(
112
+ instrument, observation, low_energy=0.5, high_energy=8.0
113
+ )
114
+
115
+ elif source == "NGC7793_ULX4_ALL":
116
+ instruments_dict = load_example_instruments(source)
117
+ observations_dict = load_example_pha(source)
118
+
119
+ return {
120
+ key: ObsConfiguration.from_instrument(
121
+ instruments_dict[key], observations_dict[key], low_energy=0.5, high_energy=8.0
122
+ )
123
+ for key in instruments_dict.keys()
124
+ }
125
+
126
+ else:
127
+ raise ValueError(f"{source} not recognized.")
100
128
 
101
129
 
102
130
  def fakeit(
@@ -107,7 +135,7 @@ def fakeit(
107
135
  sparsify_matrix: bool = False,
108
136
  ) -> ArrayLike | list[ArrayLike]:
109
137
  """
110
- This function is a convenience function that allows to simulate spectra from a given model and a set of parameters.
138
+ Convenience function to simulate a spectrum from a given model and a set of parameters.
111
139
  It requires an instrumental setup, and unlike in
112
140
  [XSPEC's fakeit](https://heasarc.gsfc.nasa.gov/xanadu/xspec/manual/node72.html), the error on the counts is given
113
141
  exclusively by Poisson statistics.
@@ -125,7 +153,9 @@ def fakeit(
125
153
 
126
154
  for i, instrument in enumerate(instruments):
127
155
  transformed_model = hk.without_apply_rng(
128
- hk.transform(lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par))
156
+ hk.transform(
157
+ lambda par: CountForwardModel(model, instrument, sparse=sparsify_matrix)(par)
158
+ )
129
159
  )
130
160
 
131
161
  def obs_model(p):
@@ -167,8 +197,7 @@ def fakeit_for_multiple_parameters(
167
197
  sparsify_matrix: bool = False,
168
198
  ):
169
199
  """
170
- This function is a convenience function that allows to simulate spectra multiple spectra from a given model and a
171
- set of parameters.
200
+ Convenience function to simulate multiple spectra from a given model and a set of parameters.
172
201
 
173
202
  TODO : avoid redundancy, better doc and type hints
174
203
 
@@ -209,9 +238,10 @@ def fakeit_for_multiple_parameters(
209
238
  return fakeits[0] if len(fakeits) == 1 else fakeits
210
239
 
211
240
 
212
- def data_path_finder(pha_path: str) -> Tuple[str | None, str | None, str | None]:
241
+ def data_path_finder(pha_path: str) -> tuple[str | None, str | None, str | None]:
213
242
  """
214
- This function tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
243
+ Function which tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
244
+
215
245
  Parameters:
216
246
  pha_path: The PHA file path.
217
247