flipcosmo 1.0.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. docs/conf.py +154 -0
  2. flip/__init__.py +4 -11
  3. flip/covariance/__init__.py +7 -8
  4. flip/covariance/analytical/__init__.py +11 -0
  5. flip/covariance/{adamsblake17plane → analytical/adamsblake17}/coefficients.py +1 -1
  6. flip/covariance/{adamsblake17plane → analytical/adamsblake17}/fisher_terms.py +1 -1
  7. flip/covariance/{adamsblake17 → analytical/adamsblake17}/flip_terms.py +0 -1
  8. flip/covariance/{adamsblake17 → analytical/adamsblake17plane}/coefficients.py +1 -1
  9. flip/covariance/{adamsblake17 → analytical/adamsblake17plane}/fisher_terms.py +1 -1
  10. flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/flip_terms.py +0 -1
  11. flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/generator.py +103 -19
  12. flip/covariance/{adamsblake20 → analytical/adamsblake20}/coefficients.py +1 -1
  13. flip/covariance/{adamsblake20 → analytical/adamsblake20}/fisher_terms.py +1 -1
  14. flip/covariance/{adamsblake20 → analytical/adamsblake20}/flip_terms.py +0 -1
  15. flip/covariance/{carreres23 → analytical/carreres23}/coefficients.py +1 -4
  16. flip/covariance/{ravouxnoanchor25 → analytical/carreres23}/fisher_terms.py +1 -1
  17. flip/covariance/{carreres23 → analytical/carreres23}/flip_terms.py +0 -1
  18. flip/covariance/analytical/carreres23/generator.py +198 -0
  19. flip/covariance/analytical/genericzdep/__init__.py +5 -0
  20. flip/covariance/analytical/genericzdep/coefficients.py +53 -0
  21. flip/covariance/analytical/genericzdep/flip_terms.py +99 -0
  22. flip/covariance/{lai22 → analytical/lai22}/coefficients.py +2 -3
  23. flip/covariance/{lai22 → analytical/lai22}/fisher_terms.py +1 -1
  24. flip/covariance/{lai22 → analytical/lai22}/flip_terms.py +0 -1
  25. flip/covariance/{lai22 → analytical/lai22}/generator.py +263 -58
  26. flip/covariance/{lai22 → analytical/lai22}/symbolic.py +55 -19
  27. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/coefficients.py +1 -1
  28. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/fisher_terms.py +1 -1
  29. flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/flip_terms.py +0 -1
  30. flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/coefficients.py +3 -2
  31. flip/covariance/{carreres23 → analytical/ravouxnoanchor25}/fisher_terms.py +1 -1
  32. flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/flip_terms.py +0 -9
  33. flip/covariance/{rcrk24 → analytical/rcrk24}/coefficients.py +6 -6
  34. flip/covariance/{rcrk24 → analytical/rcrk24}/fisher_terms.py +7 -9
  35. flip/covariance/{rcrk24 → analytical/rcrk24}/flip_terms.py +0 -8
  36. flip/covariance/contraction.py +82 -40
  37. flip/covariance/cov_utils.py +89 -81
  38. flip/covariance/covariance.py +172 -141
  39. flip/covariance/emulators/__init__.py +1 -1
  40. flip/covariance/emulators/generator.py +73 -3
  41. flip/covariance/emulators/gpmatrix.py +40 -1
  42. flip/covariance/emulators/nnmatrix.py +57 -1
  43. flip/covariance/emulators/skgpmatrix.py +125 -0
  44. flip/covariance/fisher.py +307 -0
  45. flip/{fit_utils.py → covariance/fit_utils.py} +185 -10
  46. flip/{fitter.py → covariance/fitter.py} +151 -125
  47. flip/covariance/generator.py +82 -106
  48. flip/{likelihood.py → covariance/likelihood.py} +286 -64
  49. flip/{plot_utils.py → covariance/plot_utils.py} +79 -4
  50. flip/covariance/symbolic.py +89 -44
  51. flip/data/__init__.py +1 -1
  52. flip/data/data_density.parquet +0 -0
  53. flip/data/data_velocity.parquet +0 -0
  54. flip/data/{grid_window_m.parquet → data_window_density.parquet} +0 -0
  55. flip/{gridding.py → data/gridding.py} +125 -130
  56. flip/data/load_data_test.py +102 -0
  57. flip/data/power_spectrum_mm.txt +2 -2
  58. flip/data/power_spectrum_mt.txt +2 -2
  59. flip/data/power_spectrum_tt.txt +2 -2
  60. flip/data/test_covariance_reference_values.json +145 -0
  61. flip/data/test_e2e_reference_values.json +14 -0
  62. flip/data_vector/basic.py +118 -101
  63. flip/data_vector/cosmo_utils.py +18 -0
  64. flip/data_vector/galaxypv_vectors.py +58 -94
  65. flip/data_vector/snia_vectors.py +60 -3
  66. flip/data_vector/vector_utils.py +47 -1
  67. flip/power_spectra/class_engine.py +36 -1
  68. flip/power_spectra/cosmoprimo_engine.py +37 -2
  69. flip/power_spectra/generator.py +47 -25
  70. flip/power_spectra/models.py +30 -31
  71. flip/power_spectra/pyccl_engine.py +36 -1
  72. flip/simulation/__init__.py +0 -0
  73. flip/utils.py +62 -91
  74. flipcosmo-1.2.1.dist-info/METADATA +78 -0
  75. flipcosmo-1.2.1.dist-info/RECORD +109 -0
  76. {flipcosmo-1.0.0.dist-info → flipcosmo-1.2.1.dist-info}/WHEEL +1 -1
  77. flipcosmo-1.2.1.dist-info/top_level.txt +7 -0
  78. scripts/flip_compute_correlation_model.py +70 -0
  79. scripts/flip_compute_power_spectra.py +50 -0
  80. scripts/flip_fisher_forecast_velocity.py +70 -0
  81. scripts/flip_fisher_rcrk24.py +164 -0
  82. scripts/flip_launch_minuit_density_fit.py +91 -0
  83. scripts/flip_launch_minuit_full_fit.py +117 -0
  84. scripts/flip_launch_minuit_velocity_fit.py +78 -0
  85. scripts/flip_launch_minuit_velocity_fit_full.py +107 -0
  86. scripts/flip_launch_minuit_velocity_fit_interpolation.py +93 -0
  87. test/refresh_reference_values.py +43 -0
  88. test/test_covariance_assembly.py +102 -0
  89. test/test_covariance_reference_values.py +125 -0
  90. test/test_covariance_utils.py +34 -0
  91. test/test_e2e_density.py +50 -0
  92. test/test_e2e_joint.py +65 -0
  93. test/test_e2e_velocity.py +53 -0
  94. test/test_likelihood_inversions.py +31 -0
  95. flip/covariance/carreres23/generator.py +0 -132
  96. flip/data/density_data.parquet +0 -0
  97. flip/data/velocity_data.parquet +0 -0
  98. flip/fisher.py +0 -190
  99. flipcosmo-1.0.0.dist-info/METADATA +0 -32
  100. flipcosmo-1.0.0.dist-info/RECORD +0 -82
  101. flipcosmo-1.0.0.dist-info/top_level.txt +0 -1
  102. /flip/{config.py → _config.py} +0 -0
  103. /flip/covariance/{adamsblake17 → analytical/adamsblake17}/__init__.py +0 -0
  104. /flip/covariance/{adamsblake17plane → analytical/adamsblake17plane}/__init__.py +0 -0
  105. /flip/covariance/{adamsblake20 → analytical/adamsblake20}/__init__.py +0 -0
  106. /flip/covariance/{carreres23 → analytical/carreres23}/__init__.py +0 -0
  107. /flip/covariance/{lai22 → analytical/lai22}/__init__.py +0 -0
  108. /flip/covariance/{lai22 → analytical/lai22}/h_terms.py +0 -0
  109. /flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/__init__.py +0 -0
  110. /flip/covariance/{ravouxcarreres → analytical/ravouxcarreres}/flip_terms_lmax.py +0 -0
  111. /flip/covariance/{ravouxnoanchor25 → analytical/ravouxnoanchor25}/__init__.py +0 -0
  112. /flip/covariance/{rcrk24 → analytical/rcrk24}/__init__.py +0 -0
  113. {flipcosmo-1.0.0.dist-info → flipcosmo-1.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -7,15 +7,24 @@ import emcee
7
7
  import iminuit
8
8
  import numpy as np
9
9
 
10
+ import flip.covariance.likelihood as flik
11
+ from flip.covariance.covariance import CovMatrix
10
12
  from flip.utils import create_log
11
13
 
12
14
  log = create_log()
13
15
 
14
- import flip.likelihood as flik
15
- from flip.covariance.covariance import CovMatrix
16
-
17
16
 
18
17
  class BaseFitter(abc.ABC):
18
+ """Abstract interface for fitters.
19
+
20
+ Provides common wiring between covariance, data, and likelihood construction,
21
+ and defines the contract for initialization from covariance or files.
22
+
23
+ Attributes:
24
+ covariance (CovMatrix): Covariance model to use for fits.
25
+ data (object): Data provider passed to likelihoods.
26
+ """
27
+
19
28
  def __init__(
20
29
  self,
21
30
  covariance=None,
@@ -42,13 +51,10 @@ class BaseFitter(abc.ABC):
42
51
  def init_from_covariance(
43
52
  cls,
44
53
  ):
45
- """
46
- The init_from_covariance function is a class method that initializes the
47
- fitter from the covariance matrix. It is here an abstract method that
48
- needs to be override
54
+ """Initialize fitter from covariance.
49
55
 
50
- Args:
51
- cls: Pass a class object into a method
56
+ Returns:
57
+ BaseFitter: Implementations must return an initialized fitter.
52
58
  """
53
59
  return
54
60
 
@@ -56,13 +62,10 @@ class BaseFitter(abc.ABC):
56
62
  def init_from_file(
57
63
  cls,
58
64
  ):
59
- """
60
- The init_from_covariance function is a class method that initializes the
61
- fitter from the a file containing covariance matrix. It is here an
62
- abstract method that needs to be override
65
+ """Initialize fitter from a covariance file.
63
66
 
64
- Args:
65
- cls: Pass a class object into a method
67
+ Returns:
68
+ BaseFitter: Implementations must return an initialized fitter.
66
69
  """
67
70
  return
68
71
 
@@ -73,20 +76,16 @@ class BaseFitter(abc.ABC):
73
76
  likelihood_properties=None,
74
77
  **kwargs,
75
78
  ):
76
- """
77
- The get_likelihood function is used to create a likelihood object from the covariance matrix.
78
- The function takes in a dictionary of parameters, and returns an instance of the likelihood class.
79
-
79
+ """Construct a likelihood from the fitter's covariance and data.
80
80
 
81
81
  Args:
82
- self: Bind the method to a class
83
- parameter_dict: Pass the parameters to be used in the likelihood function
84
- likelihood_type: Select the likelihood class
85
- : Select the likelihood function
82
+ parameter_dict (dict): Parameters with keys as names and values/priors.
83
+ likelihood_type (str): Likelihood class key; see `select_likelihood`.
84
+ likelihood_properties (dict, optional): Options overriding defaults.
85
+ **kwargs: Extra args forwarded to likelihood constructors.
86
86
 
87
87
  Returns:
88
- A likelihood object
89
-
88
+ BaseLikelihood: Initialized likelihood instance.
90
89
  """
91
90
 
92
91
  parameter_names = [parameters for parameters in parameter_dict]
@@ -105,15 +104,14 @@ class BaseFitter(abc.ABC):
105
104
 
106
105
  @staticmethod
107
106
  def select_likelihood(likelihood_type):
108
- """
109
- The select_likelihood function takes in a string, likelihood_type, and returns the corresponding class.
107
+ """Map a likelihood type key to its class.
110
108
 
111
109
  Args:
112
- likelihood_type: Determine which likelihood function to use
110
+ likelihood_type (str): One of `multivariate_gaussian`,
111
+ `multivariate_gaussian_interp1d`, `multivariate_gaussian_interp2d`.
113
112
 
114
113
  Returns:
115
- The likelihood class
116
-
114
+ type: Likelihood class.
117
115
  """
118
116
  if likelihood_type == "multivariate_gaussian":
119
117
  likelihood_class = flik.MultivariateGaussianLikelihood
@@ -134,22 +132,13 @@ class FitMinuit(BaseFitter):
134
132
  likelihood=None,
135
133
  minuit=None,
136
134
  ):
137
- """
138
- The __init__ function is called when the class is instantiated.
139
- It sets up the instance of the class, and defines all of its attributes.
140
- The self argument refers to the instance of this object that has been created.
135
+ """Initialize Minuit fitter.
141
136
 
142
137
  Args:
143
- self: Represent the instance of the class
144
- covariance: Pass the covariance matrix to the fit
145
- data: Pass the data to be fitted
146
- likelihood: Pass the likelihood function to the fit
147
- minuit: Pass a minuit object to the fitminuit class
148
- : Set the minuit object
149
-
150
- Returns:
151
- The object that is being created
152
-
138
+ covariance (CovMatrix, optional): Covariance model for the fit.
139
+ data (object, optional): Data provider passed to likelihoods.
140
+ likelihood (BaseLikelihood, optional): Prebuilt likelihood.
141
+ minuit (iminuit.Minuit, optional): Preconfigured Minuit instance.
153
142
  """
154
143
  super(FitMinuit, self).__init__(
155
144
  covariance=covariance,
@@ -168,26 +157,18 @@ class FitMinuit(BaseFitter):
168
157
  likelihood_properties={},
169
158
  **kwargs,
170
159
  ):
171
- """
172
- The init_from_covariance function is a class method that initializes the MinuitFitter object.
173
- It takes in the covariance matrix, data, parameter dictionary and likelihood type as arguments.
174
- The minuit_fitter object is initialized with the covariance matrix and data. The likelihood function
175
- is then calculated using get_likelihood() which returns an instance of LikelihoodFunction(). This
176
- instance is assigned to minuit_fitter's attribute 'likelihood'. The parameter values are extracted from
177
- the parameter dictionary and stored in a list called 'parameter_values'. A Minuit object called 'minuit'
178
- is
160
+ """Build a Minuit fitter from covariance and data.
179
161
 
180
162
  Args:
181
- cls: Create a new instance of the class
182
- covariance: Initialize the covariance matrix of the likelihood
183
- data: Pass the data to the likelihood function
184
- parameter_dict: Pass the parameters to be fitted
185
- likelihood_type: Specify the type of likelihood function to be used
186
- : Set the covariance matrix of the data
163
+ covariance (CovMatrix): Covariance model.
164
+ data (object): Data provider.
165
+ parameter_dict (dict): Parameter specs including values, errors, limits.
166
+ likelihood_type (str): Likelihood variant key.
167
+ likelihood_properties (dict): Options (e.g., use_jit, use_gradient).
168
+ **kwargs: Extra args forwarded to likelihood construction.
187
169
 
188
170
  Returns:
189
- A minuit_fitter object
190
-
171
+ FitMinuit: Configured fitter with `iminuit.Minuit` ready.
191
172
  """
192
173
  minuit_fitter = cls(
193
174
  covariance=covariance,
@@ -237,22 +218,22 @@ class FitMinuit(BaseFitter):
237
218
  likelihood_type="multivariate_gaussian",
238
219
  likelihood_properties=None,
239
220
  ):
240
- """
241
- The init_from_file function is a class method that initializes the fitter object from a covariance matrix.
221
+ """Initialize a Minuit fitter from a covariance file.
222
+
223
+ Detects supported formats by extension (`.pickle`, `.npz`), loads a
224
+ `CovMatrix`, and delegates to `init_from_covariance`.
242
225
 
243
226
  Args:
244
- cls: Pass the class object to the function
245
- model_name: Specify the name of the model
246
- model_kind: Specify the type of model
247
- filename: Load the covariance matrix from a file
248
- data: Initialize the fitter's data attribute
249
- parameter_dict: Pass in the parameters that are used to
250
- likelihood_type: Specify the type of likelihood function to use
251
- : Specify the type of likelihood
227
+ model_name (str): Model name (unused here; kept for API parity).
228
+ model_kind (str): Model kind (unused here).
229
+ filename (str): Path with or without extension.
230
+ data (object): Data provider.
231
+ parameter_dict (dict): Parameter specs.
232
+ likelihood_type (str): Likelihood variant key.
233
+ likelihood_properties (dict): Likelihood options.
252
234
 
253
235
  Returns:
254
- A fitter object
255
-
236
+ FitMinuit: Configured fitter with `iminuit.Minuit` ready.
256
237
  """
257
238
  # Detect supported formats by extension
258
239
  if filename.endswith(".pickle"):
@@ -280,18 +261,11 @@ class FitMinuit(BaseFitter):
280
261
  )
281
262
 
282
263
  def setup_minuit(self, parameter_dict):
283
- """
284
- The setup_minuit function is used to set up the minuit object.
285
- It takes a dictionary of parameters as input and sets the errors, fixed values, and limits for each parameter.
286
- The error is set to be equal to the value if no error is specified in the dictionary. If a parameter has been fixed then its error will be zero.
264
+ """Configure Minuit parameter errors, limits, and fixed flags.
287
265
 
288
266
  Args:
289
- self: Refer to the object itself
290
- parameter_dict: Set the initial values of the parameters
291
-
292
- Returns:
293
- A minuit object
294
-
267
+ parameter_dict (dict): Parameter config with keys `value`, optional
268
+ `error`, `fixed`, `limit_low`, `limit_up`.
295
269
  """
296
270
  self.minuit.errordef = iminuit.Minuit.LIKELIHOOD
297
271
  for parameters in parameter_dict:
@@ -306,20 +280,16 @@ class FitMinuit(BaseFitter):
306
280
  self.minuit.limits[parameters] = (limit_low, limit_up)
307
281
 
308
282
  def run(self, migrad=True, hesse=False, minos=False, n_iter=1):
309
- """
310
- The run function is the main function of the class. It takes in a number of
311
- arguments, and then runs them through Minuit. The arguments are:
283
+ """Run Minuit optimization and return fitted parameter values.
312
284
 
313
285
  Args:
314
- self: Bind the method to the object
315
- migrad: Run the migrad algorithm
316
- hesse: Run the hesse function
317
- minos: Run the minos function, which is a
318
- : Set the number of iterations for migrad
286
+ migrad (bool): Run MIGRAD algorithm.
287
+ hesse (bool): Compute HESSE errors.
288
+ minos (bool): Compute MINOS intervals (may be slow).
289
+ n_iter (int): Number of MIGRAD iterations to perform.
319
290
 
320
291
  Returns:
321
- A dictionary with the results of the minimization
322
-
292
+ dict: Fitted parameter values.
323
293
  """
324
294
  if migrad:
325
295
  for i in range(n_iter):
@@ -331,7 +301,7 @@ class FitMinuit(BaseFitter):
331
301
  if minos:
332
302
  try:
333
303
  log.add(self.minuit.minos())
334
- except:
304
+ except RuntimeError:
335
305
  pass
336
306
 
337
307
  return self.minuit.values.to_dict()
@@ -346,23 +316,12 @@ class FitMCMC(BaseFitter):
346
316
  data=None,
347
317
  sampler_name="emcee",
348
318
  ):
349
- """
350
- The __init__ function is called when the class is instantiated.
351
- It sets up the instance of the class, and defines all of its attributes.
352
- The __init__ function should always accept at least one argument, self,
353
- which refers to the instance of the object being created.
319
+ """Initialize MCMC fitter.
354
320
 
355
321
  Args:
356
- self: Represent the instance of the object itself
357
- covariance: Set the covariance matrix of the fit
358
- data: Pass the data to the likelihood function
359
- likelihood: Define the likelihood function
360
- sampler: Pass the sampler object to the fit
361
- : Define the sampler that will be used in the fit
362
-
363
- Returns:
364
- The object itself, so the return value is self
365
-
322
+ covariance (CovMatrix, optional): Covariance model.
323
+ data (object, optional): Data provider.
324
+ sampler_name (str): Sampler backend name (only `emcee` supported).
366
325
  """
367
326
  super().__init__(
368
327
  covariance=covariance,
@@ -384,20 +343,21 @@ class FitMCMC(BaseFitter):
384
343
  backend_file=None,
385
344
  **kwargs,
386
345
  ):
387
- """
388
- The init_from_covariance function is a class method that initializes the MCMC fitter from a covariance matrix.
346
+ """Build an MCMC fitter from covariance and data.
389
347
 
390
348
  Args:
391
- cls: Create a new instance of the class
392
- covariance: Set the covariance matrix of the multivariate gaussian
393
- data: Calculate the likelihood
394
- parameter_dict: Pass in the parameters of the model
395
- likelihood_type: Specify the type of likelihood function to use
396
- : Set the covariance matrix
349
+ covariance (CovMatrix): Covariance model.
350
+ data (object): Data provider.
351
+ parameter_dict (dict): Parameter specs including random initialization.
352
+ likelihood_type (str): Likelihood variant key.
353
+ likelihood_properties (dict): Options; sets negative_log_likelihood=False.
354
+ sampler_name (str): Sampler backend (`emcee`).
355
+ nwalkers (int): Number of walkers.
356
+ backend_file (str, optional): HDF backend path for resume/checkpoint.
357
+ **kwargs: Extra args forwarded to likelihood.
397
358
 
398
359
  Returns:
399
- A mcmc_fitter object
400
-
360
+ FitMCMC: Configured fitter with sampler set.
401
361
  """
402
362
 
403
363
  mcmc_fitter = cls(
@@ -432,18 +392,22 @@ class FitMCMC(BaseFitter):
432
392
  def init_from_file(
433
393
  cls,
434
394
  ):
435
- """
436
- The init_from_covariance function is a class method that initializes the
437
- fitter from the a file containing covariance matrix. It is here an
438
- abstract method that needs to be override
395
+ """Not implemented for MCMC from file.
439
396
 
440
- Args:
441
- cls: Pass a class object into a method
397
+ Raises:
398
+ NotImplementedError: Always.
442
399
  """
443
400
 
444
401
  raise NotImplementedError
445
402
 
446
403
  def set_sampler(self, likelihood, p0=None, **kwargs):
404
+ """Create sampler backend from likelihood and initial positions.
405
+
406
+ Args:
407
+ likelihood (Callable): Log-probability callable.
408
+ p0 (numpy.ndarray, optional): Initial walker positions `(nwalkers, ndim)`.
409
+ **kwargs: Backend-specific options (e.g., `backend_file`).
410
+ """
447
411
  if self.sampler_name == "emcee":
448
412
  self.sampler = EMCEESampler(likelihood, p0=p0, **kwargs)
449
413
  else:
@@ -451,8 +415,15 @@ class FitMCMC(BaseFitter):
451
415
 
452
416
 
453
417
  class Sampler(abc.ABC):
418
+ """Abstract sampler interface wrapping different MCMC engines."""
454
419
 
455
420
  def __init__(self, likelihood, p0=None):
421
+ """Initialize sampler.
422
+
423
+ Args:
424
+ likelihood (Callable): Log-probability function.
425
+ p0 (numpy.ndarray, optional): Initial positions `(nwalkers, ndim)`.
426
+ """
456
427
  self.likelihood = likelihood
457
428
  self._p0 = None
458
429
  if p0 is not None:
@@ -460,18 +431,44 @@ class Sampler(abc.ABC):
460
431
 
461
432
  @abc.abstractmethod
462
433
  def run_chains(self, nsteps):
434
+ """Run sampler chains for a fixed number of steps.
435
+
436
+ Args:
437
+ nsteps (int): Number of steps per walker.
438
+
439
+ Returns:
440
+ Any: Backend-specific sampler object.
441
+ """
463
442
  return
464
443
 
465
444
  @property
466
445
  def ndim(self):
446
+ """Return dimensionality of parameter space.
447
+
448
+ Returns:
449
+ int: Number of parameters.
450
+ """
467
451
  return len(self.likelihood.parameter_names)
468
452
 
469
453
  @property
470
454
  def p0(self):
455
+ """Initial positions of walkers.
456
+
457
+ Returns:
458
+ numpy.ndarray: Array of shape `(nwalkers, ndim)`.
459
+ """
471
460
  return self._p0
472
461
 
473
462
  @p0.setter
474
463
  def p0(self, value):
464
+ """Set initial positions ensuring shape consistency.
465
+
466
+ Args:
467
+ value (numpy.ndarray): Initial positions `(nwalkers, ndim)`.
468
+
469
+ Raises:
470
+ ValueError: If `ndim` mismatch.
471
+ """
475
472
  if value.shape[1] != self.ndim:
476
473
  raise ValueError(
477
474
  f"p0.shape[1] is equal to ndim={self.ndim}, currently {value.shape[1]}"
@@ -483,7 +480,13 @@ class Sampler(abc.ABC):
483
480
  class EMCEESampler(Sampler):
484
481
  def __init__(self, likelihood, p0=None, backend_file=None):
485
482
  super().__init__(likelihood, p0=p0)
483
+ """Create an emcee sampler with optional HDF backend.
486
484
 
485
+ Args:
486
+ likelihood (Callable): Log-probability function.
487
+ p0 (numpy.ndarray, optional): Initial positions.
488
+ backend_file (str, optional): HDF backend filename to resume/checkpoint.
489
+ """
487
490
  self.backend = None
488
491
  if backend_file is not None:
489
492
  backend_file_exists = os.path.exists(backend_file)
@@ -504,6 +507,16 @@ class EMCEESampler(Sampler):
504
507
  self.backend_file_exists = False
505
508
 
506
509
  def run_chains(self, nsteps, number_worker=1, progress=False):
510
+ """Run emcee chains for a fixed number of steps.
511
+
512
+ Args:
513
+ nsteps (int): Number of steps to run.
514
+ number_worker (int): Parallel workers via multiprocessing.
515
+ progress (bool): Show progress bar.
516
+
517
+ Returns:
518
+ emcee.EnsembleSampler: The sampler instance.
519
+ """
507
520
  with mp.Pool(number_worker) if number_worker != 1 else nullcontext() as pool:
508
521
  sampler = emcee.EnsembleSampler(
509
522
  self.nwalkers,
@@ -526,7 +539,20 @@ class EMCEESampler(Sampler):
526
539
  tau_conv=0.01,
527
540
  progress=False,
528
541
  ):
529
- """Run chains until reaching auto correlation convergence criteria."""
542
+ """Run chains until reaching autocorrelation convergence criteria.
543
+
544
+ Uses emcee's `get_autocorr_time` to check stabilization of autocorrelation
545
+ time and sufficient chain length.
546
+
547
+ Args:
548
+ number_worker (int): Parallel workers.
549
+ maxstep (int): Maximum steps if not converged earlier.
550
+ tau_conv (float): Relative change threshold for convergence.
551
+ progress (bool): Show progress bar.
552
+
553
+ Returns:
554
+ emcee.EnsembleSampler: The sampler instance.
555
+ """
530
556
  old_tau = np.inf
531
557
  with mp.Pool(number_worker) if number_worker != 1 else nullcontext() as pool:
532
558
  sampler = emcee.EnsembleSampler(