jaxspec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +35 -0
- jaxspec/data/obsconf.py +34 -0
- jaxspec/fit.py +158 -310
- jaxspec/model/abc.py +26 -0
- jaxspec/model/additive.py +0 -2
- {jaxspec-0.1.1.dist-info → jaxspec-0.1.2.dist-info}/METADATA +13 -7
- {jaxspec-0.1.1.dist-info → jaxspec-0.1.2.dist-info}/RECORD +10 -12
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +0 -316
- jaxspec/model/_additive/apec_loaders.py +0 -73
- {jaxspec-0.1.1.dist-info → jaxspec-0.1.2.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.1.dist-info → jaxspec-0.1.2.dist-info}/WHEEL +0 -0
- {jaxspec-0.1.1.dist-info → jaxspec-0.1.2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from jax.typing import ArrayLike
|
|
5
|
+
from scipy.stats import nbinom
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _plot_poisson_data_with_error(
|
|
9
|
+
ax: plt.Axes,
|
|
10
|
+
x_bins: ArrayLike,
|
|
11
|
+
y: ArrayLike,
|
|
12
|
+
percentiles: tuple = (16, 84),
|
|
13
|
+
):
|
|
14
|
+
"""
|
|
15
|
+
Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
|
|
16
|
+
distributed according to a Gamma RV.
|
|
17
|
+
"""
|
|
18
|
+
y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
|
|
19
|
+
y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)
|
|
20
|
+
|
|
21
|
+
ax_to_plot = ax.errorbar(
|
|
22
|
+
np.sqrt(x_bins[0] * x_bins[1]),
|
|
23
|
+
y,
|
|
24
|
+
xerr=np.abs(x_bins - np.sqrt(x_bins[0] * x_bins[1])),
|
|
25
|
+
yerr=[
|
|
26
|
+
y - y_low,
|
|
27
|
+
y_high - y,
|
|
28
|
+
],
|
|
29
|
+
color="black",
|
|
30
|
+
linestyle="none",
|
|
31
|
+
alpha=0.3,
|
|
32
|
+
capsize=2,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
return ax_to_plot
|
jaxspec/data/obsconf.py
CHANGED
|
@@ -229,5 +229,39 @@ class ObsConfiguration(xr.Dataset):
|
|
|
229
229
|
attrs=observation.attrs | instrument.attrs,
|
|
230
230
|
)
|
|
231
231
|
|
|
232
|
+
@classmethod
def mock_from_instrument(
    cls,
    instrument: Instrument,
    exposure: float,
    low_energy: float = 1e-20,
    high_energy: float = 1e20,
):
    """
    Build an observation configuration for an instrument, backed by a fake observation with zero counts.

    The fake observation has one entry per element of ``instrument.coords["instrument_channel"]``, zero counts
    and a ``backratio`` of one everywhere, the given exposure, and the instrument attributes merged on top of a
    ``"Mock observation"`` description. It is then combined with the instrument through `from_instrument`.

    Parameters:
        instrument: The instrument object.
        exposure: The total exposure of the mock observation.
        low_energy: The lower bound of the energy range to consider.
        high_energy: The upper bound of the energy range to consider.
    """
    n_chan = len(instrument.coords["instrument_channel"])
    mock_attributes = {"description": "Mock observation"} | instrument.attrs

    mock_observation = Observation.from_matrix(
        np.zeros(n_chan),
        sparse.eye(n_chan),
        np.arange(n_chan),
        np.zeros(n_chan, dtype=bool),
        exposure,
        backratio=np.ones(n_chan),
        attributes=mock_attributes,
    )

    return cls.from_instrument(
        instrument,
        mock_observation,
        low_energy=low_energy,
        high_energy=high_energy,
    )
|
|
265
|
+
|
|
232
266
|
def plot_counts(self, **kwargs):
    """
    Draw the folded counts as a step plot against ``e_min_folded``.

    Extra keyword arguments are forwarded to ``self.folded_counts.plot.step``, whose result is returned as is.
    """
    draw_steps = self.folded_counts.plot.step
    return draw_steps(x="e_min_folded", where="post", **kwargs)
|
jaxspec/fit.py
CHANGED
|
@@ -10,29 +10,27 @@ import arviz as az
|
|
|
10
10
|
import haiku as hk
|
|
11
11
|
import jax
|
|
12
12
|
import jax.numpy as jnp
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import numpy as np
|
|
13
15
|
import numpyro
|
|
14
|
-
import optimistix as optx
|
|
15
16
|
|
|
16
17
|
from jax import random
|
|
17
18
|
from jax.experimental.sparse import BCOO
|
|
18
|
-
from jax.flatten_util import ravel_pytree
|
|
19
19
|
from jax.random import PRNGKey
|
|
20
20
|
from jax.tree_util import tree_map
|
|
21
21
|
from jax.typing import ArrayLike
|
|
22
22
|
from numpyro.contrib.nested_sampling import NestedSampler
|
|
23
23
|
from numpyro.distributions import Distribution, Poisson, TransformedDistribution
|
|
24
24
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
|
|
25
|
-
from numpyro.infer.initialization import init_to_value
|
|
26
25
|
from numpyro.infer.inspect import get_model_relations
|
|
27
26
|
from numpyro.infer.reparam import TransformReparam
|
|
28
|
-
from numpyro.infer.util import
|
|
29
|
-
from scipy.stats import Covariance, multivariate_normal
|
|
27
|
+
from numpyro.infer.util import log_density
|
|
30
28
|
|
|
29
|
+
from .analysis._plot import _plot_poisson_data_with_error
|
|
31
30
|
from .analysis.results import FitResult
|
|
32
31
|
from .data import ObsConfiguration
|
|
33
32
|
from .model.abc import SpectralModel
|
|
34
33
|
from .model.background import BackgroundModel
|
|
35
|
-
from .util import catchtime
|
|
36
34
|
from .util.typing import PriorDictModel, PriorDictType
|
|
37
35
|
|
|
38
36
|
|
|
@@ -101,27 +99,6 @@ def build_numpyro_model_for_single_obs(
|
|
|
101
99
|
return numpyro_model
|
|
102
100
|
|
|
103
101
|
|
|
104
|
-
def filter_inference_data(
|
|
105
|
-
inference_data, observation_container, background_model=None
|
|
106
|
-
) -> az.InferenceData:
|
|
107
|
-
predictive_parameters = []
|
|
108
|
-
|
|
109
|
-
for key, value in observation_container.items():
|
|
110
|
-
if background_model is not None:
|
|
111
|
-
predictive_parameters.append(f"obs_{key}")
|
|
112
|
-
predictive_parameters.append(f"bkg_{key}")
|
|
113
|
-
else:
|
|
114
|
-
predictive_parameters.append(f"obs_{key}")
|
|
115
|
-
|
|
116
|
-
inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
|
|
117
|
-
|
|
118
|
-
parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
|
|
119
|
-
inference_data.posterior = inference_data.posterior[parameters]
|
|
120
|
-
inference_data.prior = inference_data.prior[parameters]
|
|
121
|
-
|
|
122
|
-
return inference_data
|
|
123
|
-
|
|
124
|
-
|
|
125
102
|
class CountForwardModel(hk.Module):
|
|
126
103
|
"""
|
|
127
104
|
A haiku module which allows to build the function that simulates the measured counts
|
|
@@ -154,7 +131,8 @@ class CountForwardModel(hk.Module):
|
|
|
154
131
|
|
|
155
132
|
class BayesianModel:
|
|
156
133
|
"""
|
|
157
|
-
|
|
134
|
+
Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
|
|
135
|
+
and compute the log-likelihood and posterior probability.
|
|
158
136
|
"""
|
|
159
137
|
|
|
160
138
|
def __init__(
|
|
@@ -166,6 +144,8 @@ class BayesianModel:
|
|
|
166
144
|
sparsify_matrix: bool = False,
|
|
167
145
|
):
|
|
168
146
|
"""
|
|
147
|
+
Build a Bayesian model for a given spectral model and observations.
|
|
148
|
+
|
|
169
149
|
Parameters:
|
|
170
150
|
model: the spectral model to fit.
|
|
171
151
|
prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
|
|
@@ -191,7 +171,7 @@ class BayesianModel:
|
|
|
191
171
|
prior_distributions_func = prior_distributions
|
|
192
172
|
|
|
193
173
|
self.prior_distributions_func = prior_distributions_func
|
|
194
|
-
self.init_params = self.
|
|
174
|
+
self.init_params = self.prior_samples()
|
|
195
175
|
|
|
196
176
|
@cached_property
|
|
197
177
|
def observation_container(self) -> dict[str, ObsConfiguration]:
|
|
@@ -215,9 +195,6 @@ class BayesianModel:
|
|
|
215
195
|
def numpyro_model(self) -> Callable:
|
|
216
196
|
"""
|
|
217
197
|
Build the numpyro model using the observed data, the prior distributions and the spectral model.
|
|
218
|
-
|
|
219
|
-
Returns:
|
|
220
|
-
A model function that can be used with numpyro.
|
|
221
198
|
"""
|
|
222
199
|
|
|
223
200
|
def model(observed=True):
|
|
@@ -257,9 +234,6 @@ class BayesianModel:
|
|
|
257
234
|
def log_likelihood_per_obs(self) -> Callable:
|
|
258
235
|
"""
|
|
259
236
|
Build the log likelihood function for each bins in each observation.
|
|
260
|
-
|
|
261
|
-
Returns:
|
|
262
|
-
Callable log-likelihood function.
|
|
263
237
|
"""
|
|
264
238
|
|
|
265
239
|
@jax.jit
|
|
@@ -316,6 +290,16 @@ class BayesianModel:
|
|
|
316
290
|
observed_sites = relations["observed"]
|
|
317
291
|
return [site for site in all_sites if site not in observed_sites]
|
|
318
292
|
|
|
293
|
+
@cached_property
def observation_names(self) -> list[str]:
    """
    Names of the observed sample sites of the numpyro model.

    A site is listed when `get_model_relations` reports it as observed; names follow the iteration order of the
    ``"sample_sample"`` mapping returned by `get_model_relations`.
    """
    model_relations = get_model_relations(self.numpyro_model)
    observed = model_relations["observed"]

    names = []
    for site in model_relations["sample_sample"]:
        if site in observed:
            names.append(site)
    return names
|
|
302
|
+
|
|
319
303
|
def array_to_dict(self, theta):
|
|
320
304
|
"""
|
|
321
305
|
Convert an array of parameters to a dictionary of parameters.
|
|
@@ -339,7 +323,7 @@ class BayesianModel:
|
|
|
339
323
|
|
|
340
324
|
return theta
|
|
341
325
|
|
|
342
|
-
def
|
|
326
|
+
def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100):
|
|
343
327
|
"""
|
|
344
328
|
Get initial parameters for the model by sampling from the prior distribution
|
|
345
329
|
|
|
@@ -348,9 +332,84 @@ class BayesianModel:
|
|
|
348
332
|
num_samples: the number of samples to draw from the prior.
|
|
349
333
|
"""
|
|
350
334
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
335
|
+
@jax.jit
|
|
336
|
+
def prior_sample(key):
|
|
337
|
+
return Predictive(
|
|
338
|
+
self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
|
|
339
|
+
)(key, observed=False)
|
|
340
|
+
|
|
341
|
+
return prior_sample(key)
|
|
342
|
+
|
|
343
|
+
def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)):
    """
    Simulate observed spectra from given parameter values.

    Parameters:
        parameters: samples of the model parameters (e.g. the output of `prior_samples`), passed as the
            ``posterior_samples`` of a numpyro `Predictive`.
        key: the random key used for the simulation.

    Returns:
        A dictionary mapping each observed site (see `observation_names`) to its simulated values.
    """

    def simulate(rng_key, samples):
        predictive = Predictive(
            self.numpyro_model,
            posterior_samples=samples,
            return_sites=self.observation_names,
        )
        return predictive(rng_key, observed=False)

    return jax.jit(simulate)(key, parameters)
|
|
353
|
+
|
|
354
|
+
def prior_predictive_coverage(
    self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
):
    """
    Check if the prior predictive distribution includes the observed data.

    Parameters are drawn from the prior and used to simulate observations. For each observation a figure is
    shown with:

    - top panel: the observed folded counts (with error bars) over the min-max envelope of the simulated counts;
    - bottom panel: the rank (in %) of the observed counts among the simulated ones in each bin, 50% meaning
      that the observation sits in the middle of the simulations.

    Parameters:
        key: the random key used to draw the prior samples and the simulated observations.
        num_samples: the number of prior samples to draw.
        percentiles: the percentiles (in %) delimiting the error bars of the observed counts.
    """
    key_prior, key_mock = jax.random.split(key, 2)
    prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
    simulated_observations = self.mock_observations(prior_params, key=key_mock)

    for obs_name, obs_conf in self.observation_container.items():
        energies = obs_conf.out_energies
        edges = [*list(energies[0]), energies[1][-1]]
        e_min, e_max = energies.min(), energies.max()

        simulated = simulated_observations["obs_" + obs_name]
        observed = obs_conf.folded_counts.values

        fig, axs = plt.subplots(
            nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
        )

        _plot_poisson_data_with_error(axs[0], energies, observed, percentiles=percentiles)

        # Envelope spanned by all the simulated spectra.
        axs[0].stairs(
            np.max(simulated, axis=0),
            edges=edges,
            baseline=np.min(simulated, axis=0),
            alpha=0.3,
            fill=True,
            color=(0.15, 0.25, 0.45),
        )

        # Mid-rank of the observed counts among the simulated draws: ties count for half,
        # so bins where every draw equals the observation land exactly at 50%.
        n_draws = simulated.shape[0]
        n_below = (simulated < observed).sum(axis=0)
        n_equal = (simulated == observed).sum(axis=0)
        rank = (n_below + 0.5 * n_equal) / n_draws * 100

        axs[1].stairs(rank, edges=edges)
        axs[1].plot((e_min, e_max), (50, 50), color="black", linestyle="--")

        axs[1].set_xlabel("Energy (keV)")
        axs[0].set_ylabel("Counts")
        axs[1].set_ylabel("Rank (%)")
        axs[1].set_ylim(0, 100)
        axs[0].set_xlim(e_min, e_max)
        axs[0].loglog()
        fig.suptitle(f"Prior Predictive coverage for {obs_name}")
        plt.show()
|
|
354
413
|
|
|
355
414
|
|
|
356
415
|
class BayesianModelFitter(BayesianModel, ABC):
|
|
@@ -359,11 +418,20 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
359
418
|
posterior_samples,
|
|
360
419
|
num_chains: int = 1,
|
|
361
420
|
num_predictive_samples: int = 1000,
|
|
362
|
-
key: PRNGKey = PRNGKey(
|
|
421
|
+
key: PRNGKey = PRNGKey(42),
|
|
363
422
|
use_transformed_model: bool = False,
|
|
423
|
+
filter_inference_data: bool = True,
|
|
364
424
|
) -> az.InferenceData:
|
|
365
425
|
"""
|
|
366
|
-
Build an InferenceData object from
|
|
426
|
+
Build an [InferenceData][arviz.InferenceData] object from posterior samples.
|
|
427
|
+
|
|
428
|
+
Parameters:
|
|
429
|
+
posterior_samples: the samples from the posterior distribution.
|
|
430
|
+
num_chains: the number of chains used to sample the posterior.
|
|
431
|
+
num_predictive_samples: the number of samples to draw from the prior.
|
|
432
|
+
key: the random key used to initialize the sampler.
|
|
433
|
+
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
434
|
+
filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
|
|
367
435
|
"""
|
|
368
436
|
|
|
369
437
|
numpyro_model = (
|
|
@@ -409,7 +477,7 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
409
477
|
key: reshape_first_dimension(value) for key, value in log_likelihood.items()
|
|
410
478
|
}
|
|
411
479
|
|
|
412
|
-
|
|
480
|
+
inference_data = az.from_dict(
|
|
413
481
|
posterior_samples,
|
|
414
482
|
prior=prior,
|
|
415
483
|
posterior_predictive=posterior_predictive,
|
|
@@ -417,81 +485,42 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
417
485
|
observed_data=observations,
|
|
418
486
|
)
|
|
419
487
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
class NUTSFitter(BayesianModelFitter):
|
|
425
|
-
"""
|
|
426
|
-
A class to fit a model to a given set of observation using a Bayesian approach. This class uses the NUTS sampler
|
|
427
|
-
from numpyro to perform the inference on the model parameters.
|
|
428
|
-
"""
|
|
488
|
+
return (
|
|
489
|
+
self.filter_inference_data(inference_data) if filter_inference_data else inference_data
|
|
490
|
+
)
|
|
429
491
|
|
|
430
|
-
def
|
|
492
|
+
def filter_inference_data(
|
|
431
493
|
self,
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
num_warmup: int = 1000,
|
|
435
|
-
num_samples: int = 1000,
|
|
436
|
-
max_tree_depth: int = 10,
|
|
437
|
-
target_accept_prob: float = 0.8,
|
|
438
|
-
dense_mass: bool = False,
|
|
439
|
-
kernel_kwargs: dict = {},
|
|
440
|
-
mcmc_kwargs: dict = {},
|
|
441
|
-
) -> FitResult:
|
|
494
|
+
inference_data: az.InferenceData,
|
|
495
|
+
) -> az.InferenceData:
|
|
442
496
|
"""
|
|
443
|
-
|
|
497
|
+
Filter the inference data to keep only the relevant parameters for the observations.
|
|
444
498
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
num_chains: the number of chains to run.
|
|
448
|
-
num_warmup: the number of warmup steps.
|
|
449
|
-
num_samples: the number of samples to draw.
|
|
450
|
-
max_tree_depth: the recursion depth of NUTS sampler.
|
|
451
|
-
target_accept_prob: the target acceptance probability for the NUTS sampler.
|
|
452
|
-
dense_mass: whether to use a dense mass for the NUTS sampler.
|
|
453
|
-
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
454
|
-
|
|
455
|
-
Returns:
|
|
456
|
-
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
499
|
+
- Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
|
|
500
|
+
- Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
|
|
457
501
|
"""
|
|
458
502
|
|
|
459
|
-
|
|
460
|
-
# bayesian_model = self.numpyro_model(prior_distributions)
|
|
461
|
-
|
|
462
|
-
chain_kwargs = {
|
|
463
|
-
"num_warmup": num_warmup,
|
|
464
|
-
"num_samples": num_samples,
|
|
465
|
-
"num_chains": num_chains,
|
|
466
|
-
}
|
|
467
|
-
|
|
468
|
-
kernel = NUTS(
|
|
469
|
-
bayesian_model,
|
|
470
|
-
max_tree_depth=max_tree_depth,
|
|
471
|
-
target_accept_prob=target_accept_prob,
|
|
472
|
-
dense_mass=dense_mass,
|
|
473
|
-
**kernel_kwargs,
|
|
474
|
-
)
|
|
503
|
+
predictive_parameters = []
|
|
475
504
|
|
|
476
|
-
|
|
477
|
-
|
|
505
|
+
for key, value in self.observation_container.items():
|
|
506
|
+
if self.background_model is not None:
|
|
507
|
+
predictive_parameters.append(f"obs_{key}")
|
|
508
|
+
predictive_parameters.append(f"bkg_{key}")
|
|
509
|
+
else:
|
|
510
|
+
predictive_parameters.append(f"obs_{key}")
|
|
478
511
|
|
|
479
|
-
|
|
512
|
+
inference_data.posterior_predictive = inference_data.posterior_predictive[
|
|
513
|
+
predictive_parameters
|
|
514
|
+
]
|
|
480
515
|
|
|
481
|
-
|
|
516
|
+
parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
|
|
517
|
+
inference_data.posterior = inference_data.posterior[parameters]
|
|
518
|
+
inference_data.prior = inference_data.prior[parameters]
|
|
482
519
|
|
|
483
|
-
inference_data
|
|
484
|
-
self.build_inference_data(posterior, num_chains=num_chains),
|
|
485
|
-
self.observation_container,
|
|
486
|
-
self.background_model,
|
|
487
|
-
)
|
|
520
|
+
return inference_data
|
|
488
521
|
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
inference_data,
|
|
492
|
-
self.model.params,
|
|
493
|
-
background_model=self.background_model,
|
|
494
|
-
)
|
|
522
|
+
@abstractmethod
def fit(self, **kwargs) -> FitResult:
    """
    Fit the model to the observations and return a `FitResult`.

    Concrete fitters (e.g. `MCMCFitter`, `NSFitter`) implement this method; the keyword arguments they accept
    depend on the fitter.
    """
    ...
|
|
495
524
|
|
|
496
525
|
|
|
497
526
|
class MCMCFitter(BayesianModelFitter):
|
|
@@ -513,6 +542,7 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
513
542
|
num_warmup: int = 1000,
|
|
514
543
|
num_samples: int = 1000,
|
|
515
544
|
sampler: Literal["nuts", "aies", "ess"] = "nuts",
|
|
545
|
+
use_transformed_model: bool = True,
|
|
516
546
|
kernel_kwargs: dict = {},
|
|
517
547
|
mcmc_kwargs: dict = {},
|
|
518
548
|
) -> FitResult:
|
|
@@ -524,17 +554,18 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
524
554
|
num_chains: the number of chains to run.
|
|
525
555
|
num_warmup: the number of warmup steps.
|
|
526
556
|
num_samples: the number of samples to draw.
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
557
|
+
sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
|
|
558
|
+
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
559
|
+
kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
|
|
530
560
|
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
531
561
|
|
|
532
562
|
Returns:
|
|
533
563
|
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
534
564
|
"""
|
|
535
565
|
|
|
536
|
-
bayesian_model =
|
|
537
|
-
|
|
566
|
+
bayesian_model = (
|
|
567
|
+
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
568
|
+
)
|
|
538
569
|
|
|
539
570
|
chain_kwargs = {
|
|
540
571
|
"num_warmup": num_warmup,
|
|
@@ -557,10 +588,8 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
557
588
|
|
|
558
589
|
posterior = mcmc.get_samples()
|
|
559
590
|
|
|
560
|
-
inference_data =
|
|
561
|
-
|
|
562
|
-
self.observation_container,
|
|
563
|
-
self.background_model,
|
|
591
|
+
inference_data = self.build_inference_data(
|
|
592
|
+
posterior, num_chains=num_chains, use_transformed_model=True
|
|
564
593
|
)
|
|
565
594
|
|
|
566
595
|
return FitResult(
|
|
@@ -571,175 +600,22 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
571
600
|
)
|
|
572
601
|
|
|
573
602
|
|
|
574
|
-
class
|
|
575
|
-
"""
|
|
576
|
-
A class to fit a model to a given set of observation using a minimization algorithm. This class uses the L-BFGS
|
|
577
|
-
algorithm from jaxopt to perform the minimization on the model parameters. The uncertainties are computed using the
|
|
578
|
-
Hessian of the log-log_likelihood, assuming that it is a multivariate Gaussian in the unbounded space defined by
|
|
579
|
-
numpyro.
|
|
580
|
-
"""
|
|
581
|
-
|
|
582
|
-
def fit(
|
|
583
|
-
self,
|
|
584
|
-
rng_key: int = 0,
|
|
585
|
-
num_iter_max: int = 100_000,
|
|
586
|
-
num_samples: int = 1_000,
|
|
587
|
-
solver: Literal["bfgs", "levenberg_marquardt"] = "bfgs",
|
|
588
|
-
init_params=None,
|
|
589
|
-
refine_first_guess=True,
|
|
590
|
-
) -> FitResult:
|
|
591
|
-
"""
|
|
592
|
-
Fit the model to the data using L-BFGS algorithm.
|
|
593
|
-
|
|
594
|
-
Parameters:
|
|
595
|
-
rng_key: the random key used to initialize the sampler.
|
|
596
|
-
num_iter_max: the maximum number of iteration in the minimization algorithm.
|
|
597
|
-
num_samples: the number of sample to draw from the best-fit covariance.
|
|
598
|
-
|
|
599
|
-
Returns:
|
|
600
|
-
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
601
|
-
"""
|
|
602
|
-
|
|
603
|
-
bayesian_model = self.numpyro_model
|
|
604
|
-
keys = jax.random.split(PRNGKey(rng_key), 4)
|
|
605
|
-
|
|
606
|
-
if init_params is not None:
|
|
607
|
-
# We initialize the parameters by randomly sampling from the prior
|
|
608
|
-
local_keys = jax.random.split(keys[0], 2)
|
|
609
|
-
|
|
610
|
-
with numpyro.handlers.seed(rng_seed=local_keys[0]):
|
|
611
|
-
starting_value = self.prior_distributions_func()
|
|
612
|
-
|
|
613
|
-
# We update the starting value with the provided init_params
|
|
614
|
-
for m, n, val in hk.data_structures.traverse(init_params):
|
|
615
|
-
if f"{m}_{n}" in starting_value.keys():
|
|
616
|
-
starting_value[f"{m}_{n}"] = val
|
|
617
|
-
|
|
618
|
-
init_params, _ = numpyro.infer.util.find_valid_initial_params(
|
|
619
|
-
local_keys[1], bayesian_model, init_strategy=init_to_value(values=starting_value)
|
|
620
|
-
)
|
|
621
|
-
|
|
622
|
-
else:
|
|
623
|
-
init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
|
|
624
|
-
|
|
625
|
-
init_params = init_params[0]
|
|
626
|
-
|
|
627
|
-
@jax.jit
|
|
628
|
-
def nll(unconstrained_params, _):
|
|
629
|
-
constrained_params = constrain_fn(
|
|
630
|
-
bayesian_model, tuple(), dict(observed=True), unconstrained_params
|
|
631
|
-
)
|
|
632
|
-
|
|
633
|
-
log_likelihood = numpyro.infer.util.log_likelihood(
|
|
634
|
-
model=bayesian_model, posterior_samples=constrained_params
|
|
635
|
-
)
|
|
636
|
-
|
|
637
|
-
# We solve a least square problem, this function ensure that the total residual is indeed the nll
|
|
638
|
-
return jax.tree.map(lambda x: jnp.sqrt(-x), log_likelihood)
|
|
639
|
-
|
|
640
|
-
"""
|
|
641
|
-
if refine_first_guess:
|
|
642
|
-
with catchtime("Refine_first"):
|
|
643
|
-
solution = optx.least_squares(
|
|
644
|
-
nll,
|
|
645
|
-
optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
|
|
646
|
-
init_params,
|
|
647
|
-
max_steps=1000,
|
|
648
|
-
throw=False
|
|
649
|
-
)
|
|
650
|
-
init_params = solution.value
|
|
651
|
-
"""
|
|
652
|
-
|
|
653
|
-
if solver == "bfgs":
|
|
654
|
-
solver = optx.BestSoFarMinimiser(optx.BFGS(1e-6, 1e-6))
|
|
655
|
-
elif solver == "levenberg_marquardt":
|
|
656
|
-
solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(1e-6, 1e-6))
|
|
657
|
-
else:
|
|
658
|
-
raise NotImplementedError(f"{solver} is not implemented")
|
|
659
|
-
|
|
660
|
-
with catchtime("Minimization"):
|
|
661
|
-
solution = optx.least_squares(
|
|
662
|
-
nll,
|
|
663
|
-
solver,
|
|
664
|
-
init_params,
|
|
665
|
-
max_steps=num_iter_max,
|
|
666
|
-
)
|
|
667
|
-
|
|
668
|
-
params = solution.value
|
|
669
|
-
value_flat, unflatten_fun = ravel_pytree(params)
|
|
670
|
-
|
|
671
|
-
with catchtime("Compute error"):
|
|
672
|
-
precision = jax.hessian(
|
|
673
|
-
lambda p: jnp.sum(ravel_pytree(nll(unflatten_fun(p), None))[0] ** 2)
|
|
674
|
-
)(value_flat)
|
|
675
|
-
|
|
676
|
-
cov = Covariance.from_precision(precision)
|
|
677
|
-
|
|
678
|
-
samples_flat = multivariate_normal.rvs(mean=value_flat, cov=cov, size=num_samples)
|
|
679
|
-
|
|
680
|
-
samples = jax.vmap(unflatten_fun)(samples_flat)
|
|
681
|
-
posterior_samples = jax.jit(
|
|
682
|
-
jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
|
|
683
|
-
)(samples)
|
|
684
|
-
|
|
685
|
-
with catchtime("Posterior"):
|
|
686
|
-
posterior_predictive = Predictive(bayesian_model, posterior_samples)(
|
|
687
|
-
keys[2], observed=False
|
|
688
|
-
)
|
|
689
|
-
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
|
|
690
|
-
log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
|
|
691
|
-
|
|
692
|
-
def sanitize_chain(chain):
|
|
693
|
-
"""
|
|
694
|
-
reshape the samples so that it is arviz compliant with an extra starting dimension
|
|
695
|
-
"""
|
|
696
|
-
return tree_map(lambda x: x[None, ...], chain)
|
|
697
|
-
|
|
698
|
-
# We export the observed values to the inference_data
|
|
699
|
-
seeded_model = numpyro.handlers.substitute(
|
|
700
|
-
numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
|
|
701
|
-
substitute_fn=numpyro.infer.init_to_sample,
|
|
702
|
-
)
|
|
703
|
-
trace = numpyro.handlers.trace(seeded_model).get_trace()
|
|
704
|
-
observations = {
|
|
705
|
-
name: site["value"]
|
|
706
|
-
for name, site in trace.items()
|
|
707
|
-
if site["type"] == "sample" and site["is_observed"]
|
|
708
|
-
}
|
|
709
|
-
|
|
710
|
-
with catchtime("InferenceData wrapping"):
|
|
711
|
-
inference_data = az.from_dict(
|
|
712
|
-
sanitize_chain(posterior_samples),
|
|
713
|
-
prior=sanitize_chain(prior),
|
|
714
|
-
posterior_predictive=sanitize_chain(posterior_predictive),
|
|
715
|
-
log_likelihood=sanitize_chain(log_likelihood),
|
|
716
|
-
observed_data=observations,
|
|
717
|
-
)
|
|
718
|
-
|
|
719
|
-
inference_data = filter_inference_data(
|
|
720
|
-
inference_data, self.observation_container, self.background_model
|
|
721
|
-
)
|
|
722
|
-
|
|
723
|
-
return FitResult(
|
|
724
|
-
self,
|
|
725
|
-
inference_data,
|
|
726
|
-
self.model.params,
|
|
727
|
-
background_model=self.background_model,
|
|
728
|
-
)
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
class NestedSamplingFitter(BayesianModelFitter):
|
|
603
|
+
class NSFitter(BayesianModelFitter):
|
|
732
604
|
r"""
|
|
733
605
|
A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
|
|
734
606
|
[`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
|
|
735
607
|
implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
|
|
736
|
-
|
|
608
|
+
|
|
609
|
+
!!! info
|
|
610
|
+
Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
|
|
611
|
+
|
|
737
612
|
"""
|
|
738
613
|
|
|
739
614
|
def fit(
|
|
740
615
|
self,
|
|
741
616
|
rng_key: int = 0,
|
|
742
617
|
num_samples: int = 1000,
|
|
618
|
+
num_live_points: int = 1000,
|
|
743
619
|
plot_diagnostics=False,
|
|
744
620
|
termination_kwargs: dict | None = None,
|
|
745
621
|
verbose=True,
|
|
@@ -750,6 +626,10 @@ class NestedSamplingFitter(BayesianModelFitter):
|
|
|
750
626
|
Parameters:
|
|
751
627
|
rng_key: the random key used to initialize the sampler.
|
|
752
628
|
num_samples: the number of samples to draw.
|
|
629
|
+
num_live_points: the number of live points to use at the start of the NS algorithm.
|
|
630
|
+
plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
|
|
631
|
+
termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
|
|
632
|
+
verbose: whether to print the progress of the NS algorithm.
|
|
753
633
|
|
|
754
634
|
Returns:
|
|
755
635
|
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
@@ -766,7 +646,7 @@ class NestedSamplingFitter(BayesianModelFitter):
|
|
|
766
646
|
difficult_model=True,
|
|
767
647
|
max_samples=1e6,
|
|
768
648
|
parameter_estimation=True,
|
|
769
|
-
num_live_points=
|
|
649
|
+
num_live_points=num_live_points,
|
|
770
650
|
),
|
|
771
651
|
termination_kwargs=termination_kwargs if termination_kwargs else dict(),
|
|
772
652
|
)
|
|
@@ -776,41 +656,9 @@ class NestedSamplingFitter(BayesianModelFitter):
|
|
|
776
656
|
if plot_diagnostics:
|
|
777
657
|
ns.diagnostics()
|
|
778
658
|
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
keys[2], observed=False
|
|
783
|
-
)
|
|
784
|
-
|
|
785
|
-
prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
|
|
786
|
-
|
|
787
|
-
seeded_model = numpyro.handlers.substitute(
|
|
788
|
-
numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
|
|
789
|
-
substitute_fn=numpyro.infer.init_to_sample,
|
|
790
|
-
)
|
|
791
|
-
trace = numpyro.handlers.trace(seeded_model).get_trace()
|
|
792
|
-
observations = {
|
|
793
|
-
name: site["value"]
|
|
794
|
-
for name, site in trace.items()
|
|
795
|
-
if site["type"] == "sample" and site["is_observed"]
|
|
796
|
-
}
|
|
797
|
-
|
|
798
|
-
def sanitize_chain(chain):
|
|
799
|
-
"""
|
|
800
|
-
reshape the samples so that it is arviz compliant with an extra starting dimension
|
|
801
|
-
"""
|
|
802
|
-
return tree_map(lambda x: x[None, ...], chain)
|
|
803
|
-
|
|
804
|
-
inference_data = az.from_dict(
|
|
805
|
-
sanitize_chain(posterior_samples),
|
|
806
|
-
prior=sanitize_chain(prior),
|
|
807
|
-
posterior_predictive=sanitize_chain(posterior_predictive),
|
|
808
|
-
log_likelihood=sanitize_chain(log_likelihood),
|
|
809
|
-
observed_data=observations,
|
|
810
|
-
)
|
|
811
|
-
|
|
812
|
-
inference_data = filter_inference_data(
|
|
813
|
-
inference_data, self.observation_container, self.background_model
|
|
659
|
+
posterior = ns.get_samples(keys[1], num_samples=num_samples)
|
|
660
|
+
inference_data = self.build_inference_data(
|
|
661
|
+
posterior, num_chains=1, use_transformed_model=True
|
|
814
662
|
)
|
|
815
663
|
|
|
816
664
|
return FitResult(
|
jaxspec/model/abc.py
CHANGED
|
@@ -7,9 +7,11 @@ import haiku as hk
|
|
|
7
7
|
import jax
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
import networkx as nx
|
|
10
|
+
import rich
|
|
10
11
|
|
|
11
12
|
from haiku._src import base
|
|
12
13
|
from jax.scipy.integrate import trapezoid
|
|
14
|
+
from rich.table import Table
|
|
13
15
|
from simpleeval import simple_eval
|
|
14
16
|
|
|
15
17
|
|
|
@@ -110,6 +112,30 @@ class SpectralModel:
|
|
|
110
112
|
def params(self):
|
|
111
113
|
return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
|
|
112
114
|
|
|
115
|
+
def __rich_repr__(self):
    """
    Build a rich `Table` listing, for each component of the model, the names of its parameters.

    NOTE(review): rich's `__rich_repr__` protocol expects an iterable of fields, not a renderable, so rich's
    pretty-printer cannot use this method as such; it is kept under this name for backward compatibility, and
    rich renders the table through `__rich__`.
    NOTE(review): the title uses `str(self)`; this assumes the class defines `__str__`, because the default
    `__str__` falls back to `__repr__`, which renders this very table — confirm.
    """
    table = Table(title=str(self))

    table.add_column("Component", justify="right", style="bold", no_wrap=True)
    table.add_column("Parameter")

    for component, component_params in self.params.items():
        for row_index, parameter in enumerate(component_params.keys()):
            # Only the first row of a component carries its name; the following ones leave it blank.
            table.add_row(component if row_index == 0 else "", parameter)

    return table

def __rich__(self):
    """Renderable used by rich (e.g. `rich.print(model)`): the table built by `__rich_repr__`."""
    return self.__rich_repr__()

def _repr_html_(self) -> str:
    """
    HTML rendering of the parameter table, picked up by Jupyter/IPython to display the model.

    The table is recorded on an off-screen console (``force_jupyter=False`` keeps rich from routing the output to
    the notebook display itself) and exported as an HTML snippet with inline styles.
    """
    import io

    from rich.console import Console

    console = Console(file=io.StringIO(), record=True, force_jupyter=False)
    console.print(self.__rich_repr__())
    return console.export_html(
        inline_styles=True,
        code_format=(
            '<pre style="white-space:pre;overflow-x:auto;line-height:normal;'
            "font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">{code}</pre>"
        ),
    )

def __repr__(self):
    """
    Plain-text rendering of the parameter table.

    The table is rendered into a string instead of being printed, so that `repr()` has no side effect and its
    result can be used in logs, f-strings or containers.
    """
    import io

    from rich.console import Console

    buffer = io.StringIO()
    Console(file=buffer, force_jupyter=False).print(self.__rich_repr__())
    return buffer.getvalue().rstrip("\n")
|
|
138
|
+
|
|
113
139
|
def photon_flux(self, params, e_low, e_high, n_points=2):
|
|
114
140
|
r"""
|
|
115
141
|
Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
|
jaxspec/model/additive.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
Home-page: https://github.com/renecotyfanboy/jaxspec
|
|
6
6
|
License: MIT
|
|
@@ -15,19 +15,18 @@ Requires-Dist: arviz (>=0.17.1,<0.20.0)
|
|
|
15
15
|
Requires-Dist: astropy (>=6.0.0,<7.0.0)
|
|
16
16
|
Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
|
|
17
17
|
Requires-Dist: cmasher (>=1.6.3,<2.0.0)
|
|
18
|
-
Requires-Dist: dm-haiku (>=0.0.
|
|
18
|
+
Requires-Dist: dm-haiku (>=0.0.12,<0.0.13)
|
|
19
19
|
Requires-Dist: gpjax (>=0.8.0,<0.9.0)
|
|
20
20
|
Requires-Dist: interpax (>=0.3.3,<0.4.0)
|
|
21
|
-
Requires-Dist: jax (>=0.4.
|
|
21
|
+
Requires-Dist: jax (>=0.4.33,<0.5.0)
|
|
22
22
|
Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
|
|
23
|
-
Requires-Dist: jaxns (
|
|
23
|
+
Requires-Dist: jaxns (<2.6)
|
|
24
24
|
Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
|
|
25
25
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
|
26
26
|
Requires-Dist: mendeleev (>=0.15,<0.18)
|
|
27
|
-
Requires-Dist: mkdocstrings (>=0.24,<0.27)
|
|
28
27
|
Requires-Dist: networkx (>=3.1,<4.0)
|
|
29
28
|
Requires-Dist: numpy (<2.0.0)
|
|
30
|
-
Requires-Dist: numpyro (>=0.15.
|
|
29
|
+
Requires-Dist: numpyro (>=0.15.3,<0.16.0)
|
|
31
30
|
Requires-Dist: optimistix (>=0.0.7,<0.0.8)
|
|
32
31
|
Requires-Dist: pandas (>=2.2.0,<3.0.0)
|
|
33
32
|
Requires-Dist: pooch (>=1.8.2,<2.0.0)
|
|
@@ -41,7 +40,14 @@ Requires-Dist: watermark (>=2.4.3,<3.0.0)
|
|
|
41
40
|
Project-URL: Documentation, https://jaxspec.readthedocs.io/en/latest/
|
|
42
41
|
Description-Content-Type: text/markdown
|
|
43
42
|
|
|
44
|
-
|
|
43
|
+
<p align="center">
|
|
44
|
+
<img src="https://raw.githubusercontent.com/renecotyfanboy/jaxspec/main/docs/logo/logo_small.svg" alt="Logo" width="100" height="100">
|
|
45
|
+
</p>
|
|
46
|
+
|
|
47
|
+
<h1 align="center">
|
|
48
|
+
jaxspec
|
|
49
|
+
</h1>
|
|
50
|
+
|
|
45
51
|
|
|
46
52
|
[)](https://pypi.org/project/jaxspec/)
|
|
47
53
|
[](https://pypi.org/project/jaxspec/)
|
|
@@ -1,21 +1,19 @@
|
|
|
1
1
|
jaxspec/__init__.py,sha256=Sbn02lX6Y-zNXk17N8dec22c5jeypiS0LkHmGfz7lWA,126
|
|
2
2
|
jaxspec/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
jaxspec/analysis/_plot.py,sha256=C4XljmuzQz8xQur_jQddgInrBDmKgTn0eugSreLoD5k,862
|
|
3
4
|
jaxspec/analysis/compare.py,sha256=g2UFhmR9Zt-7cz5gQFOB6lXuklXB3yTyUvjTypOzoSY,725
|
|
4
5
|
jaxspec/analysis/results.py,sha256=Kz3eryxS3N_hiajcFLTWS1dtgTQo5hlh-rDCnJ3A-3c,27811
|
|
5
6
|
jaxspec/data/__init__.py,sha256=aantcYKC9kZFvaE-V2SIwSuLhIld17Kjrd9CIUu___Y,415
|
|
6
7
|
jaxspec/data/grouping.py,sha256=hhgBt-voiH0DDSyePacaIGsaMnrYbJM_-ZeU66keC7I,622
|
|
7
8
|
jaxspec/data/instrument.py,sha256=0pSf1p82g7syDMmKm13eVbYih-Veiq5DnwsyZe6_b4g,3890
|
|
8
|
-
jaxspec/data/obsconf.py,sha256=
|
|
9
|
+
jaxspec/data/obsconf.py,sha256=gv14sL6azK2avRiMCWuTbyLBPulzm4PwvoLY6iWPEVE,9833
|
|
9
10
|
jaxspec/data/observation.py,sha256=1UnFu5ihZp9z-vP_I7tsFY8jhhIJunv46JyuE-acrg0,6394
|
|
10
11
|
jaxspec/data/ogip.py,sha256=sv9p00qHS5pzw61pzWyyF0nV-E-RXySdSFK2tUavokA,9545
|
|
11
12
|
jaxspec/data/util.py,sha256=ycLPVE-cjn6VpUWYlBU1BGfw73ANXIBilyVAUOYOSj0,9540
|
|
12
|
-
jaxspec/fit.py,sha256=
|
|
13
|
+
jaxspec/fit.py,sha256=hI0koMO4KsNpe9mLlaFm_tNLgm4BVAYVyiMb1E1eyZE,24553
|
|
13
14
|
jaxspec/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
jaxspec/model/
|
|
15
|
-
jaxspec/model/
|
|
16
|
-
jaxspec/model/_additive/apec_loaders.py,sha256=jkUoH0ezeYdaNw3oV10V0L-jt848SKp2thanLWLWp9k,2412
|
|
17
|
-
jaxspec/model/abc.py,sha256=nQZUmtUzXjW94gv3BJg1lHXHZtgrHoOlAR4a6G2a9VQ,20234
|
|
18
|
-
jaxspec/model/additive.py,sha256=xD5E30nd5pqa-swQireA52ch1czxnqRosnh-dsp5xL0,22485
|
|
15
|
+
jaxspec/model/abc.py,sha256=MuxEyvn223QPwGoFIJiST8nRMgrZ08ZLkw33oep3tx4,20887
|
|
16
|
+
jaxspec/model/additive.py,sha256=wjY2wL3Io3F45GJpz-UB8xYVnA-W1OFBnZMbj5pWPbQ,22449
|
|
19
17
|
jaxspec/model/background.py,sha256=QSFFiuyUEvuzXBx3QfkvVneUR8KKEP-VaANEVXcavDE,7865
|
|
20
18
|
jaxspec/model/list.py,sha256=0RPAoscVz_zM1CWdx_Gd5wfrQWV5Nv4Kd4bSXu2ayUA,860
|
|
21
19
|
jaxspec/model/multiplicative.py,sha256=GCQ6JRz92QqbzDBFwWxGZ9SUqTJZQpD7B6ji9VEFXWo,8135
|
|
@@ -26,8 +24,8 @@ jaxspec/util/abundance.py,sha256=fsC313taIlGzQsZNwbYsJupDWm7ZbqzGhY66Ku394Mw,854
|
|
|
26
24
|
jaxspec/util/integrate.py,sha256=_Ax_knpC7d4et2-QFkOUzVtNeQLX1-cwLvm-FRBxYcw,4505
|
|
27
25
|
jaxspec/util/online_storage.py,sha256=vm56RfcbFKpkRVfr0bXO7J9aQxuBq-I_oEgA26YIhCo,2469
|
|
28
26
|
jaxspec/util/typing.py,sha256=8qK1aJlsqTcVKjYN-BxsDx20BTwtnS-wMw6Bdurpm-o,2459
|
|
29
|
-
jaxspec-0.1.
|
|
30
|
-
jaxspec-0.1.
|
|
31
|
-
jaxspec-0.1.
|
|
32
|
-
jaxspec-0.1.
|
|
33
|
-
jaxspec-0.1.
|
|
27
|
+
jaxspec-0.1.2.dist-info/LICENSE.md,sha256=2q5XoWzddts5IqzIcgYYMOL21puU3MfO8gvT3Ype1eQ,1073
|
|
28
|
+
jaxspec-0.1.2.dist-info/METADATA,sha256=FE2bTAk-3Xryi6fplV4Y-F2eibUdLZgC9ET9_4HvdOA,3708
|
|
29
|
+
jaxspec-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
30
|
+
jaxspec-0.1.2.dist-info/entry_points.txt,sha256=kzLG2mGlCWITRn4Q6zKG_idx-_RKAncvA0DMNYTgHAg,71
|
|
31
|
+
jaxspec-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
jaxspec/model/_additive/apec.py
DELETED
|
@@ -1,316 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
|
|
3
|
-
from typing import Literal
|
|
4
|
-
|
|
5
|
-
import astropy.units as u
|
|
6
|
-
import haiku as hk
|
|
7
|
-
import jax
|
|
8
|
-
import jax.numpy as jnp
|
|
9
|
-
|
|
10
|
-
from astropy.constants import c, m_p
|
|
11
|
-
from haiku.initializers import Constant as HaikuConstant
|
|
12
|
-
from jax import lax
|
|
13
|
-
from jax.lax import fori_loop, scan
|
|
14
|
-
from jax.scipy.stats import norm as gaussian
|
|
15
|
-
|
|
16
|
-
from ...util.abundance import abundance_table, element_data
|
|
17
|
-
from ..abc import AdditiveComponent
|
|
18
|
-
from .apec_loaders import get_continuum, get_lines, get_pseudo, get_temperature
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@jax.jit
|
|
22
|
-
def lerp(x, x0, x1, y0, y1):
|
|
23
|
-
"""
|
|
24
|
-
Linear interpolation routine
|
|
25
|
-
Return y(x) = (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
|
|
26
|
-
"""
|
|
27
|
-
return (y0 * (x1 - x) + y1 * (x - x0)) / (x1 - x0)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@jax.jit
|
|
31
|
-
def interp_and_integrate(energy_low, energy_high, energy_ref, continuum_ref, end_index):
|
|
32
|
-
"""
|
|
33
|
-
This function interpolate & integrate the values of a tabulated reference continuum between two energy limits
|
|
34
|
-
Sorry for the boilerplate here, but be sure that it works !
|
|
35
|
-
|
|
36
|
-
Parameters:
|
|
37
|
-
energy_low: lower limit of the integral
|
|
38
|
-
energy_high: upper limit of the integral
|
|
39
|
-
energy_ref: energy grid of the reference continuum
|
|
40
|
-
continuum_ref: continuum values evaluated at energy_ref
|
|
41
|
-
|
|
42
|
-
"""
|
|
43
|
-
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
44
|
-
start_index = jnp.searchsorted(energy_ref, energy_low, side="left") - 1
|
|
45
|
-
end_index = jnp.searchsorted(energy_ref, energy_high, side="left") + 1
|
|
46
|
-
|
|
47
|
-
def body_func(index, value):
|
|
48
|
-
integrated_flux, previous_energy, previous_continuum = value
|
|
49
|
-
current_energy, current_continuum = energy_ref[index], continuum_ref[index]
|
|
50
|
-
|
|
51
|
-
# 5 cases
|
|
52
|
-
# Neither current and previous energies are within the integral limits > nothing is added to the integrated flux
|
|
53
|
-
# The left limit of the integral is between the current and previous energy > previous energy is set to the limit, previous continuum is interpolated, and then added to the integrated flux
|
|
54
|
-
# The right limit of the integral is between the current and previous energy > current energy is set to the limit, current continuum is interpolated, and then added to the integrated flux
|
|
55
|
-
# Both current and previous energies are within the integral limits -> add to the integrated flux
|
|
56
|
-
# Within
|
|
57
|
-
|
|
58
|
-
current_energy_is_between = (energy_low <= current_energy) * (current_energy < energy_high)
|
|
59
|
-
previous_energy_is_between = (energy_low <= previous_energy) * (
|
|
60
|
-
previous_energy < energy_high
|
|
61
|
-
)
|
|
62
|
-
energies_within_bins = (previous_energy <= energy_low) * (energy_high < current_energy)
|
|
63
|
-
|
|
64
|
-
case = (
|
|
65
|
-
(1 - previous_energy_is_between) * current_energy_is_between * 1
|
|
66
|
-
+ previous_energy_is_between * (1 - current_energy_is_between) * 2
|
|
67
|
-
+ (previous_energy_is_between * current_energy_is_between) * 3
|
|
68
|
-
+ energies_within_bins * 4
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
term_to_add = lax.switch(
|
|
72
|
-
case,
|
|
73
|
-
[
|
|
74
|
-
lambda pe, pc, ce, cc, el, er: 0.0, # 1
|
|
75
|
-
lambda pe, pc, ce, cc, el, er: (cc + lerp(el, pe, ce, pc, cc)) * (ce - el) / 2, # 2
|
|
76
|
-
lambda pe, pc, ce, cc, el, er: (pc + lerp(er, pe, ce, pc, cc)) * (er - pe) / 2, # 3
|
|
77
|
-
lambda pe, pc, ce, cc, el, er: (pc + cc) * (ce - pe) / 2, # 4
|
|
78
|
-
lambda pe, pc, ce, cc, el, er: (lerp(el, pe, ce, pc, cc) + lerp(er, pe, ce, pc, cc))
|
|
79
|
-
* (er - el)
|
|
80
|
-
/ 2,
|
|
81
|
-
# 5
|
|
82
|
-
],
|
|
83
|
-
previous_energy,
|
|
84
|
-
previous_continuum,
|
|
85
|
-
current_energy,
|
|
86
|
-
current_continuum,
|
|
87
|
-
energy_low,
|
|
88
|
-
energy_high,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
return integrated_flux + term_to_add, current_energy, current_continuum
|
|
92
|
-
|
|
93
|
-
integrated_flux, _, _ = fori_loop(start_index, end_index, body_func, (0.0, 0.0, 0.0))
|
|
94
|
-
|
|
95
|
-
return integrated_flux
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
@jax.jit
|
|
99
|
-
def interp(e_low, e_high, energy_ref, continuum_ref, end_index):
|
|
100
|
-
energy_ref = jnp.where(jnp.arange(energy_ref.shape[0]) < end_index, energy_ref, jnp.nan)
|
|
101
|
-
|
|
102
|
-
return (
|
|
103
|
-
jnp.interp(e_high, energy_ref, continuum_ref) - jnp.interp(e_low, energy_ref, continuum_ref)
|
|
104
|
-
) / (e_high - e_low)
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@jax.jit
|
|
108
|
-
def interp_flux(energy, energy_ref, continuum_ref, end_index):
|
|
109
|
-
"""
|
|
110
|
-
Iterate through an array of shape (energy_ref,) and compute the flux between the bins defined by energy
|
|
111
|
-
"""
|
|
112
|
-
|
|
113
|
-
def scanned_func(carry, unpack):
|
|
114
|
-
e_low, e_high = unpack
|
|
115
|
-
continuum = interp_and_integrate(e_low, e_high, energy_ref, continuum_ref, end_index)
|
|
116
|
-
|
|
117
|
-
return carry, continuum
|
|
118
|
-
|
|
119
|
-
_, continuum = scan(scanned_func, 0.0, (energy[:-1], energy[1:]))
|
|
120
|
-
|
|
121
|
-
return continuum
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@jax.jit
|
|
125
|
-
def interp_flux_elements(energy_ref, continuum_ref, end_index, energy, abundances):
|
|
126
|
-
"""
|
|
127
|
-
Iterate through an array of shape (abundance, energy_ref) and compute the flux between the bins defined by energy
|
|
128
|
-
and weight the flux depending on the abundance of each element
|
|
129
|
-
"""
|
|
130
|
-
|
|
131
|
-
def scanned_func(_, unpack):
|
|
132
|
-
energy_ref, continuum_ref, end_idx = unpack
|
|
133
|
-
element_flux = interp_flux(energy, energy_ref, continuum_ref, end_idx)
|
|
134
|
-
|
|
135
|
-
return _, element_flux
|
|
136
|
-
|
|
137
|
-
_, flux = scan(scanned_func, 0.0, (energy_ref, continuum_ref, end_index))
|
|
138
|
-
|
|
139
|
-
return abundances @ flux
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
@jax.jit
|
|
143
|
-
def get_lines_contribution_broadening(
|
|
144
|
-
line_energy, line_element, line_emissivity, end_index, energy, abundances, total_broadening
|
|
145
|
-
):
|
|
146
|
-
def body_func(i, flux):
|
|
147
|
-
# Notice the -1 in line element to match the 0-based indexing
|
|
148
|
-
l_energy, l_emissivity, l_element = line_energy[i], line_emissivity[i], line_element[i] - 1
|
|
149
|
-
broadening = l_energy * total_broadening[l_element]
|
|
150
|
-
l_flux = gaussian.cdf(energy[1:], l_energy, broadening) - gaussian.cdf(
|
|
151
|
-
energy[:-1], l_energy, broadening
|
|
152
|
-
)
|
|
153
|
-
l_flux = l_flux * l_emissivity * abundances[l_element]
|
|
154
|
-
|
|
155
|
-
return flux + l_flux
|
|
156
|
-
|
|
157
|
-
return fori_loop(0, end_index, body_func, jnp.zeros_like(energy[:-1]))
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
@jax.jit
|
|
161
|
-
def continuum_func(energy, kT, abundances):
|
|
162
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
163
|
-
continuum_low = interp_flux_elements(*get_continuum(idx), energy, abundances)
|
|
164
|
-
continuum_high = interp_flux_elements(*get_continuum(idx + 1), energy, abundances)
|
|
165
|
-
|
|
166
|
-
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
@jax.jit
|
|
170
|
-
def pseudo_func(energy, kT, abundances):
|
|
171
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
172
|
-
continuum_low = interp_flux_elements(*get_pseudo(idx), energy, abundances)
|
|
173
|
-
continuum_high = interp_flux_elements(*get_pseudo(idx + 1), energy, abundances)
|
|
174
|
-
|
|
175
|
-
return lerp(kT, kT_low, kT_high, continuum_low, continuum_high)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
# @jax.custom_jvp
|
|
179
|
-
@jax.jit
|
|
180
|
-
def lines_func(energy, kT, abundances, broadening):
|
|
181
|
-
idx, kT_low, kT_high = get_temperature(kT)
|
|
182
|
-
line_low = get_lines_contribution_broadening(*get_lines(idx), energy, abundances, broadening)
|
|
183
|
-
line_high = get_lines_contribution_broadening(
|
|
184
|
-
*get_lines(idx + 1), energy, abundances, broadening
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
return lerp(kT, kT_low, kT_high, line_low, line_high)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
class APEC(AdditiveComponent):
|
|
191
|
-
"""
|
|
192
|
-
APEC model implementation in pure JAX for X-ray spectral fitting.
|
|
193
|
-
|
|
194
|
-
!!! warning
|
|
195
|
-
This implementation is optimised for the CPU, it shows poor performance on the GPU.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
def __init__(
|
|
199
|
-
self,
|
|
200
|
-
continuum: bool = True,
|
|
201
|
-
pseudo: bool = True,
|
|
202
|
-
lines: bool = True,
|
|
203
|
-
thermal_broadening: bool = True,
|
|
204
|
-
turbulent_broadening: bool = True,
|
|
205
|
-
variant: Literal["none", "v", "vv"] = "none",
|
|
206
|
-
abundance_table: Literal[
|
|
207
|
-
"angr", "aspl", "feld", "aneb", "grsa", "wilm", "lodd", "lgpp", "lgps"
|
|
208
|
-
] = "angr",
|
|
209
|
-
trace_abundance: float = 1.0,
|
|
210
|
-
**kwargs,
|
|
211
|
-
):
|
|
212
|
-
super().__init__(**kwargs)
|
|
213
|
-
|
|
214
|
-
warnings.warn("Be aware that this APEC implementation is not meant to be used yet")
|
|
215
|
-
|
|
216
|
-
self.atomic_weights = jnp.asarray(element_data["atomic_weight"].to_numpy())
|
|
217
|
-
|
|
218
|
-
self.abundance_table = abundance_table
|
|
219
|
-
self.thermal_broadening = thermal_broadening
|
|
220
|
-
self.turbulent_broadening = turbulent_broadening
|
|
221
|
-
self.continuum_to_compute = continuum
|
|
222
|
-
self.pseudo_to_compute = pseudo
|
|
223
|
-
self.lines_to_compute = lines
|
|
224
|
-
self.trace_abundance = trace_abundance
|
|
225
|
-
self.variant = variant
|
|
226
|
-
|
|
227
|
-
def get_thermal_broadening(self):
|
|
228
|
-
r"""
|
|
229
|
-
Compute the thermal broadening $\sigma_T$ for each element using :
|
|
230
|
-
|
|
231
|
-
$$ \frac{\sigma_T}{E_{\text{line}}} = \frac{1}{c}\sqrt{\frac{k_{B} T}{A m_p}}$$
|
|
232
|
-
|
|
233
|
-
where $E_{\text{line}}$ is the energy of the line, $c$ is the speed of light, $k_{B}$ is the Boltzmann constant,
|
|
234
|
-
$T$ is the temperature, $A$ is the atomic weight of the element and $m_p$ is the proton mass.
|
|
235
|
-
"""
|
|
236
|
-
|
|
237
|
-
if self.thermal_broadening:
|
|
238
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
239
|
-
factor = 1 / c * (1 / m_p) ** (1 / 2)
|
|
240
|
-
factor = factor.to(u.keV ** (-1 / 2)).value
|
|
241
|
-
|
|
242
|
-
# Multiply this factor by Line_Energy * sqrt(kT/A) to get the broadening for a line
|
|
243
|
-
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
244
|
-
return factor * jnp.sqrt(kT / self.atomic_weights)
|
|
245
|
-
|
|
246
|
-
else:
|
|
247
|
-
return jnp.zeros((30,))
|
|
248
|
-
|
|
249
|
-
def get_turbulent_broadening(self):
|
|
250
|
-
r"""
|
|
251
|
-
Return the turbulent broadening using :
|
|
252
|
-
|
|
253
|
-
$$\frac{\sigma_\text{turb}}{E_{\text{line}}} = \frac{\sigma_{v ~ ||}}{c}$$
|
|
254
|
-
|
|
255
|
-
where $\sigma_{v ~ ||}$ is the velocity dispersion along the line of sight in km/s.
|
|
256
|
-
"""
|
|
257
|
-
if self.turbulent_broadening:
|
|
258
|
-
# This return value must be multiplied by the energy of the line to get actual broadening
|
|
259
|
-
return (
|
|
260
|
-
hk.get_parameter("Velocity", [], init=HaikuConstant(100.0)) / c.to(u.km / u.s).value
|
|
261
|
-
)
|
|
262
|
-
else:
|
|
263
|
-
return 0.0
|
|
264
|
-
|
|
265
|
-
def get_parameters(self):
|
|
266
|
-
none_elements = ["C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
267
|
-
v_elements = ["He", "C", "N", "O", "Ne", "Mg", "Al", "Si", "S", "Ar", "Ca", "Fe", "Ni"]
|
|
268
|
-
trace_elements = (
|
|
269
|
-
jnp.asarray([3, 4, 5, 9, 11, 15, 17, 19, 21, 22, 23, 24, 25, 27, 29, 30], dtype=int) - 1
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
# Set abundances of trace element (will be overwritten in the vv case)
|
|
273
|
-
abund = jnp.ones((30,)).at[trace_elements].multiply(self.trace_abundance)
|
|
274
|
-
|
|
275
|
-
if self.variant == "vv":
|
|
276
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
277
|
-
if element != "H":
|
|
278
|
-
abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
|
|
279
|
-
|
|
280
|
-
elif self.variant == "v":
|
|
281
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
282
|
-
if element != "H" and element in v_elements:
|
|
283
|
-
abund = abund.at[i].set(hk.get_parameter(element, [], init=HaikuConstant(1.0)))
|
|
284
|
-
|
|
285
|
-
else:
|
|
286
|
-
Z = hk.get_parameter("Abundance", [], init=HaikuConstant(1.0))
|
|
287
|
-
for i, element in enumerate(abundance_table["Element"]):
|
|
288
|
-
if element != "H" and element in none_elements:
|
|
289
|
-
abund = abund.at[i].set(Z)
|
|
290
|
-
|
|
291
|
-
if abund != "angr":
|
|
292
|
-
abund = abund * jnp.asarray(
|
|
293
|
-
abundance_table[self.abundance_table] / abundance_table["angr"]
|
|
294
|
-
)
|
|
295
|
-
|
|
296
|
-
# Set the temperature, redshift, normalisation
|
|
297
|
-
kT = hk.get_parameter("kT", [], init=HaikuConstant(6.5))
|
|
298
|
-
z = hk.get_parameter("Redshift", [], init=HaikuConstant(0.0))
|
|
299
|
-
norm = hk.get_parameter("norm", [], init=HaikuConstant(1.0))
|
|
300
|
-
|
|
301
|
-
return kT, z, norm, abund
|
|
302
|
-
|
|
303
|
-
def emission_lines(self, e_low, e_high):
|
|
304
|
-
# Get the parameters and extract the relevant data
|
|
305
|
-
energy = jnp.hstack([e_low, e_high[-1]])
|
|
306
|
-
kT, z, norm, abundances = self.get_parameters()
|
|
307
|
-
total_broadening = jnp.hypot(self.get_thermal_broadening(), self.get_turbulent_broadening())
|
|
308
|
-
energy = energy * (1 + z)
|
|
309
|
-
|
|
310
|
-
continuum = continuum_func(energy, kT, abundances) if self.continuum_to_compute else 0.0
|
|
311
|
-
pseudo_continuum = pseudo_func(energy, kT, abundances) if self.pseudo_to_compute else 0.0
|
|
312
|
-
lines = (
|
|
313
|
-
lines_func(energy, kT, abundances, total_broadening) if self.lines_to_compute else 0.0
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
return (continuum + pseudo_continuum + lines) * norm * 1e14 / (1 + z), (e_low + e_high) / 2
|
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
"""This module contains the functions that load the APEC tables from the HDF5 file. They are implemented as JAX
|
|
2
|
-
pure callback to enable reading data from the files without saturating the memory."""
|
|
3
|
-
|
|
4
|
-
import h5netcdf
|
|
5
|
-
import jax
|
|
6
|
-
import jax.numpy as jnp
|
|
7
|
-
|
|
8
|
-
from ...util.online_storage import table_manager
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@jax.jit
|
|
12
|
-
def temperature_table_getter():
|
|
13
|
-
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
14
|
-
temperature = jnp.asarray(f["/temperature"])
|
|
15
|
-
|
|
16
|
-
return temperature
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@jax.jit
|
|
20
|
-
def get_temperature(kT):
|
|
21
|
-
temperature = temperature_table_getter()
|
|
22
|
-
idx = jnp.searchsorted(temperature, kT) - 1
|
|
23
|
-
|
|
24
|
-
return idx, temperature[idx], temperature[idx + 1]
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@jax.jit
|
|
28
|
-
def continuum_table_getter():
|
|
29
|
-
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
30
|
-
continuum_energy = jnp.asarray(f["/continuum_energy"])
|
|
31
|
-
continuum_emissivity = jnp.asarray(f["/continuum_emissivity"])
|
|
32
|
-
continuum_end_index = jnp.asarray(f["/continuum_end_index"])
|
|
33
|
-
|
|
34
|
-
return continuum_energy, continuum_emissivity, continuum_end_index
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@jax.jit
|
|
38
|
-
def pseudo_table_getter():
|
|
39
|
-
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
40
|
-
pseudo_energy = jnp.asarray(f["/pseudo_energy"])
|
|
41
|
-
pseudo_emissivity = jnp.asarray(f["/pseudo_emissivity"])
|
|
42
|
-
pseudo_end_index = jnp.asarray(f["/pseudo_end_index"])
|
|
43
|
-
|
|
44
|
-
return pseudo_energy, pseudo_emissivity, pseudo_end_index
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@jax.jit
|
|
48
|
-
def line_table_getter():
|
|
49
|
-
with h5netcdf.File(table_manager.fetch("apec.nc"), "r") as f:
|
|
50
|
-
line_energy = jnp.asarray(f["/line_energy"])
|
|
51
|
-
line_element = jnp.asarray(f["/line_element"])
|
|
52
|
-
line_emissivity = jnp.asarray(f["/line_emissivity"])
|
|
53
|
-
line_end_index = jnp.asarray(f["/line_end_index"])
|
|
54
|
-
|
|
55
|
-
return line_energy, line_element, line_emissivity, line_end_index
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
@jax.jit
|
|
59
|
-
def get_continuum(idx):
|
|
60
|
-
continuum_energy, continuum_emissivity, continuum_end_index = continuum_table_getter()
|
|
61
|
-
return continuum_energy[idx], continuum_emissivity[idx], continuum_end_index[idx]
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@jax.jit
|
|
65
|
-
def get_pseudo(idx):
|
|
66
|
-
pseudo_energy, pseudo_emissivity, pseudo_end_index = pseudo_table_getter()
|
|
67
|
-
return pseudo_energy[idx], pseudo_emissivity[idx], pseudo_end_index[idx]
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
@jax.jit
|
|
71
|
-
def get_lines(idx):
|
|
72
|
-
line_energy, line_element, line_emissivity, line_end_index = line_table_getter()
|
|
73
|
-
return line_energy[idx], line_element[idx], line_emissivity[idx], line_end_index[idx]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|