google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
"""Meridian module for the geo-level Bayesian hierarchical media mix model."""
|
|
16
16
|
|
|
17
|
+
import collections
|
|
17
18
|
from collections.abc import Mapping, Sequence
|
|
18
19
|
import functools
|
|
19
20
|
import numbers
|
|
@@ -34,19 +35,26 @@ from meridian.model import prior_distribution
|
|
|
34
35
|
from meridian.model import prior_sampler
|
|
35
36
|
from meridian.model import spec
|
|
36
37
|
from meridian.model import transformers
|
|
38
|
+
from meridian.model.eda import eda_engine
|
|
39
|
+
from meridian.model.eda import eda_outcome
|
|
40
|
+
from meridian.model.eda import eda_spec as eda_spec_module
|
|
37
41
|
import numpy as np
|
|
38
42
|
|
|
39
|
-
|
|
40
43
|
__all__ = [
|
|
41
44
|
"MCMCSamplingError",
|
|
42
45
|
"MCMCOOMError",
|
|
43
46
|
"Meridian",
|
|
47
|
+
"ModelFittingError",
|
|
44
48
|
"NotFittedModelError",
|
|
45
49
|
"save_mmm",
|
|
46
50
|
"load_mmm",
|
|
47
51
|
]
|
|
48
52
|
|
|
49
53
|
|
|
54
|
+
class ModelFittingError(Exception):
|
|
55
|
+
"""Model has critical issues preventing fitting."""
|
|
56
|
+
|
|
57
|
+
|
|
50
58
|
class NotFittedModelError(Exception):
|
|
51
59
|
"""Model has not been fitted."""
|
|
52
60
|
|
|
@@ -91,6 +99,10 @@ class Meridian:
|
|
|
91
99
|
model_spec: A `ModelSpec` object containing the model specification.
|
|
92
100
|
inference_data: A _mutable_ `arviz.InferenceData` object containing the
|
|
93
101
|
resulting data from fitting the model.
|
|
102
|
+
eda_engine: An `EDAEngine` object containing the EDA engine.
|
|
103
|
+
eda_spec: An `EDASpec` object containing the EDA specification.
|
|
104
|
+
eda_outcomes: A list of `EDAOutcome` objects containing the outcomes from
|
|
105
|
+
running critical EDA checks.
|
|
94
106
|
n_geos: Number of geos in the data.
|
|
95
107
|
n_media_channels: Number of media channels in the data.
|
|
96
108
|
n_rf_channels: Number of reach and frequency (RF) channels in the data.
|
|
@@ -126,11 +138,17 @@ class Meridian:
|
|
|
126
138
|
treatmenttensors using the model's non-media treatment data.
|
|
127
139
|
kpi_transformer: A `KpiTransformer` to scale KPI tensors using the model's
|
|
128
140
|
KPI data.
|
|
129
|
-
controls_scaled: The controls tensor
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
141
|
+
controls_scaled: The controls tensor after pre-modeling transformations
|
|
142
|
+
including population scaling (for variables with
|
|
143
|
+
`ModelSpec.control_population_scaling_id` set to `True`), centering by the
|
|
144
|
+
mean, and scaling by the standard deviation.
|
|
145
|
+
non_media_treatments_scaled: The non-media treatment tensor after
|
|
146
|
+
pre-modeling transformations including population scaling (for variables
|
|
147
|
+
with `ModelSpec.non_media_population_scaling_id` set to `True`), centering
|
|
148
|
+
by the mean, and scaling by the standard deviation.
|
|
149
|
+
kpi_scaled: The KPI tensor after pre-modeling transformations including
|
|
150
|
+
population scaling, centering by the mean, and scaling by the standard
|
|
151
|
+
deviation.
|
|
134
152
|
media_effects_dist: A string to specify the distribution of media random
|
|
135
153
|
effects across geos.
|
|
136
154
|
unique_sigma_for_each_geo: A boolean indicating whether to use a unique
|
|
@@ -148,6 +166,7 @@ class Meridian:
|
|
|
148
166
|
inference_data: (
|
|
149
167
|
az.InferenceData | None
|
|
150
168
|
) = None, # for deserializer use only
|
|
169
|
+
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
|
|
151
170
|
):
|
|
152
171
|
self._input_data = input_data
|
|
153
172
|
self._model_spec = model_spec if model_spec else spec.ModelSpec()
|
|
@@ -155,6 +174,8 @@ class Meridian:
|
|
|
155
174
|
inference_data if inference_data else az.InferenceData()
|
|
156
175
|
)
|
|
157
176
|
|
|
177
|
+
self._eda_spec = eda_spec
|
|
178
|
+
|
|
158
179
|
self._validate_data_dependent_model_spec()
|
|
159
180
|
self._validate_injected_inference_data()
|
|
160
181
|
|
|
@@ -184,6 +205,18 @@ class Meridian:
|
|
|
184
205
|
def inference_data(self) -> az.InferenceData:
|
|
185
206
|
return self._inference_data
|
|
186
207
|
|
|
208
|
+
@functools.cached_property
|
|
209
|
+
def eda_engine(self) -> eda_engine.EDAEngine:
|
|
210
|
+
return eda_engine.EDAEngine(self, spec=self._eda_spec)
|
|
211
|
+
|
|
212
|
+
@property
|
|
213
|
+
def eda_spec(self) -> eda_spec_module.EDASpec:
|
|
214
|
+
return self._eda_spec
|
|
215
|
+
|
|
216
|
+
@property
|
|
217
|
+
def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
|
|
218
|
+
return self.eda_engine.run_all_critical_checks()
|
|
219
|
+
|
|
187
220
|
@functools.cached_property
|
|
188
221
|
def media_tensors(self) -> media.MediaTensors:
|
|
189
222
|
return media.build_media_tensors(self.input_data, self.model_spec)
|
|
@@ -444,7 +477,8 @@ class Meridian:
|
|
|
444
477
|
f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
|
|
445
478
|
" either contain only channel_names"
|
|
446
479
|
f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
|
|
447
|
-
" be one or more of {'media', 'rf', 'organic_media',
|
|
480
|
+
" be one or more of {'media', 'rf', 'organic_media',"
|
|
481
|
+
" 'organic_rf'}."
|
|
448
482
|
) from e
|
|
449
483
|
|
|
450
484
|
@functools.cached_property
|
|
@@ -561,7 +595,9 @@ class Meridian:
|
|
|
561
595
|
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
562
596
|
)
|
|
563
597
|
elif isinstance(baseline_value, numbers.Number):
|
|
564
|
-
baseline_for_channel = backend.
|
|
598
|
+
baseline_for_channel = backend.to_tensor(
|
|
599
|
+
baseline_value, dtype=backend.float32
|
|
600
|
+
)
|
|
565
601
|
else:
|
|
566
602
|
raise ValueError(
|
|
567
603
|
f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
|
|
@@ -1135,16 +1171,11 @@ class Meridian:
|
|
|
1135
1171
|
" time."
|
|
1136
1172
|
)
|
|
1137
1173
|
|
|
1138
|
-
def _kpi_has_variability(self):
|
|
1139
|
-
"""Returns True if the KPI has variability across geos and times."""
|
|
1140
|
-
return self.kpi_transformer.population_scaled_stdev != 0
|
|
1141
|
-
|
|
1142
1174
|
def _validate_kpi_transformer(self):
|
|
1143
1175
|
"""Validates the KPI transformer."""
|
|
1144
|
-
if self.
|
|
1176
|
+
if self.eda_engine.kpi_has_variability:
|
|
1145
1177
|
return
|
|
1146
|
-
|
|
1147
|
-
kpi = "kpi" if self.is_national else "population_scaled_kpi"
|
|
1178
|
+
kpi = self.eda_engine.kpi_scaled_da.name
|
|
1148
1179
|
|
|
1149
1180
|
if (
|
|
1150
1181
|
self.n_media_channels > 0
|
|
@@ -1569,6 +1600,36 @@ class Meridian:
|
|
|
1569
1600
|
"""
|
|
1570
1601
|
self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
1571
1602
|
|
|
1603
|
+
def _run_model_fitting_guardrail(self):
|
|
1604
|
+
"""Raises an error if the model has critical EDA issues."""
|
|
1605
|
+
error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
|
|
1606
|
+
collections.defaultdict(list)
|
|
1607
|
+
)
|
|
1608
|
+
for outcome in self.eda_outcomes:
|
|
1609
|
+
error_findings = [
|
|
1610
|
+
finding
|
|
1611
|
+
for finding in outcome.findings
|
|
1612
|
+
if finding.severity == eda_outcome.EDASeverity.ERROR
|
|
1613
|
+
]
|
|
1614
|
+
if error_findings:
|
|
1615
|
+
error_findings_by_type[outcome.check_type].extend(
|
|
1616
|
+
[finding.explanation for finding in error_findings]
|
|
1617
|
+
)
|
|
1618
|
+
|
|
1619
|
+
if error_findings_by_type:
|
|
1620
|
+
error_message_lines = [
|
|
1621
|
+
"Model has critical EDA issues. Please fix before running"
|
|
1622
|
+
" `sample_posterior`.\n"
|
|
1623
|
+
]
|
|
1624
|
+
for check_type, explanations in error_findings_by_type.items():
|
|
1625
|
+
error_message_lines.append(f"Check type: {check_type.name}")
|
|
1626
|
+
for explanation in explanations:
|
|
1627
|
+
error_message_lines.append(f"- {explanation}")
|
|
1628
|
+
error_message_lines.append(
|
|
1629
|
+
"For further details, please refer to `Meridian.eda_outcomes`."
|
|
1630
|
+
)
|
|
1631
|
+
raise ModelFittingError("\n".join(error_message_lines))
|
|
1632
|
+
|
|
1572
1633
|
def sample_posterior(
|
|
1573
1634
|
self,
|
|
1574
1635
|
n_chains: Sequence[int] | int,
|
|
@@ -1644,8 +1705,10 @@ class Meridian:
|
|
|
1644
1705
|
a list of integers as `n_chains` to sample chains serially. For more
|
|
1645
1706
|
information, see
|
|
1646
1707
|
[ResourceExhaustedError when running Meridian.sample_posterior]
|
|
1647
|
-
(https://developers.google.com/meridian/docs/
|
|
1708
|
+
(https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
|
|
1648
1709
|
"""
|
|
1710
|
+
self._run_model_fitting_guardrail()
|
|
1711
|
+
|
|
1649
1712
|
self.posterior_sampler_callable(
|
|
1650
1713
|
n_chains=n_chains,
|
|
1651
1714
|
n_adapt=n_adapt,
|
|
@@ -143,6 +143,7 @@ class WithInputDataSamples:
|
|
|
143
143
|
_short_input_data_with_rf_only: input_data.InputData
|
|
144
144
|
_short_input_data_with_media_and_rf: input_data.InputData
|
|
145
145
|
_national_input_data_media_only: input_data.InputData
|
|
146
|
+
_national_input_data_rf_only: input_data.InputData
|
|
146
147
|
_national_input_data_media_and_rf: input_data.InputData
|
|
147
148
|
_test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
|
|
148
149
|
_test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
|
|
@@ -156,6 +157,8 @@ class WithInputDataSamples:
|
|
|
156
157
|
_short_input_data_non_media_and_organic: input_data.InputData
|
|
157
158
|
_short_input_data_non_media: input_data.InputData
|
|
158
159
|
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
160
|
+
_input_data_organic_only: input_data.InputData
|
|
161
|
+
_national_input_data_organic_only: input_data.InputData
|
|
159
162
|
|
|
160
163
|
# The following NamedTuples and their attributes are immutable, so they can
|
|
161
164
|
# be accessed directly.
|
|
@@ -280,6 +283,16 @@ class WithInputDataSamples:
|
|
|
280
283
|
seed=0,
|
|
281
284
|
)
|
|
282
285
|
)
|
|
286
|
+
cls._national_input_data_rf_only = (
|
|
287
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
288
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
289
|
+
n_times=cls._N_TIMES,
|
|
290
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
291
|
+
n_controls=cls._N_CONTROLS,
|
|
292
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
293
|
+
seed=0,
|
|
294
|
+
)
|
|
295
|
+
)
|
|
283
296
|
cls._national_input_data_media_only = (
|
|
284
297
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
285
298
|
n_geos=cls._N_GEOS_NATIONAL,
|
|
@@ -496,6 +509,34 @@ class WithInputDataSamples:
|
|
|
496
509
|
seed=0,
|
|
497
510
|
)
|
|
498
511
|
)
|
|
512
|
+
cls._input_data_organic_only = (
|
|
513
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
514
|
+
n_geos=cls._N_GEOS,
|
|
515
|
+
n_times=cls._N_TIMES,
|
|
516
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
517
|
+
n_controls=cls._N_CONTROLS,
|
|
518
|
+
n_non_media_channels=0,
|
|
519
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
520
|
+
n_rf_channels=0,
|
|
521
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
522
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
523
|
+
seed=0,
|
|
524
|
+
)
|
|
525
|
+
)
|
|
526
|
+
cls._national_input_data_organic_only = (
|
|
527
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
528
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
529
|
+
n_times=cls._N_TIMES,
|
|
530
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
531
|
+
n_controls=cls._N_CONTROLS,
|
|
532
|
+
n_non_media_channels=0,
|
|
533
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
534
|
+
n_rf_channels=0,
|
|
535
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
536
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
537
|
+
seed=0,
|
|
538
|
+
)
|
|
539
|
+
)
|
|
499
540
|
|
|
500
541
|
@property
|
|
501
542
|
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
@@ -551,6 +592,10 @@ class WithInputDataSamples:
|
|
|
551
592
|
def national_input_data_media_only(self) -> input_data.InputData:
|
|
552
593
|
return self._national_input_data_media_only.copy(deep=True)
|
|
553
594
|
|
|
595
|
+
@property
|
|
596
|
+
def national_input_data_rf_only(self) -> input_data.InputData:
|
|
597
|
+
return self._national_input_data_rf_only.copy(deep=True)
|
|
598
|
+
|
|
554
599
|
@property
|
|
555
600
|
def national_input_data_media_and_rf(self) -> input_data.InputData:
|
|
556
601
|
return self._national_input_data_media_and_rf.copy(deep=True)
|
|
@@ -606,3 +651,11 @@ class WithInputDataSamples:
|
|
|
606
651
|
self,
|
|
607
652
|
) -> input_data.InputData:
|
|
608
653
|
return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
|
|
654
|
+
|
|
655
|
+
@property
|
|
656
|
+
def input_data_organic_only(self) -> input_data.InputData:
|
|
657
|
+
return self._input_data_organic_only.copy(deep=True)
|
|
658
|
+
|
|
659
|
+
@property
|
|
660
|
+
def national_input_data_organic_only(self) -> input_data.InputData:
|
|
661
|
+
return self._national_input_data_organic_only.copy(deep=True)
|
|
@@ -72,12 +72,6 @@ def _get_tau_g(
|
|
|
72
72
|
return backend.tfd.Deterministic(tau_g, name="tau_g")
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
@backend.function(autograph=False, jit_compile=True)
|
|
76
|
-
def _xla_windowed_adaptive_nuts(**kwargs):
|
|
77
|
-
"""XLA wrapper for windowed_adaptive_nuts."""
|
|
78
|
-
return backend.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
|
|
79
|
-
|
|
80
|
-
|
|
81
75
|
def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
82
76
|
"""Returns unpinned joint distribution."""
|
|
83
77
|
|
|
@@ -447,26 +441,44 @@ class PosteriorMCMCSampler:
|
|
|
447
441
|
|
|
448
442
|
def __init__(self, meridian: "model.Meridian"):
|
|
449
443
|
self._meridian = meridian
|
|
444
|
+
self._joint_dist = None
|
|
445
|
+
|
|
446
|
+
def __getstate__(self):
|
|
447
|
+
state = self.__dict__.copy()
|
|
448
|
+
# Exclude unpickleable objects.
|
|
449
|
+
if "_joint_dist" in state:
|
|
450
|
+
del state["_joint_dist"]
|
|
451
|
+
return state
|
|
452
|
+
|
|
453
|
+
def __setstate__(self, state):
|
|
454
|
+
self.__dict__.update(state)
|
|
455
|
+
self._joint_dist = None
|
|
450
456
|
|
|
451
457
|
@property
|
|
452
458
|
def model(self) -> "model.Meridian":
|
|
453
459
|
return self._meridian
|
|
454
460
|
|
|
461
|
+
def _joint_dist_unpinned_fn(self):
|
|
462
|
+
return _joint_dist_unpinned(self.model)
|
|
463
|
+
|
|
455
464
|
def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
|
|
456
|
-
"""
|
|
465
|
+
"""Builds a `JointDistributionCoroutineAutoBatched` function for MCMC."""
|
|
457
466
|
mmm = self.model
|
|
458
467
|
mmm.populate_cached_properties()
|
|
459
|
-
|
|
460
|
-
|
|
468
|
+
return backend.tfd.JointDistributionCoroutineAutoBatched(
|
|
469
|
+
self._joint_dist_unpinned_fn
|
|
470
|
+
)
|
|
461
471
|
|
|
462
472
|
def _get_joint_dist(self) -> backend.tfd.Distribution:
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
473
|
+
if self._joint_dist is None:
|
|
474
|
+
mmm = self.model
|
|
475
|
+
y = (
|
|
476
|
+
backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
|
|
477
|
+
if mmm.holdout_id is not None
|
|
478
|
+
else mmm.kpi_scaled
|
|
479
|
+
)
|
|
480
|
+
self._joint_dist = self._get_joint_dist_unpinned().experimental_pin(y=y)
|
|
481
|
+
return self._joint_dist
|
|
470
482
|
|
|
471
483
|
def __call__(
|
|
472
484
|
self,
|
|
@@ -541,26 +553,22 @@ class PosteriorMCMCSampler:
|
|
|
541
553
|
a list of integers as `n_chains` to sample chains serially. For more
|
|
542
554
|
information, see
|
|
543
555
|
[ResourceExhaustedError when running Meridian.sample_posterior]
|
|
544
|
-
(https://developers.google.com/meridian/docs/
|
|
556
|
+
(https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
|
|
545
557
|
"""
|
|
546
|
-
|
|
547
|
-
raise ValueError(
|
|
548
|
-
"Invalid seed: Must be either a single integer (stateful seed) or a"
|
|
549
|
-
" pair of two integers (stateless seed). See"
|
|
550
|
-
" [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
|
|
551
|
-
" for details."
|
|
552
|
-
)
|
|
553
|
-
if seed is not None and isinstance(seed, int):
|
|
554
|
-
seed = (seed, seed)
|
|
555
|
-
seed = backend.random.sanitize_seed(seed) if seed is not None else None
|
|
558
|
+
rng_handler = backend.RNGHandler(seed)
|
|
556
559
|
n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
|
|
557
560
|
total_chains = np.sum(n_chains_list)
|
|
558
561
|
|
|
562
|
+
# Clear joint distribution cache prior to sampling.
|
|
563
|
+
self._joint_dist = None
|
|
564
|
+
|
|
559
565
|
states = []
|
|
560
566
|
traces = []
|
|
561
567
|
for n_chains_batch in n_chains_list:
|
|
568
|
+
kernel_seed = rng_handler.get_kernel_seed()
|
|
569
|
+
|
|
562
570
|
try:
|
|
563
|
-
mcmc =
|
|
571
|
+
mcmc = backend.xla_windowed_adaptive_nuts(
|
|
564
572
|
n_draws=n_burnin + n_keep,
|
|
565
573
|
joint_dist=self._get_joint_dist(),
|
|
566
574
|
n_chains=n_chains_batch,
|
|
@@ -572,17 +580,16 @@ class PosteriorMCMCSampler:
|
|
|
572
580
|
max_energy_diff=max_energy_diff,
|
|
573
581
|
unrolled_leapfrog_steps=unrolled_leapfrog_steps,
|
|
574
582
|
parallel_iterations=parallel_iterations,
|
|
575
|
-
seed=
|
|
583
|
+
seed=kernel_seed,
|
|
576
584
|
**pins,
|
|
577
585
|
)
|
|
578
586
|
except backend.errors.ResourceExhaustedError as error:
|
|
579
587
|
raise MCMCOOMError(
|
|
580
588
|
"ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
|
|
581
589
|
" integers as `n_chains` to sample chains serially (see"
|
|
582
|
-
" https://developers.google.com/meridian/docs/
|
|
590
|
+
" https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error)"
|
|
583
591
|
) from error
|
|
584
|
-
|
|
585
|
-
seed += 1
|
|
592
|
+
rng_handler = rng_handler.advance_handler()
|
|
586
593
|
states.append(mcmc.all_states._asdict())
|
|
587
594
|
traces.append(mcmc.trace)
|
|
588
595
|
|
|
@@ -35,6 +35,7 @@ __all__ = [
|
|
|
35
35
|
'PriorDistribution',
|
|
36
36
|
'distributions_are_equal',
|
|
37
37
|
'lognormal_dist_from_mean_std',
|
|
38
|
+
'lognormal_dist_from_range',
|
|
38
39
|
]
|
|
39
40
|
|
|
40
41
|
|
|
@@ -1195,7 +1196,7 @@ def lognormal_dist_from_range(
|
|
|
1195
1196
|
"""Define a LogNormal distribution from a specified range.
|
|
1196
1197
|
|
|
1197
1198
|
This function parameterizes lognormal distributions by the bounds of a range,
|
|
1198
|
-
so that the
|
|
1199
|
+
so that the specified probability mass falls within the bounds defined by
|
|
1199
1200
|
`low` and `high`. The probability mass is symmetric about the median. For
|
|
1200
1201
|
example, to define a lognormal distribution with a 95% probability mass of
|
|
1201
1202
|
(1, 10), use:
|
|
@@ -1210,7 +1211,7 @@ def lognormal_dist_from_range(
|
|
|
1210
1211
|
high: Float or array-like denoting the upper bound of range. Values must be
|
|
1211
1212
|
non-negative.
|
|
1212
1213
|
mass_percent: Float or array-like denoting the probability mass. Values must
|
|
1213
|
-
be between 0 and 1 (
|
|
1214
|
+
be between 0 and 1 (exclusive). Default: 0.95.
|
|
1214
1215
|
|
|
1215
1216
|
Returns:
|
|
1216
1217
|
A `backend.tfd.LogNormal` object with the input percentage mass falling
|
|
@@ -1341,6 +1342,15 @@ def _validate_support(
|
|
|
1341
1342
|
f'{parameter_name} was assigned a point mass (deterministic) prior'
|
|
1342
1343
|
f' at {bounds[i]}, which is not allowed.'
|
|
1343
1344
|
)
|
|
1345
|
+
elif isinstance(tfp_dist, backend.tfd.TruncatedNormal):
|
|
1346
|
+
# TruncatedNormal quantile method is not reliable, particularly when the
|
|
1347
|
+
# `low` or `high` value falls into extreme percentile of the untruncated
|
|
1348
|
+
# distribution. Note that
|
|
1349
|
+
# `TruncatedNormal.experimental_default_event_space_bijector()([-inf, inf])`
|
|
1350
|
+
# returns the correct support range, so this method could be used if the
|
|
1351
|
+
# `quantile` method is found to be unreliable for other distributions.
|
|
1352
|
+
support_min_vals = tfp_dist.low
|
|
1353
|
+
support_max_vals = tfp_dist.high
|
|
1344
1354
|
else:
|
|
1345
1355
|
try:
|
|
1346
1356
|
support_min_vals = tfp_dist.quantile(0)
|