jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,22 +1,27 @@
1
1
  from abc import ABC, abstractmethod
2
2
 
3
+ import jax
3
4
  import jax.numpy as jnp
4
5
  import numpyro
5
6
  import numpyro.distributions as dist
6
7
 
7
8
  from jax.scipy.integrate import trapezoid
9
+ from numpyro.distributions import Poisson
8
10
  from tinygp import GaussianProcess, kernels
9
11
 
12
+ from .._fit._build_model import build_prior, forward_model
13
+ from .abc import SpectralModel
14
+
10
15
 
11
16
  class BackgroundModel(ABC):
12
17
  """
13
- This class handles the modelization of backgrounds in our spectra. This is handled in a separate class for now
18
+ Handles the background modelling in our spectra. This is handled in a separate class for now
14
19
  since backgrounds can be phenomenological models fitted directly on the folded spectrum. This is not the case for
15
20
  the source model, which is fitted on the unfolded spectrum. This might be changed later.
16
21
  """
17
22
 
18
23
  @abstractmethod
19
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
24
+ def numpyro_model(self, observation, name: str = "", observed=True):
20
25
  """
21
26
  Build the model for the background.
22
27
  """
@@ -25,7 +30,7 @@ class BackgroundModel(ABC):
25
30
 
26
31
  class SubtractedBackground(BackgroundModel):
27
32
  """
28
- This class is to use when implying that the observed background should be simply subtracted from the observed.
33
+ Define a model where the observed background is simply subtracted from the observed spectrum.
29
34
 
30
35
  !!! danger
31
36
 
@@ -35,93 +40,40 @@ class SubtractedBackground(BackgroundModel):
35
40
 
36
41
  """
37
42
 
38
- def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
39
- _, observed_counts = obs.out_energies, obs.folded_background.data
40
- numpyro.deterministic(f"{name}", observed_counts)
43
+ def numpyro_model(self, observation, name: str = "", observed=True):
44
+ _, observed_counts = observation.out_energies, observation.folded_background.data
45
+ numpyro.deterministic(f"bkg_{name}", observed_counts)
41
46
 
42
- return jnp.zeros_like(observed_counts)
47
+ return observed_counts
43
48
 
44
49
 
45
50
  class BackgroundWithError(BackgroundModel):
46
51
  """
47
- This class is to use when implying that the observed background should be simply subtracted from the observed. It
52
+ Define a model where the observed background is subtracted from the observed spectrum, accounting for its intrinsic spread. It
48
53
  fits a countrate for each background bin assuming a Poisson distribution.
49
-
50
- !!! warning
51
- This is the same as [`ConjugateBackground`][jaxspec.model.background.ConjugateBackground]
52
- but slower since it performs the fit using MCMC instead of analytical solution.
53
54
  """
54
55
 
55
- def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
56
+ def numpyro_model(self, obs, name: str = "", observed=True):
57
+ # We can't use the build_prior_function method here because the parameter size varies
58
+ # with the current observation. It must be instantiated in place.
56
59
  # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
60
+
57
61
  _, observed_counts = obs.out_energies, obs.folded_background.data
58
62
  alpha = observed_counts + 1
59
63
  beta = 1
60
- countrate = numpyro.sample(f"{name}_params", dist.Gamma(alpha, rate=beta))
64
+ countrate = numpyro.sample(f"_bkg_{name}_countrate", dist.Gamma(alpha, rate=beta))
61
65
 
62
- with numpyro.plate(f"{name}_plate", len(observed_counts)):
66
+ with numpyro.plate(f"bkg_{name}_plate", len(observed_counts)):
63
67
  numpyro.sample(
64
- f"{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
68
+ f"bkg_{name}", dist.Poisson(countrate), obs=observed_counts if observed else None
65
69
  )
66
70
 
67
71
  return countrate
68
72
 
69
73
 
70
- '''
71
- # TODO: Implement this class and sample it with Gibbs Sampling
72
-
73
- class ConjugateBackground(BackgroundModel):
74
- r"""
75
- This class fit an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
76
- distribution, we can analytically derive the posterior as a Negative binomial distribution.
77
-
78
- $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
79
- p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
80
- \right) $$
81
-
82
- !!! info
83
- Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
84
- the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
85
- $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
86
- counts are 0, and add a small scatter even if the measured background is effectively null.
87
-
88
- ??? abstract "References"
89
-
90
- - https://en.wikipedia.org/wiki/Conjugate_prior
91
- - https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
92
- - https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
93
- - https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
94
-
95
- """
96
-
97
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
98
- # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
99
- # alpha = observed_counts + 1
100
- # beta = 1
101
-
102
- with numpyro.plate(f"{name}_plate", len(observed_counts)):
103
- countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
104
-
105
- return countrate
106
- '''
107
-
108
- """
109
- class SpectralBackgroundModel(BackgroundModel):
110
- # I should pass the current spectral model as an argument to the background model
111
- # In the numpyro model function
112
- def __init__(self, model, prior):
113
- self.model = model
114
- self.prior = prior
115
-
116
- def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
117
- #TODO : keep the sparsification from top model
118
- transformed_model = hk.without_apply_rng(hk.transform(lambda par: CountForwardModel(model, obs, sparse=False)(par)))
119
- """
120
-
121
-
122
74
  class GaussianProcessBackground(BackgroundModel):
123
75
  """
124
- This class use a Gaussian Process to model the background. The GP is built using the
76
+ Define a Gaussian Process to model the background. The GP is built using the
125
77
  [`tinygp`](https://tinygp.readthedocs.io/en/stable/guide.html) library.
126
78
  """
127
79
 
@@ -146,16 +98,7 @@ class GaussianProcessBackground(BackgroundModel):
146
98
  self.n_nodes = n_nodes
147
99
  self.kernel = kernel
148
100
 
149
- def numpyro_model(self, obs, spectral_model, name: str = "bkg", observed=True):
150
- """
151
- Build the model for the background.
152
-
153
- Parameters:
154
- energy: The energy bins lower and upper values (e_low, e_high).
155
- observed_counts: The observed counts in each energy bin.
156
- name: The name of the background model for parameters disambiguation.
157
- observed: Whether the model is observed or not. Useful for `numpyro.infer.Predictive` calls.
158
- """
101
+ def numpyro_model(self, obs, name: str = "", observed=True):
159
102
  energy, observed_counts = obs.out_energies, obs.folded_background.data
160
103
 
161
104
  if (observed_counts is not None) and (self.n_nodes >= len(observed_counts)):
@@ -163,28 +106,92 @@ class GaussianProcessBackground(BackgroundModel):
163
106
  "More nodes than channels in the observation associated with GaussianProcessBackground."
164
107
  )
165
108
 
109
+ else:
110
+ observed_counts = jnp.asarray(observed_counts)
111
+
166
112
  # The parameters of the GP model
167
- mean = numpyro.sample(f"{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0))
168
- sigma = numpyro.sample(f"{name}_sigma", dist.HalfNormal(3.0))
169
- rho = numpyro.sample(f"{name}_rho", dist.HalfNormal(10.0))
113
+ mean = numpyro.sample(
114
+ f"_bkg_{name}_mean", dist.Normal(jnp.log(jnp.mean(observed_counts)), 2.0)
115
+ )
116
+ sigma = numpyro.sample(f"_bkg_{name}_sigma", dist.HalfNormal(3.0))
117
+ rho = numpyro.sample(f"_bkg_{name}_rho", dist.HalfNormal(10.0))
170
118
 
171
119
  # Set up the kernel and GP objects
172
120
  kernel = sigma**2 * self.kernel(rho)
173
121
  nodes = jnp.linspace(0, 1, self.n_nodes)
174
122
  gp = GaussianProcess(kernel, nodes, diag=1e-5 * jnp.ones_like(nodes), mean=mean)
175
123
 
176
- log_rate = numpyro.sample(f"_{name}_log_rate_nodes", gp.numpyro_dist())
124
+ log_rate = numpyro.sample(f"_bkg_{name}_log_rate_nodes", gp.numpyro_dist())
125
+
177
126
  interp_count_rate = jnp.exp(
178
127
  jnp.interp(energy, nodes * (self.e_max - self.e_min) + self.e_min, log_rate)
179
128
  )
180
129
  count_rate = trapezoid(interp_count_rate, energy, axis=0)
181
130
 
182
131
  # Finally, our observation model is Poisson
183
- with numpyro.plate(f"{name}_plate", len(observed_counts)):
184
- # TODO : change to Poisson Likelihood when there is no background model
185
- # TODO : Otherwise clip the background model to 1e-6 to avoid numerical issues
132
+ with numpyro.plate("bkg_plate_" + name, len(observed_counts)):
186
133
  numpyro.sample(
187
- f"{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
134
+ f"bkg_{name}", dist.Poisson(count_rate), obs=observed_counts if observed else None
188
135
  )
189
136
 
190
137
  return count_rate
138
+
139
+
140
+ class SpectralModelBackground(BackgroundModel):
141
+ def __init__(self, spectral_model: "SpectralModel", prior_distributions, sparse=False):
142
+ self.spectral_model = spectral_model
143
+ self.prior = prior_distributions
144
+ self.sparse = sparse
145
+
146
+ def numpyro_model(self, observation, name: str = "", observed=True):
147
+ params = build_prior(self.prior, prefix=f"_bkg_{name}_")
148
+ bkg_model = jax.jit(
149
+ lambda par: forward_model(self.spectral_model, par, observation, sparse=self.sparse)
150
+ )
151
+ bkg_countrate = bkg_model(params)
152
+
153
+ with numpyro.plate("bkg_plate_" + name, len(observation.folded_background)):
154
+ numpyro.sample(
155
+ "bkg_" + name,
156
+ Poisson(bkg_countrate),
157
+ obs=observation.folded_background.data if observed else None,
158
+ )
159
+
160
+ return bkg_countrate
161
+
162
+
163
+ '''
164
+ class ConjugateBackground(BackgroundModel):
165
+ r"""
166
+ This class fits an expected rate $\\lambda$ in each bin of the background spectrum. Assuming a Gamma prior
167
+ distribution, we can analytically derive the posterior as a Negative binomial distribution.
168
+
169
+ $$ p(\\lambda_{\text{Bkg}}) \\sim \\Gamma \\left( \alpha, \beta \right) \\implies
170
+ p\\left(\\lambda_{\text{Bkg}} | \text{Counts}_{\text{Bkg}}\right) \\sim \text{NB}\\left(\alpha, \frac{\beta}{\beta +1}
171
+ \right) $$
172
+
173
+ !!! info
174
+ Here, $\alpha$ and $\beta$ are set to $\alpha = \text{Counts}_{\text{Bkg}} + 1$ and $\beta = 1$. Doing so,
175
+ the prior distribution is such that $\\mathbb{E}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}} +1$ and
176
+ $\text{Var}[\\lambda_{\text{Bkg}}] = \text{Counts}_{\text{Bkg}}+1$. The +1 is to avoid numerical issues when the
177
+ counts are 0, and add a small scatter even if the measured background is effectively null.
178
+
179
+ ??? abstract "References"
180
+
181
+ - https://en.wikipedia.org/wiki/Conjugate_prior
182
+ - https://www.acsu.buffalo.edu/~adamcunn/probability/gamma.html
183
+ - https://bayesiancomputationbook.com/markdown/chp_01.html?highlight=conjugate#conjugate-priors
184
+ - https://vioshyvo.github.io/Bayesian_inference/conjugate-distributions.html
185
+
186
+ """
187
+
188
+ def numpyro_model(self, energy, observed_counts, name: str = "bkg", observed=True):
189
+ # Gamma in numpyro is parameterized by concentration and rate (alpha/beta)
190
+ # alpha = observed_counts + 1
191
+ # beta = 1
192
+
193
+ with numpyro.plate(f"{name}_plate", len(observed_counts)):
194
+ countrate = numpyro.sample(f"{name}", dist.Gamma(2 * observed_counts + 1, 2), obs=None)
195
+
196
+ return countrate
197
+ '''
@@ -1,16 +1,31 @@
1
1
  from __future__ import annotations
2
2
 
3
- import haiku as hk
3
+ import flax.nnx as nnx
4
4
  import jax.numpy as jnp
5
5
  import numpy as np
6
6
 
7
7
  from astropy.table import Table
8
- from haiku.initializers import Constant as HaikuConstant
9
8
 
10
9
  from ..util.online_storage import table_manager
11
10
  from .abc import MultiplicativeComponent
12
11
 
13
12
 
13
+ class MultiplicativeConstant(MultiplicativeComponent):
14
+ r"""
15
+ A multiplicative constant
16
+
17
+ !!! abstract "Parameters"
18
+ * $K$ (`norm`) $\left[\text{dimensionless}\right]$: The multiplicative constant.
19
+
20
+ """
21
+
22
+ def __init__(self):
23
+ self.norm = nnx.Param(1.0)
24
+
25
+ def factor(self, energy):
26
+ return self.norm
27
+
28
+
14
29
  class Expfac(MultiplicativeComponent):
15
30
  r"""
16
31
  An exponential modification of a spectrum.
@@ -20,19 +35,20 @@ class Expfac(MultiplicativeComponent):
20
35
  \text{if $E>E_c$}\\1 & \text{if $E<E_c$}\end{cases}
21
36
  $$
22
37
 
23
- ??? abstract "Parameters"
24
- * $A$ : Amplitude of the modification $\left[\text{dimensionless}\right]$
25
- * $f$ : Exponential factor $\left[\text{keV}^{-1}\right]$
26
- * $E_c$ : Start energy of modification $\left[\text{keV}\right]$
38
+ !!! abstract "Parameters"
39
+ * $A$ (`A`) $\left[\text{dimensionless}\right]$ : Amplitude of the modification
40
+ * $f$ (`f`) $\left[\text{keV}^{-1}\right]$ : Exponential factor
41
+ * $E_c$ (`E_c`) $\left[\text{keV}\right]$: Start energy of modification
27
42
 
28
43
  """
29
44
 
30
- def continuum(self, energy):
31
- amplitude = hk.get_parameter("A", [], float, init=HaikuConstant(1))
32
- factor = hk.get_parameter("f", [], float, init=HaikuConstant(1))
33
- pivot = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
45
+ def __init__(self):
46
+ self.A = nnx.Param(1.0)
47
+ self.f = nnx.Param(1.0)
48
+ self.E_c = nnx.Param(1.0)
34
49
 
35
- return jnp.where(energy >= pivot, 1.0 + amplitude * jnp.exp(-factor * energy), 1.0)
50
+ def factor(self, energy):
51
+ return jnp.where(energy >= self.E_c, 1.0 + self.A * jnp.exp(-self.f * energy), 1.0)
36
52
 
37
53
 
38
54
  class Tbabs(MultiplicativeComponent):
@@ -45,49 +61,47 @@ class Tbabs(MultiplicativeComponent):
45
61
  \mathcal{M}(E) = \exp\left(-N_{\text{H}}\sigma(E)\right)
46
62
  $$
47
63
 
48
- ??? abstract "Parameters"
49
- * $N_{\text{H}}$ : Equivalent hydrogen column density
50
- $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$
64
+ !!! abstract "Parameters"
65
+ * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
66
+
51
67
 
52
68
  !!! note
53
69
  Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
54
70
 
55
71
  """
56
72
 
57
- def __init__(self, *args, **kwargs):
58
- super().__init__(*args, **kwargs)
73
+ def __init__(self):
59
74
  table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
60
- self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
61
- self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
75
+ self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
76
+ self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
77
+ self.nh = nnx.Param(1.0)
62
78
 
63
- def continuum(self, energy):
64
- nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
79
+ def factor(self, energy):
65
80
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
66
81
 
67
- return jnp.exp(-nh * sigma)
82
+ return jnp.exp(-self.nh * sigma)
68
83
 
69
84
 
70
85
  class Phabs(MultiplicativeComponent):
71
86
  r"""
72
87
  A photoelectric absorption model.
73
88
 
74
- ??? abstract "Parameters"
75
- * $N_{\text{H}}$ : Equivalent hydrogen column density
76
- $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$
89
+ !!! abstract "Parameters"
90
+ * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
91
+
77
92
 
78
93
  """
79
94
 
80
- def __init__(self, *args, **kwargs):
81
- super().__init__(*args, **kwargs)
95
+ def __init__(self):
82
96
  table = Table.read(table_manager.fetch("xsect_phabs_aspl.fits"))
83
- self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
84
- self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
97
+ self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
98
+ self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
99
+ self.nh = nnx.Param(1.0)
85
100
 
86
- def continuum(self, energy):
87
- nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
88
- sigma = jnp.interp(energy, self.energy, self.sigma, left=jnp.inf, right=0.0)
101
+ def factor(self, energy):
102
+ sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
89
103
 
90
- return jnp.exp(-nh * sigma)
104
+ return jnp.exp(-self.nh * sigma)
91
105
 
92
106
 
93
107
  class Wabs(MultiplicativeComponent):
@@ -95,22 +109,19 @@ class Wabs(MultiplicativeComponent):
95
109
  A photo-electric absorption using Wisconsin (Morrison & McCammon 1983) cross-sections.
96
110
 
97
111
  ??? abstract "Parameters"
98
- * $N_{\text{H}}$ : Equivalent hydrogen column density
99
- $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$
100
-
112
+ * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
101
113
  """
102
114
 
103
- def __init__(self, *args, **kwargs):
104
- super().__init__(*args, **kwargs)
115
+ def __init__(self):
105
116
  table = Table.read(table_manager.fetch("xsect_wabs_angr.fits"))
106
- self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
107
- self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
117
+ self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
118
+ self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
119
+ self.nh = nnx.Param(1.0)
108
120
 
109
- def continuum(self, energy):
110
- nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
121
+ def factor(self, energy):
111
122
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
112
123
 
113
- return jnp.exp(-nh * sigma)
124
+ return jnp.exp(-self.nh * sigma)
114
125
 
115
126
 
116
127
  class Gabs(MultiplicativeComponent):
@@ -122,23 +133,26 @@ class Gabs(MultiplicativeComponent):
122
133
  \left( -\frac{\left(E-E_0\right)^2}{2 \sigma^2} \right) \right)
123
134
  $$
124
135
 
125
- ??? abstract "Parameters"
126
- * $\tau$ : Absorption strength $\left[\text{dimensionless}\right]$
127
- * $\sigma$ : Absorption width $\left[\text{keV}\right]$
128
- * $E_0$ : Absorption center $\left[\text{keV}\right]$
136
+ !!! abstract "Parameters"
137
+ * $\tau$ (`tau`) $\left[\text{dimensionless}\right]$ : Absorption strength
138
+ * $\sigma$ (`sigma`) $\left[\text{keV}\right]$ : Absorption width
139
+ * $E_0$ (`E0`) $\left[\text{keV}\right]$ : Absorption center
129
140
 
130
141
  !!! note
131
142
  The optical depth at line center is $\tau/(\sqrt{2 \pi} \sigma)$.
132
143
 
133
144
  """
134
145
 
135
- def continuum(self, energy):
136
- tau = hk.get_parameter("tau", [], float, init=HaikuConstant(1))
137
- sigma = hk.get_parameter("sigma", [], float, init=HaikuConstant(1))
138
- center = hk.get_parameter("E_0", [], float, init=HaikuConstant(1))
146
+ def __init__(self):
147
+ self.tau = nnx.Param(1.0)
148
+ self.sigma = nnx.Param(1.0)
149
+ self.E0 = nnx.Param(1.0)
139
150
 
151
+ def factor(self, energy):
140
152
  return jnp.exp(
141
- -tau / (jnp.sqrt(2 * jnp.pi) * sigma) * jnp.exp(-0.5 * ((energy - center) / sigma) ** 2)
153
+ -self.tau
154
+ / (jnp.sqrt(2 * jnp.pi) * self.sigma)
155
+ * jnp.exp(-0.5 * ((energy - self.E0) / self.sigma) ** 2)
142
156
  )
143
157
 
144
158
 
@@ -151,16 +165,17 @@ class Highecut(MultiplicativeComponent):
151
165
  \left( \frac{E_c - E}{E_f} \right)& \text{if $E > E_c$}\\ 1 & \text{if $E < E_c$}\end{cases}
152
166
  $$
153
167
 
154
- ??? abstract "Parameters"
155
- * $E_c$ : Cutoff energy $\left[\text{keV}\right]$
156
- * $E_f$ : Folding energy $\left[\text{keV}\right]$
168
+ !!! abstract "Parameters"
169
+ * $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Cutoff energy
170
+ * $E_f$ (`Ef`) $\left[\text{keV}\right]$ : Folding energy
157
171
  """
158
172
 
159
- def continuum(self, energy):
160
- cutoff = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
161
- folding = hk.get_parameter("E_f", [], float, init=HaikuConstant(1))
173
+ def __init__(self):
174
+ self.Ec = nnx.Param(1.0)
175
+ self.Ef = nnx.Param(1.0)
162
176
 
163
- return jnp.where(energy <= cutoff, 1.0, jnp.exp((cutoff - energy) / folding))
177
+ def factor(self, energy):
178
+ return jnp.where(energy <= self.Ec, 1.0, jnp.exp((self.Ec - energy) / self.Ef))
164
179
 
165
180
 
166
181
  class Zedge(MultiplicativeComponent):
@@ -172,18 +187,21 @@ class Zedge(MultiplicativeComponent):
172
187
  & \text{if $E > E_c$}\\ 1 & \text{if $E < E_c$}\end{cases}
173
188
  $$
174
189
 
175
- ??? abstract "Parameters"
176
- * $E_c$ : Threshold energy
177
- * $E_f$ : Absorption depth at the threshold
178
- * $z$ : Redshift [dimensionless]
190
+ !!! abstract "Parameters"
191
+ * $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Threshold energy
192
+ * $D$ (`D`) $\left[\text{dimensionless}\right]$ : Absorption depth at the threshold
193
+ * $z$ (`z`) $\left[\text{dimensionless}\right]$ : Redshift
179
194
  """
180
195
 
181
- def continuum(self, energy):
182
- E_c = hk.get_parameter("E_c", [], float, init=HaikuConstant(1))
183
- D = hk.get_parameter("D", [], float, init=HaikuConstant(1))
184
- z = hk.get_parameter("z", [], float, init=HaikuConstant(0))
196
+ def __init__(self):
197
+ self.Ec = nnx.Param(1.0)
198
+ self.D = nnx.Param(1.0)
199
+ self.z = nnx.Param(0.0)
185
200
 
186
- return jnp.where(energy <= E_c, 1.0, jnp.exp(-D * (energy * (1 + z) / E_c) ** 3))
201
+ def factor(self, energy):
202
+ return jnp.where(
203
+ energy <= self.Ec, 1.0, jnp.exp(-self.D * (energy * (1 + self.z) / self.Ec) ** 3)
204
+ )
187
205
 
188
206
 
189
207
  class Tbpcf(MultiplicativeComponent):
@@ -194,28 +212,25 @@ class Tbpcf(MultiplicativeComponent):
194
212
  \mathcal{M}(E) = f \exp\left(-N_{\text{H}}\sigma(E)\right) + (1-f)
195
213
  $$
196
214
 
197
- ??? abstract "Parameters"
198
- * $N_{\text{H}}$ : Equivalent hydrogen column density
199
- $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$
200
- * $f$ : Partial covering fraction, ranges between 0 and 1 [dimensionless]
215
+ !!! abstract "Parameters"
216
+ * $N_{\text{H}}$ (`nh`) $\left[\frac{\text{atoms}~10^{22}}{\text{cm}^2}\right]$ : Equivalent hydrogen column density
217
+ * $f$ (`f`) $\left[\text{dimensionless}\right]$ : Partial covering fraction, ranges between 0 and 1
201
218
 
202
219
  !!! note
203
220
  Abundances and cross-sections $\sigma$ can be found in Wilms et al. (2000).
204
221
 
205
222
  """
206
223
 
207
- def __init__(self, *args, **kwargs):
208
- super().__init__(*args, **kwargs)
224
+ def __init__(self):
209
225
  table = Table.read(table_manager.fetch("xsect_tbabs_wilm.fits"))
210
- self.energy = jnp.asarray(np.array(table["ENERGY"]), dtype=np.float64)
211
- self.sigma = jnp.asarray(np.array(table["SIGMA"]), dtype=np.float64)
226
+ self.energy = nnx.Variable(np.asarray(table["ENERGY"], dtype=np.float64))
227
+ self.sigma = nnx.Variable(np.asarray(table["SIGMA"], dtype=np.float64))
228
+ self.nh = nnx.Param(1.0)
229
+ self.f = nnx.Param(0.2)
212
230
 
213
231
  def continuum(self, energy):
214
- nh = hk.get_parameter("N_H", [], float, init=HaikuConstant(1))
215
- f = hk.get_parameter("f", [], float, init=HaikuConstant(0.2))
216
232
  sigma = jnp.interp(energy, self.energy, self.sigma, left=1e9, right=0.0)
217
-
218
- return f * jnp.exp(-nh * sigma) + (1 - f)
233
+ return self.f * jnp.exp(-self.nh * sigma) + (1 - self.f)
219
234
 
220
235
 
221
236
  class FDcut(MultiplicativeComponent):
@@ -227,12 +242,13 @@ class FDcut(MultiplicativeComponent):
227
242
  $$
228
243
 
229
244
  ??? abstract "Parameters"
230
- * $E_c$ : Cutoff energy $\left[\text{keV}\right]$
231
- * $E_f$ : Folding energy $\left[\text{keV}\right]$
245
+ * $E_c$ (`Ec`) $\left[\text{keV}\right]$ : Cutoff energy
246
+ * $E_f$ (`Ef`) $\left[\text{keV}\right]$ : Folding energy
232
247
  """
233
248
 
234
- def continuum(self, energy):
235
- cutoff = hk.get_parameter("E_c", [], init=HaikuConstant(1))
236
- folding = hk.get_parameter("E_f", [], init=HaikuConstant(1))
249
+ def __init__(self):
250
+ self.Ec = nnx.Param(1.0)
251
+ self.Ef = nnx.Param(3.0)
237
252
 
238
- return (1 + jnp.exp((energy - cutoff) / folding)) ** -1
253
+ def continuum(self, energy):
254
+ return (1 + jnp.exp((energy - self.Ec) / self.Ef)) ** -1
jaxspec/scripts/debug.py CHANGED
@@ -8,4 +8,4 @@ def debug_info():
8
8
 
9
9
  # Thanks to CamilleTheBest for the idea
10
10
  print(watermark())
11
- print(watermark(packages="jaxspec,jax,jaxlib,numpyro,haiku,numpy,scipy"))
11
+ print(watermark(packages="jaxspec,jax,jaxlib,numpyro,flax,numpy,scipy"))
jaxspec/util/__init__.py CHANGED
@@ -1,45 +0,0 @@
1
- from collections.abc import Callable
2
- from contextlib import contextmanager
3
- from time import perf_counter
4
-
5
- import haiku as hk
6
-
7
- from jax.random import PRNGKey, split
8
-
9
-
10
- @contextmanager
11
- def catchtime(desc="Task", print_time=True) -> Callable[[], float]:
12
- """
13
- Context manager to measure time taken by a task.
14
-
15
- Parameters
16
- ----------
17
- desc (str): Description of the task.
18
- print_time (bool): Whether to print the time taken by the task.
19
-
20
- Returns
21
- -------
22
- Callable[[], float]: Function to get the time taken by the task.
23
- """
24
-
25
- t1 = t2 = perf_counter()
26
- yield lambda: t2 - t1
27
- t2 = perf_counter()
28
- if print_time:
29
- print(f"{desc}: {t2 - t1:.3f} seconds")
30
-
31
-
32
- def sample_prior(dict_of_prior, key=PRNGKey(0), flat_parameters=False):
33
- """
34
- Sample the prior distribution from a dict of distributions
35
- """
36
-
37
- parameters = dict(hk.data_structures.to_haiku_dict(dict_of_prior))
38
- parameters_flat = {}
39
-
40
- for m, n, distribution in hk.data_structures.traverse(dict_of_prior):
41
- key, subkey = split(key)
42
- parameters[m][n] = distribution.sample(subkey)
43
- parameters_flat[m + "_" + n] = distribution.sample(subkey)
44
-
45
- return parameters if not flat_parameters else parameters_flat
jaxspec/util/misc.py ADDED
@@ -0,0 +1,25 @@
1
+ from collections.abc import Callable
2
+ from contextlib import contextmanager
3
+ from time import perf_counter
4
+
5
+
6
+ @contextmanager
7
+ def catchtime(desc="Task", print_time=True) -> Callable[[], float]:
8
+ """
9
+ Context manager to measure time taken by a task.
10
+
11
+ Parameters
12
+ ----------
13
+ desc (str): Description of the task.
14
+ print_time (bool): Whether to print the time taken by the task.
15
+
16
+ Returns
17
+ -------
18
+ Callable[[], float]: Function to get the time taken by the task.
19
+ """
20
+
21
+ t1 = t2 = perf_counter()
22
+ yield lambda: t2 - t1
23
+ t2 = perf_counter()
24
+ if print_time:
25
+ print(f"{desc}: {t2 - t1:.3f} seconds")