eryn 1.2.3__tar.gz → 1.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {eryn-1.2.3 → eryn-1.2.5}/PKG-INFO +1 -1
  2. {eryn-1.2.3 → eryn-1.2.5}/pyproject.toml +1 -1
  3. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/backends/backend.py +11 -2
  4. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/backends/hdfbackend.py +15 -0
  5. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/ensemble.py +18 -5
  6. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/gaussian.py +6 -1
  7. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/prior.py +47 -2
  8. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/periodic.py +19 -3
  9. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/transform.py +52 -39
  10. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/utility.py +4 -3
  11. eryn-1.2.3/src/eryn/tests/__init__.py +0 -0
  12. eryn-1.2.3/src/eryn/tests/test_eryn.py +0 -1246
  13. {eryn-1.2.3 → eryn-1.2.5}/README.md +0 -0
  14. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/CMakeLists.txt +0 -0
  15. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/__init__.py +0 -0
  16. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/backends/__init__.py +0 -0
  17. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/git_version.py.in +0 -0
  18. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/model.py +0 -0
  19. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/__init__.py +0 -0
  20. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/combine.py +0 -0
  21. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/delayedrejection.py +0 -0
  22. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/distgen.py +0 -0
  23. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/distgenrj.py +0 -0
  24. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/group.py +0 -0
  25. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/groupstretch.py +0 -0
  26. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/mh.py +0 -0
  27. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/move.py +0 -0
  28. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/mtdistgen.py +0 -0
  29. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/mtdistgenrj.py +0 -0
  30. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/multipletry.py +0 -0
  31. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/red_blue.py +0 -0
  32. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/rj.py +0 -0
  33. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/stretch.py +0 -0
  34. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/moves/tempering.py +0 -0
  35. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/pbar.py +0 -0
  36. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/state.py +0 -0
  37. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/__init__.py +0 -0
  38. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/stopping.py +0 -0
  39. {eryn-1.2.3 → eryn-1.2.5}/src/eryn/utils/updates.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: eryn
3
- Version: 1.2.3
3
+ Version: 1.2.5
4
4
  Summary: Eryn: an omni-MCMC sampling package.
5
5
  Author: Michael Katz
6
6
  Author-email: Michael Katz <mikekatz04@gmail.com>
@@ -8,7 +8,7 @@ requires = [
8
8
  [project]
9
9
  name = "eryn" #@NAMESUFFIX@
10
10
 
11
- version = "1.2.3"
11
+ version = "1.2.5"
12
12
 
13
13
  description = "Eryn: an omni-MCMC sampling package."
14
14
 
@@ -83,6 +83,7 @@ class Backend(object):
83
83
  nbranches=1,
84
84
  rj=False,
85
85
  moves=None,
86
+ key_order=None,
86
87
  **info,
87
88
  ):
88
89
  """Clear the state of the chain and empty the backend
@@ -106,6 +107,9 @@ class Backend(object):
106
107
  (default: ``False``)
107
108
  moves (list, optional): List of all of the move classes input into the sampler.
108
109
  (default: ``None``)
110
+ key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
111
+ branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
112
+ (default: ``None``)
109
113
  **info (dict, optional): Any other key-value pairs to be added
110
114
  as attributes to the backend.
111
115
 
@@ -118,6 +122,7 @@ class Backend(object):
118
122
  branch_names=branch_names,
119
123
  rj=rj,
120
124
  moves=moves,
125
+ key_order=key_order,
121
126
  info=info,
122
127
  )
123
128
 
@@ -184,6 +189,7 @@ class Backend(object):
184
189
  self.branch_names = branch_names
185
190
  self.ndims = ndims
186
191
  self.nleaves_max = nleaves_max
192
+ self.key_order = key_order
187
193
 
188
194
  self.iteration = 0
189
195
 
@@ -703,8 +709,11 @@ class Backend(object):
703
709
  "thermo",
704
710
  "ti",
705
711
  ]:
706
- logls = np.mean(logls_all, axis=(0, -1))
707
- logZ, dlogZ = thermodynamic_integration_log_evidence(betas, logls)
712
+ logls = logls_all.copy()
713
+ logls[~np.isfinite(logls)] = np.nan
714
+ meanlogls = np.nanmean(logls, axis=(0, -1))
715
+ logZ, dlogZ = thermodynamic_integration_log_evidence(betas, meanlogls)
716
+
708
717
  elif method.lower() in [
709
718
  "stepping stone",
710
719
  "ss",
@@ -176,6 +176,7 @@ class HDFBackend(Backend):
176
176
  nbranches=1,
177
177
  rj=False,
178
178
  moves=None,
179
+ key_order=None,
179
180
  **info,
180
181
  ):
181
182
  """Clear the state of the chain and empty the backend
@@ -199,6 +200,9 @@ class HDFBackend(Backend):
199
200
  (default: ``False``)
200
201
  moves (list, optional): List of all of the move classes input into the sampler.
201
202
  (default: ``None``)
203
+ key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
204
+ branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
205
+ (default: ``None``)
202
206
  **info (dict, optional): Any other key-value pairs to be added
203
207
  as attributes to the backend. These are also added to the HDF5 file.
204
208
 
@@ -343,6 +347,7 @@ class HDFBackend(Backend):
343
347
 
344
348
  chain = g.create_group("chain")
345
349
  inds = g.create_group("inds")
350
+ k_o_g = g.create_group("key_order")
346
351
 
347
352
  for name in branch_names:
348
353
  nleaves = self.nleaves_max[name]
@@ -365,6 +370,9 @@ class HDFBackend(Backend):
365
370
  compression_opts=self.compression_opts,
366
371
  )
367
372
 
373
+ if key_order is not None:
374
+ k_o_g.attrs[name] = key_order[name]
375
+
368
376
  # store move specific information
369
377
  if moves is not None:
370
378
  move_group = g.create_group("moves")
@@ -388,6 +396,12 @@ class HDFBackend(Backend):
388
396
 
389
397
  self.blobs = None
390
398
 
399
+ @property
400
+ def key_order(self):
401
+ """Key order of parameters for each model."""
402
+ with self.open() as f:
403
+ return {key: value for key, value in f[self.name]["key_order"].attrs.items()}
404
+
391
405
  @property
392
406
  def nwalkers(self):
393
407
  """Get nwalkers from h5 file."""
@@ -456,6 +470,7 @@ class HDFBackend(Backend):
456
470
  branch_names=self.branch_names,
457
471
  rj=self.rj,
458
472
  moves=self.moves,
473
+ key_order=self.key_order,
459
474
  )
460
475
 
461
476
  @property
@@ -198,6 +198,7 @@ class EnsembleSampler(object):
198
198
  if it is changed. In this case, the user should declare a new backend and use the last
199
199
  state from the previous backend. **Warning**: If the order of moves of the same move class
200
200
  is changed, the check may not catch it, so the tracking may mix move acceptance fractions together.
201
+ (default: ``True'')
201
202
  info (dict, optional): Key and value pairs reprenting any information
202
203
  the user wants to add to the backend if the user is not inputing
203
204
  their own backend.
@@ -597,6 +598,7 @@ class EnsembleSampler(object):
597
598
  nleaves_max=nleaves_max,
598
599
  rj=self.has_reversible_jump,
599
600
  moves=move_keys,
601
+ key_order=self.key_order,
600
602
  **info,
601
603
  )
602
604
  state = np.random.get_state()
@@ -615,6 +617,9 @@ class EnsembleSampler(object):
615
617
  "Configuration of moves has changed. Cannot use the same backend. Declare a new backend and start from the previous state. If you would prefer not to track move acceptance fraction, set track_moves to False in the EnsembleSampler."
616
618
  )
617
619
 
620
+ if self.key_order != self.backend.key_order:
621
+ raise ValueError("Input key order from priors does not match backend.")
622
+
618
623
  # Check the backend shape
619
624
  for i, (name, shape) in enumerate(self.backend.shape.items()):
620
625
  test_shape = (
@@ -689,7 +694,7 @@ class EnsembleSampler(object):
689
694
 
690
695
  """
691
696
  return self._random.get_state()
692
-
697
+
693
698
  @random_state.setter # NOQA
694
699
  def random_state(self, state):
695
700
  """
@@ -751,13 +756,14 @@ class EnsembleSampler(object):
751
756
  else:
752
757
  raise ValueError("Priors must be a dictionary.")
753
758
 
759
+ self.key_order = {key: value.key_order for key, value in self._priors.items()}
754
760
  return
755
761
 
756
762
  @property
757
763
  def iteration(self):
758
764
  return self.backend.iteration
759
765
 
760
- def reset(self, **info):
766
+ def reset(self, **kwargs):
761
767
  """
762
768
  Reset the backend.
763
769
 
@@ -765,7 +771,7 @@ class EnsembleSampler(object):
765
771
  **info (dict, optional): information to pass to backend reset method.
766
772
 
767
773
  """
768
- self.backend.reset(self.nwalkers, self.ndims, **info)
774
+ self.backend.reset(self.nwalkers, self.ndims, **kwargs)
769
775
 
770
776
  def __getstate__(self):
771
777
  # In order to be generally picklable, we need to discard the pool
@@ -1208,6 +1214,9 @@ class EnsembleSampler(object):
1208
1214
  # vectorized because everything is rectangular (no groups to indicate model difference)
1209
1215
  prior_out += prior_out_temp.sum(axis=-1)
1210
1216
 
1217
+ if np.any(np.isnan(prior_out)):
1218
+ raise ValueError("The prior function is returning Nan.")
1219
+
1211
1220
  return prior_out
1212
1221
 
1213
1222
  def compute_log_like(
@@ -1493,8 +1502,9 @@ class EnsembleSampler(object):
1493
1502
  ll[inds_fix_zeros] = self.fill_zero_leaves_val
1494
1503
 
1495
1504
  # deal with blobs
1496
- blobs_out = np.zeros((nwalkers_all, results.shape[1] - 1))
1497
- blobs_out[unique_groups] = results[:, 1:]
1505
+ _blobs_out = np.zeros((nwalkers_all, results.shape[1] - 1))
1506
+ _blobs_out[unique_groups] = results[:, 1:]
1507
+ blobs_out = _blobs_out.reshape(ntemps, nwalkers)
1498
1508
 
1499
1509
  elif results.dtype == "object":
1500
1510
  # TODO: check blobs and add this capability
@@ -1531,6 +1541,9 @@ class EnsembleSampler(object):
1531
1541
  for key in branch_supps_in_2[name_i]
1532
1542
  }
1533
1543
 
1544
+ if np.any(np.isnan(ll)):
1545
+ raise ValueError("The likelihood function is returning Nan.")
1546
+
1534
1547
  # return Likelihood and blobs
1535
1548
  return ll.reshape(ntemps, nwalkers), blobs_out
1536
1549
 
@@ -137,7 +137,12 @@ class _isotropic_proposal(object):
137
137
  def __init__(self, scale, factor, mode):
138
138
  self.index = 0
139
139
  self.scale = scale
140
- self.invscale = np.linalg.inv(np.linalg.cholesky(scale))
140
+
141
+ if isinstance(scale, float):
142
+ self.invscale = 1. / scale
143
+ else:
144
+ self.invscale = np.linalg.inv(np.linalg.cholesky(scale))
145
+
141
146
  if factor is None:
142
147
  self._log_factor = None
143
148
  else:
@@ -248,24 +248,69 @@ class ProbDistContainer:
248
248
  # to separate out in list form
249
249
  self.priors = []
250
250
 
251
+ self.has_strings = False
252
+ self.has_ints = False
253
+
254
+ # this is for the strings (for the ints it just counts them)
255
+ current_ind = 0
256
+ key_order = []
257
+
251
258
  # setup lists
252
259
  temp_inds = []
253
260
  for inds, dist in priors_in.items():
254
261
  # multiple index
255
262
  if isinstance(inds, tuple):
256
- inds_in = np.asarray(inds)
263
+ inds_tmp = []
264
+ for i in range(len(inds)):
265
+ if isinstance(inds[i], str):
266
+ assert not self.has_ints
267
+ self.has_strings = True
268
+ inds_tmp.append(current_ind)
269
+ key_order.append(inds[i])
270
+
271
+ elif isinstance(inds[i], int):
272
+ assert not self.has_strings
273
+ self.has_ints = True
274
+ inds_tmp.append(i)
275
+
276
+ else:
277
+ raise ValueError("Index in tuple must be int or str and all be the same type.")
278
+
279
+ current_ind += 1
280
+
281
+ inds_in = np.asarray(inds_tmp)
257
282
  self.priors.append([inds_in, dist])
258
283
 
259
284
  # single index
260
285
  elif isinstance(inds, int):
286
+ self.has_ints = True
287
+ assert not self.has_strings
261
288
  inds_in = np.array([inds])
262
289
  self.priors.append([inds_in, dist])
290
+ current_ind += 1
291
+
292
+ elif isinstance(inds, str):
293
+ assert not self.has_ints
294
+ self.has_strings = True
295
+ key_order.append(inds)
296
+ inds_in = np.array([current_ind])
297
+ current_ind += 1
298
+ self.priors.append([inds_in, dist])
263
299
 
264
300
  else:
265
301
  raise ValueError(
266
- "Keys for prior dictionary must be an integer or tuple."
302
+ "Keys for prior dictionary must be an integer, string, or tuple."
267
303
  )
268
304
 
305
+ if self.has_strings:
306
+ assert not self.has_ints
307
+ # key order is already set
308
+ self.key_order = key_order
309
+
310
+ if self.has_ints:
311
+ self.key_order = [i for i in range(current_ind)] # here current_ind is the total count
312
+ assert not self.has_strings
313
+
269
314
  temp_inds.append(np.asarray([inds_in]))
270
315
 
271
316
  uni_inds = np.unique(np.concatenate(temp_inds, axis=1).flatten())
@@ -18,15 +18,31 @@ class PeriodicContainer:
18
18
 
19
19
  """
20
20
 
21
- def __init__(self, periodic):
21
+ def __init__(self, periodic, key_order=None):
22
22
 
23
23
  # store all the information
24
24
  self.periodic = periodic
25
+ inds_periodic = {}
26
+ periods = {}
27
+ for key in periodic:
28
+ if periodic[key] is None:
29
+ continue
30
+ inds_periodic[key] = []
31
+ periods[key] = []
32
+ for var, period in periodic[key].items():
33
+ if isinstance(var, str):
34
+ if key_order is None:
35
+ raise ValueError(f"If providing str values for the variable names, must provide key_order argument.")
36
+
37
+ index = key_order[key].index(var)
38
+ inds_periodic[key].append(index)
39
+ periods[key].append(period)
40
+
25
41
  self.inds_periodic = {
26
- key: np.asarray([i for i in periodic[key].keys()]) for key in periodic
42
+ key: np.asarray(tmp) for key, tmp in inds_periodic.items()
27
43
  }
28
44
  self.periods = {
29
- key: np.asarray([i for i in periodic[key].values()]) for key in periodic
45
+ key: np.asarray(tmp) for key, tmp in periods.items()
30
46
  }
31
47
 
32
48
  def distance(self, p1, p2, xp=None):
@@ -11,6 +11,10 @@ class TransformContainer:
11
11
  """Container for helpful transformations
12
12
 
13
13
  Args:
14
+ input_basis (list): List of integers or strings representing each
15
+ basis element from the input basis.
16
+ output_basis (list): List of integers or strings representing each
17
+ basis element for the output basis.
14
18
  parameter_transforms (dict, optional): Keys are ``int`` or ``tuple``
15
19
  of ``int`` that contain the indexes into the parameters
16
20
  that correspond to the transformation added as the Values to the
@@ -31,53 +35,70 @@ class TransformContainer:
31
35
 
32
36
  """
33
37
 
34
- def __init__(self, parameter_transforms=None, fill_dict=None):
38
+ def __init__(self, input_basis=None, output_basis=None, parameter_transforms=None, fill_dict=None, key_map={}):
35
39
 
40
+
36
41
  # store originals
37
42
  self.original_parameter_transforms = parameter_transforms
43
+ self.ndim_full = len(output_basis)
44
+ self.ndim = len(input_basis)
45
+
46
+ self.input_basis, self.output_basis = input_basis, output_basis
47
+
48
+ test_inds = []
49
+ for key in input_basis:
50
+ if key not in output_basis and key not in key_map:
51
+ raise ValueError("All keys in input_basis must be present in output basis, or you must provide a key_map")
52
+ key_in = key if key not in key_map else key_map[key]
53
+ test_inds.append(output_basis.index(key_in))
54
+
55
+ self.test_inds = test_inds = np.asarray(test_inds)
38
56
  if parameter_transforms is not None:
39
57
  # differentiate between single and multi parameter transformations
40
58
  self.base_transforms = {"single_param": {}, "mult_param": {}}
41
59
 
42
60
  # iterate through transforms and setup single and multiparameter transforms
43
61
  for key, item in parameter_transforms.items():
44
- if isinstance(key, int):
45
- self.base_transforms["single_param"][key] = item
62
+ if isinstance(key, str) or isinstance(key, int):
63
+ if key not in output_basis:
64
+ assert key in key_map
65
+ key = key_map[key]
66
+ key_in = output_basis.index(key)
67
+ self.base_transforms["single_param"][key_in] = item
46
68
  elif isinstance(key, tuple):
47
- self.base_transforms["mult_param"][key] = item
69
+ _tmp = []
70
+ for i in range(len(key)):
71
+ key_tmp = key[i]
72
+ if key_tmp not in output_basis:
73
+ assert key_tmp in key_map
74
+ key_tmp = key_map[key_tmp]
75
+ _tmp.append(output_basis.index(key_tmp))
76
+ self.base_transforms["mult_param"][tuple(_tmp)] = item
48
77
  else:
49
78
  raise ValueError(
50
- "Parameter transform keys must be int or tuple of ints. {} is neither.".format(
79
+ "Parameter transform keys must be str (or int) or tuple of strs (or ints). {} is neither.".format(
51
80
  key
52
81
  )
53
82
  )
54
83
  else:
55
84
  self.base_transforms = None
56
85
 
86
+ self.original_fill_dict = fill_dict
57
87
  if fill_dict is not None:
58
88
  if not isinstance(fill_dict, dict):
59
89
  raise ValueError("fill_dict must be a dictionary.")
60
90
 
61
- self.fill_dict = fill_dict
62
- fill_dict_keys = list(self.fill_dict.keys())
63
- for key in ["ndim_full", "fill_inds", "fill_values"]:
64
- # check to make sure it has all necessary pieces
65
- if key not in fill_dict_keys:
66
- raise ValueError(
67
- f"If providing fill_inds, dictionary must have {key} as a key."
68
- )
69
- # check all the inputs
70
- if not isinstance(fill_dict["ndim_full"], int):
71
- raise ValueError("fill_dict['ndim_full'] must be an int.")
72
- if not isinstance(fill_dict["fill_inds"], np.ndarray):
73
- raise ValueError("fill_dict['fill_inds'] must be an np.ndarray.")
74
- if not isinstance(fill_dict["fill_values"], np.ndarray):
75
- raise ValueError("fill_dict['fill_values'] must be an np.ndarray.")
91
+ self.fill_dict = {}
92
+ self.fill_dict["fill_inds"] = []
93
+ self.fill_dict["fill_values"] = []
94
+ for key in fill_dict.keys():
95
+ self.fill_dict["fill_inds"].append(output_basis.index(key))
96
+ self.fill_dict["fill_values"].append(fill_dict[key])
76
97
 
77
98
  # set up test_inds accordingly
78
- self.fill_dict["test_inds"] = np.delete(
79
- np.arange(self.fill_dict["ndim_full"]), self.fill_dict["fill_inds"]
80
- )
99
+ self.fill_dict["test_inds"] = test_inds
100
+ self.fill_dict["fill_inds"] = np.asarray(self.fill_dict["fill_inds"])
101
+ self.fill_dict["fill_values"] = np.asarray(self.fill_dict["fill_values"])
81
102
 
82
103
  else:
83
104
  self.fill_dict = None
@@ -134,6 +155,8 @@ class TransformContainer:
134
155
  def fill_values(self, params, xp=None):
135
156
  """fill fixed parameters
136
157
 
158
+ This also adjusts parameter order as needed between the two bases.
159
+
137
160
  Args:
138
161
  params (np.ndarray[..., ndim]): Array with coordinates. This array is
139
162
  filled with values according to the ``self.fill_dict`` dictionary.
@@ -152,7 +175,7 @@ class TransformContainer:
152
175
  shape = params.shape
153
176
 
154
177
  # setup new array to fill
155
- params_filled = xp.zeros(shape[:-1] + (self.fill_dict["ndim_full"],))
178
+ params_filled = xp.zeros(shape[:-1] + (self.ndim_full,))
156
179
  test_inds = xp.asarray(self.fill_dict["test_inds"])
157
180
  # special indexing to properly fill array with params
158
181
  indexing_test_inds = tuple([slice(0, temp) for temp in shape[:-1]]) + (
@@ -179,7 +202,7 @@ class TransformContainer:
179
202
  return params
180
203
 
181
204
  def both_transforms(
182
- self, params, copy=True, return_transpose=False, reverse=False, xp=None
205
+ self, params, copy=True, return_transpose=False, xp=None
183
206
  ):
184
207
  """Transform the parameters and fill fixed parameters
185
208
 
@@ -197,9 +220,6 @@ class TransformContainer:
197
220
  (default: ``True``)
198
221
  return_transpose (bool, optional): If ``True``, return the transpose of the
199
222
  array. (default: ``False``)
200
- reverse (bool, optional): If ``True`` perform the filling after the transforms. This makes
201
- indexing easier, but removes the ability of fixed parameters to affect transforms.
202
- (default: ``False``)
203
223
  xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
204
224
  (default: ``None``)
205
225
 
@@ -212,15 +232,8 @@ class TransformContainer:
212
232
  xp = np
213
233
 
214
234
  # run transforms first
215
- if reverse:
216
- temp = self.transform_base_parameters(
217
- params, copy=copy, return_transpose=return_transpose, xp=xp
218
- )
219
- temp = self.fill_values(temp, xp=xp)
220
-
221
- else:
222
- temp = self.fill_values(params, xp=xp)
223
- temp = self.transform_base_parameters(
224
- temp, copy=copy, return_transpose=return_transpose, xp=xp
225
- )
235
+ temp = self.fill_values(params, xp=xp)
236
+ temp = self.transform_base_parameters(
237
+ temp, copy=copy, return_transpose=return_transpose, xp=xp
238
+ )
226
239
  return temp
@@ -237,11 +237,12 @@ def stepping_stone_log_evidence(betas, logls, block_len=50, repeats=100):
237
237
 
238
238
  def calculate_stepping_stone(betas, logls):
239
239
  n = logls.shape[0]
240
- delta_betas = betas[1:] - betas[:-1]
241
240
  n_T = betas.shape[0]
242
- log_ratio = logsumexp(delta_betas * logls[:, :-1], axis=0) - np.log(n)
241
+ delta_betas = betas[1:] - betas[:-1]
242
+ throwaways = np.any(~np.isfinite(logls), axis=1) # a safeguard against non-finite entries
243
+ log_ratio = logsumexp(delta_betas * logls[~throwaways, :-1], axis=0) - (n_T - 1.0)*np.log(n - np.sum(throwaways))
243
244
  return np.sum(log_ratio), log_ratio
244
-
245
+
245
246
  # make sure they are the same length
246
247
  if len(betas) != logls.shape[1]:
247
248
  raise ValueError(
File without changes