flipcosmo 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. docs/conf.py +154 -0
  2. flip/__init__.py +4 -11
  3. flip/covariance/__init__.py +7 -8
  4. flip/covariance/analytical/__init__.py +11 -0
  5. flip/covariance/{adamsblake17plane → analytical/adamsblake17}/coefficients.py +1 -1
  6. flip/covariance/{adamsblake17plane → analytical/adamsblake17}/fisher_terms.py +1 -1
  7. flip/covariance/{adamsblake17 → analytical/adamsblake17}/flip_terms.py +0 -1
  8. flip/covariance/{adamsblake17 → analytical/adamsblake17plane}/coefficients.py +1 -1
  9. flip/covariance/{adamsblake17 → analytical/adamsblake17plane}/fisher_terms.py +1 -1
  10. flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/flip_terms.py +0 -1
  11. flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/generator.py +103 -19
  12. flip/covariance/{adamsblake20 → analytical/adamsblake20}/coefficients.py +1 -1
  13. flip/covariance/{adamsblake20 → analytical/adamsblake20}/fisher_terms.py +1 -1
  14. flip/covariance/{adamsblake20 → analytical/adamsblake20}/flip_terms.py +0 -1
  15. flip/covariance/{carreres23 → analytical/carreres23}/coefficients.py +1 -4
  16. flip/covariance/{ravouxnoanchor25 → analytical/carreres23}/fisher_terms.py +1 -1
  17. flip/covariance/{carreres23 → analytical/carreres23}/flip_terms.py +0 -1
  18. flip/covariance/analytical/carreres23/generator.py +198 -0
  19. flip/covariance/analytical/genericzdep/__init__.py +5 -0
  20. flip/covariance/analytical/genericzdep/coefficients.py +53 -0
  21. flip/covariance/analytical/genericzdep/flip_terms.py +99 -0
  22. flip/covariance/{lai22 → analytical/lai22}/coefficients.py +2 -3
  23. flip/covariance/{lai22 → analytical/lai22}/fisher_terms.py +1 -1
  24. flip/covariance/{lai22 → analytical/lai22}/flip_terms.py +0 -1
  25. flip/covariance/{lai22 → analytical/lai22}/generator.py +263 -58
  26. flip/covariance/{lai22 → analytical/lai22}/symbolic.py +55 -19
  27. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/coefficients.py +1 -1
  28. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/fisher_terms.py +1 -1
  29. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/flip_terms.py +0 -1
  30. flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/coefficients.py +3 -2
  31. flip/covariance/{carreres23 → analytical/ravouxnoanchor25}/fisher_terms.py +1 -1
  32. flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/flip_terms.py +0 -9
  33. flip/covariance/{rcrk24 → analytical/rcrk24}/coefficients.py +6 -6
  34. flip/covariance/{rcrk24 → analytical/rcrk24}/fisher_terms.py +7 -9
  35. flip/covariance/{rcrk24 → analytical/rcrk24}/flip_terms.py +0 -8
  36. flip/covariance/contraction.py +82 -40
  37. flip/covariance/cov_utils.py +89 -81
  38. flip/covariance/covariance.py +172 -141
  39. flip/covariance/emulators/__init__.py +1 -1
  40. flip/covariance/emulators/generator.py +73 -3
  41. flip/covariance/emulators/gpmatrix.py +40 -1
  42. flip/covariance/emulators/nnmatrix.py +57 -1
  43. flip/covariance/emulators/skgpmatrix.py +125 -0
  44. flip/covariance/fisher.py +307 -0
  45. flip/{fit_utils.py → covariance/fit_utils.py} +185 -10
  46. flip/{fitter.py → covariance/fitter.py} +151 -125
  47. flip/covariance/generator.py +82 -106
  48. flip/{likelihood.py → covariance/likelihood.py} +286 -64
  49. flip/{plot_utils.py → covariance/plot_utils.py} +79 -4
  50. flip/covariance/symbolic.py +89 -44
  51. flip/data/__init__.py +1 -1
  52. flip/data/data_density.parquet +0 -0
  53. flip/data/data_velocity.parquet +0 -0
  54. flip/data/{grid_window_m.parquet → data_window_density.parquet} +0 -0
  55. flip/{gridding.py → data/gridding.py} +125 -130
  56. flip/data/load_data_test.py +102 -0
  57. flip/data/power_spectrum_mm.txt +2 -2
  58. flip/data/power_spectrum_mt.txt +2 -2
  59. flip/data/power_spectrum_tt.txt +2 -2
  60. flip/data/test_covariance_reference_values.json +145 -0
  61. flip/data/test_e2e_reference_values.json +14 -0
  62. flip/data_vector/basic.py +118 -101
  63. flip/data_vector/cosmo_utils.py +18 -0
  64. flip/data_vector/galaxypv_vectors.py +58 -94
  65. flip/data_vector/snia_vectors.py +60 -3
  66. flip/data_vector/vector_utils.py +47 -1
  67. flip/power_spectra/class_engine.py +36 -1
  68. flip/power_spectra/cosmoprimo_engine.py +37 -2
  69. flip/power_spectra/generator.py +47 -25
  70. flip/power_spectra/models.py +30 -31
  71. flip/power_spectra/pyccl_engine.py +36 -1
  72. flip/simulation/__init__.py +0 -0
  73. flip/utils.py +62 -91
  74. flipcosmo-1.2.1.dist-info/METADATA +78 -0
  75. flipcosmo-1.2.1.dist-info/RECORD +109 -0
  76. {flipcosmo-1.0.0.dist-info → flipcosmo-1.2.1.dist-info}/WHEEL +1 -1
  77. flipcosmo-1.2.1.dist-info/top_level.txt +7 -0
  78. scripts/flip_compute_correlation_model.py +70 -0
  79. scripts/flip_compute_power_spectra.py +50 -0
  80. scripts/flip_fisher_forecast_velocity.py +70 -0
  81. scripts/flip_fisher_rcrk24.py +164 -0
  82. scripts/flip_launch_minuit_density_fit.py +91 -0
  83. scripts/flip_launch_minuit_full_fit.py +117 -0
  84. scripts/flip_launch_minuit_velocity_fit.py +78 -0
  85. scripts/flip_launch_minuit_velocity_fit_full.py +107 -0
  86. scripts/flip_launch_minuit_velocity_fit_interpolation.py +93 -0
  87. test/refresh_reference_values.py +43 -0
  88. test/test_covariance_assembly.py +102 -0
  89. test/test_covariance_reference_values.py +125 -0
  90. test/test_covariance_utils.py +34 -0
  91. test/test_e2e_density.py +50 -0
  92. test/test_e2e_joint.py +65 -0
  93. test/test_e2e_velocity.py +53 -0
  94. test/test_likelihood_inversions.py +31 -0
  95. flip/covariance/carreres23/generator.py +0 -132
  96. flip/data/density_data.parquet +0 -0
  97. flip/data/velocity_data.parquet +0 -0
  98. flip/fisher.py +0 -190
  99. flipcosmo-1.0.0.dist-info/METADATA +0 -32
  100. flipcosmo-1.0.0.dist-info/RECORD +0 -82
  101. flipcosmo-1.0.0.dist-info/top_level.txt +0 -1
  102. /flip/{config.py → _config.py} +0 -0
  103. /flip/covariance/{adamsblake17 → analytical/adamsblake17}/__init__.py +0 -0
  104. /flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/__init__.py +0 -0
  105. /flip/covariance/{adamsblake20 → analytical/adamsblake20}/__init__.py +0 -0
  106. /flip/covariance/{carreres23 → analytical/carreres23}/__init__.py +0 -0
  107. /flip/covariance/{lai22 → analytical/lai22}/__init__.py +0 -0
  108. /flip/covariance/{lai22 → analytical/lai22}/h_terms.py +0 -0
  109. /flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/__init__.py +0 -0
  110. /flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/flip_terms_lmax.py +0 -0
  111. /flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/__init__.py +0 -0
  112. /flip/covariance/{rcrk24 → analytical/rcrk24}/__init__.py +0 -0
  113. {flipcosmo-1.0.0.dist-info → flipcosmo-1.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ import scipy as sc
6
6
 
7
7
  from flip.utils import create_log
8
8
 
9
- from .config import __use_jax__
9
+ from .._config import __use_jax__
10
10
 
11
11
  if __use_jax__:
12
12
  try:
@@ -44,6 +44,21 @@ _available_inversion_methods = [
44
44
 
45
45
 
46
46
  def log_likelihood_gaussian_inverse(vector, covariance_sum):
47
+ """Compute multivariate Gaussian log-likelihood using explicit inverse.
48
+
49
+ This evaluates $\mathcal{L} = -\tfrac{1}{2}[N\log(2\pi) + \log|C| + \chi^2]$
50
+ with $\chi^2 = v^T C^{-1} v$ by explicitly inverting the covariance matrix.
51
+
52
+ Args:
53
+ vector (array-like): Residual data vector `v` of shape `(N,)`.
54
+ covariance_sum (array-like): Total covariance matrix `C` of shape `(N, N)`.
55
+
56
+ Returns:
57
+ float: Log-likelihood value of the Gaussian model.
58
+
59
+ Notes:
60
+ Prefer solve or Cholesky variants for better numerical stability on ill-conditioned matrices.
61
+ """
47
62
  _, logdet = jnp.linalg.slogdet(covariance_sum)
48
63
  inverse_covariance_sum = jnp.linalg.inv(covariance_sum)
49
64
  chi2 = jnp.dot(vector, jnp.dot(inverse_covariance_sum, vector))
@@ -51,12 +66,36 @@ def log_likelihood_gaussian_inverse(vector, covariance_sum):
51
66
 
52
67
 
53
68
  def log_likelihood_gaussian_solve(vector, covariance_sum):
69
+ """Compute multivariate Gaussian log-likelihood via linear solver.
70
+
71
+ Uses `solve(C, v)` to avoid explicit inversion when computing $\chi^2 = v^T C^{-1} v$.
72
+
73
+ Args:
74
+ vector (array-like): Residual data vector `v` of shape `(N,)`.
75
+ covariance_sum (array-like): Total covariance matrix `C` of shape `(N, N)`.
76
+
77
+ Returns:
78
+ float: Log-likelihood value of the Gaussian model.
79
+
80
+ """
54
81
  _, logdet = jnp.linalg.slogdet(covariance_sum)
55
82
  chi2 = jnp.dot(vector.T, jnp.linalg.solve(covariance_sum, vector))
56
83
  return -0.5 * (vector.size * jnp.log(2 * jnp.pi) + logdet + chi2)
57
84
 
58
85
 
59
86
  def log_likelihood_gaussian_cholesky(vector, covariance_sum):
87
+ """Compute Gaussian log-likelihood using Cholesky factorization.
88
+
89
+ Factorizes `C = L L^T` to compute both `log|C|` and $\chi^2$ stably.
90
+
91
+ Args:
92
+ vector (array-like): Residual data vector `v` of shape `(N,)`.
93
+ covariance_sum (array-like): Positive-definite covariance matrix `C`.
94
+
95
+ Returns:
96
+ float: Log-likelihood value of the Gaussian model.
97
+
98
+ """
60
99
  cholesky = jsc.linalg.cho_factor(covariance_sum)
61
100
  logdet = 2 * jnp.sum(jnp.log(jnp.diag(cholesky[0])))
62
101
  chi2 = jnp.dot(vector, jsc.linalg.cho_solve(cholesky, vector))
@@ -64,13 +103,37 @@ def log_likelihood_gaussian_cholesky(vector, covariance_sum):
64
103
 
65
104
 
66
105
  def log_likelihood_gaussian_cholesky_inverse(vector, covariance_sum):
106
+ """Compute Gaussian log-likelihood using Cholesky, fallback to inverse.
107
+
108
+ Attempts a Cholesky factorization; if it fails (non-PD matrix), falls back to
109
+ explicit inversion-based computation.
110
+
111
+ Args:
112
+ vector (array-like): Residual data vector `v` of shape `(N,)`.
113
+ covariance_sum (array-like): Covariance matrix `C` (ideally PD).
114
+
115
+ Returns:
116
+ float: Log-likelihood value of the Gaussian model.
117
+ """
67
118
  try:
68
119
  return log_likelihood_gaussian_cholesky(vector, covariance_sum)
69
- except:
120
+ except jnp.linalg.LinAlgError:
70
121
  return log_likelihood_gaussian_inverse(vector, covariance_sum)
71
122
 
72
123
 
73
124
  def log_likelihood_gaussian_cholesky_regularized(vector, covariance_sum):
125
+ """Compute Gaussian log-likelihood with eigenvalue regularization.
126
+
127
+ Ensures positive-definiteness by replacing negative eigenvalues with their absolute
128
+ values before Cholesky factorization.
129
+
130
+ Args:
131
+ vector (array-like): Residual data vector `v` of shape `(N,)`.
132
+ covariance_sum (array-like): Covariance matrix `C` that may be indefinite.
133
+
134
+ Returns:
135
+ float: Log-likelihood value of the Gaussian model.
136
+ """
74
137
  eigval, eigvec = jnp.linalg.eig(covariance_sum)
75
138
  cov_sum_regularized = eigvec @ jnp.abs(jnp.diag(eigval)) @ jnp.linalg.inv(eigvec)
76
139
  cholesky = jsc.linalg.cho_factor(cov_sum_regularized)
@@ -89,14 +152,47 @@ if jax_installed:
89
152
 
90
153
 
91
154
  def no_prior(x):
155
+ """Return zero prior contribution.
156
+
157
+ Args:
158
+ x (Any): Ignored parameter values input.
159
+
160
+ Returns:
161
+ int: Zero.
162
+ """
92
163
  return 0
93
164
 
94
165
 
95
166
  def prior_sum(priors, x):
167
+ """Sum multiple prior contributions.
168
+
169
+ Args:
170
+ priors (list[Callable]): List of prior callables accepting a parameter dict.
171
+ x (dict): Parameter values dictionary.
172
+
173
+ Returns:
174
+ float: Sum of all prior log-probabilities.
175
+ """
96
176
  return sum(prior(x) for prior in priors)
97
177
 
98
178
 
99
179
  class BaseLikelihood(abc.ABC):
180
+ """Abstract base class for likelihood evaluation.
181
+
182
+ Provides common setup for covariance verification, properties validation, prior
183
+ initialization, and building the callable likelihood and optional gradient.
184
+
185
+ Attributes:
186
+ covariance (CovMatrix or list[CovMatrix]): Covariance model(s).
187
+ data (object): Data provider with `free_par` and `give_data_and_variance`.
188
+ parameter_names (list[str]): Ordered parameter names.
189
+ free_par (list[str]): Combined free parameters from data and covariance.
190
+ likelihood_properties (dict): Controls inversion method, sign, JIT, gradients.
191
+ prior (Callable): Prior function returning log-prior given parameter dict.
192
+ likelihood_call (Callable): Callable evaluating the likelihood.
193
+ likelihood_grad (Callable|None): Gradient of likelihood if JAX is available.
194
+
195
+ """
100
196
 
101
197
  _default_likelihood_properties = {
102
198
  "inversion_method": "inverse",
@@ -129,15 +225,29 @@ class BaseLikelihood(abc.ABC):
129
225
  }
130
226
 
131
227
  self.verify_covariance()
228
+ self.verify_properties()
132
229
  self.prior = self.initialize_prior()
133
230
 
134
231
  self.likelihood_call, self.likelihood_grad = self._init_likelihood()
135
232
 
136
233
  def __call__(self, parameter_values):
234
+ """Evaluate likelihood at parameter values.
235
+
236
+ Args:
237
+ parameter_values (array-like): Parameter vector aligned with `parameter_names`.
238
+
239
+ Returns:
240
+ float: Likelihood value, sign controlled by `negative_log_likelihood`.
241
+ """
137
242
  return self.likelihood_call(parameter_values)
138
243
 
139
244
  @abc.abstractmethod
140
245
  def _init_likelihood(self, *args):
246
+ """Initialize likelihood and optional gradient.
247
+
248
+ Returns:
249
+ tuple[Callable, Callable|None]: `(likelihood_call, likelihood_grad)`.
250
+ """
141
251
  likelihood_fun = None
142
252
  likelihood_grad = None
143
253
  return likelihood_fun, likelihood_grad
@@ -151,21 +261,17 @@ class BaseLikelihood(abc.ABC):
151
261
  likelihood_properties={},
152
262
  **kwargs,
153
263
  ):
154
- """
155
- The init_from_covariance function is a class method that initializes the likelihood object from a covariance matrix.
264
+ """Construct a likelihood instance from a covariance.
156
265
 
157
266
  Args:
158
- cls: Create a new instance of the class
159
- covariance: Compute the full matrix of the covariance
160
- parameter_names: Set the names of the parameters
161
- density: Compute the vector and its error
162
- density_err: Compute the vector_err
163
- velocity: Compute the vector and vector_err
164
- velocity_err: Compute the error in the vector
165
- : Compute the vector
267
+ covariance (CovMatrix or list[CovMatrix]): Covariance model(s) to use.
268
+ data (object): Data provider used by likelihood to build residuals/errors.
269
+ parameter_names (list[str]): Parameter names ordering the input vector.
270
+ likelihood_properties (dict, optional): Likelihood options overriding defaults.
271
+ **kwargs: Extra arguments forwarded to subclass constructor.
166
272
 
167
273
  Returns:
168
- A likelihood object
274
+ BaseLikelihood: Initialized likelihood instance.
169
275
 
170
276
  """
171
277
 
@@ -182,6 +288,14 @@ class BaseLikelihood(abc.ABC):
182
288
  def initialize_prior(
183
289
  self,
184
290
  ):
291
+ """Build prior function from likelihood properties.
292
+
293
+ Returns:
294
+ Callable: Prior function mapping parameter dict to log-prior.
295
+
296
+ Raises:
297
+ ValueError: If an unsupported prior type is requested.
298
+ """
185
299
  if "prior" not in self.likelihood_properties.keys():
186
300
  return no_prior
187
301
  else:
@@ -214,6 +328,11 @@ class BaseLikelihood(abc.ABC):
214
328
  return prior_function
215
329
 
216
330
  def verify_covariance(self):
331
+ """Ensure covariance matrices are ready for likelihood evaluation.
332
+
333
+ Converts flat covariances to matrix form if required and initializes the
334
+ cached `compute_covariance_sum` (and JIT variant) functions.
335
+ """
217
336
  if isinstance(self.covariance, list):
218
337
  for i in range(len(self.covariance)):
219
338
  if self.covariance[i].matrix_form is False:
@@ -234,8 +353,28 @@ class BaseLikelihood(abc.ABC):
234
353
  ):
235
354
  self.covariance.init_compute_covariance_sum()
236
355
 
356
+ def verify_properties(self):
357
+ """Validate likelihood properties such as inversion method.
358
+
359
+ Raises:
360
+ ValueError: If the inversion method is not supported.
361
+ """
362
+ if (
363
+ self.likelihood_properties["inversion_method"]
364
+ not in _available_inversion_methods
365
+ ):
366
+ raise ValueError(
367
+ f"""The inversion method {self.likelihood_properties['inversion_method']} is not available. """
368
+ f"""Please choose between {_available_inversion_methods}"""
369
+ )
370
+
237
371
 
238
372
  class MultivariateGaussianLikelihood(BaseLikelihood):
373
+ """Gaussian likelihood for a single covariance model.
374
+
375
+ Supports multiple inversion strategies and optional JAX JIT/grad for speed.
376
+ """
377
+
239
378
  def __init__(
240
379
  self,
241
380
  covariance=None,
@@ -251,6 +390,11 @@ class MultivariateGaussianLikelihood(BaseLikelihood):
251
390
  )
252
391
 
253
392
  def _init_likelihood(self):
393
+ """Build callable likelihood and optional gradient for Gaussian model.
394
+
395
+ Returns:
396
+ tuple[Callable, Callable|None]: `(likelihood_call, likelihood_grad)`.
397
+ """
254
398
 
255
399
  use_jit = self.likelihood_properties["use_jit"]
256
400
 
@@ -261,6 +405,7 @@ class MultivariateGaussianLikelihood(BaseLikelihood):
261
405
 
262
406
  give_data_and_variance = eval(f"self.data.give_data_and_variance{suffix}")
263
407
  compute_covariance_sum = eval(f"self.covariance.compute_covariance_sum{suffix}")
408
+
264
409
  likelihood_function = eval(
265
410
  f"log_likelihood_gaussian_{self.likelihood_properties['inversion_method']}{suffix}"
266
411
  )
@@ -270,37 +415,51 @@ class MultivariateGaussianLikelihood(BaseLikelihood):
270
415
  else:
271
416
  prior = self.prior
272
417
 
273
- def likelihood_evaluation(parameter_values, neg_like=False):
418
+ def likelihood_evaluation(
419
+ parameter_values,
420
+ covariance_prefactor_dict=None,
421
+ ):
422
+ """Evaluate likelihood for given parameters.
423
+
424
+ Args:
425
+ parameter_values (array-like): Parameter vector aligned to names.
426
+ covariance_prefactor_dict (dict, optional): Prefactors per block (gg/gv/vv).
427
+
428
+ Returns:
429
+ float: Likelihood value (sign depends on `negative_log_likelihood`).
430
+ """
274
431
  parameter_values_dict = dict(zip(self.parameter_names, parameter_values))
275
432
  vector, vector_variance = give_data_and_variance(parameter_values_dict)
276
433
  covariance_sum = compute_covariance_sum(
277
- parameter_values_dict, vector_variance
434
+ parameter_values_dict,
435
+ vector_variance,
436
+ covariance_prefactor_dict=covariance_prefactor_dict,
278
437
  )
279
438
  likelihood_value = likelihood_function(vector, covariance_sum) + prior(
280
439
  parameter_values_dict
281
440
  )
282
441
 
283
- if neg_like:
442
+ if self.likelihood_properties["negative_log_likelihood"]:
284
443
  likelihood_value *= -1
285
444
  return likelihood_value
286
445
 
287
- if self.likelihood_properties["negative_log_likelihood"]:
288
- neg_like = True
289
- else:
290
- neg_like = False
291
-
292
- likelihood_fun = partial(likelihood_evaluation, neg_like=neg_like)
293
446
  if jax_installed:
294
- likelihood_grad = grad(likelihood_fun)
447
+ likelihood_grad = grad(likelihood_evaluation)
295
448
  if use_jit:
296
- likelihood_fun = jit(likelihood_fun)
449
+ likelihood_evaluation = jit(likelihood_evaluation)
297
450
  likelihood_grad = jit(likelihood_grad)
298
451
  else:
299
452
  likelihood_grad = None
300
- return likelihood_fun, likelihood_grad
453
+ return likelihood_evaluation, likelihood_grad
301
454
 
302
455
 
303
456
  class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
457
+ """Gaussian likelihood with 1D interpolation over precomputed covariances.
458
+
459
+ Interpolates the covariance matrix across a scalar parameter grid to avoid
460
+ regeneration during fits.
461
+ """
462
+
304
463
  def __init__(
305
464
  self,
306
465
  covariance=None,
@@ -310,23 +469,18 @@ class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
310
469
  interpolation_value_name=None,
311
470
  interpolation_value_range=None,
312
471
  ):
313
- """
314
- The __init__ function is called when the class is instantiated.
315
- It sets up the instance of the class, and defines all its attributes.
316
- The __init__ function takes arguments, which are then assigned to object attributes:
472
+ """Initialize 1D interpolation likelihood.
317
473
 
318
474
  Args:
319
- self: Represent the instance of the class
320
- covariance: Set the covariance matrix of the likelihood
321
- data: Store the data
322
- parameter_names: Specify the names of the parameters that are used in this likelihood
323
- likelihood_properties: Pass in the interpolation_value_name and interpolation_value_range
324
- interpolation_value_name: Specify the name of the parameter that is being interpolated
325
- interpolation_value_range: Specify the range of values that will be used to interpolate
326
- : Define the interpolation value name
475
+ covariance (list[CovMatrix]): Covariance models sampled along interpolation axis.
476
+ data (object): Data provider with residuals and variance.
477
+ parameter_names (list[str]): Parameter names ordering the input vector.
478
+ likelihood_properties (dict, optional): Likelihood options overriding defaults.
479
+ interpolation_value_name (str): Name of the interpolation parameter.
480
+ interpolation_value_range (array-like): Sorted grid of interpolation values.
327
481
 
328
482
  Returns:
329
- The object itself
483
+ None: Initializes attributes and base class.
330
484
  """
331
485
  self.interpolation_value_name = interpolation_value_name
332
486
  self.interpolation_value_range = interpolation_value_range
@@ -341,6 +495,11 @@ class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
341
495
  self.free_par = [interpolation_value_name] + self.free_par
342
496
 
343
497
  def _init_likelihood(self):
498
+ """Build callable likelihood and optional gradient for 1D interpolation.
499
+
500
+ Returns:
501
+ tuple[Callable, Callable|None]: `(likelihood_call, likelihood_grad)`.
502
+ """
344
503
  use_jit = self.likelihood_properties["use_jit"]
345
504
 
346
505
  if jax_installed & use_jit:
@@ -367,7 +526,22 @@ class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
367
526
  else:
368
527
  prior = self.prior
369
528
 
370
- def likelihood_evaluation(parameter_values, neg_like=False):
529
+ def likelihood_evaluation(
530
+ parameter_values,
531
+ covariance_prefactor_dict=None,
532
+ ):
533
+ """Evaluate likelihood with interpolated covariance.
534
+
535
+ Performs linear interpolation between nearest covariances along the
536
+ interpolation axis.
537
+
538
+ Args:
539
+ parameter_values (array-like): Parameter vector.
540
+ covariance_prefactor_dict (dict, optional): Prefactors per block.
541
+
542
+ Returns:
543
+ float: Likelihood value including prior on interpolation range.
544
+ """
371
545
  parameter_values_dict = dict(zip(self.parameter_names, parameter_values))
372
546
  interpolation_value = parameter_values_dict[self.interpolation_value_name]
373
547
 
@@ -383,10 +557,17 @@ class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
383
557
  upper_index = jnp.searchsorted(
384
558
  interpolation_value_range, interpolation_value
385
559
  )
560
+ upper_index = jnp.min(
561
+ jnp.array([upper_index, len(interpolation_value_range) - 1])
562
+ )
386
563
 
387
564
  covariance_sum_list = jnp.array(
388
565
  [
389
- compute_covariance_sum(parameter_values_dict, vector_variance)
566
+ compute_covariance_sum(
567
+ parameter_values_dict,
568
+ vector_variance,
569
+ covariance_prefactor_dict=covariance_prefactor_dict,
570
+ )
390
571
  for compute_covariance_sum in compute_covariance_sum_list
391
572
  ]
392
573
  )
@@ -414,29 +595,32 @@ class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
414
595
  + prior_interpolation_range
415
596
  )
416
597
 
417
- if neg_like:
598
+ if self.likelihood_properties["negative_log_likelihood"]:
418
599
  likelihood_value *= -1
419
600
  return likelihood_value
420
601
 
421
- if self.likelihood_properties["negative_log_likelihood"]:
422
- neg_like = True
423
- else:
424
- neg_like = False
425
-
426
- likelihood_fun = partial(likelihood_evaluation, neg_like=neg_like)
427
-
428
602
  if jax_installed:
429
- likelihood_grad = grad(likelihood_fun)
603
+ likelihood_grad = grad(likelihood_evaluation)
430
604
  if use_jit:
431
- likelihood_fun = jit(likelihood_fun)
605
+ likelihood_evaluation = jit(likelihood_evaluation)
432
606
  likelihood_grad = jit(likelihood_grad)
433
607
  else:
434
608
  likelihood_grad = None
435
609
 
436
- return likelihood_fun, likelihood_grad
610
+ return likelihood_evaluation, likelihood_grad
611
+
612
+
613
+ # CR - This class is now deprecated, do not use it for now.
437
614
 
438
615
 
439
616
  class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
617
+ """Deprecated 2D interpolation Gaussian likelihood.
618
+
619
+ Note:
620
+ Uses `scipy.interpolate.interp2d`, which is deprecated upstream.
621
+ Prefer emulator-based or grid-based approaches.
622
+ """
623
+
440
624
  def __init__(
441
625
  self,
442
626
  covariance=None,
@@ -482,22 +666,16 @@ class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
482
666
  def __call__(
483
667
  self,
484
668
  parameter_values,
669
+ covariance_prefactor_dict=None,
485
670
  ):
486
- """
487
- The __call__ function is the function that will be called when the likelihood
488
- object is called. It takes in a list of parameter values, and returns a float
489
- value representing the log-likelihood value for those parameters. The __call__
490
- method should not be overwritten by subclasses unless you know what you are doing!
671
+ """Evaluate 2D interpolated likelihood.
491
672
 
492
673
  Args:
493
- self: Refer to the object itself
494
- parameter_values: Compute the covariance matrix
495
- interpolation_value_0: Interpolate the covariance matrix along the first dimension
496
- interpolation_value_1: Interpolate the covariance matrix
497
- : Compute the covariance sum
674
+ parameter_values (array-like): Parameter vector aligned to names.
675
+ covariance_prefactor_dict (dict, optional): Prefactors per covariance block.
498
676
 
499
677
  Returns:
500
- The log-likelihood function
678
+ float: Log-likelihood value (sign depends on `negative_log_likelihood`).
501
679
  """
502
680
  parameter_values_dict = dict(zip(self.parameter_names, parameter_values))
503
681
 
@@ -515,7 +693,7 @@ class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
515
693
  else:
516
694
  return -np.inf
517
695
 
518
- vector, vector_variance = self.data(parameter_values)
696
+ vector, vector_variance = self.data.give_data_and_variance(parameter_values)
519
697
 
520
698
  covariance_sum_matrix = []
521
699
 
@@ -524,7 +702,9 @@ class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
524
702
  for j in range(len(self.covariance[i])):
525
703
  covariance_sum_matrix_i.append(
526
704
  self.covariance[i][j].compute_covariance_sum(
527
- parameter_values_dict, vector_variance
705
+ parameter_values_dict,
706
+ vector_variance,
707
+ covariance_prefactor_dict=covariance_prefactor_dict,
528
708
  )
529
709
  )
530
710
  covariance_sum_matrix.append(covariance_sum_matrix_i)
@@ -549,6 +729,12 @@ class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
549
729
 
550
730
 
551
731
  class Prior:
732
+ """Base prior class encapsulating parameter-specific priors.
733
+
734
+ Attributes:
735
+ parameter_name (str): Name of the parameter this prior applies to.
736
+ """
737
+
552
738
  def __init__(
553
739
  self,
554
740
  parameter_name=None,
@@ -557,6 +743,10 @@ class Prior:
557
743
 
558
744
 
559
745
  class GaussianPrior(Prior):
746
+ """Univariate Gaussian prior on a parameter.
747
+
748
+ Models $p(\theta) \propto \exp\{-\tfrac{1}{2}[(\theta-\mu)^2/\sigma^2]\}$.
749
+ """
560
750
 
561
751
  def __init__(
562
752
  self,
@@ -572,6 +762,14 @@ class GaussianPrior(Prior):
572
762
  self,
573
763
  parameter_values_dict,
574
764
  ):
765
+ """Return Gaussian log-prior for the parameter value.
766
+
767
+ Args:
768
+ parameter_values_dict (dict): Map of parameter names to values.
769
+
770
+ Returns:
771
+ float: Log-prior value.
772
+ """
575
773
  return -0.5 * (
576
774
  np.log(2 * jnp.pi * self.prior_standard_deviation**2)
577
775
  + (parameter_values_dict[self.parameter_name] - self.prior_mean) ** 2
@@ -580,6 +778,10 @@ class GaussianPrior(Prior):
580
778
 
581
779
 
582
780
  class PositivePrior(Prior):
781
+ """Log-prior enforcing parameter positivity via Heaviside function.
782
+
783
+ Returns `log(Heaviside(value))`, which is `0` for positive values and `-inf` otherwise.
784
+ """
583
785
 
584
786
  def __init__(
585
787
  self,
@@ -591,10 +793,22 @@ class PositivePrior(Prior):
591
793
  self,
592
794
  parameter_values_dict,
593
795
  ):
796
+ """Return log-prior that is zero for positive values, -inf otherwise.
797
+
798
+ Args:
799
+ parameter_values_dict (dict): Map of parameter names to values.
800
+
801
+ Returns:
802
+ float: Log-prior (0 or -inf).
803
+ """
594
804
  return jnp.log(jnp.heaviside(parameter_values_dict[self.parameter_name], 0))
595
805
 
596
806
 
597
807
  class UniformPrior(Prior):
808
+ """Uniform prior over a finite interval for a parameter.
809
+
810
+ Uses `scipy.stats.uniform.logpdf` for the specified range.
811
+ """
598
812
 
599
813
  def __init__(self, parameter_name=None, range=None):
600
814
  super().__init__(parameter_name=parameter_name)
@@ -604,6 +818,14 @@ class UniformPrior(Prior):
604
818
  self,
605
819
  parameter_values_dict,
606
820
  ):
821
+ """Return uniform log-prior over `[range[0], range[1]]`.
822
+
823
+ Args:
824
+ parameter_values_dict (dict): Map of parameter names to values.
825
+
826
+ Returns:
827
+ float: Log-prior value (constant inside range, -inf outside).
828
+ """
607
829
  value = parameter_values_dict[self.parameter_name]
608
830
  return jsc.stats.uniform.logpdf(
609
831
  value, loc=self.range[0], scale=self.range[1] - self.range[0]