jaxspec 0.2.0__tar.gz → 0.2.1.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/PKG-INFO +3 -3
  2. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/pyproject.toml +4 -4
  3. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/results.py +14 -4
  4. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/util.py +68 -8
  5. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/fit.py +22 -2
  6. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/additive.py +13 -13
  7. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/multiplicative.py +3 -3
  8. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/LICENSE.md +0 -0
  9. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/README.md +0 -0
  10. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/__init__.py +0 -0
  11. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/_fit/__init__.py +0 -0
  12. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/_fit/_build_model.py +0 -0
  13. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/__init__.py +0 -0
  14. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/_plot.py +0 -0
  15. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/analysis/compare.py +0 -0
  16. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/__init__.py +0 -0
  17. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/instrument.py +0 -0
  18. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/obsconf.py +0 -0
  19. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/observation.py +0 -0
  20. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/data/ogip.py +0 -0
  21. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/__init__.py +0 -0
  22. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/_graph_util.py +0 -0
  23. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/abc.py +0 -0
  24. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/background.py +0 -0
  25. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/model/list.py +0 -0
  26. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/scripts/__init__.py +0 -0
  27. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/scripts/debug.py +0 -0
  28. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/__init__.py +0 -0
  29. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/abundance.py +0 -0
  30. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/integrate.py +0 -0
  31. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/misc.py +0 -0
  32. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/online_storage.py +0 -0
  33. {jaxspec-0.2.0 → jaxspec-0.2.1.dev2}/src/jaxspec/util/typing.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: jaxspec
3
- Version: 0.2.0
3
+ Version: 0.2.1.dev2
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  License: MIT
6
6
  Author: sdupourque
@@ -18,7 +18,7 @@ Requires-Dist: chainconsumer (>=1.1.2,<2.0.0)
18
18
  Requires-Dist: cmasher (>=1.6.3,<2.0.0)
19
19
  Requires-Dist: flax (>=0.10.1,<0.11.0)
20
20
  Requires-Dist: interpax (>=0.3.3,<0.4.0)
21
- Requires-Dist: jax (>=0.4.37,<0.5.0)
21
+ Requires-Dist: jax (>=0.5.0,<0.6.0)
22
22
  Requires-Dist: jaxns (>=2.6.7,<3.0.0)
23
23
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
24
24
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
@@ -26,7 +26,7 @@ Requires-Dist: mendeleev (>=0.15,<0.20)
26
26
  Requires-Dist: networkx (>=3.1,<4.0)
27
27
  Requires-Dist: numpy (<2.0.0)
28
28
  Requires-Dist: numpyro (>=0.16.1,<0.17.0)
29
- Requires-Dist: optimistix (>=0.0.7,<0.0.10)
29
+ Requires-Dist: optimistix (>=0.0.7,<0.0.11)
30
30
  Requires-Dist: pandas (>=2.2.0,<3.0.0)
31
31
  Requires-Dist: pooch (>=1.8.2,<2.0.0)
32
32
  Requires-Dist: scipy (<1.15)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "jaxspec"
3
- version = "0.2.0"
3
+ version = "0.2.1dev-2"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = ["sdupourque <sdupourque@irap.omp.eu>"]
6
6
  license = "MIT"
@@ -11,7 +11,7 @@ documentation = "https://jaxspec.readthedocs.io/en/latest/"
11
11
 
12
12
  [tool.poetry.dependencies]
13
13
  python = ">=3.10,<3.13"
14
- jax = "^0.4.37"
14
+ jax = "^0.5.0"
15
15
  numpy = "<2.0.0"
16
16
  pandas = "^2.2.0"
17
17
  astropy = "^6.0.0"
@@ -26,7 +26,7 @@ jaxopt = "^0.8.1"
26
26
  tinygp = "^0.3.0"
27
27
  seaborn = "^0.13.1"
28
28
  sparse = "^0.15.4"
29
- optimistix = ">=0.0.7,<0.0.10"
29
+ optimistix = ">=0.0.7,<0.0.11"
30
30
  scipy = "<1.15"
31
31
  mendeleev = ">=0.15,<0.20"
32
32
  jaxns = "^2.6.7"
@@ -56,7 +56,7 @@ testbook = "^0.4.2"
56
56
 
57
57
  [tool.poetry.group.dev.dependencies]
58
58
  pre-commit = ">=3.5,<5.0"
59
- ruff = ">=0.2.1,<0.9.0"
59
+ ruff = ">=0.2.1,<0.10.0"
60
60
  jupyterlab = "^4.0.7"
61
61
  notebook = "^7.0.6"
62
62
  ipywidgets = "^8.1.1"
@@ -391,6 +391,8 @@ class FitResult:
391
391
  alpha_envelope: (float, float) = (0.15, 0.25),
392
392
  style: str | Any = "default",
393
393
  title: str | None = None,
394
+ figsize: tuple[float, float] = (6, 6),
395
+ x_lims: tuple[float, float] | None = None,
394
396
  ) -> list[plt.Figure]:
395
397
  r"""
396
398
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
@@ -400,7 +402,7 @@ class FitResult:
400
402
  {(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
401
403
 
402
404
  Parameters:
403
- percentile: The percentile of the posterior predictive distribution to plot.
405
+ n_sigmas: The number of sigmas for which to plot the envelopes.
404
406
  x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
405
407
  y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
406
408
  plot_background: Whether to plot the background model if it is included in the fit.
@@ -408,6 +410,9 @@ class FitResult:
408
410
  scale: The axes scaling
409
411
  alpha_envelope: The transparency range for the envelopes
410
412
  style: The style of the plot. It can be either a string or a matplotlib style context.
413
+ title: The title of the plot.
414
+ figsize: The size of the figure.
415
+ x_lims: The limits of the x-axis.
411
416
 
412
417
  Returns:
413
418
  A list of matplotlib figures for each observation in the model.
@@ -436,7 +441,7 @@ class FitResult:
436
441
  fig, ax = plt.subplots(
437
442
  2,
438
443
  1,
439
- figsize=(6, 6),
444
+ figsize=figsize,
440
445
  sharex="col",
441
446
  height_ratios=[0.7, 0.3],
442
447
  )
@@ -525,8 +530,10 @@ class FitResult:
525
530
  alpha_envelope=alpha_envelope,
526
531
  )
527
532
 
533
+ name = component_name.split("*")[-1]
534
+
528
535
  legend_plots += component_plot
529
- legend_labels.append(component_name)
536
+ legend_labels.append(name)
530
537
 
531
538
  if self.background_model is not None and plot_background:
532
539
  # We plot the background only if it is included in the fit, i.e. by subtracting
@@ -617,6 +624,9 @@ class FitResult:
617
624
  ax[0].set_xscale("log")
618
625
  ax[0].set_yscale("log")
619
626
 
627
+ if x_lims is not None:
628
+ ax[0].set_xlim(*x_lims)
629
+
620
630
  fig.align_ylabels()
621
631
  plt.subplots_adjust(hspace=0.0)
622
632
  fig.tight_layout()
@@ -654,7 +664,7 @@ class FitResult:
654
664
  """
655
665
 
656
666
  consumer = ChainConsumer()
657
- consumer.add_chain(self.to_chain(self.model.to_string()))
667
+ consumer.add_chain(self.to_chain("Results"))
658
668
  consumer.set_plot_config(config)
659
669
 
660
670
  # Context for default mpl style
@@ -1,14 +1,16 @@
1
1
  from collections.abc import Mapping
2
2
  from pathlib import Path
3
- from typing import Literal, TypeVar
3
+ from typing import TYPE_CHECKING, Literal, TypeVar
4
4
 
5
5
  import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
6
8
  import numpyro
7
9
 
8
10
  from astropy.io import fits
11
+ from jax.experimental.sparse import BCOO
9
12
  from numpyro import handlers
10
13
 
11
- from .._fit._build_model import forward_model
12
14
  from ..model.abc import SpectralModel
13
15
  from ..util.online_storage import table_manager
14
16
  from . import Instrument, ObsConfiguration, Observation
@@ -16,6 +18,10 @@ from . import Instrument, ObsConfiguration, Observation
16
18
  K = TypeVar("K")
17
19
  V = TypeVar("V")
18
20
 
21
+ if TYPE_CHECKING:
22
+ from ..data import ObsConfiguration
23
+ from ..model.abc import SpectralModel
24
+
19
25
 
20
26
  def load_example_pha(
21
27
  source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
@@ -124,8 +130,40 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
124
130
  raise ValueError(f"{source} not recognized.")
125
131
 
126
132
 
133
+ def forward_model_with_multiple_inputs(
134
+ model: "SpectralModel",
135
+ parameters,
136
+ obs_configuration: "ObsConfiguration",
137
+ sparse=False,
138
+ ):
139
+ energies = np.asarray(obs_configuration.in_energies)
140
+ parameter_dims = next(iter(parameters.values())).shape
141
+
142
+ def flux_func(p):
143
+ return model.photon_flux(p, *energies)
144
+
145
+ for _ in parameter_dims:
146
+ flux_func = jax.vmap(flux_func)
147
+
148
+ flux_func = jax.jit(flux_func)
149
+
150
+ if sparse:
151
+ # folding.transfer_matrix.data.density > 0.015 is a good criterion for considering sparsification
152
+ transfer_matrix = BCOO.from_scipy_sparse(
153
+ obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
154
+ )
155
+
156
+ else:
157
+ transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
158
+
159
+ expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))
160
+
161
+ # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
162
+ return jnp.clip(expected_counts, a_min=1e-6)
163
+
164
+
127
165
  def fakeit_for_multiple_parameters(
128
- instrument: ObsConfiguration | list[ObsConfiguration],
166
+ obsconfs: ObsConfiguration | list[ObsConfiguration],
129
167
  model: SpectralModel,
130
168
  parameters: Mapping[K, V],
131
169
  rng_key: int = 0,
@@ -134,10 +172,32 @@ def fakeit_for_multiple_parameters(
134
172
  ):
135
173
  """
136
174
  Convenience function to simulate multiple spectra from a given model and a set of parameters.
175
+ This is supposed to be somewhat optimized and can handle multiple parameters at once without blowing
176
+ up the memory. The parameters should be passed as a dictionary with the parameter name as the key and
177
+ the parameter values as the values; each value can be a scalar or an nd-array.
178
+
179
+ # Example:
180
+
181
+ ``` python
182
+ from jaxspec.data.util import fakeit_for_multiple_parameters
183
+ from numpy.random import default_rng
184
+
185
+ rng = default_rng(42)
186
+ size = (10, 30)
187
+
188
+ parameters = {
189
+ "tbabs_1_nh": rng.uniform(0.1, 0.4, size=size),
190
+ "powerlaw_1_alpha": rng.uniform(1, 3, size=size),
191
+ "powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=size),
192
+ "blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=size),
193
+ "blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=size)
194
+ }
137
195
 
196
+ spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
197
+ ```
138
198
 
139
199
  Parameters:
140
- instrument: The instrumental setup.
200
+ obsconfs: The observational setup(s).
141
201
  model: The model to use.
142
202
  parameters: The parameters of the model.
143
203
  rng_key: The random number generator seed.
@@ -145,12 +205,12 @@ def fakeit_for_multiple_parameters(
145
205
  sparsify_matrix: Whether to sparsify the matrix or not.
146
206
  """
147
207
 
148
- instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
208
+ obsconf_list = [obsconfs] if isinstance(obsconfs, ObsConfiguration) else obsconfs
149
209
  fakeits = []
150
210
 
151
- for i, obs in enumerate(instruments):
152
- countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
153
- parameters
211
+ for i, obsconf in enumerate(obsconf_list):
212
+ countrate = forward_model_with_multiple_inputs(
213
+ model, parameters, obsconf, sparse=sparsify_matrix
154
214
  )
155
215
 
156
216
  if apply_stat:
@@ -13,7 +13,9 @@ import matplotlib.pyplot as plt
13
13
  import numpyro
14
14
 
15
15
  from jax import random
16
+ from jax.experimental import mesh_utils
16
17
  from jax.random import PRNGKey
18
+ from jax.sharding import PositionalSharding
17
19
  from numpyro.contrib.nested_sampling import NestedSampler
18
20
  from numpyro.distributions import Poisson, TransformedDistribution
19
21
  from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
@@ -312,14 +314,27 @@ class BayesianModel:
312
314
  Check if the prior distribution includes the observed data.
313
315
  """
314
316
  key_prior, key_posterior = jax.random.split(key, 2)
317
+ n_devices = len(jax.local_devices())
318
+ sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))
319
+
320
+ # Round num_samples up to the next multiple of the number of devices before sampling from the prior
321
+ if num_samples % n_devices != 0:
322
+ num_samples = num_samples + n_devices - (num_samples % n_devices)
323
+
315
324
  prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
316
- posterior_observations = self.mock_observations(prior_params, key=key_posterior)
325
+
326
+ # Split the parameters on every device
327
+ sharded_parameters = jax.device_put(prior_params, sharding)
328
+ posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
317
329
 
318
330
  for key, value in self.observation_container.items():
319
331
  fig, ax = plt.subplots(
320
332
  nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
321
333
  )
322
334
 
335
+ legend_plots = []
336
+ legend_labels = []
337
+
323
338
  y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
324
339
  value.folded_counts.values, 1.0, "ct"
325
340
  )
@@ -337,6 +352,11 @@ class BayesianModel:
337
352
  ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
338
353
  )
339
354
 
355
+ legend_plots.append((true_data_plot,))
356
+ legend_labels.append("Observed")
357
+ legend_plots += prior_plot
358
+ legend_labels.append("Prior Predictive")
359
+
340
360
  # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
341
361
  counts = posterior_observations["obs_" + key]
342
362
  observed = value.folded_counts.values
@@ -363,7 +383,7 @@ class BayesianModel:
363
383
  ax[1].set_ylim(0, 100)
364
384
  ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
365
385
  ax[0].loglog()
366
- ax[0].legend(loc="upper right")
386
+ ax[0].legend(legend_plots, legend_labels)
367
387
  plt.suptitle(f"Prior Predictive coverage for {key}")
368
388
  plt.tight_layout()
369
389
  plt.show()
@@ -156,13 +156,13 @@ class Gauss(AdditiveComponent):
156
156
  $$\mathcal{M}\left( E \right) = \frac{K}{\sigma \sqrt{2 \pi}}\exp\left(\frac{-(E-E_L)^2}{2\sigma^2}\right)$$
157
157
 
158
158
  !!! abstract "Parameters"
159
- * $E_L$ (`E_l`) $\left[\text{keV}\right]$ : Energy of the line
159
+ * $E_L$ (`El`) $\left[\text{keV}\right]$ : Energy of the line
160
160
  * $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
161
161
  * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
162
162
  """
163
163
 
164
164
  def __init__(self):
165
- self.E_l = nnx.Param(2.0)
165
+ self.El = nnx.Param(2.0)
166
166
  self.sigma = nnx.Param(1e-2)
167
167
  self.norm = nnx.Param(1.0)
168
168
 
@@ -170,12 +170,12 @@ class Gauss(AdditiveComponent):
170
170
  return self.norm * (
171
171
  jsp.stats.norm.cdf(
172
172
  e_high,
173
- loc=jnp.asarray(self.E_l, dtype=jnp.float64),
173
+ loc=jnp.asarray(self.El, dtype=jnp.float64),
174
174
  scale=jnp.asarray(self.sigma, dtype=jnp.float64),
175
175
  )
176
176
  - jsp.stats.norm.cdf(
177
177
  e_low,
178
- loc=jnp.asarray(self.E_l, dtype=jnp.float64),
178
+ loc=jnp.asarray(self.El, dtype=jnp.float64),
179
179
  scale=jnp.asarray(self.sigma, dtype=jnp.float64),
180
180
  )
181
181
  )
@@ -246,13 +246,13 @@ class Agauss(AdditiveComponent):
246
246
  \frac{K}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda - \lambda_L)^2}{2 \sigma^2}\right)$$
247
247
 
248
248
  !!! abstract "Parameters"
249
- * $\lambda_L$ (`lambda_l`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
249
+ * $\lambda_L$ (`lambdal`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
250
250
  * $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line in Angström
251
251
  * $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$: Normalization
252
252
  """
253
253
 
254
254
  def __init__(self):
255
- self.lambda_l = nnx.Param(12.0)
255
+ self.lambdal = nnx.Param(12.0)
256
256
  self.sigma = nnx.Param(1e-2)
257
257
  self.norm = nnx.Param(1.0)
258
258
 
@@ -261,7 +261,7 @@ class Agauss(AdditiveComponent):
261
261
 
262
262
  return self.norm * jsp.stats.norm.pdf(
263
263
  hc / energy,
264
- loc=jnp.asarray(self.lambda_l, dtype=jnp.float64),
264
+ loc=jnp.asarray(self.lambdal, dtype=jnp.float64),
265
265
  scale=jnp.asarray(self.sigma, dtype=jnp.float64),
266
266
  )
267
267
 
@@ -275,14 +275,14 @@ class Zagauss(AdditiveComponent):
275
275
  \frac{K (1+z)}{\sigma \sqrt{2 \pi}} \exp\left(\frac{-(\lambda/(1+z) - \lambda_L)^2}{2 \sigma^2}\right)$$
276
276
 
277
277
  !!! abstract "Parameters"
278
- * $\lambda_L$ (`lambda_l`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
278
+ * $\lambda_L$ (`lambdal`) $\left[\unicode{x212B}\right]$ : Wavelength of the line in Angström
279
279
  * $\sigma$ (`sigma`) $\left[\unicode{x212B}\right]$ : Width of the line in Angström
280
280
  * $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
281
281
  * $K$ (`norm`) $\left[\frac{\unicode{x212B}~\text{photons}}{\text{keV}\text{cm}^2\text{s}}\right]$ : Normalization
282
282
  """
283
283
 
284
284
  def __init__(self):
285
- self.lambda_l = nnx.Param(12.0)
285
+ self.lambdal = nnx.Param(12.0)
286
286
  self.sigma = nnx.Param(1e-2)
287
287
  self.redshift = nnx.Param(0.0)
288
288
  self.norm = nnx.Param(1.0)
@@ -297,7 +297,7 @@ class Zagauss(AdditiveComponent):
297
297
  * (1 + redshift)
298
298
  * jsp.stats.norm.pdf(
299
299
  (hc / energy) / (1 + redshift),
300
- loc=jnp.asarray(self.lambda_l, dtype=jnp.float64),
300
+ loc=jnp.asarray(self.lambdal, dtype=jnp.float64),
301
301
  scale=jnp.asarray(self.sigma, dtype=jnp.float64),
302
302
  )
303
303
  )
@@ -311,14 +311,14 @@ class Zgauss(AdditiveComponent):
311
311
  \frac{K}{(1+z) \sigma \sqrt{2 \pi}}\exp\left(\frac{-(E(1+z)-E_L)^2}{2\sigma^2}\right)$$
312
312
 
313
313
  !!! abstract "Parameters"
314
- * $E_L$ (`E_l`) $\left[\text{keV}\right]$ : Energy of the line
314
+ * $E_L$ (`El`) $\left[\text{keV}\right]$ : Energy of the line
315
315
  * $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Width of the line
316
316
  * $K$ (`norm`) $\left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$ : Normalization
317
317
  * $z$ (`redshift`) $\left[\text{dimensionless}\right]$ : Redshift
318
318
  """
319
319
 
320
320
  def __init__(self):
321
- self.E_l = nnx.Param(2.0)
321
+ self.El = nnx.Param(2.0)
322
322
  self.sigma = nnx.Param(1e-2)
323
323
  self.redshift = nnx.Param(0.0)
324
324
  self.norm = nnx.Param(1.0)
@@ -326,7 +326,7 @@ class Zgauss(AdditiveComponent):
326
326
  def continuum(self, energy) -> (jax.Array, jax.Array):
327
327
  return (self.norm / (1 + self.redshift)) * jsp.stats.norm.pdf(
328
328
  energy * (1 + self.redshift),
329
- loc=jnp.asarray(self.E_l, dtype=jnp.float64),
329
+ loc=jnp.asarray(self.El, dtype=jnp.float64),
330
330
  scale=jnp.asarray(self.sigma, dtype=jnp.float64),
331
331
  )
332
332
 
@@ -228,9 +228,9 @@ class Tbpcf(MultiplicativeComponent):
228
228
  self.nh = nnx.Param(1.0)
229
229
  self.f = nnx.Param(0.2)
230
230
 
231
- def continuum(self, energy):
231
+ def factor(self, energy):
232
232
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
233
- return self.f * jnp.exp(-self.nh * sigma) + (1 - self.f)
233
+ return self.f * jnp.exp(-self.nh * sigma) + (1.0 - self.f)
234
234
 
235
235
 
236
236
  class FDcut(MultiplicativeComponent):
@@ -250,5 +250,5 @@ class FDcut(MultiplicativeComponent):
250
250
  self.Ec = nnx.Param(1.0)
251
251
  self.Ef = nnx.Param(3.0)
252
252
 
253
- def continuum(self, energy):
253
+ def factor(self, energy):
254
254
  return (1 + jnp.exp((energy - self.Ec) / self.Ef)) ** -1
File without changes
File without changes