jaxspec 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,54 +1,54 @@
1
1
  import operator
2
- import warnings
3
2
 
4
- from abc import ABC, abstractmethod
5
3
  from collections.abc import Callable
6
4
  from functools import cached_property
7
- from typing import Literal
5
+ from typing import Any
8
6
 
9
- import arviz as az
10
7
  import jax
11
8
  import jax.numpy as jnp
12
9
  import matplotlib.pyplot as plt
13
10
  import numpyro
14
11
 
15
- from jax import random
12
+ from flax import nnx
16
13
  from jax.experimental import mesh_utils
17
14
  from jax.random import PRNGKey
18
15
  from jax.sharding import PositionalSharding
19
- from numpyro.contrib.nested_sampling import NestedSampler
20
16
  from numpyro.distributions import Poisson, TransformedDistribution
21
- from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
17
+ from numpyro.infer import Predictive
22
18
  from numpyro.infer.inspect import get_model_relations
23
19
  from numpyro.infer.reparam import TransformReparam
24
20
  from numpyro.infer.util import log_density
25
21
 
26
- from ._fit._build_model import build_prior, forward_model
27
- from .analysis._plot import (
22
+ from ..analysis._plot import (
28
23
  _error_bars_for_observed_data,
29
24
  _plot_binned_samples_with_error,
30
25
  _plot_poisson_data_with_error,
31
26
  )
32
- from .analysis.results import FitResult
33
- from .data import ObsConfiguration
34
- from .model.abc import SpectralModel
35
- from .model.background import BackgroundModel
36
- from .util.typing import PriorDictType
27
+ from ..data import ObsConfiguration
28
+ from ..model.abc import SpectralModel
29
+ from ..model.background import BackgroundModel
30
+ from ..model.instrument import InstrumentModel
31
+ from ..util.typing import PriorDictType
32
+ from ._build_model import build_prior, forward_model
37
33
 
38
34
 
39
- class BayesianModel:
35
+ class BayesianModel(nnx.Module):
40
36
  """
41
37
  Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
42
38
  and compute the log-likelihood and posterior probability.
43
39
  """
44
40
 
41
+ settings: dict[str, Any]
42
+
45
43
  def __init__(
46
44
  self,
47
45
  model: SpectralModel,
48
46
  prior_distributions: PriorDictType | Callable,
49
47
  observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
50
48
  background_model: BackgroundModel = None,
49
+ instrument_model: InstrumentModel = None,
51
50
  sparsify_matrix: bool = False,
51
+ n_points: int = 2,
52
52
  ):
53
53
  """
54
54
  Build a Bayesian model for a given spectral model and observations.
@@ -59,74 +59,36 @@ class BayesianModel:
59
59
  callable function that returns parameter samples.
60
60
  observations: the observations to fit the model to.
61
61
  background_model: the background model to fit.
62
+ instrument_model: the instrument model to fit.
62
63
  sparsify_matrix: whether to sparsify the transfer matrix.
63
64
  """
64
- self.model = model
65
+
66
+ self.spectral_model = model
65
67
  self._observations = observations
66
68
  self.background_model = background_model
67
- self.sparse = sparsify_matrix
69
+ self.instrument_model = instrument_model
70
+ self.settings = {"sparse": sparsify_matrix}
68
71
 
69
72
  if not callable(prior_distributions):
70
- # Validate the entry with pydantic
71
- # prior = PriorDictModel.from_dict(prior_distributions).
72
73
 
73
74
  def prior_distributions_func():
74
75
  return build_prior(
75
- prior_distributions, expand_shape=(len(self.observation_container),)
76
+ prior_distributions,
77
+ expand_shape=(len(self._observation_container),),
78
+ prefix="mod/~/",
76
79
  )
77
80
 
78
81
  else:
79
82
  prior_distributions_func = prior_distributions
80
83
 
81
84
  self.prior_distributions_func = prior_distributions_func
82
- self.init_params = self.prior_samples()
83
-
84
- # Check the priors are suited for the observations
85
- split_parameters = [
86
- (param, shape[-1])
87
- for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
88
- if (len(shape) > 1)
89
- and not param.startswith("_")
90
- and not param.startswith("bkg") # hardcoded for subtracted background
91
- ]
92
-
93
- for parameter, proposed_number_of_obs in split_parameters:
94
- if proposed_number_of_obs != len(self.observation_container):
95
- raise ValueError(
96
- f"Invalid splitting in the prior distribution. "
97
- f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
98
- )
99
-
100
- @cached_property
101
- def observation_container(self) -> dict[str, ObsConfiguration]:
102
- """
103
- The observations used in the fit as a dictionary of observations.
104
- """
105
-
106
- if isinstance(self._observations, dict):
107
- return self._observations
108
-
109
- elif isinstance(self._observations, list):
110
- return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
111
-
112
- elif isinstance(self._observations, ObsConfiguration):
113
- return {"data": self._observations}
114
-
115
- else:
116
- raise ValueError(f"Invalid type for observations : {type(self._observations)}")
117
-
118
- @cached_property
119
- def numpyro_model(self) -> Callable:
120
- """
121
- Build the numpyro model using the observed data, the prior distributions and the spectral model.
122
- """
123
85
 
124
86
  def numpyro_model(observed=True):
125
87
  # Instantiate and register the parameters of the spectral model and the background
126
88
  prior_params = self.prior_distributions_func()
127
89
 
128
90
  # Iterate over all the observations in our container and build a single numpyro model for each observation
129
- for i, (name, observation) in enumerate(self.observation_container.items()):
91
+ for i, (name, observation) in enumerate(self._observation_container.items()):
130
92
  # Check that we can indeed fit a background
131
93
  if (getattr(observation, "folded_background", None) is not None) and (
132
94
  self.background_model is not None
@@ -150,23 +112,70 @@ class BayesianModel:
150
112
  # They can be identical or different for each observation
151
113
  params = jax.tree.map(lambda x: x[i], prior_params)
152
114
 
115
+ if self.instrument_model is not None:
116
+ gain, shift = self.instrument_model.get_gain_and_shift_model(name)
117
+ else:
118
+ gain, shift = None, None
119
+
153
120
  # Forward model the observation and get the associated countrate
154
121
  obs_model = jax.jit(
155
- lambda par: forward_model(self.model, par, observation, sparse=self.sparse)
122
+ lambda par: forward_model(
123
+ self.spectral_model,
124
+ par,
125
+ observation,
126
+ sparse=self.settings.get("sparse", False),
127
+ gain=gain,
128
+ shift=shift,
129
+ n_points=n_points,
130
+ )
156
131
  )
132
+
157
133
  obs_countrate = obs_model(params)
158
134
 
159
135
  # Register the observation as an observed site
160
- with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
136
+ with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
161
137
  numpyro.sample(
162
- "obs_" + name,
163
- Poisson(
164
- obs_countrate + bkg_countrate
165
- ), # / observation.folded_backratio.data
138
+ "obs/~/" + name,
139
+ Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
166
140
  obs=observation.folded_counts.data if observed else None,
167
141
  )
168
142
 
169
- return numpyro_model
143
+ self.numpyro_model = numpyro_model
144
+ self._init_params = self.prior_samples()
145
+ # Check the priors are suited for the observations
146
+ split_parameters = [
147
+ (param, shape[-1])
148
+ for param, shape in jax.tree.map(lambda x: x.shape, self._init_params).items()
149
+ if (len(shape) > 1)
150
+ and not param.startswith("_")
151
+ and not param.startswith("bkg") # hardcoded for subtracted background
152
+ and not param.startswith("ins")
153
+ ]
154
+
155
+ for parameter, proposed_number_of_obs in split_parameters:
156
+ if proposed_number_of_obs != len(self._observation_container):
157
+ raise ValueError(
158
+ f"Invalid splitting in the prior distribution. "
159
+ f"Expected {len(self._observation_container)} but got {proposed_number_of_obs} for {parameter}"
160
+ )
161
+
162
+ @cached_property
163
+ def _observation_container(self) -> dict[str, ObsConfiguration]:
164
+ """
165
+ The observations used in the fit as a dictionary of observations.
166
+ """
167
+
168
+ if isinstance(self._observations, dict):
169
+ return self._observations
170
+
171
+ elif isinstance(self._observations, list):
172
+ return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
173
+
174
+ elif isinstance(self._observations, ObsConfiguration):
175
+ return {"data": self._observations}
176
+
177
+ else:
178
+ raise ValueError(f"Invalid type for observations : {type(self._observations)}")
170
179
 
171
180
  @cached_property
172
181
  def transformed_numpyro_model(self) -> Callable:
@@ -235,7 +244,7 @@ class BayesianModel:
235
244
  return log_posterior_prob
236
245
 
237
246
  @cached_property
238
- def parameter_names(self) -> list[str]:
247
+ def _parameter_names(self) -> list[str]:
239
248
  """
240
249
  A list of parameter names for the model.
241
250
  """
@@ -260,7 +269,7 @@ class BayesianModel:
260
269
  """
261
270
  input_params = {}
262
271
 
263
- for index, key in enumerate(self.parameter_names):
272
+ for index, key in enumerate(self._parameter_names):
264
273
  input_params[key] = theta[index]
265
274
 
266
275
  return input_params
@@ -270,9 +279,9 @@ class BayesianModel:
270
279
  Convert a dictionary of parameters to an array of parameters.
271
280
  """
272
281
 
273
- theta = jnp.zeros(len(self.parameter_names))
282
+ theta = jnp.zeros(len(self._parameter_names))
274
283
 
275
- for index, key in enumerate(self.parameter_names):
284
+ for index, key in enumerate(self._parameter_names):
276
285
  theta = theta.at[index].set(dict_of_params[key])
277
286
 
278
287
  return theta
@@ -289,7 +298,7 @@ class BayesianModel:
289
298
  @jax.jit
290
299
  def prior_sample(key):
291
300
  return Predictive(
292
- self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
301
+ self.numpyro_model, return_sites=self._parameter_names, num_samples=num_samples
293
302
  )(key, observed=False)
294
303
 
295
304
  return prior_sample(key)
@@ -327,7 +336,7 @@ class BayesianModel:
327
336
  sharded_parameters = jax.device_put(prior_params, sharding)
328
337
  posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
329
338
 
330
- for key, value in self.observation_container.items():
339
+ for key, value in self._observation_container.items():
331
340
  fig, ax = plt.subplots(
332
341
  nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
333
342
  )
@@ -349,7 +358,7 @@ class BayesianModel:
349
358
  )
350
359
 
351
360
  prior_plot = _plot_binned_samples_with_error(
352
- ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
361
+ ax[0], value.out_energies, posterior_observations["obs/~/" + key], n_sigmas=3
353
362
  )
354
363
 
355
364
  legend_plots.append((true_data_plot,))
@@ -358,7 +367,7 @@ class BayesianModel:
358
367
  legend_labels.append("Prior Predictive")
359
368
 
360
369
  # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
361
- counts = posterior_observations["obs_" + key]
370
+ counts = posterior_observations["obs/~/" + key]
362
371
  observed = value.folded_counts.values
363
372
 
364
373
  num_samples = counts.shape[0]
@@ -387,264 +396,3 @@ class BayesianModel:
387
396
  plt.suptitle(f"Prior Predictive coverage for {key}")
388
397
  plt.tight_layout()
389
398
  plt.show()
390
-
391
-
392
- class BayesianModelFitter(BayesianModel, ABC):
393
- def build_inference_data(
394
- self,
395
- posterior_samples,
396
- num_chains: int = 1,
397
- num_predictive_samples: int = 1000,
398
- key: PRNGKey = PRNGKey(42),
399
- use_transformed_model: bool = False,
400
- filter_inference_data: bool = True,
401
- ) -> az.InferenceData:
402
- """
403
- Build an [InferenceData][arviz.InferenceData] object from posterior samples.
404
-
405
- Parameters:
406
- posterior_samples: the samples from the posterior distribution.
407
- num_chains: the number of chains used to sample the posterior.
408
- num_predictive_samples: the number of samples to draw from the prior.
409
- key: the random key used to initialize the sampler.
410
- use_transformed_model: whether to use the transformed model to build the InferenceData.
411
- filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
412
- """
413
-
414
- numpyro_model = (
415
- self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
416
- )
417
-
418
- keys = random.split(key, 3)
419
-
420
- posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
421
-
422
- prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
423
- keys[1], observed=False
424
- )
425
-
426
- log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
427
-
428
- seeded_model = numpyro.handlers.substitute(
429
- numpyro.handlers.seed(numpyro_model, keys[3]),
430
- substitute_fn=numpyro.infer.init_to_sample,
431
- )
432
-
433
- observations = {
434
- name: site["value"]
435
- for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
436
- if site["type"] == "sample" and site["is_observed"]
437
- }
438
-
439
- def reshape_first_dimension(arr):
440
- new_dim = arr.shape[0] // num_chains
441
- new_shape = (num_chains, new_dim) + arr.shape[1:]
442
- reshaped_array = arr.reshape(new_shape)
443
-
444
- return reshaped_array
445
-
446
- posterior_samples = {
447
- key: reshape_first_dimension(value) for key, value in posterior_samples.items()
448
- }
449
- prior = {key: value[None, :] for key, value in prior.items()}
450
- posterior_predictive = {
451
- key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
452
- }
453
- log_likelihood = {
454
- key: reshape_first_dimension(value) for key, value in log_likelihood.items()
455
- }
456
-
457
- inference_data = az.from_dict(
458
- posterior_samples,
459
- prior=prior,
460
- posterior_predictive=posterior_predictive,
461
- log_likelihood=log_likelihood,
462
- observed_data=observations,
463
- )
464
-
465
- return (
466
- self.filter_inference_data(inference_data) if filter_inference_data else inference_data
467
- )
468
-
469
- def filter_inference_data(
470
- self,
471
- inference_data: az.InferenceData,
472
- ) -> az.InferenceData:
473
- """
474
- Filter the inference data to keep only the relevant parameters for the observations.
475
-
476
- - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
477
- - Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
478
- """
479
-
480
- predictive_parameters = []
481
-
482
- for key, value in self.observation_container.items():
483
- if self.background_model is not None:
484
- predictive_parameters.append(f"obs_{key}")
485
- predictive_parameters.append(f"bkg_{key}")
486
- else:
487
- predictive_parameters.append(f"obs_{key}")
488
-
489
- inference_data.posterior_predictive = inference_data.posterior_predictive[
490
- predictive_parameters
491
- ]
492
-
493
- parameters = [
494
- x
495
- for x in inference_data.posterior.keys()
496
- if not x.endswith("_base") or x.startswith("_")
497
- ]
498
- inference_data.posterior = inference_data.posterior[parameters]
499
- inference_data.prior = inference_data.prior[parameters]
500
-
501
- return inference_data
502
-
503
- @abstractmethod
504
- def fit(self, **kwargs) -> FitResult: ...
505
-
506
-
507
- class MCMCFitter(BayesianModelFitter):
508
- """
509
- A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
510
- from numpyro to perform the inference on the model parameters.
511
- """
512
-
513
- kernel_dict = {
514
- "nuts": NUTS,
515
- "aies": AIES,
516
- "ess": ESS,
517
- }
518
-
519
- def fit(
520
- self,
521
- rng_key: int = 0,
522
- num_chains: int = len(jax.devices()),
523
- num_warmup: int = 1000,
524
- num_samples: int = 1000,
525
- sampler: Literal["nuts", "aies", "ess"] = "nuts",
526
- use_transformed_model: bool = True,
527
- kernel_kwargs: dict = {},
528
- mcmc_kwargs: dict = {},
529
- ) -> FitResult:
530
- """
531
- Fit the model to the data using a MCMC sampler from numpyro.
532
-
533
- Parameters:
534
- rng_key: the random key used to initialize the sampler.
535
- num_chains: the number of chains to run.
536
- num_warmup: the number of warmup steps.
537
- num_samples: the number of samples to draw.
538
- sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
539
- use_transformed_model: whether to use the transformed model to build the InferenceData.
540
- kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
541
- mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
542
-
543
- Returns:
544
- A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
545
- """
546
-
547
- bayesian_model = (
548
- self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
549
- )
550
-
551
- chain_kwargs = {
552
- "num_warmup": num_warmup,
553
- "num_samples": num_samples,
554
- "num_chains": num_chains,
555
- }
556
-
557
- kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
558
-
559
- mcmc_kwargs = chain_kwargs | mcmc_kwargs
560
-
561
- if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
562
- mcmc_kwargs["chain_method"] = "vectorized"
563
- warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
564
-
565
- mcmc = MCMC(kernel, **mcmc_kwargs)
566
- keys = random.split(random.PRNGKey(rng_key), 3)
567
-
568
- mcmc.run(keys[0])
569
-
570
- posterior = mcmc.get_samples()
571
-
572
- inference_data = self.build_inference_data(
573
- posterior, num_chains=num_chains, use_transformed_model=True
574
- )
575
-
576
- return FitResult(
577
- self,
578
- inference_data,
579
- background_model=self.background_model,
580
- )
581
-
582
-
583
- class NSFitter(BayesianModelFitter):
584
- r"""
585
- A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
586
- [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
587
- implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
588
-
589
- !!! info
590
- Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
591
-
592
- """
593
-
594
- def fit(
595
- self,
596
- rng_key: int = 0,
597
- num_samples: int = 1000,
598
- num_live_points: int = 1000,
599
- plot_diagnostics=False,
600
- termination_kwargs: dict | None = None,
601
- verbose=True,
602
- ) -> FitResult:
603
- """
604
- Fit the model to the data using the Phantom-Powered nested sampling algorithm.
605
-
606
- Parameters:
607
- rng_key: the random key used to initialize the sampler.
608
- num_samples: the number of samples to draw.
609
- num_live_points: the number of live points to use at the start of the NS algorithm.
610
- plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
611
- termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
612
- verbose: whether to print the progress of the NS algorithm.
613
-
614
- Returns:
615
- A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
616
- """
617
-
618
- bayesian_model = self.transformed_numpyro_model
619
- keys = random.split(random.PRNGKey(rng_key), 4)
620
-
621
- ns = NestedSampler(
622
- bayesian_model,
623
- constructor_kwargs=dict(
624
- verbose=verbose,
625
- difficult_model=True,
626
- max_samples=1e5,
627
- parameter_estimation=True,
628
- gradient_guided=True,
629
- devices=jax.devices(),
630
- # init_efficiency_threshold=0.01,
631
- num_live_points=num_live_points,
632
- ),
633
- termination_kwargs=termination_kwargs if termination_kwargs else dict(),
634
- )
635
-
636
- ns.run(keys[0])
637
-
638
- if plot_diagnostics:
639
- ns.diagnostics()
640
-
641
- posterior = ns.get_samples(keys[1], num_samples=num_samples)
642
- inference_data = self.build_inference_data(
643
- posterior, num_chains=1, use_transformed_model=True
644
- )
645
-
646
- return FitResult(
647
- self,
648
- inference_data,
649
- background_model=self.background_model,
650
- )
@@ -1,5 +1,7 @@
1
+ from collections.abc import Callable
1
2
  from typing import TYPE_CHECKING
2
3
 
4
+ import jax
3
5
  import jax.numpy as jnp
4
6
  import numpy as np
5
7
  import numpyro
@@ -9,9 +11,15 @@ from jax.typing import ArrayLike
9
11
  from numpyro.distributions import Distribution
10
12
 
11
13
  if TYPE_CHECKING:
12
- from ..data import ObsConfiguration
13
- from ..model.abc import SpectralModel
14
- from ..util.typing import PriorDictType
14
+ from jaxspec.data import ObsConfiguration
15
+ from jaxspec.model.abc import SpectralModel
16
+ from jaxspec.util.typing import PriorDictType
17
+
18
+
19
+ class TiedParameter:
20
+ def __init__(self, tied_to, func):
21
+ self.tied_to = tied_to
22
+ self.func = func
15
23
 
16
24
 
17
25
  def forward_model(
@@ -19,6 +27,10 @@ def forward_model(
19
27
  parameters,
20
28
  obs_configuration: "ObsConfiguration",
21
29
  sparse=False,
30
+ gain: Callable | None = None,
31
+ shift: Callable | None = None,
32
+ split_branches: bool = False,
33
+ n_points: int | None = 2,
22
34
  ):
23
35
  energies = np.asarray(obs_configuration.in_energies)
24
36
 
@@ -31,10 +43,24 @@ def forward_model(
31
43
  else:
32
44
  transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
33
45
 
34
- expected_counts = transfer_matrix @ model.photon_flux(parameters, *energies)
46
+ energies = shift(energies) if shift is not None else energies
47
+ energies = jnp.clip(energies, min=1e-6) # Ensure shifted energies remain positive
48
+ factor = gain(energies) if gain is not None else 1.0
49
+ factor = jnp.clip(factor, min=0.0) # Ensure the gain is positive to avoid NaNs
35
50
 
36
- # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
37
- return jnp.clip(expected_counts, a_min=1e-6)
51
+ if not split_branches:
52
+ expected_counts = transfer_matrix @ (
53
+ model.photon_flux(parameters, *energies, n_points=n_points) * factor
54
+ )
55
+ return jnp.clip(expected_counts, min=1e-6) # Ensure the expected counts are positive
56
+
57
+ else:
58
+ model_flux = model.photon_flux(
59
+ parameters, *energies, split_branches=True, n_points=n_points
60
+ )
61
+ return jax.tree.map(
62
+ lambda f: jnp.clip(transfer_matrix @ (f * factor), min=1e-6), model_flux
63
+ )
38
64
 
39
65
 
40
66
  def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
@@ -43,15 +69,20 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
43
69
  Must be used within a numpyro model.
44
70
  """
45
71
  parameters = {}
72
+ params_to_tie = {}
46
73
 
47
74
  for key, value in prior.items():
48
75
  # Split the key to extract the module name and parameter name
49
76
  module_name, param_name = key.rsplit("_", 1)
77
+
50
78
  if isinstance(value, Distribution):
51
79
  parameters[key] = jnp.ones(expand_shape) * numpyro.sample(
52
80
  f"{prefix}{module_name}_{param_name}", value
53
81
  )
54
82
 
83
+ elif isinstance(value, TiedParameter):
84
+ params_to_tie[key] = value
85
+
55
86
  elif isinstance(value, ArrayLike):
56
87
  parameters[key] = jnp.ones(expand_shape) * value
57
88
 
@@ -60,4 +91,9 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
60
91
  f"Invalid prior type {type(value)} for parameter {prefix}{module_name}_{param_name} : {value}"
61
92
  )
62
93
 
94
+ for key, value in params_to_tie.items():
95
+ func_to_apply = value.func
96
+ tied_to = value.tied_to
97
+ parameters[key] = func_to_apply(parameters[tied_to])
98
+
63
99
  return parameters