google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0

meridian/model/prior_distribution.py
@@ -19,18 +19,22 @@ used by the Meridian model object.
 """
 
 from __future__ import annotations
-from collections.abc import MutableMapping
+
+from collections.abc import MutableMapping, Sequence
 import dataclasses
 from typing import Any
 import warnings
+
+from meridian import backend
 from meridian import constants
 import numpy as np
-import tensorflow as tf
-import tensorflow_probability as tfp
 
 
 __all__ = [
+    'IndependentMultivariateDistribution',
     'PriorDistribution',
+    'distributions_are_equal',
+    'lognormal_dist_from_mean_std',
 ]
 
 
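The import changes above are the visible edge of the new `meridian.backend` package added in 1.2.1 (see `meridian/backend/__init__.py` in the file list): direct `tensorflow` and `tensorflow_probability` imports are replaced by a backend shim. A minimal sketch of the pattern, under the assumption that `backend.tfd` resolves to `tfp.distributions` when the TensorFlow backend is active:

```python
# Sketch only: assumes backend.tfd aliases tfp.distributions under the
# default TensorFlow backend; the exact wiring lives in meridian/backend/.
from meridian import backend

# Before 1.2.1 this would have been tfp.distributions.HalfNormal(5.0).
prior = backend.tfd.HalfNormal(5.0, name='sigma')
print(prior.sample(3, seed=1).shape)  # (3,)
```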
@@ -172,14 +176,14 @@ class PriorDistribution:
     xi_n: Prior distribution on the hierarchical standard deviation of
       `gamma_gn` which is the coefficient on non-media channel `n` for geo `g`.
       Hierarchy is defined over geos. Default distribution is `HalfNormal(5.0)`.
-    alpha_m: Prior distribution on the `geometric decay` Adstock parameter for
+    alpha_m: Prior distribution on the Adstock decay parameter for media input.
+      Default distribution is `Uniform(0.0, 1.0)`.
+    alpha_rf: Prior distribution on the Adstock decay parameter for RF input.
+      Default distribution is `Uniform(0.0, 1.0)`.
+    alpha_om: Prior distribution on the Adstock decay parameter for organic
       media input. Default distribution is `Uniform(0.0, 1.0)`.
-    alpha_rf: Prior distribution on the `geometric decay` Adstock parameter for
-      RF input. Default distribution is `Uniform(0.0, 1.0)`.
-    alpha_om: Prior distribution on the `geometric decay` Adstock parameter for
-      organic media input. Default distribution is `Uniform(0.0, 1.0)`.
-    alpha_orf: Prior distribution on the `geometric decay` Adstock parameter for
-      organic RF input. Default distribution is `Uniform(0.0, 1.0)`.
+    alpha_orf: Prior distribution on the Adstock decay parameter for organic RF
+      input. Default distribution is `Uniform(0.0, 1.0)`.
     ec_m: Prior distribution on the `half-saturation` Hill parameter for media
       input. Default distribution is `TruncatedNormal(0.8, 0.8, 0.1, 10)`.
     ec_rf: Prior distribution on the `half-saturation` Hill parameter for RF
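The reworded docstring above covers the Adstock decay priors (`alpha_m`, `alpha_rf`, `alpha_om`, `alpha_orf`), all defaulting to `Uniform(0.0, 1.0)`. As an illustration not taken from the diff, any custom decay prior just needs support inside (0, 1), for example a Beta distribution:

```python
# Illustrative only: a Beta prior can replace the default Uniform(0, 1)
# Adstock decay prior because its support stays within (0, 1).
from meridian import backend
from meridian.model import prior_distribution

custom_priors = prior_distribution.PriorDistribution(
    alpha_m=backend.tfd.Beta(2.0, 5.0, name='alpha_m'),
)
```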
@@ -203,9 +207,9 @@ class PriorDistribution:
       distribution is `HalfNormal(5.0)`.
     roi_m: Prior distribution on the ROI of each media channel. This parameter
       is only used when `paid_media_prior_type` is `'roi'`, in which case
-      `beta_m` is calculated as a deterministic function of `
-
-
+      `beta_m` is calculated as a deterministic function of `roi_m`, `alpha_m`,
+      `ec_m`, `slope_m`, and the spend associated with each media channel.
+      Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
       `'non_revenue'` and `revenue_per_kpi` is not provided, ROI is interpreted
       as incremental KPI units per monetary unit spent. In this case, the
       default value for `roi_m` and `roi_rf` will be ignored and a common ROI
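The expanded `roi_m` text spells out that, under the `'roi'` prior type, `beta_m` is derived from `roi_m`, `alpha_m`, `ec_m`, `slope_m`, and spend. A short sketch of supplying a custom ROI prior (hypothetical values, same `LogNormal` family as the default):

```python
# Hypothetical: tighten the ROI prior relative to the default LogNormal(0.2, 0.9).
from meridian import backend
from meridian.model import prior_distribution

custom_priors = prior_distribution.PriorDistribution(
    roi_m=backend.tfd.LogNormal(0.2, 0.3, name='roi_m'),
)
```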
@@ -270,196 +274,202 @@ class PriorDistribution:
       `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
   """
 
-  knot_values: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  knot_values: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
           0.0, 5.0, name=constants.KNOT_VALUES
       ),
   )
-  tau_g_excl_baseline: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  tau_g_excl_baseline: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
           0.0, 5.0, name=constants.TAU_G_EXCL_BASELINE
       ),
   )
-  beta_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           5.0, name=constants.BETA_M
       ),
   )
-  beta_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           5.0, name=constants.BETA_RF
       ),
   )
-  beta_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           5.0, name=constants.BETA_OM
       ),
   )
-  beta_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           5.0, name=constants.BETA_ORF
       ),
   )
-  eta_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          1.0, name=constants.ETA_M
-      ),
+  eta_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(1.0, name=constants.ETA_M),
   )
-  eta_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           1.0, name=constants.ETA_RF
       ),
   )
-  eta_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           1.0, name=constants.ETA_OM
       ),
   )
-  eta_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
           1.0, name=constants.ETA_ORF
       ),
   )
-  gamma_c: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  gamma_c: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
           0.0, 5.0, name=constants.GAMMA_C
       ),
   )
-  gamma_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  gamma_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
           0.0, 5.0, name=constants.GAMMA_N
       ),
   )
-  xi_c: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.XI_C
-      ),
+  xi_c: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_C),
   )
-  xi_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.XI_N
-      ),
+  xi_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_N),
   )
-  alpha_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
           0.0, 1.0, name=constants.ALPHA_M
       ),
   )
-  alpha_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
           0.0, 1.0, name=constants.ALPHA_RF
       ),
   )
-  alpha_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
           0.0, 1.0, name=constants.ALPHA_OM
       ),
   )
-  alpha_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
           0.0, 1.0, name=constants.ALPHA_ORF
       ),
   )
-  ec_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  ec_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
           0.8, 0.8, 0.1, 10, name=constants.EC_M
       ),
   )
-  ec_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TransformedDistribution(
-          tfp.distributions.LogNormal(0.7, 0.4),
-          tfp.bijectors.Shift(0.1),
+  ec_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TransformedDistribution(
+          backend.tfd.LogNormal(0.7, 0.4),
+          backend.bijectors.Shift(0.1),
           name=constants.EC_RF,
       ),
   )
-  ec_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  ec_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
           0.8, 0.8, 0.1, 10, name=constants.EC_OM
       ),
   )
-  ec_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TransformedDistribution(
-          tfp.distributions.LogNormal(0.7, 0.4),
-          tfp.bijectors.Shift(0.1),
+  ec_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TransformedDistribution(
+          backend.tfd.LogNormal(0.7, 0.4),
+          backend.bijectors.Shift(0.1),
           name=constants.EC_ORF,
       ),
   )
-  slope_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Deterministic(
+  slope_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Deterministic(
           1.0, name=constants.SLOPE_M
       ),
   )
-  slope_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  slope_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.7, 0.4, name=constants.SLOPE_RF
       ),
   )
-  slope_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Deterministic(
+  slope_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Deterministic(
           1.0, name=constants.SLOPE_OM
       ),
   )
-  slope_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  slope_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.7, 0.4, name=constants.SLOPE_ORF
       ),
   )
-  sigma: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.SIGMA
-      ),
+  sigma: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.SIGMA),
   )
-  roi_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  roi_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.2, 0.9, name=constants.ROI_M
       ),
   )
-  roi_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  roi_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.2, 0.9, name=constants.ROI_RF
       ),
   )
-  mroi_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  mroi_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.0, 0.5, name=constants.MROI_M
       ),
   )
-  mroi_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  mroi_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
           0.0, 0.5, name=constants.MROI_RF
       ),
   )
-  contribution_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
           1.0, 99.0, name=constants.CONTRIBUTION_M
       ),
   )
-  contribution_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
           1.0, 99.0, name=constants.CONTRIBUTION_RF
       ),
   )
-  contribution_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
           1.0, 99.0, name=constants.CONTRIBUTION_OM
       ),
   )
-  contribution_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
           1.0, 99.0, name=constants.CONTRIBUTION_ORF
       ),
   )
-  contribution_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  contribution_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
           loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
       ),
   )
 
+  def __post_init__(self):
+    for param, bounds in _parameter_space_bounds.items():
+      prevent_deterministic_prior_at_bounds = (
+          _prevent_deterministic_prior_at_bounds[param]
+          if param in _prevent_deterministic_prior_at_bounds.keys()
+          else (False, False)
+      )
+      _validate_support(
+          param,
+          getattr(self, param),
+          bounds,
+          prevent_deterministic_prior_at_bounds,
+      )
+
   def __setstate__(self, state):
     # Override to support pickling.
     def _unpack_distribution_params(
         params: MutableMapping[str, Any],
-    ) -> tfp.distributions.Distribution:
+    ) -> backend.tfd.Distribution:
       if constants.DISTRIBUTION in params:
         params[constants.DISTRIBUTION] = _unpack_distribution_params(
             params[constants.DISTRIBUTION]
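The new `__post_init__` hook validates every prior against `_parameter_space_bounds` and `_prevent_deterministic_prior_at_bounds`, both defined at the end of this module (see the final hunk). A hedged example of what now fails at construction time, based on the code shown above: a decay prior whose support extends outside (0, 1).

```python
# Illustrative failure case for the new support validation.
from meridian import backend
from meridian.model import prior_distribution

try:
  prior_distribution.PriorDistribution(
      alpha_m=backend.tfd.Normal(0.5, 1.0, name='alpha_m')  # support is all reals
  )
except ValueError as err:
  print(err)  # alpha_m ... allows values less than the parameter minimum 0.
```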
@@ -478,7 +488,7 @@ class PriorDistribution:
     state = self.__dict__.copy()
 
     def _pack_distribution_params(
-        dist: tfp.distributions.Distribution,
+        dist: backend.tfd.Distribution,
     ) -> MutableMapping[str, Any]:
       params = dist.parameters
       params[constants.DISTRIBUTION_TYPE] = type(dist)
@@ -493,11 +503,9 @@ class PriorDistribution:
 
     return state
 
-  def has_deterministic_param(
-      self, param: tfp.distributions.Distribution
-  ) -> bool:
+  def has_deterministic_param(self, param: backend.tfd.Distribution) -> bool:
     return hasattr(self, param) and isinstance(
-        getattr(self, param).distribution, tfp.distributions.Deterministic
+        getattr(self, param).distribution, backend.tfd.Deterministic
     )
 
   def broadcast(
@@ -550,7 +558,7 @@ class PriorDistribution:
     """
 
     def _validate_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
           param.batch_shape.as_list()
@@ -573,7 +581,7 @@ class PriorDistribution:
     _validate_media_custom_priors(self.beta_m)
 
     def _validate_organic_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
           param.batch_shape.as_list()
@@ -595,7 +603,7 @@ class PriorDistribution:
     _validate_organic_media_custom_priors(self.beta_om)
 
     def _validate_organic_rf_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
           param.batch_shape.as_list()
@@ -617,7 +625,7 @@ class PriorDistribution:
     _validate_organic_rf_custom_priors(self.beta_orf)
 
     def _validate_rf_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if param.batch_shape.as_list() and n_rf_channels != param.batch_shape[0]:
         raise ValueError(
@@ -637,7 +645,7 @@ class PriorDistribution:
     _validate_rf_custom_priors(self.beta_rf)
 
     def _validate_control_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if param.batch_shape.as_list() and n_controls != param.batch_shape[0]:
         raise ValueError(
@@ -651,7 +659,7 @@ class PriorDistribution:
     _validate_control_custom_priors(self.xi_c)
 
     def _validate_non_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
           param.batch_shape.as_list()
@@ -669,7 +677,7 @@ class PriorDistribution:
     _validate_non_media_custom_priors(self.gamma_n)
     _validate_non_media_custom_priors(self.xi_n)
 
-    knot_values = tfp.distributions.BatchBroadcast(
+    knot_values = backend.tfd.BatchBroadcast(
         self.knot_values,
         n_knots,
         name=constants.KNOT_VALUES,
@@ -680,19 +688,19 @@ class PriorDistribution:
       )
     else:
       tau_g_converted = self.tau_g_excl_baseline
-    tau_g_excl_baseline = tfp.distributions.BatchBroadcast(
+    tau_g_excl_baseline = backend.tfd.BatchBroadcast(
         tau_g_converted, n_geos - 1, name=constants.TAU_G_EXCL_BASELINE
     )
-    beta_m = tfp.distributions.BatchBroadcast(
+    beta_m = backend.tfd.BatchBroadcast(
         self.beta_m, n_media_channels, name=constants.BETA_M
     )
-    beta_rf = tfp.distributions.BatchBroadcast(
+    beta_rf = backend.tfd.BatchBroadcast(
         self.beta_rf, n_rf_channels, name=constants.BETA_RF
     )
-    beta_om = tfp.distributions.BatchBroadcast(
+    beta_om = backend.tfd.BatchBroadcast(
         self.beta_om, n_organic_media_channels, name=constants.BETA_OM
     )
-    beta_orf = tfp.distributions.BatchBroadcast(
+    beta_orf = backend.tfd.BatchBroadcast(
         self.beta_orf, n_organic_rf_channels, name=constants.BETA_ORF
     )
     if is_national:
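The `broadcast()` changes in these hunks are mechanical: each prior is wrapped in a `BatchBroadcast` so every channel gets an i.i.d. copy, now routed through `backend.tfd`. A small sketch of the wrapper's effect, shown with `tfp` directly on the assumption that this is what `backend.tfd` resolves to under the TensorFlow backend:

```python
# BatchBroadcast turns a scalar-batch prior into one prior per channel.
import tensorflow_probability as tfp

beta_m = tfp.distributions.BatchBroadcast(
    tfp.distributions.HalfNormal(5.0), 3, name='beta_m'  # e.g. 3 media channels
)
print(beta_m.batch_shape)  # (3,)
```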
@@ -705,67 +713,67 @@ class PriorDistribution:
       eta_rf_converted = self.eta_rf
       eta_om_converted = self.eta_om
       eta_orf_converted = self.eta_orf
-    eta_m = tfp.distributions.BatchBroadcast(
+    eta_m = backend.tfd.BatchBroadcast(
         eta_m_converted, n_media_channels, name=constants.ETA_M
     )
-    eta_rf = tfp.distributions.BatchBroadcast(
+    eta_rf = backend.tfd.BatchBroadcast(
         eta_rf_converted, n_rf_channels, name=constants.ETA_RF
     )
-    eta_om = tfp.distributions.BatchBroadcast(
+    eta_om = backend.tfd.BatchBroadcast(
         eta_om_converted,
         n_organic_media_channels,
         name=constants.ETA_OM,
     )
-    eta_orf = tfp.distributions.BatchBroadcast(
+    eta_orf = backend.tfd.BatchBroadcast(
         eta_orf_converted, n_organic_rf_channels, name=constants.ETA_ORF
     )
-    gamma_c = tfp.distributions.BatchBroadcast(
+    gamma_c = backend.tfd.BatchBroadcast(
         self.gamma_c, n_controls, name=constants.GAMMA_C
     )
     if is_national:
       xi_c_converted = _convert_to_deterministic_0_distribution(self.xi_c)
     else:
       xi_c_converted = self.xi_c
-    xi_c = tfp.distributions.BatchBroadcast(
+    xi_c = backend.tfd.BatchBroadcast(
         xi_c_converted, n_controls, name=constants.XI_C
     )
-    gamma_n = tfp.distributions.BatchBroadcast(
+    gamma_n = backend.tfd.BatchBroadcast(
         self.gamma_n, n_non_media_channels, name=constants.GAMMA_N
     )
     if is_national:
       xi_n_converted = _convert_to_deterministic_0_distribution(self.xi_n)
     else:
       xi_n_converted = self.xi_n
-    xi_n = tfp.distributions.BatchBroadcast(
+    xi_n = backend.tfd.BatchBroadcast(
         xi_n_converted, n_non_media_channels, name=constants.XI_N
     )
-    alpha_m = tfp.distributions.BatchBroadcast(
+    alpha_m = backend.tfd.BatchBroadcast(
         self.alpha_m, n_media_channels, name=constants.ALPHA_M
     )
-    alpha_rf = tfp.distributions.BatchBroadcast(
+    alpha_rf = backend.tfd.BatchBroadcast(
         self.alpha_rf, n_rf_channels, name=constants.ALPHA_RF
     )
-    alpha_om = tfp.distributions.BatchBroadcast(
+    alpha_om = backend.tfd.BatchBroadcast(
         self.alpha_om, n_organic_media_channels, name=constants.ALPHA_OM
     )
-    alpha_orf = tfp.distributions.BatchBroadcast(
+    alpha_orf = backend.tfd.BatchBroadcast(
         self.alpha_orf, n_organic_rf_channels, name=constants.ALPHA_ORF
     )
-    ec_m = tfp.distributions.BatchBroadcast(
+    ec_m = backend.tfd.BatchBroadcast(
         self.ec_m, n_media_channels, name=constants.EC_M
     )
-    ec_rf = tfp.distributions.BatchBroadcast(
+    ec_rf = backend.tfd.BatchBroadcast(
         self.ec_rf, n_rf_channels, name=constants.EC_RF
     )
-    ec_om = tfp.distributions.BatchBroadcast(
+    ec_om = backend.tfd.BatchBroadcast(
         self.ec_om, n_organic_media_channels, name=constants.EC_OM
     )
-    ec_orf = tfp.distributions.BatchBroadcast(
+    ec_orf = backend.tfd.BatchBroadcast(
         self.ec_orf, n_organic_rf_channels, name=constants.EC_ORF
     )
     if (
-        not isinstance(self.slope_m, tfp.distributions.Deterministic)
-        or (tf.rank(self.slope_m.loc) == 0 and self.slope_m.loc != 1.0)
+        not isinstance(self.slope_m, backend.tfd.Deterministic)
+        or (backend.rank(self.slope_m.loc) == 0 and self.slope_m.loc != 1.0)
         or (
             self.slope_m.batch_shape.as_list()
             and any(x != 1.0 for x in self.slope_m.loc)
@@ -776,15 +784,15 @@ class PriorDistribution:
           ' This may lead to poor MCMC convergence and budget optimization'
           ' may no longer produce a global optimum.'
       )
-    slope_m = tfp.distributions.BatchBroadcast(
+    slope_m = backend.tfd.BatchBroadcast(
         self.slope_m, n_media_channels, name=constants.SLOPE_M
     )
-    slope_rf = tfp.distributions.BatchBroadcast(
+    slope_rf = backend.tfd.BatchBroadcast(
         self.slope_rf, n_rf_channels, name=constants.SLOPE_RF
     )
     if (
-        not isinstance(self.slope_om, tfp.distributions.Deterministic)
-        or (tf.rank(self.slope_om.loc) == 0 and self.slope_om.loc != 1.0)
+        not isinstance(self.slope_om, backend.tfd.Deterministic)
+        or (backend.rank(self.slope_om.loc) == 0 and self.slope_om.loc != 1.0)
         or (
             self.slope_om.batch_shape.as_list()
             and any(x != 1.0 for x in self.slope_om.loc)
@@ -795,16 +803,16 @@ class PriorDistribution:
           ' This may lead to poor MCMC convergence and budget optimization'
           ' may no longer produce a global optimum.'
       )
-    slope_om = tfp.distributions.BatchBroadcast(
+    slope_om = backend.tfd.BatchBroadcast(
         self.slope_om, n_organic_media_channels, name=constants.SLOPE_OM
     )
-    slope_orf = tfp.distributions.BatchBroadcast(
+    slope_orf = backend.tfd.BatchBroadcast(
         self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
     )
 
     # If `unique_sigma_for_each_geo == False`, then make a scalar batch.
     sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
-    sigma = tfp.distributions.BatchBroadcast(
+    sigma = backend.tfd.BatchBroadcast(
         self.sigma, sigma_shape, name=constants.SIGMA
     )
 
@@ -818,37 +826,37 @@ class PriorDistribution:
     else:
       roi_m_converted = self.roi_m
       roi_rf_converted = self.roi_rf
-    roi_m = tfp.distributions.BatchBroadcast(
+    roi_m = backend.tfd.BatchBroadcast(
         roi_m_converted, n_media_channels, name=constants.ROI_M
     )
-    roi_rf = tfp.distributions.BatchBroadcast(
+    roi_rf = backend.tfd.BatchBroadcast(
         roi_rf_converted, n_rf_channels, name=constants.ROI_RF
     )
 
-    mroi_m = tfp.distributions.BatchBroadcast(
+    mroi_m = backend.tfd.BatchBroadcast(
         self.mroi_m, n_media_channels, name=constants.MROI_M
     )
-    mroi_rf = tfp.distributions.BatchBroadcast(
+    mroi_rf = backend.tfd.BatchBroadcast(
         self.mroi_rf, n_rf_channels, name=constants.MROI_RF
     )
 
-    contribution_m = tfp.distributions.BatchBroadcast(
+    contribution_m = backend.tfd.BatchBroadcast(
         self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
     )
-    contribution_rf = tfp.distributions.BatchBroadcast(
+    contribution_rf = backend.tfd.BatchBroadcast(
         self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
     )
-    contribution_om = tfp.distributions.BatchBroadcast(
+    contribution_om = backend.tfd.BatchBroadcast(
         self.contribution_om,
         n_organic_media_channels,
         name=constants.CONTRIBUTION_OM,
     )
-    contribution_orf = tfp.distributions.BatchBroadcast(
+    contribution_orf = backend.tfd.BatchBroadcast(
         self.contribution_orf,
         n_organic_rf_channels,
         name=constants.CONTRIBUTION_ORF,
     )
-    contribution_n = tfp.distributions.BatchBroadcast(
+    contribution_n = backend.tfd.BatchBroadcast(
         self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
     )
 
@@ -892,9 +900,350 @@ class PriorDistribution:
     )
 
 
+class IndependentMultivariateDistribution(backend.tfd.Distribution):
+  """Container for a joint distribution created from independent distributions.
+
+  This class is useful when one wants to define a joint distribution for a
+  Meridian prior, where the elements are not necessarily from the same
+  distribution family. For example, to define a distribution where
+  one element is Uniform and the second is triangular:
+
+  ```python
+  distributions = [
+      tfp.distributions.Uniform(0.0, 1.0),
+      tfp.distributions.Triangular(0.0, 1.0, 0.5)
+  ]
+  distribution = IndependentMultivariateDistribution(distributions)
+  ```
+
+  It is also possible to define a distribution where multiple elements come
+  from the same distribution family. For example, to define a distribution where
+  the three elements are LogNormal(0.2, 0.9), LogNormal(0, 0.5) and
+  Gamma(2, 2):
+
+  ```python
+  distributions = [
+      tfp.distributions.LogNormal([0.2, 0.0], [0.9, 0.5]),
+      tfp.distributions.Gamma(2.0, 2.0)
+  ]
+  distribution = IndependentMultivariateDistribution(distributions)
+  ```
+
+  This class cannot contain instances of `tfd.Deterministic`.
+  """
+
+  def __init__(
+      self,
+      distributions: Sequence[backend.tfd.Distribution],
+      validate_args: bool = False,
+      allow_nan_stats: bool = True,
+      name: str | None = None,
+  ):
+    """Initializes a batch of independent distributions from different families.
+
+    Args:
+      distributions: List of `tfd.Distribution` from which to construct a
+        multivariate distribution. The distributions must have scalar or one
+        dimensional batch shapes; the resulting batch shape will be the sum of
+        the underlying batch shapes.
+      validate_args: Python `bool`. When `True` distribution parameters are
+        checked for validity despite possibly degrading runtime performance.
+        When `False` invalid inputs may silently render incorrect outputs.
+        Default value is `False`.
+      allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
+        variance) use the value "`NaN`" to indicate the result is undefined.
+        When `False`, an exception is raised if one or more of the statistic's
+        batch members are undefined. Default value is `True`.
+      name: Python `str` name prefixed to Ops created by this class. Default
+        value is 'IndependentMultivariate' followed by the names of the
+        underlying distributions.
+
+    Raises:
+      ValueError: If one or more distributions are instances of
+        `tfd.Deterministic` or dtypes differ between the
+        distributions.
+    """
+    parameters = dict(locals())
+
+    self._verify_distributions(distributions)
+
+    self._distributions = [
+        dist
+        if not dist.is_scalar_batch()
+        else backend.tfd.BatchBroadcast(dist, (1,))
+        for dist in distributions
+    ]
+
+    self._distribution_batch_shapes = self._get_distribution_batch_shapes()
+    self._distribution_batch_shape_tensors = backend.concatenate(
+        [dist.batch_shape_tensor() for dist in self._distributions],
+        axis=0,
+    )
+
+    dtype = self._verify_dtypes()
+
+    name = name or '-'.join(
+        [constants.INDEPENDENT_MULTIVARIATE] + [d.name for d in distributions]
+    )
+
+    super().__init__(
+        dtype=dtype,
+        reparameterization_type=backend.tfd.NOT_REPARAMETERIZED,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        name=name,
+    )
+
+  def _verify_distributions(
+      self, distributions: Sequence[backend.tfd.Distribution]
+  ):
+    """Check for deterministic distributions and raise an error if found."""
+
+    if any(
+        isinstance(dist, backend.tfd.Deterministic) for dist in distributions
+    ):
+      raise ValueError(
+          f'{self.__class__.__name__} cannot contain `Deterministic` '
+          'distributions. To implement a nearly deterministic element of this '
+          'distribution, we recommend using `backend.tfd.Uniform` with a '
+          'small range. For example to define a distribution that is nearly '
+          '`Deterministic(1.0)`, use '
+          '`tfp.distribution.Uniform(1.0 - 1e-9, 1.0 + 1e-9)`'
+      )
+
+  def _verify_dtypes(self) -> str:
+    dtypes = [dist.dtype for dist in self._distributions]
+    if len(set(dtypes)) != 1:
+      raise ValueError(
+          f'All distributions must have the same dtype. Found: {dtypes}.'
+      )
+
+    return backend.result_type(*dtypes)
+
+  def _event_shape(self):
+    return backend.TensorShape([])
+
+  def _batch_shape_tensor(self):
+    distribution_batch_shape_tensors = backend.concatenate(
+        [dist.batch_shape_tensor() for dist in self._distributions],
+        axis=0,
+    )
+    return backend.reduce_sum(distribution_batch_shape_tensors, keepdims=True)
+
+  def _batch_shape(self):
+    return backend.TensorShape(sum(self._distribution_batch_shapes))
+
+  def _sample_n(self, n, seed=None):
+    return backend.concatenate(
+        [dist.sample(n, seed) for dist in self._distributions], axis=-1
+    )
+
+  def _quantile(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
+    quantiles = [
+        dist.quantile(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(quantiles, axis=-1)
+
+  def _log_prob(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
+    log_probs = [
+        dist.log_prob(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(log_probs, axis=-1)
+
+  def _log_cdf(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
+
+    log_cdfs = [
+        dist.log_cdf(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(log_cdfs, axis=-1)
+
+  def _mean(self):
+    return backend.concatenate(
+        [dist.mean() for dist in self._distributions], axis=0
+    )
+
+  def _variance(self):
+    return backend.concatenate(
+        [dist.variance() for dist in self._distributions], axis=0
+    )
+
+  def _default_event_space_bijector(self):
+    """Mapping from R^n to the event space of the wrapped distributions.
+
+    This is the blockwise concatenation of the underlying bijectors.
+
+    Returns:
+      A `tfp.bijectors.Blockwise` object that concatenates the underlying
+      bijectors.
+    """
+    bijectors = [
+        d.experimental_default_event_space_bijector()
+        for d in self._distributions
+    ]
+
+    return backend.bijectors.Blockwise(
+        bijectors,
+        block_sizes=self._distribution_batch_shapes,
+    )
+
+  def _broadcast_value(self, value: backend.Tensor) -> backend.Tensor:
+    value = backend.to_tensor(value)
+    broadcast_shape = backend.broadcast_dynamic_shape(
+        value.shape, self.batch_shape_tensor()
+    )
+    return backend.broadcast_to(value, broadcast_shape)
+
+  def _get_distribution_batch_shapes(self) -> Sequence[int]:
+    """Sequence of batch shapes of underlying distributions."""
+
+    batch_shapes = []
+
+    for dist in self._distributions:
+      try:
+        (dist_batch_shape,) = dist.batch_shape
+      except ValueError as exc:
+        raise ValueError(
+            'All distributions must be 0- or 1-dimensional.'
+            f' Found {len(dist.batch_shape)}-dimensional distribution:'
+            f' {dist.batch_shape}.'
+        ) from exc
+      else:
+        batch_shapes.append(dist_batch_shape)
+
+    return batch_shapes
+
+
+def distributions_are_equal(
+    a: backend.tfd.Distribution, b: backend.tfd.Distribution
+) -> bool:
+  """Determine if two distributions are equal."""
+  if type(a) != type(b):  # pylint: disable=unidiomatic-typecheck
+    return False
+
+  a_params = a.parameters.copy()
+  b_params = b.parameters.copy()
+
+  if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
+    if not distributions_are_equal(
+        a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
+    ):
+      return False
+    del a_params[constants.DISTRIBUTION]
+    del b_params[constants.DISTRIBUTION]
+
+  if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
+    return False
+
+  if a_params.keys() != b_params.keys():
+    return False
+
+  for key in a_params.keys():
+    if isinstance(
+        a_params[key], (backend.Tensor, np.ndarray, float, int)
+    ) and isinstance(b_params[key], (backend.Tensor, np.ndarray, float, int)):
+      if not backend.allclose(a_params[key], b_params[key]):
+        return False
+    else:
+      if a_params[key] != b_params[key]:
+        return False
+
+  return True
+
+
+def lognormal_dist_from_mean_std(
+    mean: float | Sequence[float], std: float | Sequence[float]
+) -> backend.tfd.LogNormal:
+  """Define a lognormal distribution from its mean and standard deviation.
+
+  This function parameterizes lognormal distributions by their mean and
+  standard deviation.
+
+  Args:
+    mean: A float or array-like object defining the distribution mean. Must be
+      positive.
+    std: A float or array-like object defining the distribution standard
+      deviation. Must be non-negative.
+
+  Returns:
+    A `backend.tfd.LogNormal` object with the input mean and standard deviation.
+  """
+
+  mean = np.asarray(mean)
+  std = np.asarray(std)
+
+  mu = np.log(mean) - 0.5 * np.log((std / mean) ** 2 + 1)
+  sigma = np.sqrt(np.log((std / mean) ** 2 + 1))
+
+  return backend.tfd.LogNormal(mu, sigma)
+
+
+def lognormal_dist_from_range(
+    low: float | Sequence[float],
+    high: float | Sequence[float],
+    mass_percent: float | Sequence[float] = 0.95,
+) -> backend.tfd.LogNormal:
+  """Define a LogNormal distribution from a specified range.
+
+  This function parameterizes lognormal distributions by the bounds of a range,
+  so that the specificed probability mass falls within the bounds defined by
+  `low` and `high`. The probability mass is symmetric about the median. For
+  example, to define a lognormal distribution with a 95% probability mass of
+  (1, 10), use:
+
+  ```python
+  lognormal = lognormal_dist_from_range(1.0, 10.0, mass_percent=0.95)
+  ```
+
+  Args:
+    low: Float or array-like denoting the lower bound of the range. Values must
+      be non-negative.
+    high: Float or array-like denoting the upper bound of range. Values must be
+      non-negative.
+    mass_percent: Float or array-like denoting the probability mass. Values must
+      be between 0 and 1 (exlusive). Default: 0.95.
+
+  Returns:
+    A `backend.tfd.LogNormal` object with the input percentage mass falling
+    within the given range.
+  """
+  low = np.asarray(low)
+  high = np.asarray(high)
+  mass_percent = np.asarray(mass_percent)
+
+  if not ((0.0 < low).all() and (low < high).all()):  # pytype: disable=attribute-error
+    raise ValueError("'low' and 'high' values must be non-negative and satisfy "
+                     "high > low.")
+
+  if not ((0.0 < mass_percent).all() and (mass_percent < 1.0).all()):  # pytype: disable=attribute-error
+    raise ValueError(
+        "'mass_percent' values must be between 0 and 1, exclusive."
+    )
+
+  normal = backend.tfd.Normal(0, 1)
+  mass_lower = 0.5 - (mass_percent / 2)
+  mass_upper = 0.5 + (mass_percent / 2)
+
+  sigma = np.log(high / low) / (
+      normal.quantile(mass_upper) - normal.quantile(mass_lower)
+  )
+  mu = np.log(high) - normal.quantile(mass_upper) * sigma
+
+  return backend.tfd.LogNormal(mu, sigma)
+
+
 def _convert_to_deterministic_0_distribution(
-    distribution: tfp.distributions.Distribution,
-) -> tfp.distributions.Distribution:
+    distribution: backend.tfd.Distribution,
+) -> backend.tfd.Distribution:
   """Converts the given distribution to a `Deterministic(0)` one.
 
   Args:
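A brief usage sketch of the additions in this hunk, based on the docstrings above (the `tfp` namespace is assumed to match what `backend.tfd` resolves to under the TensorFlow backend): `IndependentMultivariateDistribution` concatenates per-element priors along the batch dimension, and `lognormal_dist_from_mean_std` moment-matches a LogNormal.

```python
import tensorflow_probability as tfp
from meridian.model import prior_distribution

# Joint prior whose two elements come from different families.
mixed = prior_distribution.IndependentMultivariateDistribution([
    tfp.distributions.Uniform(0.0, 1.0),
    tfp.distributions.Triangular(0.0, 1.0, 0.5),
])
print(mixed.batch_shape)      # (2,)
print(mixed.sample(4).shape)  # (4, 2)

# LogNormal parameterized by mean/std rather than (mu, sigma).
dist = prior_distribution.lognormal_dist_from_mean_std(mean=2.0, std=1.0)
print(float(dist.mean()), float(dist.stddev()))  # ~2.0 ~1.0
```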
@@ -909,7 +1258,7 @@ def _convert_to_deterministic_0_distribution(
       distribution.
   """
   if (
-      not isinstance(distribution, tfp.distributions.Deterministic)
+      not isinstance(distribution, backend.tfd.Deterministic)
       or distribution.loc != 0
   ):
     warnings.warn(
@@ -917,7 +1266,7 @@ def _convert_to_deterministic_0_distribution(
         f' for national models. {distribution.name} has been automatically set'
         ' to Deterministic(0).'
     )
-    return tfp.distributions.Deterministic(loc=0, name=distribution.name)
+    return backend.tfd.Deterministic(loc=0, name=distribution.name)
   else:
     return distribution
 
@@ -928,7 +1277,7 @@ def _get_total_media_contribution_prior(
     name: str,
     p_mean: float = constants.P_MEAN,
     p_sd: float = constants.P_SD,
-) -> tfp.distributions.Distribution:
+) -> backend.tfd.Distribution:
   """Determines ROI priors based on total media contribution.
 
   Args:
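The body of `_get_total_media_contribution_prior` (next hunk) moment-matches a LogNormal to the ROI mean and standard deviation implied by the total-media-contribution assumption. A quick numeric check of that identity:

```python
# Verifies: exp(mu + sigma^2/2) == roi_mean and the matched sd == roi_sd.
import numpy as np

roi_mean, roi_sd = 2.0, 1.0
sigma = np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1))
mu = np.log(roi_mean * np.exp(-(sigma**2) / 2))
mean_check = np.exp(mu + sigma**2 / 2)
sd_check = mean_check * np.sqrt(np.exp(sigma**2) - 1)
print(round(float(mean_check), 6), round(float(sd_check), 6))  # 2.0 1.0
```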
@@ -945,34 +1294,120 @@ def _get_total_media_contribution_prior(
   """
   roi_mean = p_mean * kpi / np.sum(total_spend)
   roi_sd = p_sd * kpi / np.sqrt(np.sum(np.power(total_spend, 2)))
-  lognormal_sigma = tf.cast(
-      np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=tf.float32
+  lognormal_sigma = backend.cast(
+      np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=backend.float32
   )
-  lognormal_mu = tf.cast(
-      np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)), dtype=tf.float32
+  lognormal_mu = backend.cast(
+      np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)),
+      dtype=backend.float32,
   )
-  return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
+  return backend.tfd.LogNormal(lognormal_mu, lognormal_sigma, name=name)
 
 
-def
-
-
-
-
-
+def _validate_support(
+    parameter_name: str,
+    tfp_dist: backend.tfp.distributions.Distribution,
+    bounds: tuple[float, float],
+    prevent_deterministic_prior_at_bounds: tuple[bool, bool],
+) -> None:
+  """Validates that distribution support is within the parameter bounds.
 
-
-
+  Args:
+    parameter_name: Name of the parameter.
+    tfp_dist: The TFP distribution to validate.
+    bounds: Tuple containing the min and max values of the parameteter space.
+    prevent_deterministic_prior_at_bounds: Tuple of two booleans indicating
+      whether a deterministic prior is allowed at the lower and upper bounds,
+      respectively.
 
-
-
-
-
-
-
-
+  Raises:
+    ValueError: If the distribution support is not within the parameter bounds.
+  """
+  # Note that `tfp.distributions.BatchBroadcast` objects have a `distribution`
+  # attribute that points to a `tfp.distributions.Distribution` object.
+  if isinstance(tfp_dist, backend.tfd.BatchBroadcast):
+    tfp_dist = tfp_dist.distribution
+  # Note that `tfp.distributions.Deterministic` does not have a `quantile`
+  # method implemented, so the min and max values must be extracted from the
+  # `loc` attribute instead.
+  if isinstance(tfp_dist, backend.tfd.Deterministic):
+    support_min_vals = tfp_dist.loc
+    support_max_vals = tfp_dist.loc
+    for i in (0, 1):
+      if prevent_deterministic_prior_at_bounds[i] and np.any(
+          tfp_dist.loc == bounds[i]
+      ):
+        raise ValueError(
+            f'{parameter_name} was assigned a point mass (deterministic) prior'
+            f' at {bounds[i]}, which is not allowed.'
+        )
+  else:
+    try:
+      support_min_vals = tfp_dist.quantile(0)
+      support_max_vals = tfp_dist.quantile(1)
+    except (AttributeError, NotImplementedError):
+      warnings.warn(
+          f'The prior distribution for {parameter_name} does not have a'
+          ' `quantile` method implemented, so the support range validation'
+          f' was skipped. Confirm that your prior for {parameter_name} is'
+          ' appropriate.'
+      )
+      return
+  if np.any(support_min_vals < bounds[0]):
+    raise ValueError(
+        f'{parameter_name} was assigned a prior distribution that allows values'
+        f' less than the parameter minimum {bounds[0]}.'
+    )
+  if np.any(support_max_vals > bounds[1]):
+    raise ValueError(
+        f'{parameter_name} was assigned a prior distribution that allows values'
+        f' greater than the parameter maximum {bounds[1]}.'
+    )
 
-  if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
-    return False
 
-
+# Dictionary of parameters that have a limited parameters space. The tuple
+# contains the lower and upper bounds, respectively.
+_parameter_space_bounds = {
+    'eta_m': (0, np.inf),
+    'eta_rf': (0, np.inf),
+    'eta_om': (0, np.inf),
+    'eta_orf': (0, np.inf),
+    'xi_c': (0, np.inf),
+    'xi_n': (0, np.inf),
+    'alpha_m': (0, 1),
+    'alpha_rf': (0, 1),
+    'alpha_om': (0, 1),
+    'alpha_orf': (0, 1),
+    'ec_m': (0, np.inf),
+    'ec_rf': (0, np.inf),
+    'ec_om': (0, np.inf),
+    'ec_orf': (0, np.inf),
+    'slope_m': (0, np.inf),
+    'slope_rf': (0, np.inf),
+    'slope_om': (0, np.inf),
+    'slope_orf': (0, np.inf),
+    'sigma': (0, np.inf),
+}
+
+# Dictionary of parameters that do not allow a deterministic prior at one or
+# more of the parameter space bounds. The boolean tuple indicates whether a
+# deterministic prior is allowed at the lower bound or upper bound,
+# respectively, where `True` means "not allowed". This check is specifically for
+# point mass at finite paramteter space bounds, since point mass at infinity is
+# generally problematic for all parameters. Note that `sigma` should generally
+# not have point mass at zero, but this is not checked here because unit tests
+# require the ability to simulate data with `sigma` set to zero.
+_prevent_deterministic_prior_at_bounds = {
+    'alpha_m': (False, True),
+    'alpha_rf': (False, True),
+    'alpha_om': (False, True),
+    'alpha_orf': (False, True),
+    'ec_m': (True, False),
+    'ec_rf': (True, False),
+    'ec_om': (True, False),
+    'ec_orf': (True, False),
+    'slope_m': (True, False),
+    'slope_rf': (True, False),
+    'slope_om': (True, False),
+    'slope_orf': (True, False),
+}