google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +84 -77
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.5.dist-info/RECORD +0 -47
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
--- a/meridian/model/prior_distribution.py
+++ b/meridian/model/prior_distribution.py
@@ -19,18 +19,21 @@ used by the Meridian model object.
 """
 
 from __future__ import annotations
-from collections.abc import MutableMapping
+from collections.abc import MutableMapping, Sequence
 import dataclasses
 from typing import Any
 import warnings
+
+from meridian import backend
 from meridian import constants
+
 import numpy as np
-import tensorflow as tf
-import tensorflow_probability as tfp
 
 
 __all__ = [
+    'IndependentMultivariateDistribution',
     'PriorDistribution',
+    'distributions_are_equal',
 ]
 
 
@@ -203,9 +206,9 @@ class PriorDistribution:
       distribution is `HalfNormal(5.0)`.
     roi_m: Prior distribution on the ROI of each media channel. This parameter
       is only used when `paid_media_prior_type` is `'roi'`, in which case
-      `beta_m` is calculated as a deterministic function of `
-      `
-
+      `beta_m` is calculated as a deterministic function of `roi_m`, `alpha_m`,
+      `ec_m`, `slope_m`, and the spend associated with each media channel.
+      Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
       `'non_revenue'` and `revenue_per_kpi` is not provided, ROI is interpreted
       as incremental KPI units per monetary unit spent. In this case, the
       default value for `roi_m` and `roi_rf` will be ignored and a common ROI
@@ -270,196 +273,202 @@ class PriorDistribution:
       `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
   """
 
-  knot_values: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  knot_values: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
           0.0, 5.0, name=constants.KNOT_VALUES
       ),
   )
-  tau_g_excl_baseline: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  tau_g_excl_baseline: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
          0.0, 5.0, name=constants.TAU_G_EXCL_BASELINE
       ),
   )
-  beta_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          5.0, name=constants.BETA_M
       ),
   )
-  beta_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          5.0, name=constants.BETA_RF
       ),
   )
-  beta_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          5.0, name=constants.BETA_OM
       ),
   )
-  beta_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  beta_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          5.0, name=constants.BETA_ORF
       ),
   )
-  eta_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          1.0, name=constants.ETA_M
-      ),
+  eta_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(1.0, name=constants.ETA_M),
   )
-  eta_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          1.0, name=constants.ETA_RF
       ),
   )
-  eta_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          1.0, name=constants.ETA_OM
       ),
   )
-  eta_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
+  eta_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(
          1.0, name=constants.ETA_ORF
       ),
   )
-  gamma_c: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  gamma_c: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
          0.0, 5.0, name=constants.GAMMA_C
       ),
   )
-  gamma_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Normal(
+  gamma_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Normal(
          0.0, 5.0, name=constants.GAMMA_N
       ),
   )
-  xi_c: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.XI_C
-      ),
+  xi_c: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_C),
   )
-  xi_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.XI_N
-      ),
+  xi_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_N),
   )
-  alpha_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
          0.0, 1.0, name=constants.ALPHA_M
       ),
   )
-  alpha_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
          0.0, 1.0, name=constants.ALPHA_RF
       ),
   )
-  alpha_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
          0.0, 1.0, name=constants.ALPHA_OM
       ),
   )
-  alpha_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Uniform(
+  alpha_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Uniform(
          0.0, 1.0, name=constants.ALPHA_ORF
       ),
   )
-  ec_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  ec_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
          0.8, 0.8, 0.1, 10, name=constants.EC_M
       ),
   )
-  ec_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TransformedDistribution(
-          tfp.distributions.LogNormal(0.7, 0.4),
-          tfp.bijectors.Shift(0.1),
+  ec_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TransformedDistribution(
+          backend.tfd.LogNormal(0.7, 0.4),
+          backend.bijectors.Shift(0.1),
          name=constants.EC_RF,
       ),
   )
-  ec_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  ec_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
          0.8, 0.8, 0.1, 10, name=constants.EC_OM
       ),
   )
-  ec_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TransformedDistribution(
-          tfp.distributions.LogNormal(0.7, 0.4),
-          tfp.bijectors.Shift(0.1),
+  ec_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TransformedDistribution(
+          backend.tfd.LogNormal(0.7, 0.4),
+          backend.bijectors.Shift(0.1),
          name=constants.EC_ORF,
       ),
   )
-  slope_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Deterministic(
+  slope_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Deterministic(
          1.0, name=constants.SLOPE_M
       ),
   )
-  slope_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  slope_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.7, 0.4, name=constants.SLOPE_RF
       ),
   )
-  slope_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Deterministic(
+  slope_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Deterministic(
          1.0, name=constants.SLOPE_OM
       ),
   )
-  slope_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  slope_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.7, 0.4, name=constants.SLOPE_ORF
       ),
   )
-  sigma: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.HalfNormal(
-          5.0, name=constants.SIGMA
-      ),
+  sigma: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.SIGMA),
   )
-  roi_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  roi_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.2, 0.9, name=constants.ROI_M
       ),
   )
-  roi_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  roi_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.2, 0.9, name=constants.ROI_RF
       ),
   )
-  mroi_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  mroi_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.0, 0.5, name=constants.MROI_M
       ),
   )
-  mroi_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.LogNormal(
+  mroi_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.LogNormal(
          0.0, 0.5, name=constants.MROI_RF
       ),
   )
-  contribution_m: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_m: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
          1.0, 99.0, name=constants.CONTRIBUTION_M
       ),
   )
-  contribution_rf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_rf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
          1.0, 99.0, name=constants.CONTRIBUTION_RF
       ),
   )
-  contribution_om: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_om: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
          1.0, 99.0, name=constants.CONTRIBUTION_OM
       ),
   )
-  contribution_orf: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.Beta(
+  contribution_orf: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.Beta(
          1.0, 99.0, name=constants.CONTRIBUTION_ORF
       ),
   )
-  contribution_n: tfp.distributions.Distribution = dataclasses.field(
-      default_factory=lambda: tfp.distributions.TruncatedNormal(
+  contribution_n: backend.tfd.Distribution = dataclasses.field(
+      default_factory=lambda: backend.tfd.TruncatedNormal(
          loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
       ),
   )
 
+  def __post_init__(self):
+    for param, bounds in _parameter_space_bounds.items():
+      prevent_deterministic_prior_at_bounds = (
+          _prevent_deterministic_prior_at_bounds[param]
+          if param in _prevent_deterministic_prior_at_bounds.keys()
+          else (False, False)
+      )
+      _validate_support(
+          param,
+          getattr(self, param),
+          bounds,
+          prevent_deterministic_prior_at_bounds,
+      )
+
   def __setstate__(self, state):
     # Override to support pickling.
     def _unpack_distribution_params(
         params: MutableMapping[str, Any],
-    ) -> tfp.distributions.Distribution:
+    ) -> backend.tfd.Distribution:
       if constants.DISTRIBUTION in params:
         params[constants.DISTRIBUTION] = _unpack_distribution_params(
             params[constants.DISTRIBUTION]
@@ -478,7 +487,7 @@ class PriorDistribution:
     state = self.__dict__.copy()
 
     def _pack_distribution_params(
-        dist: tfp.distributions.Distribution,
+        dist: backend.tfd.Distribution,
     ) -> MutableMapping[str, Any]:
       params = dist.parameters
       params[constants.DISTRIBUTION_TYPE] = type(dist)
@@ -493,11 +502,9 @@ class PriorDistribution:
 
     return state
 
-  def has_deterministic_param(
-      self, param: tfp.distributions.Distribution
-  ) -> bool:
+  def has_deterministic_param(self, param: backend.tfd.Distribution) -> bool:
     return hasattr(self, param) and isinstance(
-        getattr(self, param).distribution, tfp.distributions.Deterministic
+        getattr(self, param).distribution, backend.tfd.Deterministic
     )
 
   def broadcast(
@@ -550,7 +557,7 @@ class PriorDistribution:
     """
 
     def _validate_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
          param.batch_shape.as_list()
@@ -573,7 +580,7 @@ class PriorDistribution:
     _validate_media_custom_priors(self.beta_m)
 
     def _validate_organic_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
          param.batch_shape.as_list()
@@ -595,7 +602,7 @@ class PriorDistribution:
     _validate_organic_media_custom_priors(self.beta_om)
 
     def _validate_organic_rf_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
          param.batch_shape.as_list()
@@ -617,7 +624,7 @@ class PriorDistribution:
     _validate_organic_rf_custom_priors(self.beta_orf)
 
     def _validate_rf_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if param.batch_shape.as_list() and n_rf_channels != param.batch_shape[0]:
         raise ValueError(
@@ -637,7 +644,7 @@ class PriorDistribution:
     _validate_rf_custom_priors(self.beta_rf)
 
     def _validate_control_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if param.batch_shape.as_list() and n_controls != param.batch_shape[0]:
         raise ValueError(
@@ -651,7 +658,7 @@ class PriorDistribution:
     _validate_control_custom_priors(self.xi_c)
 
     def _validate_non_media_custom_priors(
-        param: tfp.distributions.Distribution,
+        param: backend.tfd.Distribution,
     ) -> None:
       if (
          param.batch_shape.as_list()
@@ -669,7 +676,7 @@ class PriorDistribution:
     _validate_non_media_custom_priors(self.gamma_n)
     _validate_non_media_custom_priors(self.xi_n)
 
-    knot_values = tfp.distributions.BatchBroadcast(
+    knot_values = backend.tfd.BatchBroadcast(
         self.knot_values,
         n_knots,
         name=constants.KNOT_VALUES,
@@ -680,19 +687,19 @@ class PriorDistribution:
       )
     else:
       tau_g_converted = self.tau_g_excl_baseline
-    tau_g_excl_baseline = tfp.distributions.BatchBroadcast(
+    tau_g_excl_baseline = backend.tfd.BatchBroadcast(
        tau_g_converted, n_geos - 1, name=constants.TAU_G_EXCL_BASELINE
     )
-    beta_m = tfp.distributions.BatchBroadcast(
+    beta_m = backend.tfd.BatchBroadcast(
        self.beta_m, n_media_channels, name=constants.BETA_M
     )
-    beta_rf = tfp.distributions.BatchBroadcast(
+    beta_rf = backend.tfd.BatchBroadcast(
        self.beta_rf, n_rf_channels, name=constants.BETA_RF
     )
-    beta_om = tfp.distributions.BatchBroadcast(
+    beta_om = backend.tfd.BatchBroadcast(
        self.beta_om, n_organic_media_channels, name=constants.BETA_OM
     )
-    beta_orf = tfp.distributions.BatchBroadcast(
+    beta_orf = backend.tfd.BatchBroadcast(
        self.beta_orf, n_organic_rf_channels, name=constants.BETA_ORF
     )
     if is_national:
@@ -705,66 +712,66 @@ class PriorDistribution:
       eta_rf_converted = self.eta_rf
       eta_om_converted = self.eta_om
       eta_orf_converted = self.eta_orf
-    eta_m = tfp.distributions.BatchBroadcast(
+    eta_m = backend.tfd.BatchBroadcast(
        eta_m_converted, n_media_channels, name=constants.ETA_M
     )
-    eta_rf = tfp.distributions.BatchBroadcast(
+    eta_rf = backend.tfd.BatchBroadcast(
        eta_rf_converted, n_rf_channels, name=constants.ETA_RF
     )
-    eta_om = tfp.distributions.BatchBroadcast(
+    eta_om = backend.tfd.BatchBroadcast(
        eta_om_converted,
        n_organic_media_channels,
        name=constants.ETA_OM,
     )
-    eta_orf = tfp.distributions.BatchBroadcast(
+    eta_orf = backend.tfd.BatchBroadcast(
        eta_orf_converted, n_organic_rf_channels, name=constants.ETA_ORF
     )
-    gamma_c = tfp.distributions.BatchBroadcast(
+    gamma_c = backend.tfd.BatchBroadcast(
        self.gamma_c, n_controls, name=constants.GAMMA_C
     )
     if is_national:
       xi_c_converted = _convert_to_deterministic_0_distribution(self.xi_c)
     else:
       xi_c_converted = self.xi_c
-    xi_c = tfp.distributions.BatchBroadcast(
+    xi_c = backend.tfd.BatchBroadcast(
        xi_c_converted, n_controls, name=constants.XI_C
     )
-    gamma_n = tfp.distributions.BatchBroadcast(
+    gamma_n = backend.tfd.BatchBroadcast(
        self.gamma_n, n_non_media_channels, name=constants.GAMMA_N
     )
     if is_national:
       xi_n_converted = _convert_to_deterministic_0_distribution(self.xi_n)
     else:
       xi_n_converted = self.xi_n
-    xi_n = tfp.distributions.BatchBroadcast(
+    xi_n = backend.tfd.BatchBroadcast(
        xi_n_converted, n_non_media_channels, name=constants.XI_N
     )
-    alpha_m = tfp.distributions.BatchBroadcast(
+    alpha_m = backend.tfd.BatchBroadcast(
        self.alpha_m, n_media_channels, name=constants.ALPHA_M
     )
-    alpha_rf = tfp.distributions.BatchBroadcast(
+    alpha_rf = backend.tfd.BatchBroadcast(
        self.alpha_rf, n_rf_channels, name=constants.ALPHA_RF
     )
-    alpha_om = tfp.distributions.BatchBroadcast(
+    alpha_om = backend.tfd.BatchBroadcast(
        self.alpha_om, n_organic_media_channels, name=constants.ALPHA_OM
     )
-    alpha_orf = tfp.distributions.BatchBroadcast(
+    alpha_orf = backend.tfd.BatchBroadcast(
        self.alpha_orf, n_organic_rf_channels, name=constants.ALPHA_ORF
     )
-    ec_m = tfp.distributions.BatchBroadcast(
+    ec_m = backend.tfd.BatchBroadcast(
        self.ec_m, n_media_channels, name=constants.EC_M
     )
-    ec_rf = tfp.distributions.BatchBroadcast(
+    ec_rf = backend.tfd.BatchBroadcast(
        self.ec_rf, n_rf_channels, name=constants.EC_RF
     )
-    ec_om = tfp.distributions.BatchBroadcast(
+    ec_om = backend.tfd.BatchBroadcast(
        self.ec_om, n_organic_media_channels, name=constants.EC_OM
     )
-    ec_orf = tfp.distributions.BatchBroadcast(
+    ec_orf = backend.tfd.BatchBroadcast(
        self.ec_orf, n_organic_rf_channels, name=constants.EC_ORF
     )
     if (
-        not isinstance(self.slope_m, tfp.distributions.Deterministic)
+        not isinstance(self.slope_m, backend.tfd.Deterministic)
        or (np.isscalar(self.slope_m.loc.numpy()) and self.slope_m.loc != 1.0)
        or (
            self.slope_m.batch_shape.as_list()
@@ -776,14 +783,14 @@ class PriorDistribution:
          ' This may lead to poor MCMC convergence and budget optimization'
          ' may no longer produce a global optimum.'
       )
-    slope_m = tfp.distributions.BatchBroadcast(
+    slope_m = backend.tfd.BatchBroadcast(
        self.slope_m, n_media_channels, name=constants.SLOPE_M
     )
-    slope_rf = tfp.distributions.BatchBroadcast(
+    slope_rf = backend.tfd.BatchBroadcast(
        self.slope_rf, n_rf_channels, name=constants.SLOPE_RF
     )
     if (
-        not isinstance(self.slope_om, tfp.distributions.Deterministic)
+        not isinstance(self.slope_om, backend.tfd.Deterministic)
        or (np.isscalar(self.slope_om.loc.numpy()) and self.slope_om.loc != 1.0)
        or (
            self.slope_om.batch_shape.as_list()
@@ -795,16 +802,16 @@ class PriorDistribution:
          ' This may lead to poor MCMC convergence and budget optimization'
          ' may no longer produce a global optimum.'
       )
-    slope_om = tfp.distributions.BatchBroadcast(
+    slope_om = backend.tfd.BatchBroadcast(
        self.slope_om, n_organic_media_channels, name=constants.SLOPE_OM
     )
-    slope_orf = tfp.distributions.BatchBroadcast(
+    slope_orf = backend.tfd.BatchBroadcast(
        self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
     )
 
     # If `unique_sigma_for_each_geo == False`, then make a scalar batch.
     sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
-    sigma = tfp.distributions.BatchBroadcast(
+    sigma = backend.tfd.BatchBroadcast(
        self.sigma, sigma_shape, name=constants.SIGMA
     )
 
@@ -818,37 +825,37 @@ class PriorDistribution:
     else:
       roi_m_converted = self.roi_m
       roi_rf_converted = self.roi_rf
-    roi_m = tfp.distributions.BatchBroadcast(
+    roi_m = backend.tfd.BatchBroadcast(
        roi_m_converted, n_media_channels, name=constants.ROI_M
     )
-    roi_rf = tfp.distributions.BatchBroadcast(
+    roi_rf = backend.tfd.BatchBroadcast(
        roi_rf_converted, n_rf_channels, name=constants.ROI_RF
     )
 
-    mroi_m = tfp.distributions.BatchBroadcast(
+    mroi_m = backend.tfd.BatchBroadcast(
        self.mroi_m, n_media_channels, name=constants.MROI_M
     )
-    mroi_rf = tfp.distributions.BatchBroadcast(
+    mroi_rf = backend.tfd.BatchBroadcast(
        self.mroi_rf, n_rf_channels, name=constants.MROI_RF
     )
 
-    contribution_m = tfp.distributions.BatchBroadcast(
+    contribution_m = backend.tfd.BatchBroadcast(
        self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
     )
-    contribution_rf = tfp.distributions.BatchBroadcast(
+    contribution_rf = backend.tfd.BatchBroadcast(
        self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
     )
-    contribution_om = tfp.distributions.BatchBroadcast(
+    contribution_om = backend.tfd.BatchBroadcast(
        self.contribution_om,
        n_organic_media_channels,
        name=constants.CONTRIBUTION_OM,
     )
-    contribution_orf = tfp.distributions.BatchBroadcast(
+    contribution_orf = backend.tfd.BatchBroadcast(
        self.contribution_orf,
        n_organic_rf_channels,
        name=constants.CONTRIBUTION_ORF,
     )
-    contribution_n = tfp.distributions.BatchBroadcast(
+    contribution_n = backend.tfd.BatchBroadcast(
        self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
     )
 
@@ -892,9 +899,283 @@ class PriorDistribution:
     )
 
 
+class IndependentMultivariateDistribution(backend.tfd.Distribution):
+  """Container for a joint distribution created from independent distributions.
+
+  This class is useful when one wants to define a joint distribution for a
+  Meridian prior, where the elements are not necessarily from the same
+  distribution family. For example, to define a distribution where
+  one element is Uniform and the second is triangular:
+
+  ```python
+  distributions = [
+      tfp.distributions.Uniform(0.0, 1.0),
+      tfp.distributions.Triangular(0.0, 1.0, 0.5)
+  ]
+  distribution = IndependentMultivariateDistribution(distributions)
+  ```
+
+  It is also possible to define a distribution where multiple elements come
+  from the same distribution family. For example, to define a distribution where
+  the three elements are LogNormal(0.2, 0.9), LogNormal(0, 0.5) and
+  Gamma(2, 2):
+
+  ```python
+  distributions = [
+      tfp.distributions.LogNormal([0.2, 0.0], [0.9, 0.5]),
+      tfp.distributions.Gamma(2.0, 2.0)
+  ]
+  distribution = IndependentMultivariateDistribution(distributions)
+  ```
+
+  This class cannot contain instances of `tfd.Deterministic`.
+  """
+
+  def __init__(
+      self,
+      distributions: Sequence[backend.tfd.Distribution],
+      validate_args: bool = False,
+      allow_nan_stats: bool = True,
+      name: str | None = None,
+  ):
+    """Initializes a batch of independent distributions from different families.
+
+    Args:
+      distributions: List of `tfd.Distribution` from which to construct a
+        multivariate distribution. The distributions must have scalar or one
+        dimensional batch shapes; the resulting batch shape will be the sum of
+        the underlying batch shapes.
+      validate_args: Python `bool`. When `True` distribution parameters are
+        checked for validity despite possibly degrading runtime performance.
+        When `False` invalid inputs may silently render incorrect outputs.
+        Default value is `False`.
+      allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
+        variance) use the value "`NaN`" to indicate the result is undefined.
+        When `False`, an exception is raised if one or more of the statistic's
+        batch members are undefined. Default value is `True`.
+      name: Python `str` name prefixed to Ops created by this class. Default
+        value is 'IndependentMultivariate' followed by the names of the
+        underlying distributions.
+
+    Raises:
+      ValueError: If one or more distributions are instances of
+        `tfd.Deterministic` or dtypes differ between the
+        distributions.
+    """
+    parameters = dict(locals())
+
+    self._verify_distributions(distributions)
+
+    self._distributions = [
+        dist
+        if not dist.is_scalar_batch()
+        else backend.tfd.BatchBroadcast(dist, (1,))
+        for dist in distributions
+    ]
+
+    self._distribution_batch_shapes = self._get_distribution_batch_shapes()
+    self._distribution_batch_shape_tensors = backend.concatenate(
+        [dist.batch_shape_tensor() for dist in self._distributions],
+        axis=0,
+    )
+
+    dtype = self._verify_dtypes()
+
+    name = name or '-'.join(
+        [constants.INDEPENDENT_MULTIVARIATE] + [d.name for d in distributions]
+    )
+
+    super().__init__(
+        dtype=dtype,
+        reparameterization_type=backend.tfd.NOT_REPARAMETERIZED,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        name=name,
+    )
+
+  def _verify_distributions(
+      self, distributions: Sequence[backend.tfd.Distribution]
+  ):
+    """Check for deterministic distributions and raise an error if found."""
+
+    if any(
+        isinstance(dist, backend.tfd.Deterministic)
+        for dist in distributions
+    ):
+      raise ValueError(
+          f'{self.__class__.__name__} cannot contain `Deterministic` '
+          'distributions. To implement a nearly deterministic element of this '
+          'distribution, we recommend using `backend.tfd.Uniform` with a '
+          'small range. For example to define a distribution that is nearly '
+          '`Deterministic(1.0)`, use '
+          '`tfp.distribution.Uniform(1.0 - 1e-9, 1.0 + 1e-9)`'
+      )
+
+  def _verify_dtypes(self) -> str:
+    dtypes = [dist.dtype for dist in self._distributions]
+    if len(set(dtypes)) != 1:
+      raise ValueError(
+          f'All distributions must have the same dtype. Found: {dtypes}.'
+      )
+
+    return backend.result_type(*dtypes)
+
+  def _event_shape(self):
+    return backend.TensorShape([])
+
+  def _batch_shape_tensor(self):
+    distribution_batch_shape_tensors = backend.concatenate(
+        [dist.batch_shape_tensor() for dist in self._distributions],
+        axis=0,
+    )
+    return backend.reduce_sum(
+        distribution_batch_shape_tensors, keepdims=True
+    )
+
+  def _batch_shape(self):
+    return backend.TensorShape(sum(self._distribution_batch_shapes))
+
+  def _sample_n(self, n, seed=None):
+    return backend.concatenate(
+        [dist.sample(n, seed) for dist in self._distributions], axis=-1
+    )
+
+  def _quantile(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(
+        value,
+        self._distribution_batch_shapes, axis=-1
+    )
+    quantiles = [
+        dist.quantile(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(quantiles, axis=-1)
+
+  def _log_prob(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(
+        value,
+        self._distribution_batch_shapes,
+        axis=-1
+    )
+    log_probs = [
+        dist.log_prob(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(log_probs, axis=-1)
+
+  def _log_cdf(self, value):
+    value = self._broadcast_value(value)
+    split_value = backend.split(
+        value,
+        self._distribution_batch_shapes,
+        axis=-1
+    )
+
+    log_cdfs = [
+        dist.log_cdf(sv) for dist, sv in zip(self._distributions, split_value)
+    ]
+
+    return backend.concatenate(log_cdfs, axis=-1)
+
+  def _mean(self):
+    return backend.concatenate(
+        [dist.mean() for dist in self._distributions], axis=0
+    )
+
+  def _variance(self):
+    return backend.concatenate(
+        [dist.variance() for dist in self._distributions], axis=0
+    )
+
+  def _default_event_space_bijector(self):
+    """Mapping from R^n to the event space of the wrapped distributions.
+
+    This is the blockwise concatenation of the underlying bijectors.
+
+    Returns:
+      A `tfp.bijectors.Blockwise` object that concatenates the underlying
+      bijectors.
+    """
+    bijectors = [
+        d.experimental_default_event_space_bijector()
+        for d in self._distributions
+    ]
+
+    return backend.bijectors.Blockwise(
+        bijectors,
+        block_sizes=self._distribution_batch_shapes,
+    )
+
+  def _broadcast_value(self, value: backend.Tensor) -> backend.Tensor:
+    value = backend.to_tensor(value)
+    broadcast_shape = backend.broadcast_dynamic_shape(
+        value.shape, self.batch_shape_tensor()
+    )
+    return backend.broadcast_to(value, broadcast_shape)
+
+  def _get_distribution_batch_shapes(self) -> Sequence[int]:
+    """Sequence of batch shapes of underlying distributions."""
+
+    batch_shapes = []
+
+    for dist in self._distributions:
+      try:
+        (dist_batch_shape,) = dist.batch_shape
+      except ValueError as exc:
+        raise ValueError(
+            'All distributions must be 0- or 1-dimensional.'
+            f' Found {len(dist.batch_shape)}-dimensional distribution:'
+            f' {dist.batch_shape}.'
+        ) from exc
+      else:
+        batch_shapes.append(dist_batch_shape)
+
+    return batch_shapes
+
+
+def distributions_are_equal(
+    a: backend.tfd.Distribution, b: backend.tfd.Distribution
+) -> bool:
+  """Determine if two distributions are equal."""
+  if type(a) != type(b):  # pylint: disable=unidiomatic-typecheck
+    return False
+
+  a_params = a.parameters.copy()
+  b_params = b.parameters.copy()
+
+  if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
+    if not distributions_are_equal(
+        a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
+    ):
+      return False
+    del a_params[constants.DISTRIBUTION]
+    del b_params[constants.DISTRIBUTION]
+
+  if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
+    return False
+
+  if a_params.keys() != b_params.keys():
+    return False
+
+  for key in a_params.keys():
+    if isinstance(
+        a_params[key], (backend.Tensor, np.ndarray, float, int)
+    ) and isinstance(b_params[key], (backend.Tensor, np.ndarray, float, int)):
+      if not backend.allclose(a_params[key], b_params[key]):
+        return False
+    else:
+      if a_params[key] != b_params[key]:
+        return False
+
+  return True
+
+
 def _convert_to_deterministic_0_distribution(
-    distribution: tfp.distributions.Distribution,
-) -> tfp.distributions.Distribution:
+    distribution: backend.tfd.Distribution,
+) -> backend.tfd.Distribution:
   """Converts the given distribution to a `Deterministic(0)` one.
 
   Args:
@@ -909,7 +1190,7 @@ def _convert_to_deterministic_0_distribution(
     distribution.
   """
   if (
-      not isinstance(distribution, tfp.distributions.Deterministic)
+      not isinstance(distribution, backend.tfd.Deterministic)
      or distribution.loc != 0
   ):
     warnings.warn(
@@ -917,7 +1198,7 @@ def _convert_to_deterministic_0_distribution(
        f' for national models. {distribution.name} has been automatically set'
        ' to Deterministic(0).'
     )
-    return tfp.distributions.Deterministic(loc=0, name=distribution.name)
+    return backend.tfd.Deterministic(loc=0, name=distribution.name)
   else:
     return distribution
 
@@ -928,7 +1209,7 @@ def _get_total_media_contribution_prior(
     name: str,
     p_mean: float = constants.P_MEAN,
     p_sd: float = constants.P_SD,
-) -> tfp.distributions.Distribution:
+) -> backend.tfd.Distribution:
   """Determines ROI priors based on total media contribution.
 
   Args:
@@ -945,34 +1226,123 @@ def _get_total_media_contribution_prior(
   """
   roi_mean = p_mean * kpi / np.sum(total_spend)
   roi_sd = p_sd * kpi / np.sqrt(np.sum(np.power(total_spend, 2)))
-  lognormal_sigma = tf.cast(
-      np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=tf.float32
+  lognormal_sigma = backend.cast(
+      np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=backend.float32
   )
-  lognormal_mu = tf.cast(
-      np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)), dtype=tf.float32
+  lognormal_mu = backend.cast(
+      np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)),
+      dtype=backend.float32,
   )
-  return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
+  return backend.tfd.LogNormal(lognormal_mu, lognormal_sigma, name=name)
 
 
-def
-
-
-
-
-
+def _validate_support(
+    parameter_name: str,
+    tfp_dist: backend.tfp.distributions.Distribution,
+    bounds: tuple[float, float],
+    prevent_deterministic_prior_at_bounds: tuple[bool, bool],
+) -> None:
+  """Validates that distribution support is within the parameter bounds.
 
-
-
+  Args:
+    parameter_name: Name of the parameter.
+    tfp_dist: The TFP distribution to validate.
+    bounds: Tuple containing the min and max values of the parameteter space.
+    prevent_deterministic_prior_at_bounds: Tuple of two booleans indicating
+      whether a deterministic prior is allowed at the lower and upper bounds,
+      respectively.
 
-
-
-
-
-
-
-
+  Raises:
+    ValueError: If the distribution support is not within the parameter bounds.
+  """
+  # Note that `tfp.distributions.BatchBroadcast` objects have a `distribution`
+  # attribute that points to a `tfp.distributions.Distribution` object.
+  if isinstance(tfp_dist, backend.tfp.distributions.BatchBroadcast):
+    tfp_dist = tfp_dist.distribution
+  # Note that `tfp.distributions.Deterministic` does not have a `quantile`
+  # method implemented, so the min and max values must be extracted from the
+  # `loc` attribute instead.
+  if isinstance(
+      tfp_dist,
+      backend.tfp.python.distributions.deterministic.Deterministic
+  ):
+    support_min_vals = tfp_dist.loc
+    support_max_vals = tfp_dist.loc
+    for i in (0, 1):
+      if (
+          prevent_deterministic_prior_at_bounds[i]
+          and np.any(tfp_dist.loc == bounds[i])
+      ):
+        raise ValueError(
+            f'{parameter_name} was assigned a point mass (deterministic) prior'
+            f' at {bounds[i]}, which is not allowed.'
+        )
+  else:
+    try:
+      support_min_vals = tfp_dist.quantile(0)
+      support_max_vals = tfp_dist.quantile(1)
+    except (AttributeError, NotImplementedError):
+      warnings.warn(
+          f'The prior distribution for {parameter_name} does not have a'
+          f' `quantile` method implemented, so the support range validation'
+          f' was skipped. Confirm that your prior for {parameter_name} is'
+          f' appropriate.'
+      )
+      return
+  if np.any(support_min_vals < bounds[0]):
+    raise ValueError(
+        f'{parameter_name} was assigned a prior distribution that allows values'
+        f' less than the parameter minimum {bounds[0]}.'
+    )
+  if np.any(support_max_vals > bounds[1]):
+    raise ValueError(
+        f'{parameter_name} was assigned a prior distribution that allows values'
+        f' greater than the parameter maximum {bounds[1]}.'
+    )
 
-
-
+# Dictionary of parameters that have a limited parameters space. The tuple
+# contains the lower and upper bounds, respectively.
+_parameter_space_bounds = {
+    'eta_m': (0, np.inf),
+    'eta_rf': (0, np.inf),
+    'eta_om': (0, np.inf),
+    'eta_orf': (0, np.inf),
+    'xi_c': (0, np.inf),
+    'xi_n': (0, np.inf),
+    'alpha_m': (0, 1),
+    'alpha_rf': (0, 1),
+    'alpha_om': (0, 1),
+    'alpha_orf': (0, 1),
+    'ec_m': (0, np.inf),
+    'ec_rf': (0, np.inf),
+    'ec_om': (0, np.inf),
+    'ec_orf': (0, np.inf),
+    'slope_m': (0, np.inf),
+    'slope_rf': (0, np.inf),
+    'slope_om': (0, np.inf),
+    'slope_orf': (0, np.inf),
+    'sigma': (0, np.inf),
+}
 
-
+# Dictionary of parameters that do not allow a deterministic prior at one or
+# more of the parameter space bounds. The boolean tuple indicates whether a
+# deterministic prior is allowed at the lower bound or upper bound,
+# respectively, where `True` means "not allowed". This check is specifically for
+# point mass at finite paramteter space bounds, since point mass at infinity is
+# generally problematic for all parameters. Note that `sigma` should generally
+# not have point mass at zero, but this is not checked here because unit tests
+# require the ability to simulate data with `sigma` set to zero.
+_prevent_deterministic_prior_at_bounds = {
+    'alpha_m': (False, True),
+    'alpha_rf': (False, True),
+    'alpha_om': (False, True),
+    'alpha_orf': (False, True),
+    'ec_m': (True, False),
+    'ec_rf': (True, False),
+    'ec_om': (True, False),
+    'ec_orf': (True, False),
+    'slope_m': (True, False),
+    'slope_rf': (True, False),
+    'slope_om': (True, False),
+    'slope_orf': (True, False),
+}