google-meridian 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +6 -1
- meridian/analysis/test_utils.py +2898 -2538
- meridian/analysis/visualizer.py +28 -9
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +1 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +25 -41
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +134 -0
- meridian/model/eda/constants.py +334 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/model.py +159 -110
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/linkingapi/constants.py +1 -1
- scenarioplanner/mmm_ui_proto_generator.py +1 -0
- schema/processors/marketing_processor.py +11 -10
- schema/processors/model_processor.py +4 -1
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/utils/__init__.py +1 -0
- schema/utils/proto_enum_converter.py +127 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
meridian/model/model.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import collections
|
|
18
18
|
from collections.abc import Mapping, Sequence
|
|
19
|
+
import dataclasses
|
|
19
20
|
import functools
|
|
20
21
|
import os
|
|
21
22
|
import warnings
|
|
@@ -84,7 +85,7 @@ class Meridian:
|
|
|
84
85
|
input_data: An `InputData` object containing the input data for the model.
|
|
85
86
|
model_spec: A `ModelSpec` object containing the model specification.
|
|
86
87
|
model_context: A `ModelContext` object containing the model context.
|
|
87
|
-
|
|
88
|
+
model_equations: A `ModelEquations` object containing stateless mathematical
|
|
88
89
|
functions and utilities for Meridian MMM.
|
|
89
90
|
inference_data: A _mutable_ `arviz.InferenceData` object containing the
|
|
90
91
|
resulting data from fitting the model.
|
|
@@ -164,7 +165,7 @@ class Meridian:
|
|
|
164
165
|
input_data=input_data,
|
|
165
166
|
model_spec=model_spec if model_spec else spec.ModelSpec(),
|
|
166
167
|
)
|
|
167
|
-
self.
|
|
168
|
+
self._model_equations = equations.ModelEquations(self._model_context)
|
|
168
169
|
|
|
169
170
|
self._eda_spec = eda_spec
|
|
170
171
|
|
|
@@ -190,8 +191,8 @@ class Meridian:
|
|
|
190
191
|
return self._model_context
|
|
191
192
|
|
|
192
193
|
@property
|
|
193
|
-
def
|
|
194
|
-
return self.
|
|
194
|
+
def model_equations(self) -> equations.ModelEquations:
|
|
195
|
+
return self._model_equations
|
|
195
196
|
|
|
196
197
|
@property
|
|
197
198
|
def inference_data(self) -> az.InferenceData:
|
|
@@ -199,57 +200,59 @@ class Meridian:
|
|
|
199
200
|
|
|
200
201
|
@functools.cached_property
|
|
201
202
|
def eda_engine(self) -> eda_engine.EDAEngine:
|
|
202
|
-
return eda_engine.EDAEngine(
|
|
203
|
+
return eda_engine.EDAEngine(
|
|
204
|
+
spec=self._eda_spec, model_context=self.model_context
|
|
205
|
+
)
|
|
203
206
|
|
|
204
207
|
@property
|
|
205
208
|
def eda_spec(self) -> eda_spec_module.EDASpec:
|
|
206
209
|
return self._eda_spec
|
|
207
210
|
|
|
208
211
|
@property
|
|
209
|
-
def eda_outcomes(self) ->
|
|
212
|
+
def eda_outcomes(self) -> eda_outcome.CriticalCheckEDAOutcomes:
|
|
210
213
|
return self.eda_engine.run_all_critical_checks()
|
|
211
214
|
|
|
212
|
-
@
|
|
215
|
+
@property
|
|
213
216
|
def media_tensors(self) -> media.MediaTensors:
|
|
214
217
|
return self._model_context.media_tensors
|
|
215
218
|
|
|
216
|
-
@
|
|
219
|
+
@property
|
|
217
220
|
def rf_tensors(self) -> media.RfTensors:
|
|
218
221
|
return self._model_context.rf_tensors
|
|
219
222
|
|
|
220
|
-
@
|
|
223
|
+
@property
|
|
221
224
|
def organic_media_tensors(self) -> media.OrganicMediaTensors:
|
|
222
225
|
return self._model_context.organic_media_tensors
|
|
223
226
|
|
|
224
|
-
@
|
|
227
|
+
@property
|
|
225
228
|
def organic_rf_tensors(self) -> media.OrganicRfTensors:
|
|
226
229
|
return self._model_context.organic_rf_tensors
|
|
227
230
|
|
|
228
|
-
@
|
|
231
|
+
@property
|
|
229
232
|
def kpi(self) -> backend.Tensor:
|
|
230
233
|
return self._model_context.kpi
|
|
231
234
|
|
|
232
|
-
@
|
|
235
|
+
@property
|
|
233
236
|
def revenue_per_kpi(self) -> backend.Tensor | None:
|
|
234
237
|
return self._model_context.revenue_per_kpi
|
|
235
238
|
|
|
236
|
-
@
|
|
239
|
+
@property
|
|
237
240
|
def controls(self) -> backend.Tensor | None:
|
|
238
241
|
return self._model_context.controls
|
|
239
242
|
|
|
240
|
-
@
|
|
243
|
+
@property
|
|
241
244
|
def non_media_treatments(self) -> backend.Tensor | None:
|
|
242
245
|
return self._model_context.non_media_treatments
|
|
243
246
|
|
|
244
|
-
@
|
|
247
|
+
@property
|
|
245
248
|
def population(self) -> backend.Tensor:
|
|
246
249
|
return self._model_context.population
|
|
247
250
|
|
|
248
|
-
@
|
|
251
|
+
@property
|
|
249
252
|
def total_spend(self) -> backend.Tensor:
|
|
250
253
|
return self._model_context.total_spend
|
|
251
254
|
|
|
252
|
-
@
|
|
255
|
+
@property
|
|
253
256
|
def total_outcome(self) -> backend.Tensor:
|
|
254
257
|
return self._model_context.total_outcome
|
|
255
258
|
|
|
@@ -293,31 +296,31 @@ class Meridian:
|
|
|
293
296
|
def is_national(self) -> bool:
|
|
294
297
|
return self._model_context.is_national
|
|
295
298
|
|
|
296
|
-
@
|
|
299
|
+
@property
|
|
297
300
|
def knot_info(self) -> knots.KnotInfo:
|
|
298
301
|
return self._model_context.knot_info
|
|
299
302
|
|
|
300
|
-
@
|
|
303
|
+
@property
|
|
301
304
|
def controls_transformer(
|
|
302
305
|
self,
|
|
303
306
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
304
307
|
return self._model_context.controls_transformer
|
|
305
308
|
|
|
306
|
-
@
|
|
309
|
+
@property
|
|
307
310
|
def non_media_transformer(
|
|
308
311
|
self,
|
|
309
312
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
310
313
|
return self._model_context.non_media_transformer
|
|
311
314
|
|
|
312
|
-
@
|
|
315
|
+
@property
|
|
313
316
|
def kpi_transformer(self) -> transformers.KpiTransformer:
|
|
314
317
|
return self._model_context.kpi_transformer
|
|
315
318
|
|
|
316
|
-
@
|
|
319
|
+
@property
|
|
317
320
|
def controls_scaled(self) -> backend.Tensor | None:
|
|
318
321
|
return self._model_context.controls_scaled
|
|
319
322
|
|
|
320
|
-
@
|
|
323
|
+
@property
|
|
321
324
|
def non_media_treatments_normalized(self) -> backend.Tensor | None:
|
|
322
325
|
"""Normalized non-media treatments.
|
|
323
326
|
|
|
@@ -327,33 +330,33 @@ class Meridian:
|
|
|
327
330
|
"""
|
|
328
331
|
return self._model_context.non_media_treatments_normalized
|
|
329
332
|
|
|
330
|
-
@
|
|
333
|
+
@property
|
|
331
334
|
def kpi_scaled(self) -> backend.Tensor:
|
|
332
335
|
return self._model_context.kpi_scaled
|
|
333
336
|
|
|
334
|
-
@
|
|
337
|
+
@property
|
|
335
338
|
def media_effects_dist(self) -> str:
|
|
336
339
|
return self._model_context.media_effects_dist
|
|
337
340
|
|
|
338
|
-
@
|
|
341
|
+
@property
|
|
339
342
|
def unique_sigma_for_each_geo(self) -> bool:
|
|
340
343
|
return self._model_context.unique_sigma_for_each_geo
|
|
341
344
|
|
|
342
|
-
@
|
|
345
|
+
@property
|
|
343
346
|
def baseline_geo_idx(self) -> int:
|
|
344
347
|
"""Returns the index of the baseline geo."""
|
|
345
348
|
return self._model_context.baseline_geo_idx
|
|
346
349
|
|
|
347
|
-
@
|
|
350
|
+
@property
|
|
348
351
|
def holdout_id(self) -> backend.Tensor | None:
|
|
349
352
|
return self._model_context.holdout_id
|
|
350
353
|
|
|
351
|
-
@
|
|
354
|
+
@property
|
|
352
355
|
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
|
|
353
356
|
"""Returns `AdstockDecaySpec` object with correctly mapped channels."""
|
|
354
357
|
return self._model_context.adstock_decay_spec
|
|
355
358
|
|
|
356
|
-
@
|
|
359
|
+
@property
|
|
357
360
|
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
|
|
358
361
|
"""Returns broadcasted `PriorDistribution` object."""
|
|
359
362
|
return self._model_context.prior_broadcast
|
|
@@ -361,17 +364,20 @@ class Meridian:
|
|
|
361
364
|
@functools.cached_property
|
|
362
365
|
def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
|
|
363
366
|
"""A `PriorDistributionSampler` callable bound to this model."""
|
|
364
|
-
return prior_sampler.PriorDistributionSampler(
|
|
367
|
+
return prior_sampler.PriorDistributionSampler(
|
|
368
|
+
model_context=self.model_context,
|
|
369
|
+
)
|
|
365
370
|
|
|
366
371
|
@functools.cached_property
|
|
367
372
|
def posterior_sampler_callable(
|
|
368
373
|
self,
|
|
369
374
|
) -> posterior_sampler.PosteriorMCMCSampler:
|
|
370
375
|
"""A `PosteriorMCMCSampler` callable bound to this model."""
|
|
371
|
-
return posterior_sampler.PosteriorMCMCSampler(
|
|
376
|
+
return posterior_sampler.PosteriorMCMCSampler(
|
|
377
|
+
model_context=self.model_context,
|
|
378
|
+
)
|
|
372
379
|
|
|
373
|
-
# TODO:
|
|
374
|
-
# `equations.py`.
|
|
380
|
+
# TODO: Remove this method.
|
|
375
381
|
def compute_non_media_treatments_baseline(
|
|
376
382
|
self,
|
|
377
383
|
non_media_baseline_values: Sequence[str | float] | None = None,
|
|
@@ -394,10 +400,18 @@ class Meridian:
|
|
|
394
400
|
A tensor of shape `(n_non_media_channels,)` containing the
|
|
395
401
|
baseline values for each non-media treatment channel.
|
|
396
402
|
"""
|
|
397
|
-
|
|
403
|
+
warnings.warn(
|
|
404
|
+
"Meridian.compute_non_media_treatments_baseline() is deprecated and"
|
|
405
|
+
" will be removed in a future version. Use"
|
|
406
|
+
" `ModelEquations.compute_non_media_treatments_baseline()` instead.",
|
|
407
|
+
DeprecationWarning,
|
|
408
|
+
stacklevel=2,
|
|
409
|
+
)
|
|
410
|
+
return self.model_equations.compute_non_media_treatments_baseline(
|
|
398
411
|
non_media_baseline_values=non_media_baseline_values
|
|
399
412
|
)
|
|
400
413
|
|
|
414
|
+
# TODO: Remove this method.
|
|
401
415
|
def expand_selected_time_dims(
|
|
402
416
|
self,
|
|
403
417
|
start_date: tc.Date = None,
|
|
@@ -425,12 +439,16 @@ class Meridian:
|
|
|
425
439
|
ValueError if `start_date` or `end_date` is not in the input data time
|
|
426
440
|
dimensions.
|
|
427
441
|
"""
|
|
428
|
-
|
|
442
|
+
warnings.warn(
|
|
443
|
+
"Meridian.expand_selected_time_dims() is deprecated and will be removed"
|
|
444
|
+
" in a future version. Use `ModelContext.expand_selected_time_dims()`"
|
|
445
|
+
" instead.",
|
|
446
|
+
DeprecationWarning,
|
|
447
|
+
stacklevel=2,
|
|
448
|
+
)
|
|
449
|
+
return self._model_context.expand_selected_time_dims(
|
|
429
450
|
start_date=start_date, end_date=end_date
|
|
430
451
|
)
|
|
431
|
-
if expanded is None:
|
|
432
|
-
return None
|
|
433
|
-
return [date.strftime(constants.DATE_FORMAT) for date in expanded]
|
|
434
452
|
|
|
435
453
|
def _validate_injected_inference_data(self):
|
|
436
454
|
"""Validates that the injected inference data has correct shapes.
|
|
@@ -598,7 +616,7 @@ class Meridian:
|
|
|
598
616
|
f' "{self.model_spec.non_media_treatments_prior_type}".'
|
|
599
617
|
)
|
|
600
618
|
|
|
601
|
-
# TODO:
|
|
619
|
+
# TODO: Remove this method.
|
|
602
620
|
def linear_predictor_counterfactual_difference_media(
|
|
603
621
|
self,
|
|
604
622
|
media_transformed: backend.Tensor,
|
|
@@ -627,15 +645,24 @@ class Meridian:
|
|
|
627
645
|
The linear predictor difference between the treatment variable and its
|
|
628
646
|
counterfactual.
|
|
629
647
|
"""
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
648
|
+
warnings.warn(
|
|
649
|
+
"Meridian.linear_predictor_counterfactual_difference_media() is"
|
|
650
|
+
" deprecated and will be removed in a future version. Use "
|
|
651
|
+
"`ModelEquations.linear_predictor_counterfactual_difference_media()`"
|
|
652
|
+
" instead.",
|
|
653
|
+
DeprecationWarning,
|
|
654
|
+
stacklevel=2,
|
|
655
|
+
)
|
|
656
|
+
return (
|
|
657
|
+
self.model_equations.linear_predictor_counterfactual_difference_media(
|
|
658
|
+
media_transformed=media_transformed,
|
|
659
|
+
alpha_m=alpha_m,
|
|
660
|
+
ec_m=ec_m,
|
|
661
|
+
slope_m=slope_m,
|
|
662
|
+
)
|
|
635
663
|
)
|
|
636
664
|
|
|
637
|
-
# TODO:
|
|
638
|
-
# ModelEquations.linear_predictor_counterfactual_difference_rf.
|
|
665
|
+
# TODO: Remove this method.
|
|
639
666
|
def linear_predictor_counterfactual_difference_rf(
|
|
640
667
|
self,
|
|
641
668
|
rf_transformed: backend.Tensor,
|
|
@@ -664,14 +691,21 @@ class Meridian:
|
|
|
664
691
|
The linear predictor difference between the treatment variable and its
|
|
665
692
|
counterfactual.
|
|
666
693
|
"""
|
|
667
|
-
|
|
694
|
+
warnings.warn(
|
|
695
|
+
"Meridian.linear_predictor_counterfactual_difference_rf() is deprecated"
|
|
696
|
+
" and will be removed in a future version. Use `ModelEquations."
|
|
697
|
+
"linear_predictor_counterfactual_difference_rf()` instead.",
|
|
698
|
+
DeprecationWarning,
|
|
699
|
+
stacklevel=2,
|
|
700
|
+
)
|
|
701
|
+
return self.model_equations.linear_predictor_counterfactual_difference_rf(
|
|
668
702
|
rf_transformed=rf_transformed,
|
|
669
703
|
alpha_rf=alpha_rf,
|
|
670
704
|
ec_rf=ec_rf,
|
|
671
705
|
slope_rf=slope_rf,
|
|
672
706
|
)
|
|
673
707
|
|
|
674
|
-
# TODO:
|
|
708
|
+
# TODO: Remove this method.
|
|
675
709
|
def calculate_beta_x(
|
|
676
710
|
self,
|
|
677
711
|
is_non_media: bool,
|
|
@@ -717,7 +751,14 @@ class Meridian:
|
|
|
717
751
|
The coefficient mean parameter of the treatment variable, which has
|
|
718
752
|
dimension equal to the number of treatment channels..
|
|
719
753
|
"""
|
|
720
|
-
|
|
754
|
+
warnings.warn(
|
|
755
|
+
"Meridian.calculate_beta_x() is deprecated and will be removed in a"
|
|
756
|
+
" future version. Use `ModelEquations.calculate_beta_x()`"
|
|
757
|
+
" instead.",
|
|
758
|
+
DeprecationWarning,
|
|
759
|
+
stacklevel=2,
|
|
760
|
+
)
|
|
761
|
+
return self.model_equations.calculate_beta_x(
|
|
721
762
|
is_non_media=is_non_media,
|
|
722
763
|
incremental_outcome_x=incremental_outcome_x,
|
|
723
764
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -725,7 +766,7 @@ class Meridian:
|
|
|
725
766
|
beta_gx_dev=beta_gx_dev,
|
|
726
767
|
)
|
|
727
768
|
|
|
728
|
-
# TODO:
|
|
769
|
+
# TODO: Remove this method.
|
|
729
770
|
def adstock_hill_media(
|
|
730
771
|
self,
|
|
731
772
|
media: backend.Tensor, # pylint: disable=redefined-outer-name
|
|
@@ -756,7 +797,14 @@ class Meridian:
|
|
|
756
797
|
Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
|
|
757
798
|
representing Adstock and Hill-transformed media.
|
|
758
799
|
"""
|
|
759
|
-
|
|
800
|
+
warnings.warn(
|
|
801
|
+
"Meridian.adstock_hill_media() is deprecated and will be removed in a"
|
|
802
|
+
" future version. Use `ModelEquations.adstock_hill_media()`"
|
|
803
|
+
" instead.",
|
|
804
|
+
DeprecationWarning,
|
|
805
|
+
stacklevel=2,
|
|
806
|
+
)
|
|
807
|
+
return self.model_equations.adstock_hill_media(
|
|
760
808
|
media=media,
|
|
761
809
|
alpha=alpha,
|
|
762
810
|
ec=ec,
|
|
@@ -765,7 +813,7 @@ class Meridian:
|
|
|
765
813
|
n_times_output=n_times_output,
|
|
766
814
|
)
|
|
767
815
|
|
|
768
|
-
# TODO:
|
|
816
|
+
# TODO: Remove this method.
|
|
769
817
|
def adstock_hill_rf(
|
|
770
818
|
self,
|
|
771
819
|
reach: backend.Tensor,
|
|
@@ -797,7 +845,14 @@ class Meridian:
|
|
|
797
845
|
Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
|
|
798
846
|
representing Hill and Adstock-transformed RF.
|
|
799
847
|
"""
|
|
800
|
-
|
|
848
|
+
warnings.warn(
|
|
849
|
+
"Meridian.adstock_hill_rf() is deprecated and will be removed in a"
|
|
850
|
+
" future version. Use `ModelEquations.adstock_hill_rf()`"
|
|
851
|
+
" instead.",
|
|
852
|
+
DeprecationWarning,
|
|
853
|
+
stacklevel=2,
|
|
854
|
+
)
|
|
855
|
+
return self.model_equations.adstock_hill_rf(
|
|
801
856
|
reach=reach,
|
|
802
857
|
frequency=frequency,
|
|
803
858
|
alpha=alpha,
|
|
@@ -827,66 +882,30 @@ class Meridian:
|
|
|
827
882
|
for attr in cached_properties:
|
|
828
883
|
_ = getattr(self, attr)
|
|
829
884
|
|
|
885
|
+
# TODO: Remove this method.
|
|
830
886
|
def create_inference_data_coords(
|
|
831
887
|
self, n_chains: int, n_draws: int
|
|
832
888
|
) -> Mapping[str, np.ndarray | Sequence[str]]:
|
|
833
889
|
"""Creates data coordinates for inference data."""
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
self.input_data.rf_channel
|
|
841
|
-
if self.input_data.rf_channel is not None
|
|
842
|
-
else np.array([])
|
|
843
|
-
)
|
|
844
|
-
organic_media_channel_names = (
|
|
845
|
-
self.input_data.organic_media_channel
|
|
846
|
-
if self.input_data.organic_media_channel is not None
|
|
847
|
-
else np.array([])
|
|
890
|
+
warnings.warn(
|
|
891
|
+
"Meridian.create_inference_data_coords() is deprecated and will be"
|
|
892
|
+
" removed in a future version. Use"
|
|
893
|
+
" `ModelContext.create_inference_data_coords()` instead.",
|
|
894
|
+
DeprecationWarning,
|
|
895
|
+
stacklevel=2,
|
|
848
896
|
)
|
|
849
|
-
|
|
850
|
-
self.input_data.organic_rf_channel
|
|
851
|
-
if self.input_data.organic_rf_channel is not None
|
|
852
|
-
else np.array([])
|
|
853
|
-
)
|
|
854
|
-
non_media_channel_names = (
|
|
855
|
-
self.input_data.non_media_channel
|
|
856
|
-
if self.input_data.non_media_channel is not None
|
|
857
|
-
else np.array([])
|
|
858
|
-
)
|
|
859
|
-
control_variable_names = (
|
|
860
|
-
self.input_data.control_variable
|
|
861
|
-
if self.input_data.control_variable is not None
|
|
862
|
-
else np.array([])
|
|
863
|
-
)
|
|
864
|
-
return {
|
|
865
|
-
constants.CHAIN: np.arange(n_chains),
|
|
866
|
-
constants.DRAW: np.arange(n_draws),
|
|
867
|
-
constants.GEO: self.input_data.geo,
|
|
868
|
-
constants.TIME: self.input_data.time,
|
|
869
|
-
constants.MEDIA_TIME: self.input_data.media_time,
|
|
870
|
-
constants.KNOTS: np.arange(self.knot_info.n_knots),
|
|
871
|
-
constants.CONTROL_VARIABLE: control_variable_names,
|
|
872
|
-
constants.NON_MEDIA_CHANNEL: non_media_channel_names,
|
|
873
|
-
constants.MEDIA_CHANNEL: media_channel_names,
|
|
874
|
-
constants.RF_CHANNEL: rf_channel_names,
|
|
875
|
-
constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
|
|
876
|
-
constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
|
|
877
|
-
}
|
|
897
|
+
return self._model_context.create_inference_data_coords(n_chains, n_draws)
|
|
878
898
|
|
|
899
|
+
# TODO: Remove this method.
|
|
879
900
|
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
for param, dims in inference_dims.items()
|
|
889
|
-
}
|
|
901
|
+
warnings.warn(
|
|
902
|
+
"Meridian.create_inference_data_dims() is deprecated and will be"
|
|
903
|
+
" removed in a future version. Use"
|
|
904
|
+
" `ModelContext.create_inference_data_dims()` instead.",
|
|
905
|
+
DeprecationWarning,
|
|
906
|
+
stacklevel=2,
|
|
907
|
+
)
|
|
908
|
+
return self._model_context.create_inference_data_dims()
|
|
890
909
|
|
|
891
910
|
def sample_prior(self, n_draws: int, seed: int | None = None):
|
|
892
911
|
"""Draws samples from the prior distributions.
|
|
@@ -899,14 +918,25 @@ class Meridian:
|
|
|
899
918
|
see [PRNGS and seeds]
|
|
900
919
|
(https://github.com/tensorflow/probability/blob/main/PRNGS.md).
|
|
901
920
|
"""
|
|
902
|
-
self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
921
|
+
prior_draws = self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
922
|
+
# Create Arviz InferenceData for prior draws.
|
|
923
|
+
prior_coords = self._model_context.create_inference_data_coords(1, n_draws)
|
|
924
|
+
prior_dims = self._model_context.create_inference_data_dims()
|
|
925
|
+
prior_inference_data = az.convert_to_inference_data(
|
|
926
|
+
prior_draws,
|
|
927
|
+
coords=prior_coords,
|
|
928
|
+
dims=prior_dims,
|
|
929
|
+
group=constants.PRIOR,
|
|
930
|
+
)
|
|
931
|
+
self.inference_data.extend(prior_inference_data, join="right")
|
|
903
932
|
|
|
904
933
|
def _run_model_fitting_guardrail(self):
|
|
905
934
|
"""Raises an error if the model has critical EDA issues."""
|
|
906
935
|
error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
|
|
907
936
|
collections.defaultdict(list)
|
|
908
937
|
)
|
|
909
|
-
for
|
|
938
|
+
for field in dataclasses.fields(self.eda_outcomes):
|
|
939
|
+
outcome = getattr(self.eda_outcomes, field.name)
|
|
910
940
|
error_findings = [
|
|
911
941
|
finding
|
|
912
942
|
for finding in outcome.findings
|
|
@@ -1010,7 +1040,7 @@ class Meridian:
|
|
|
1010
1040
|
"""
|
|
1011
1041
|
self._run_model_fitting_guardrail()
|
|
1012
1042
|
|
|
1013
|
-
self.posterior_sampler_callable(
|
|
1043
|
+
posterior_inference_data = self.posterior_sampler_callable(
|
|
1014
1044
|
n_chains=n_chains,
|
|
1015
1045
|
n_adapt=n_adapt,
|
|
1016
1046
|
n_burnin=n_burnin,
|
|
@@ -1025,6 +1055,7 @@ class Meridian:
|
|
|
1025
1055
|
seed=seed,
|
|
1026
1056
|
**pins,
|
|
1027
1057
|
)
|
|
1058
|
+
self.inference_data.extend(posterior_inference_data, join="right")
|
|
1028
1059
|
|
|
1029
1060
|
|
|
1030
1061
|
def save_mmm(mmm: Meridian, file_path: str):
|
|
@@ -1038,6 +1069,15 @@ def save_mmm(mmm: Meridian, file_path: str):
|
|
|
1038
1069
|
mmm: Model object to save.
|
|
1039
1070
|
file_path: File path to save a pickled model object.
|
|
1040
1071
|
"""
|
|
1072
|
+
warnings.warn(
|
|
1073
|
+
"save_mmm is deprecated and will be removed in a future release. Please"
|
|
1074
|
+
" use `schema.serde.meridian_serde.save_meridian` instead. See"
|
|
1075
|
+
" https://developers.google.com/meridian/docs/user-guide/saving-model-object"
|
|
1076
|
+
" for details.",
|
|
1077
|
+
DeprecationWarning,
|
|
1078
|
+
stacklevel=2,
|
|
1079
|
+
)
|
|
1080
|
+
|
|
1041
1081
|
if not os.path.exists(os.path.dirname(file_path)):
|
|
1042
1082
|
os.makedirs(os.path.dirname(file_path))
|
|
1043
1083
|
|
|
@@ -1061,6 +1101,15 @@ def load_mmm(file_path: str) -> Meridian:
|
|
|
1061
1101
|
Raises:
|
|
1062
1102
|
FileNotFoundError: If `file_path` does not exist.
|
|
1063
1103
|
"""
|
|
1104
|
+
warnings.warn(
|
|
1105
|
+
"load_mmm is deprecated and will be removed in a future release. Please"
|
|
1106
|
+
" use `meridian.schema.serde.meridian_serde.load_meridian` instead. See"
|
|
1107
|
+
" https://developers.google.com/meridian/docs/user-guide/saving-model-object"
|
|
1108
|
+
" for details.",
|
|
1109
|
+
DeprecationWarning,
|
|
1110
|
+
stacklevel=2,
|
|
1111
|
+
)
|
|
1112
|
+
|
|
1064
1113
|
try:
|
|
1065
1114
|
with open(file_path, "rb") as f:
|
|
1066
1115
|
mmm = joblib.load(f)
|
|
@@ -159,6 +159,8 @@ class WithInputDataSamples:
|
|
|
159
159
|
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
160
160
|
_input_data_organic_only: input_data.InputData
|
|
161
161
|
_national_input_data_organic_only: input_data.InputData
|
|
162
|
+
_input_data_non_media_only: input_data.InputData
|
|
163
|
+
_national_input_data_non_media_only: input_data.InputData
|
|
162
164
|
|
|
163
165
|
# The following NamedTuples and their attributes are immutable, so they can
|
|
164
166
|
# be accessed directly.
|
|
@@ -537,6 +539,34 @@ class WithInputDataSamples:
|
|
|
537
539
|
seed=0,
|
|
538
540
|
)
|
|
539
541
|
)
|
|
542
|
+
cls._input_data_non_media_only = (
|
|
543
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
544
|
+
n_geos=cls._N_GEOS,
|
|
545
|
+
n_times=cls._N_TIMES,
|
|
546
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
547
|
+
n_controls=0,
|
|
548
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
549
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
550
|
+
n_rf_channels=0,
|
|
551
|
+
n_organic_media_channels=0,
|
|
552
|
+
n_organic_rf_channels=0,
|
|
553
|
+
seed=0,
|
|
554
|
+
)
|
|
555
|
+
)
|
|
556
|
+
cls._national_input_data_non_media_only = (
|
|
557
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
558
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
559
|
+
n_times=cls._N_TIMES,
|
|
560
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
561
|
+
n_controls=0,
|
|
562
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
563
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
564
|
+
n_rf_channels=0,
|
|
565
|
+
n_organic_media_channels=0,
|
|
566
|
+
n_organic_rf_channels=0,
|
|
567
|
+
seed=0,
|
|
568
|
+
)
|
|
569
|
+
)
|
|
540
570
|
|
|
541
571
|
@property
|
|
542
572
|
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
@@ -659,3 +689,11 @@ class WithInputDataSamples:
|
|
|
659
689
|
@property
|
|
660
690
|
def national_input_data_organic_only(self) -> input_data.InputData:
|
|
661
691
|
return self._national_input_data_organic_only.copy(deep=True)
|
|
692
|
+
|
|
693
|
+
@property
|
|
694
|
+
def input_data_non_media_only(self) -> input_data.InputData:
|
|
695
|
+
return self._input_data_non_media_only.copy(deep=True)
|
|
696
|
+
|
|
697
|
+
@property
|
|
698
|
+
def national_input_data_non_media_only(self) -> input_data.InputData:
|
|
699
|
+
return self._national_input_data_non_media_only.copy(deep=True)
|