google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -159,6 +159,8 @@ class WithInputDataSamples:
|
|
|
159
159
|
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
160
160
|
_input_data_organic_only: input_data.InputData
|
|
161
161
|
_national_input_data_organic_only: input_data.InputData
|
|
162
|
+
_input_data_non_media_only: input_data.InputData
|
|
163
|
+
_national_input_data_non_media_only: input_data.InputData
|
|
162
164
|
|
|
163
165
|
# The following NamedTuples and their attributes are immutable, so they can
|
|
164
166
|
# be accessed directly.
|
|
@@ -537,6 +539,34 @@ class WithInputDataSamples:
|
|
|
537
539
|
seed=0,
|
|
538
540
|
)
|
|
539
541
|
)
|
|
542
|
+
cls._input_data_non_media_only = (
|
|
543
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
544
|
+
n_geos=cls._N_GEOS,
|
|
545
|
+
n_times=cls._N_TIMES,
|
|
546
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
547
|
+
n_controls=0,
|
|
548
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
549
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
550
|
+
n_rf_channels=0,
|
|
551
|
+
n_organic_media_channels=0,
|
|
552
|
+
n_organic_rf_channels=0,
|
|
553
|
+
seed=0,
|
|
554
|
+
)
|
|
555
|
+
)
|
|
556
|
+
cls._national_input_data_non_media_only = (
|
|
557
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
558
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
559
|
+
n_times=cls._N_TIMES,
|
|
560
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
561
|
+
n_controls=0,
|
|
562
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
563
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
564
|
+
n_rf_channels=0,
|
|
565
|
+
n_organic_media_channels=0,
|
|
566
|
+
n_organic_rf_channels=0,
|
|
567
|
+
seed=0,
|
|
568
|
+
)
|
|
569
|
+
)
|
|
540
570
|
|
|
541
571
|
@property
|
|
542
572
|
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
@@ -659,3 +689,11 @@ class WithInputDataSamples:
|
|
|
659
689
|
@property
|
|
660
690
|
def national_input_data_organic_only(self) -> input_data.InputData:
|
|
661
691
|
return self._national_input_data_organic_only.copy(deep=True)
|
|
692
|
+
|
|
693
|
+
@property
|
|
694
|
+
def input_data_non_media_only(self) -> input_data.InputData:
|
|
695
|
+
return self._input_data_non_media_only.copy(deep=True)
|
|
696
|
+
|
|
697
|
+
@property
|
|
698
|
+
def national_input_data_non_media_only(self) -> input_data.InputData:
|
|
699
|
+
return self._national_input_data_non_media_only.copy(deep=True)
|
|
@@ -15,11 +15,15 @@
|
|
|
15
15
|
"""Module for MCMC sampling of posterior distributions in a Meridian model."""
|
|
16
16
|
|
|
17
17
|
from collections.abc import Mapping, Sequence
|
|
18
|
-
|
|
18
|
+
import functools
|
|
19
|
+
from typing import Optional, TYPE_CHECKING
|
|
20
|
+
import warnings
|
|
19
21
|
|
|
20
22
|
import arviz as az
|
|
21
23
|
from meridian import backend
|
|
22
24
|
from meridian import constants
|
|
25
|
+
from meridian.model import context
|
|
26
|
+
from meridian.model import equations
|
|
23
27
|
import numpy as np
|
|
24
28
|
|
|
25
29
|
if TYPE_CHECKING:
|
|
@@ -72,34 +76,39 @@ def _get_tau_g(
|
|
|
72
76
|
return backend.tfd.Deterministic(tau_g, name="tau_g")
|
|
73
77
|
|
|
74
78
|
|
|
75
|
-
def _joint_dist_unpinned(
|
|
79
|
+
def _joint_dist_unpinned(
|
|
80
|
+
model_context: context.ModelContext,
|
|
81
|
+
model_equations: equations.ModelEquations,
|
|
82
|
+
):
|
|
76
83
|
"""Returns unpinned joint distribution."""
|
|
77
84
|
|
|
78
85
|
# This lists all the derived properties and states of this Meridian object
|
|
79
86
|
# that are referenced by the joint distribution coroutine.
|
|
80
87
|
# That is, these are the list of captured parameters.
|
|
81
|
-
prior_broadcast =
|
|
82
|
-
baseline_geo_idx =
|
|
83
|
-
knot_info =
|
|
84
|
-
n_geos =
|
|
85
|
-
n_times =
|
|
86
|
-
n_media_channels =
|
|
87
|
-
n_rf_channels =
|
|
88
|
-
n_organic_media_channels =
|
|
89
|
-
n_organic_rf_channels =
|
|
90
|
-
n_controls =
|
|
91
|
-
n_non_media_channels =
|
|
92
|
-
holdout_id =
|
|
93
|
-
media_tensors =
|
|
94
|
-
rf_tensors =
|
|
95
|
-
organic_media_tensors =
|
|
96
|
-
organic_rf_tensors =
|
|
97
|
-
controls_scaled =
|
|
98
|
-
non_media_treatments_normalized =
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
88
|
+
prior_broadcast = model_context.prior_broadcast
|
|
89
|
+
baseline_geo_idx = model_context.baseline_geo_idx
|
|
90
|
+
knot_info = model_context.knot_info
|
|
91
|
+
n_geos = model_context.n_geos
|
|
92
|
+
n_times = model_context.n_times
|
|
93
|
+
n_media_channels = model_context.n_media_channels
|
|
94
|
+
n_rf_channels = model_context.n_rf_channels
|
|
95
|
+
n_organic_media_channels = model_context.n_organic_media_channels
|
|
96
|
+
n_organic_rf_channels = model_context.n_organic_rf_channels
|
|
97
|
+
n_controls = model_context.n_controls
|
|
98
|
+
n_non_media_channels = model_context.n_non_media_channels
|
|
99
|
+
holdout_id = model_context.holdout_id
|
|
100
|
+
media_tensors = model_context.media_tensors
|
|
101
|
+
rf_tensors = model_context.rf_tensors
|
|
102
|
+
organic_media_tensors = model_context.organic_media_tensors
|
|
103
|
+
organic_rf_tensors = model_context.organic_rf_tensors
|
|
104
|
+
controls_scaled = model_context.controls_scaled
|
|
105
|
+
non_media_treatments_normalized = (
|
|
106
|
+
model_context.non_media_treatments_normalized
|
|
107
|
+
)
|
|
108
|
+
media_effects_dist = model_context.media_effects_dist
|
|
109
|
+
adstock_hill_media_fn = model_equations.adstock_hill_media
|
|
110
|
+
adstock_hill_rf_fn = model_equations.adstock_hill_rf
|
|
111
|
+
total_outcome = model_context.total_outcome
|
|
103
112
|
|
|
104
113
|
# Sample directly from prior.
|
|
105
114
|
knot_values = yield prior_broadcast.knot_values
|
|
@@ -142,9 +151,9 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
142
151
|
alpha=alpha_m,
|
|
143
152
|
ec=ec_m,
|
|
144
153
|
slope=slope_m,
|
|
145
|
-
decay_functions=
|
|
154
|
+
decay_functions=model_context.adstock_decay_spec.media,
|
|
146
155
|
)
|
|
147
|
-
prior_type =
|
|
156
|
+
prior_type = model_context.model_spec.effective_media_prior_type
|
|
148
157
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
149
158
|
beta_m = yield prior_broadcast.beta_m
|
|
150
159
|
else:
|
|
@@ -160,14 +169,14 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
160
169
|
treatment_parameter_m * media_tensors.prior_denominator
|
|
161
170
|
)
|
|
162
171
|
linear_predictor_counterfactual_difference = (
|
|
163
|
-
|
|
172
|
+
model_equations.linear_predictor_counterfactual_difference_media(
|
|
164
173
|
media_transformed=media_transformed,
|
|
165
174
|
alpha_m=alpha_m,
|
|
166
175
|
ec_m=ec_m,
|
|
167
176
|
slope_m=slope_m,
|
|
168
177
|
)
|
|
169
178
|
)
|
|
170
|
-
beta_m_value =
|
|
179
|
+
beta_m_value = model_equations.calculate_beta_x(
|
|
171
180
|
is_non_media=False,
|
|
172
181
|
incremental_outcome_x=incremental_outcome_m,
|
|
173
182
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -208,10 +217,10 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
208
217
|
alpha=alpha_rf,
|
|
209
218
|
ec=ec_rf,
|
|
210
219
|
slope=slope_rf,
|
|
211
|
-
decay_functions=
|
|
220
|
+
decay_functions=model_context.adstock_decay_spec.rf,
|
|
212
221
|
)
|
|
213
222
|
|
|
214
|
-
prior_type =
|
|
223
|
+
prior_type = model_context.model_spec.effective_rf_prior_type
|
|
215
224
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
216
225
|
beta_rf = yield prior_broadcast.beta_rf
|
|
217
226
|
else:
|
|
@@ -227,14 +236,14 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
227
236
|
treatment_parameter_rf * rf_tensors.prior_denominator
|
|
228
237
|
)
|
|
229
238
|
linear_predictor_counterfactual_difference = (
|
|
230
|
-
|
|
239
|
+
model_equations.linear_predictor_counterfactual_difference_rf(
|
|
231
240
|
rf_transformed=rf_transformed,
|
|
232
241
|
alpha_rf=alpha_rf,
|
|
233
242
|
ec_rf=ec_rf,
|
|
234
243
|
slope_rf=slope_rf,
|
|
235
244
|
)
|
|
236
245
|
)
|
|
237
|
-
beta_rf_value =
|
|
246
|
+
beta_rf_value = model_equations.calculate_beta_x(
|
|
238
247
|
is_non_media=False,
|
|
239
248
|
incremental_outcome_x=incremental_outcome_rf,
|
|
240
249
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -274,15 +283,15 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
274
283
|
alpha=alpha_om,
|
|
275
284
|
ec=ec_om,
|
|
276
285
|
slope=slope_om,
|
|
277
|
-
decay_functions=
|
|
286
|
+
decay_functions=model_context.adstock_decay_spec.organic_media,
|
|
278
287
|
)
|
|
279
|
-
prior_type =
|
|
288
|
+
prior_type = model_context.model_spec.organic_media_prior_type
|
|
280
289
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
281
290
|
beta_om = yield prior_broadcast.beta_om
|
|
282
291
|
elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
|
|
283
292
|
contribution_om = yield prior_broadcast.contribution_om
|
|
284
293
|
incremental_outcome_om = contribution_om * total_outcome
|
|
285
|
-
beta_om_value =
|
|
294
|
+
beta_om_value = model_equations.calculate_beta_x(
|
|
286
295
|
is_non_media=False,
|
|
287
296
|
incremental_outcome_x=incremental_outcome_om,
|
|
288
297
|
linear_predictor_counterfactual_difference=organic_media_transformed,
|
|
@@ -325,16 +334,16 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
325
334
|
alpha=alpha_orf,
|
|
326
335
|
ec=ec_orf,
|
|
327
336
|
slope=slope_orf,
|
|
328
|
-
decay_functions=
|
|
337
|
+
decay_functions=model_context.adstock_decay_spec.organic_rf,
|
|
329
338
|
)
|
|
330
339
|
|
|
331
|
-
prior_type =
|
|
340
|
+
prior_type = model_context.model_spec.organic_rf_prior_type
|
|
332
341
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
333
342
|
beta_orf = yield prior_broadcast.beta_orf
|
|
334
343
|
elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
|
|
335
344
|
contribution_orf = yield prior_broadcast.contribution_orf
|
|
336
345
|
incremental_outcome_orf = contribution_orf * total_outcome
|
|
337
|
-
beta_orf_value =
|
|
346
|
+
beta_orf_value = model_equations.calculate_beta_x(
|
|
338
347
|
is_non_media=False,
|
|
339
348
|
incremental_outcome_x=incremental_outcome_orf,
|
|
340
349
|
linear_predictor_counterfactual_difference=organic_rf_transformed,
|
|
@@ -382,26 +391,26 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
382
391
|
"gtc,gc->gt", controls_scaled, gamma_gc
|
|
383
392
|
)
|
|
384
393
|
|
|
385
|
-
if
|
|
394
|
+
if model_context.non_media_treatments is not None:
|
|
386
395
|
xi_n = yield prior_broadcast.xi_n
|
|
387
396
|
gamma_gn_dev = yield backend.tfd.Sample(
|
|
388
397
|
backend.tfd.Normal(0, 1),
|
|
389
398
|
[n_geos, n_non_media_channels],
|
|
390
399
|
name=constants.GAMMA_GN_DEV,
|
|
391
400
|
)
|
|
392
|
-
prior_type =
|
|
401
|
+
prior_type = model_context.model_spec.non_media_treatments_prior_type
|
|
393
402
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
394
403
|
gamma_n = yield prior_broadcast.gamma_n
|
|
395
404
|
elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
|
|
396
405
|
contribution_n = yield prior_broadcast.contribution_n
|
|
397
406
|
incremental_outcome_n = contribution_n * total_outcome
|
|
398
|
-
baseline_scaled =
|
|
399
|
-
|
|
407
|
+
baseline_scaled = model_context.non_media_transformer.forward( # pytype: disable=attribute-error
|
|
408
|
+
model_equations.compute_non_media_treatments_baseline()
|
|
400
409
|
)
|
|
401
410
|
linear_predictor_counterfactual_difference = (
|
|
402
411
|
non_media_treatments_normalized - baseline_scaled
|
|
403
412
|
)
|
|
404
|
-
gamma_n_value =
|
|
413
|
+
gamma_n_value = model_equations.calculate_beta_x(
|
|
405
414
|
is_non_media=True,
|
|
406
415
|
incremental_outcome_x=incremental_outcome_n,
|
|
407
416
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -439,10 +448,36 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
|
|
|
439
448
|
class PosteriorMCMCSampler:
|
|
440
449
|
"""A callable that samples from posterior distributions using MCMC."""
|
|
441
450
|
|
|
442
|
-
|
|
443
|
-
|
|
451
|
+
# TODO: Deprecate direct injection of `model.Meridian`.
|
|
452
|
+
def __init__(
|
|
453
|
+
self,
|
|
454
|
+
meridian: Optional["model.Meridian"] = None,
|
|
455
|
+
*,
|
|
456
|
+
model_context: context.ModelContext | None = None,
|
|
457
|
+
):
|
|
458
|
+
if meridian is not None:
|
|
459
|
+
warnings.warn(
|
|
460
|
+
"Initializing PosteriorMCMCSampler with a Meridian object is"
|
|
461
|
+
" deprecated and will be removed in a future version. Please use"
|
|
462
|
+
" `model_context` instead.",
|
|
463
|
+
DeprecationWarning,
|
|
464
|
+
stacklevel=2,
|
|
465
|
+
)
|
|
466
|
+
self._meridian = meridian
|
|
467
|
+
self._model_context = meridian.model_context
|
|
468
|
+
elif model_context is not None:
|
|
469
|
+
self._meridian = None
|
|
470
|
+
self._model_context = model_context
|
|
471
|
+
else:
|
|
472
|
+
raise ValueError(
|
|
473
|
+
"Either `meridian` or `model_context` must be provided."
|
|
474
|
+
)
|
|
444
475
|
self._joint_dist = None
|
|
445
476
|
|
|
477
|
+
@functools.cached_property
|
|
478
|
+
def _model_equations(self) -> equations.ModelEquations:
|
|
479
|
+
return equations.ModelEquations(self._model_context)
|
|
480
|
+
|
|
446
481
|
def __getstate__(self):
|
|
447
482
|
state = self.__dict__.copy()
|
|
448
483
|
# Exclude unpickleable objects.
|
|
@@ -454,29 +489,33 @@ class PosteriorMCMCSampler:
|
|
|
454
489
|
self.__dict__.update(state)
|
|
455
490
|
self._joint_dist = None
|
|
456
491
|
|
|
492
|
+
# TODO: Remove this property in favor of using `model_context`
|
|
493
|
+
# and `model_equations` directly.
|
|
457
494
|
@property
|
|
458
|
-
def model(self) -> "model.Meridian":
|
|
495
|
+
def model(self) -> Optional["model.Meridian"]:
|
|
459
496
|
return self._meridian
|
|
460
497
|
|
|
461
498
|
def _joint_dist_unpinned_fn(self):
|
|
462
|
-
return _joint_dist_unpinned(self.
|
|
499
|
+
return _joint_dist_unpinned(self._model_context, self._model_equations)
|
|
463
500
|
|
|
464
501
|
def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
|
|
465
502
|
"""Builds a `JointDistributionCoroutineAutoBatched` function for MCMC."""
|
|
466
|
-
|
|
467
|
-
mmm.populate_cached_properties()
|
|
503
|
+
self._model_context.populate_cached_properties()
|
|
468
504
|
return backend.tfd.JointDistributionCoroutineAutoBatched(
|
|
469
505
|
self._joint_dist_unpinned_fn
|
|
470
506
|
)
|
|
471
507
|
|
|
472
508
|
def _get_joint_dist(self) -> backend.tfd.Distribution:
|
|
509
|
+
"""Returns a joint distribution for MCMC sampling."""
|
|
473
510
|
if self._joint_dist is None:
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
511
|
+
if self._model_context.holdout_id is not None:
|
|
512
|
+
y = backend.where(
|
|
513
|
+
self._model_context.holdout_id,
|
|
514
|
+
0.0,
|
|
515
|
+
self._model_context.kpi_scaled,
|
|
516
|
+
)
|
|
517
|
+
else:
|
|
518
|
+
y = self._model_context.kpi_scaled
|
|
480
519
|
self._joint_dist = self._get_joint_dist_unpinned().experimental_pin(y=y)
|
|
481
520
|
return self._joint_dist
|
|
482
521
|
|
|
@@ -495,7 +534,7 @@ class PosteriorMCMCSampler:
|
|
|
495
534
|
parallel_iterations: int = 10,
|
|
496
535
|
seed: Sequence[int] | int | None = None,
|
|
497
536
|
**pins,
|
|
498
|
-
) ->
|
|
537
|
+
) -> az.InferenceData:
|
|
499
538
|
"""Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
|
|
500
539
|
|
|
501
540
|
For more information about the arguments, see [`windowed_adaptive_nuts`]
|
|
@@ -548,6 +587,10 @@ class PosteriorMCMCSampler:
|
|
|
548
587
|
**pins: These are used to condition the provided joint distribution, and
|
|
549
588
|
are passed directly to `joint_dist.experimental_pin(**pins)`.
|
|
550
589
|
|
|
590
|
+
Returns:
|
|
591
|
+
An `arviz.InferenceData` object containing the posterior samples, trace
|
|
592
|
+
metrics, and sampling statistics.
|
|
593
|
+
|
|
551
594
|
Throws:
|
|
552
595
|
MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or pass
|
|
553
596
|
a list of integers as `n_chains` to sample chains serially. For more
|
|
@@ -604,10 +647,10 @@ class PosteriorMCMCSampler:
|
|
|
604
647
|
if k not in constants.UNSAVED_PARAMETERS
|
|
605
648
|
}
|
|
606
649
|
# Create Arviz InferenceData for posterior draws.
|
|
607
|
-
posterior_coords = self.
|
|
650
|
+
posterior_coords = self._model_context.create_inference_data_coords(
|
|
608
651
|
total_chains, n_keep
|
|
609
652
|
)
|
|
610
|
-
posterior_dims = self.
|
|
653
|
+
posterior_dims = self._model_context.create_inference_data_dims()
|
|
611
654
|
infdata_posterior = az.convert_to_inference_data(
|
|
612
655
|
mcmc_states, coords=posterior_coords, dims=posterior_dims
|
|
613
656
|
)
|
|
@@ -669,7 +712,5 @@ class PosteriorMCMCSampler:
|
|
|
669
712
|
dims=sample_stats_dims,
|
|
670
713
|
group="sample_stats",
|
|
671
714
|
)
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
)
|
|
675
|
-
self.model.inference_data.extend(posterior_inference_data, join="right")
|
|
715
|
+
|
|
716
|
+
return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats)
|