google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/METADATA +2 -2
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/RECORD +24 -24
- meridian/analysis/analyzer.py +101 -37
- meridian/analysis/optimizer.py +132 -88
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/visualizer.py +16 -5
- meridian/backend/__init__.py +475 -14
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +87 -1
- meridian/constants.py +14 -9
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/eda_engine.py +440 -11
- meridian/model/knots.py +1 -1
- meridian/model/model_test_data.py +15 -9
- meridian/model/posterior_sampler.py +365 -365
- meridian/model/prior_distribution.py +104 -39
- meridian/model/transformers.py +5 -5
- meridian/version.py +1 -1
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
meridian/model/eda/eda_engine.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
"""Meridian EDA Engine."""
|
|
16
16
|
|
|
17
|
+
import dataclasses
|
|
17
18
|
import functools
|
|
18
19
|
from typing import Callable, Dict, Optional, TypeAlias
|
|
19
20
|
from meridian import constants
|
|
@@ -28,11 +29,64 @@ _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
|
|
|
28
29
|
AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
33
|
+
class ReachFrequencyData:
|
|
34
|
+
"""Holds reach and frequency data arrays.
|
|
35
|
+
|
|
36
|
+
Attributes:
|
|
37
|
+
reach_raw_da: Raw reach data.
|
|
38
|
+
reach_scaled_da: Scaled reach data.
|
|
39
|
+
reach_raw_da_national: National raw reach data.
|
|
40
|
+
reach_scaled_da_national: National scaled reach data.
|
|
41
|
+
frequency_da: Frequency data.
|
|
42
|
+
frequency_da_national: National frequency data.
|
|
43
|
+
rf_impressions_scaled_da: Scaled reach * frequency impressions data.
|
|
44
|
+
rf_impressions_scaled_da_national: National scaled reach * frequency
|
|
45
|
+
impressions data.
|
|
46
|
+
rf_impressions_raw_da: Raw reach * frequency impressions data.
|
|
47
|
+
rf_impressions_raw_da_national: National raw reach * frequency impressions
|
|
48
|
+
data.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
reach_raw_da: xr.DataArray
|
|
52
|
+
reach_scaled_da: xr.DataArray
|
|
53
|
+
reach_raw_da_national: xr.DataArray
|
|
54
|
+
reach_scaled_da_national: xr.DataArray
|
|
55
|
+
frequency_da: xr.DataArray
|
|
56
|
+
frequency_da_national: xr.DataArray
|
|
57
|
+
rf_impressions_scaled_da: xr.DataArray
|
|
58
|
+
rf_impressions_scaled_da_national: xr.DataArray
|
|
59
|
+
rf_impressions_raw_da: xr.DataArray
|
|
60
|
+
rf_impressions_raw_da_national: xr.DataArray
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
64
|
+
class AggregationConfig:
|
|
65
|
+
"""Configuration for custom aggregation functions.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
control_variables: A dictionary mapping control variable names to
|
|
69
|
+
aggregation functions. Defaults to `np.sum` if a variable is not
|
|
70
|
+
specified.
|
|
71
|
+
non_media_treatments: A dictionary mapping non-media variable names to
|
|
72
|
+
aggregation functions. Defaults to `np.sum` if a variable is not
|
|
73
|
+
specified.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
control_variables: AggregationMap = dataclasses.field(default_factory=dict)
|
|
77
|
+
non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
|
|
78
|
+
|
|
79
|
+
|
|
31
80
|
class EDAEngine:
|
|
32
81
|
"""Meridian EDA Engine."""
|
|
33
82
|
|
|
34
|
-
def __init__(
|
|
83
|
+
def __init__(
|
|
84
|
+
self,
|
|
85
|
+
meridian: model.Meridian,
|
|
86
|
+
agg_config: AggregationConfig = AggregationConfig(),
|
|
87
|
+
):
|
|
35
88
|
self._meridian = meridian
|
|
89
|
+
self._agg_config = agg_config
|
|
36
90
|
|
|
37
91
|
@functools.cached_property
|
|
38
92
|
def controls_scaled_da(self) -> xr.DataArray | None:
|
|
@@ -44,6 +98,26 @@ class EDAEngine:
|
|
|
44
98
|
)
|
|
45
99
|
return controls_scaled_da
|
|
46
100
|
|
|
101
|
+
@functools.cached_property
|
|
102
|
+
def controls_scaled_da_national(self) -> xr.DataArray | None:
|
|
103
|
+
"""Returns the national controls data array."""
|
|
104
|
+
if self._meridian.input_data.controls is None:
|
|
105
|
+
return None
|
|
106
|
+
if self._meridian.is_national:
|
|
107
|
+
if self.controls_scaled_da is None:
|
|
108
|
+
# This case should be impossible given the check above.
|
|
109
|
+
raise RuntimeError(
|
|
110
|
+
'controls_scaled_da is None when controls is not None.'
|
|
111
|
+
)
|
|
112
|
+
return self.controls_scaled_da.squeeze(constants.GEO)
|
|
113
|
+
else:
|
|
114
|
+
return self._aggregate_and_scale_geo_da(
|
|
115
|
+
self._meridian.input_data.controls,
|
|
116
|
+
transformers.CenteringAndScalingTransformer,
|
|
117
|
+
constants.CONTROL_VARIABLE,
|
|
118
|
+
self._agg_config.control_variables,
|
|
119
|
+
)
|
|
120
|
+
|
|
47
121
|
@functools.cached_property
|
|
48
122
|
def media_raw_da(self) -> xr.DataArray | None:
|
|
49
123
|
if self._meridian.input_data.media is None:
|
|
@@ -71,13 +145,32 @@ class EDAEngine:
|
|
|
71
145
|
# No need to truncate the media time for media spend.
|
|
72
146
|
return media_spend_da
|
|
73
147
|
|
|
148
|
+
@functools.cached_property
|
|
149
|
+
def media_spend_da_national(self) -> xr.DataArray | None:
|
|
150
|
+
"""Returns the national media spend data array."""
|
|
151
|
+
if self._meridian.input_data.media_spend is None:
|
|
152
|
+
return None
|
|
153
|
+
if self._meridian.is_national:
|
|
154
|
+
if self.media_spend_da is None:
|
|
155
|
+
# This case should be impossible given the check above.
|
|
156
|
+
raise RuntimeError(
|
|
157
|
+
'media_spend_da is None when media_spend is not None.'
|
|
158
|
+
)
|
|
159
|
+
return self.media_spend_da.squeeze(constants.GEO)
|
|
160
|
+
else:
|
|
161
|
+
return self._aggregate_and_scale_geo_da(
|
|
162
|
+
self._meridian.input_data.media_spend,
|
|
163
|
+
None,
|
|
164
|
+
)
|
|
165
|
+
|
|
74
166
|
@functools.cached_property
|
|
75
167
|
def media_raw_da_national(self) -> xr.DataArray | None:
|
|
76
168
|
if self.media_raw_da is None:
|
|
77
169
|
return None
|
|
78
170
|
if self._meridian.is_national:
|
|
79
|
-
return self.media_raw_da
|
|
171
|
+
return self.media_raw_da.squeeze(constants.GEO)
|
|
80
172
|
else:
|
|
173
|
+
# Note that media is summable by assumption.
|
|
81
174
|
return self._aggregate_and_scale_geo_da(
|
|
82
175
|
self.media_raw_da,
|
|
83
176
|
None,
|
|
@@ -88,8 +181,9 @@ class EDAEngine:
|
|
|
88
181
|
if self.media_scaled_da is None:
|
|
89
182
|
return None
|
|
90
183
|
if self._meridian.is_national:
|
|
91
|
-
return self.media_scaled_da
|
|
184
|
+
return self.media_scaled_da.squeeze(constants.GEO)
|
|
92
185
|
else:
|
|
186
|
+
# Note that media is summable by assumption.
|
|
93
187
|
return self._aggregate_and_scale_geo_da(
|
|
94
188
|
self.media_raw_da,
|
|
95
189
|
transformers.MediaTransformer,
|
|
@@ -116,8 +210,9 @@ class EDAEngine:
|
|
|
116
210
|
if self.organic_media_raw_da is None:
|
|
117
211
|
return None
|
|
118
212
|
if self._meridian.is_national:
|
|
119
|
-
return self.organic_media_raw_da
|
|
213
|
+
return self.organic_media_raw_da.squeeze(constants.GEO)
|
|
120
214
|
else:
|
|
215
|
+
# Note that organic media is summable by assumption.
|
|
121
216
|
return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)
|
|
122
217
|
|
|
123
218
|
@functools.cached_property
|
|
@@ -125,8 +220,9 @@ class EDAEngine:
|
|
|
125
220
|
if self.organic_media_scaled_da is None:
|
|
126
221
|
return None
|
|
127
222
|
if self._meridian.is_national:
|
|
128
|
-
return self.organic_media_scaled_da
|
|
223
|
+
return self.organic_media_scaled_da.squeeze(constants.GEO)
|
|
129
224
|
else:
|
|
225
|
+
# Note that organic media is summable by assumption.
|
|
130
226
|
return self._aggregate_and_scale_geo_da(
|
|
131
227
|
self.organic_media_raw_da,
|
|
132
228
|
transformers.MediaTransformer,
|
|
@@ -142,6 +238,26 @@ class EDAEngine:
|
|
|
142
238
|
)
|
|
143
239
|
return non_media_scaled_da
|
|
144
240
|
|
|
241
|
+
@functools.cached_property
|
|
242
|
+
def non_media_scaled_da_national(self) -> xr.DataArray | None:
|
|
243
|
+
"""Returns the national non-media treatment data array."""
|
|
244
|
+
if self._meridian.input_data.non_media_treatments is None:
|
|
245
|
+
return None
|
|
246
|
+
if self._meridian.is_national:
|
|
247
|
+
if self.non_media_scaled_da is None:
|
|
248
|
+
# This case should be impossible given the check above.
|
|
249
|
+
raise RuntimeError(
|
|
250
|
+
'non_media_scaled_da is None when non_media_treatments is not None.'
|
|
251
|
+
)
|
|
252
|
+
return self.non_media_scaled_da.squeeze(constants.GEO)
|
|
253
|
+
else:
|
|
254
|
+
return self._aggregate_and_scale_geo_da(
|
|
255
|
+
self._meridian.input_data.non_media_treatments,
|
|
256
|
+
transformers.CenteringAndScalingTransformer,
|
|
257
|
+
constants.NON_MEDIA_CHANNEL,
|
|
258
|
+
self._agg_config.non_media_treatments,
|
|
259
|
+
)
|
|
260
|
+
|
|
145
261
|
@functools.cached_property
|
|
146
262
|
def rf_spend_da(self) -> xr.DataArray | None:
|
|
147
263
|
if self._meridian.input_data.rf_spend is None:
|
|
@@ -157,12 +273,226 @@ class EDAEngine:
|
|
|
157
273
|
if self._meridian.input_data.rf_spend is None:
|
|
158
274
|
return None
|
|
159
275
|
if self._meridian.is_national:
|
|
160
|
-
|
|
276
|
+
if self.rf_spend_da is None:
|
|
277
|
+
# This case should be impossible given the check above.
|
|
278
|
+
raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
|
|
279
|
+
return self.rf_spend_da.squeeze(constants.GEO)
|
|
161
280
|
else:
|
|
162
281
|
return self._aggregate_and_scale_geo_da(
|
|
163
282
|
self._meridian.input_data.rf_spend, None
|
|
164
283
|
)
|
|
165
284
|
|
|
285
|
+
@functools.cached_property
|
|
286
|
+
def _rf_data(self) -> ReachFrequencyData | None:
|
|
287
|
+
if self._meridian.input_data.reach is None:
|
|
288
|
+
return None
|
|
289
|
+
return self._get_rf_data(
|
|
290
|
+
self._meridian.input_data.reach,
|
|
291
|
+
self._meridian.input_data.frequency,
|
|
292
|
+
is_organic=False,
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
@property
|
|
296
|
+
def reach_raw_da(self) -> xr.DataArray | None:
|
|
297
|
+
if self._rf_data is None:
|
|
298
|
+
return None
|
|
299
|
+
return self._rf_data.reach_raw_da
|
|
300
|
+
|
|
301
|
+
@property
|
|
302
|
+
def reach_scaled_da(self) -> xr.DataArray | None:
|
|
303
|
+
if self._rf_data is None:
|
|
304
|
+
return None
|
|
305
|
+
return self._rf_data.reach_scaled_da
|
|
306
|
+
|
|
307
|
+
@property
|
|
308
|
+
def reach_raw_da_national(self) -> xr.DataArray | None:
|
|
309
|
+
if self._rf_data is None:
|
|
310
|
+
return None
|
|
311
|
+
return self._rf_data.reach_raw_da_national
|
|
312
|
+
|
|
313
|
+
@property
|
|
314
|
+
def reach_scaled_da_national(self) -> xr.DataArray | None:
|
|
315
|
+
if self._rf_data is None:
|
|
316
|
+
return None
|
|
317
|
+
return self._rf_data.reach_scaled_da_national
|
|
318
|
+
|
|
319
|
+
@property
|
|
320
|
+
def frequency_da(self) -> xr.DataArray | None:
|
|
321
|
+
if self._rf_data is None:
|
|
322
|
+
return None
|
|
323
|
+
return self._rf_data.frequency_da
|
|
324
|
+
|
|
325
|
+
@property
|
|
326
|
+
def frequency_da_national(self) -> xr.DataArray | None:
|
|
327
|
+
if self._rf_data is None:
|
|
328
|
+
return None
|
|
329
|
+
return self._rf_data.frequency_da_national
|
|
330
|
+
|
|
331
|
+
@property
|
|
332
|
+
def rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
333
|
+
if self._rf_data is None:
|
|
334
|
+
return None
|
|
335
|
+
return self._rf_data.rf_impressions_raw_da
|
|
336
|
+
|
|
337
|
+
@property
|
|
338
|
+
def rf_impressions_raw_da_national(self) -> xr.DataArray | None:
|
|
339
|
+
if self._rf_data is None:
|
|
340
|
+
return None
|
|
341
|
+
return self._rf_data.rf_impressions_raw_da_national
|
|
342
|
+
|
|
343
|
+
@property
|
|
344
|
+
def rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
345
|
+
if self._rf_data is None:
|
|
346
|
+
return None
|
|
347
|
+
return self._rf_data.rf_impressions_scaled_da
|
|
348
|
+
|
|
349
|
+
@property
|
|
350
|
+
def rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
|
|
351
|
+
if self._rf_data is None:
|
|
352
|
+
return None
|
|
353
|
+
return self._rf_data.rf_impressions_scaled_da_national
|
|
354
|
+
|
|
355
|
+
@functools.cached_property
|
|
356
|
+
def _organic_rf_data(self) -> ReachFrequencyData | None:
|
|
357
|
+
if self._meridian.input_data.organic_reach is None:
|
|
358
|
+
return None
|
|
359
|
+
return self._get_rf_data(
|
|
360
|
+
self._meridian.input_data.organic_reach,
|
|
361
|
+
self._meridian.input_data.organic_frequency,
|
|
362
|
+
is_organic=True,
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
@property
|
|
366
|
+
def organic_reach_raw_da(self) -> xr.DataArray | None:
|
|
367
|
+
if self._organic_rf_data is None:
|
|
368
|
+
return None
|
|
369
|
+
return self._organic_rf_data.reach_raw_da
|
|
370
|
+
|
|
371
|
+
@property
|
|
372
|
+
def organic_reach_scaled_da(self) -> xr.DataArray | None:
|
|
373
|
+
if self._organic_rf_data is None:
|
|
374
|
+
return None
|
|
375
|
+
return self._organic_rf_data.reach_scaled_da
|
|
376
|
+
|
|
377
|
+
@property
|
|
378
|
+
def organic_reach_raw_da_national(self) -> xr.DataArray | None:
|
|
379
|
+
if self._organic_rf_data is None:
|
|
380
|
+
return None
|
|
381
|
+
return self._organic_rf_data.reach_raw_da_national
|
|
382
|
+
|
|
383
|
+
@property
|
|
384
|
+
def organic_reach_scaled_da_national(self) -> xr.DataArray | None:
|
|
385
|
+
if self._organic_rf_data is None:
|
|
386
|
+
return None
|
|
387
|
+
return self._organic_rf_data.reach_scaled_da_national
|
|
388
|
+
|
|
389
|
+
@property
|
|
390
|
+
def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
391
|
+
if self._organic_rf_data is None:
|
|
392
|
+
return None
|
|
393
|
+
return self._organic_rf_data.rf_impressions_scaled_da
|
|
394
|
+
|
|
395
|
+
@property
|
|
396
|
+
def organic_rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
|
|
397
|
+
if self._organic_rf_data is None:
|
|
398
|
+
return None
|
|
399
|
+
return self._organic_rf_data.rf_impressions_scaled_da_national
|
|
400
|
+
|
|
401
|
+
@property
|
|
402
|
+
def organic_frequency_da(self) -> xr.DataArray | None:
|
|
403
|
+
if self._organic_rf_data is None:
|
|
404
|
+
return None
|
|
405
|
+
return self._organic_rf_data.frequency_da
|
|
406
|
+
|
|
407
|
+
@property
|
|
408
|
+
def organic_frequency_da_national(self) -> xr.DataArray | None:
|
|
409
|
+
if self._organic_rf_data is None:
|
|
410
|
+
return None
|
|
411
|
+
return self._organic_rf_data.frequency_da_national
|
|
412
|
+
|
|
413
|
+
@property
|
|
414
|
+
def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
415
|
+
if self._organic_rf_data is None:
|
|
416
|
+
return None
|
|
417
|
+
return self._organic_rf_data.rf_impressions_raw_da
|
|
418
|
+
|
|
419
|
+
@property
|
|
420
|
+
def organic_rf_impressions_raw_da_national(self) -> xr.DataArray | None:
|
|
421
|
+
if self._organic_rf_data is None:
|
|
422
|
+
return None
|
|
423
|
+
return self._organic_rf_data.rf_impressions_raw_da_national
|
|
424
|
+
|
|
425
|
+
@functools.cached_property
|
|
426
|
+
def geo_population_da(self) -> xr.DataArray | None:
|
|
427
|
+
if self._meridian.is_national:
|
|
428
|
+
return None
|
|
429
|
+
return xr.DataArray(
|
|
430
|
+
self._meridian.population,
|
|
431
|
+
coords={constants.GEO: self._meridian.input_data.geo.values},
|
|
432
|
+
dims=[constants.GEO],
|
|
433
|
+
name=constants.POPULATION,
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
@functools.cached_property
|
|
437
|
+
def kpi_scaled_da(self) -> xr.DataArray:
|
|
438
|
+
return _data_array_like(
|
|
439
|
+
da=self._meridian.input_data.kpi,
|
|
440
|
+
values=self._meridian.kpi_scaled,
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
@functools.cached_property
|
|
444
|
+
def kpi_scaled_da_national(self) -> xr.DataArray:
|
|
445
|
+
if self._meridian.is_national:
|
|
446
|
+
return self.kpi_scaled_da.squeeze(constants.GEO)
|
|
447
|
+
else:
|
|
448
|
+
# Note that kpi is summable by assumption.
|
|
449
|
+
return self._aggregate_and_scale_geo_da(
|
|
450
|
+
self._meridian.input_data.kpi,
|
|
451
|
+
transformers.CenteringAndScalingTransformer,
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
@functools.cached_property
|
|
455
|
+
def treatment_control_scaled_ds(self) -> xr.Dataset:
|
|
456
|
+
"""Returns a Dataset containing all scaled treatments and controls.
|
|
457
|
+
|
|
458
|
+
This includes media, RF impressions, organic media, organic RF impressions,
|
|
459
|
+
non-media treatments, and control variables, all at the geo level.
|
|
460
|
+
"""
|
|
461
|
+
to_merge = [
|
|
462
|
+
da
|
|
463
|
+
for da in [
|
|
464
|
+
self.media_scaled_da,
|
|
465
|
+
self.rf_impressions_scaled_da,
|
|
466
|
+
self.organic_media_scaled_da,
|
|
467
|
+
self.organic_rf_impressions_scaled_da,
|
|
468
|
+
self.controls_scaled_da,
|
|
469
|
+
self.non_media_scaled_da,
|
|
470
|
+
]
|
|
471
|
+
if da is not None
|
|
472
|
+
]
|
|
473
|
+
return xr.merge(to_merge, join='inner')
|
|
474
|
+
|
|
475
|
+
@functools.cached_property
|
|
476
|
+
def treatment_control_scaled_ds_national(self) -> xr.Dataset:
|
|
477
|
+
"""Returns a Dataset containing all scaled treatments and controls.
|
|
478
|
+
|
|
479
|
+
This includes media, RF impressions, organic media, organic RF impressions,
|
|
480
|
+
non-media treatments, and control variables, all at the national level.
|
|
481
|
+
"""
|
|
482
|
+
to_merge_national = [
|
|
483
|
+
da
|
|
484
|
+
for da in [
|
|
485
|
+
self.media_scaled_da_national,
|
|
486
|
+
self.rf_impressions_scaled_da_national,
|
|
487
|
+
self.organic_media_scaled_da_national,
|
|
488
|
+
self.organic_rf_impressions_scaled_da_national,
|
|
489
|
+
self.controls_scaled_da_national,
|
|
490
|
+
self.non_media_scaled_da_national,
|
|
491
|
+
]
|
|
492
|
+
if da is not None
|
|
493
|
+
]
|
|
494
|
+
return xr.merge(to_merge_national, join='inner')
|
|
495
|
+
|
|
166
496
|
def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
|
|
167
497
|
"""Truncates the first `start` elements of the media time of a variable."""
|
|
168
498
|
# This should not happen. If it does, it means this function is mis-used.
|
|
@@ -183,15 +513,17 @@ class EDAEngine:
|
|
|
183
513
|
population: tf.Tensor = tf.constant([1.0], dtype=tf.float32),
|
|
184
514
|
):
|
|
185
515
|
"""Scales xarray values with a TensorTransformer."""
|
|
516
|
+
da = xarray.copy()
|
|
517
|
+
|
|
186
518
|
if transformer_class is None:
|
|
187
|
-
return
|
|
519
|
+
return da
|
|
188
520
|
elif transformer_class is transformers.CenteringAndScalingTransformer:
|
|
189
521
|
xarray_transformer = transformers.CenteringAndScalingTransformer(
|
|
190
|
-
tensor=
|
|
522
|
+
tensor=da.values, population=population
|
|
191
523
|
)
|
|
192
524
|
elif transformer_class is transformers.MediaTransformer:
|
|
193
525
|
xarray_transformer = transformers.MediaTransformer(
|
|
194
|
-
media=
|
|
526
|
+
media=da.values, population=population
|
|
195
527
|
)
|
|
196
528
|
else:
|
|
197
529
|
raise ValueError(
|
|
@@ -200,8 +532,8 @@ class EDAEngine:
|
|
|
200
532
|
+ '.\nMust be one of: CenteringAndScalingTransformer or'
|
|
201
533
|
' MediaTransformer.'
|
|
202
534
|
)
|
|
203
|
-
|
|
204
|
-
return
|
|
535
|
+
da.values = xarray_transformer.forward(da.values)
|
|
536
|
+
return da
|
|
205
537
|
|
|
206
538
|
def _aggregate_variables(
|
|
207
539
|
self,
|
|
@@ -282,6 +614,103 @@ class EDAEngine:
|
|
|
282
614
|
|
|
283
615
|
return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)
|
|
284
616
|
|
|
617
|
+
def _get_rf_data(
|
|
618
|
+
self,
|
|
619
|
+
reach_raw_da: xr.DataArray,
|
|
620
|
+
freq_raw_da: xr.DataArray,
|
|
621
|
+
is_organic: bool,
|
|
622
|
+
) -> ReachFrequencyData:
|
|
623
|
+
"""Get impressions and frequencies data arrays for RF channels."""
|
|
624
|
+
if is_organic:
|
|
625
|
+
scaled_reach_values = (
|
|
626
|
+
self._meridian.organic_rf_tensors.organic_reach_scaled
|
|
627
|
+
)
|
|
628
|
+
else:
|
|
629
|
+
scaled_reach_values = self._meridian.rf_tensors.reach_scaled
|
|
630
|
+
reach_scaled_da = _data_array_like(
|
|
631
|
+
da=reach_raw_da, values=scaled_reach_values
|
|
632
|
+
)
|
|
633
|
+
# Truncate the media time for reach and scaled reach.
|
|
634
|
+
reach_raw_da = self._truncate_media_time(reach_raw_da)
|
|
635
|
+
reach_scaled_da = self._truncate_media_time(reach_scaled_da)
|
|
636
|
+
|
|
637
|
+
# The geo level frequency
|
|
638
|
+
frequency_da = self._truncate_media_time(freq_raw_da)
|
|
639
|
+
|
|
640
|
+
# The raw geo level impression
|
|
641
|
+
# It's equal to reach * frequency.
|
|
642
|
+
impressions_raw_da = reach_raw_da * frequency_da
|
|
643
|
+
impressions_raw_da.name = (
|
|
644
|
+
constants.ORGANIC_RF_IMPRESSIONS
|
|
645
|
+
if is_organic
|
|
646
|
+
else constants.RF_IMPRESSIONS
|
|
647
|
+
)
|
|
648
|
+
impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)
|
|
649
|
+
|
|
650
|
+
if self._meridian.is_national:
|
|
651
|
+
reach_raw_da_national = reach_raw_da.squeeze(constants.GEO)
|
|
652
|
+
reach_scaled_da_national = reach_scaled_da.squeeze(constants.GEO)
|
|
653
|
+
impressions_raw_da_national = impressions_raw_da.squeeze(constants.GEO)
|
|
654
|
+
frequency_da_national = frequency_da.squeeze(constants.GEO)
|
|
655
|
+
|
|
656
|
+
# Scaled impressions
|
|
657
|
+
impressions_scaled_da = self._scale_xarray(
|
|
658
|
+
impressions_raw_da, transformers.MediaTransformer
|
|
659
|
+
)
|
|
660
|
+
impressions_scaled_da_national = impressions_scaled_da.squeeze(
|
|
661
|
+
constants.GEO
|
|
662
|
+
)
|
|
663
|
+
else:
|
|
664
|
+
reach_raw_da_national = self._aggregate_and_scale_geo_da(
|
|
665
|
+
reach_raw_da, None
|
|
666
|
+
)
|
|
667
|
+
reach_scaled_da_national = self._aggregate_and_scale_geo_da(
|
|
668
|
+
reach_raw_da, transformers.MediaTransformer
|
|
669
|
+
)
|
|
670
|
+
impressions_raw_da_national = self._aggregate_and_scale_geo_da(
|
|
671
|
+
impressions_raw_da, None
|
|
672
|
+
)
|
|
673
|
+
|
|
674
|
+
# National frequency is a weighted average of geo frequencies,
|
|
675
|
+
# weighted by reach.
|
|
676
|
+
frequency_da_national = xr.where(
|
|
677
|
+
reach_raw_da_national == 0.0,
|
|
678
|
+
0.0,
|
|
679
|
+
impressions_raw_da_national / reach_raw_da_national,
|
|
680
|
+
)
|
|
681
|
+
frequency_da_national.name = (
|
|
682
|
+
constants.ORGANIC_PREFIX if is_organic else ''
|
|
683
|
+
) + constants.FREQUENCY
|
|
684
|
+
frequency_da_national.values = tf.cast(
|
|
685
|
+
frequency_da_national.values, tf.float32
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
# Scale the impressions by population
|
|
689
|
+
impressions_scaled_da = self._scale_xarray(
|
|
690
|
+
impressions_raw_da,
|
|
691
|
+
transformers.MediaTransformer,
|
|
692
|
+
population=self._meridian.population,
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
# Scale the national impressions
|
|
696
|
+
impressions_scaled_da_national = self._aggregate_and_scale_geo_da(
|
|
697
|
+
impressions_raw_da,
|
|
698
|
+
transformers.MediaTransformer,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
return ReachFrequencyData(
|
|
702
|
+
reach_raw_da=reach_raw_da,
|
|
703
|
+
reach_scaled_da=reach_scaled_da,
|
|
704
|
+
reach_raw_da_national=reach_raw_da_national,
|
|
705
|
+
reach_scaled_da_national=reach_scaled_da_national,
|
|
706
|
+
frequency_da=frequency_da,
|
|
707
|
+
frequency_da_national=frequency_da_national,
|
|
708
|
+
rf_impressions_scaled_da=impressions_scaled_da,
|
|
709
|
+
rf_impressions_scaled_da_national=impressions_scaled_da_national,
|
|
710
|
+
rf_impressions_raw_da=impressions_raw_da,
|
|
711
|
+
rf_impressions_raw_da_national=impressions_raw_da_national,
|
|
712
|
+
)
|
|
713
|
+
|
|
285
714
|
|
|
286
715
|
def _data_array_like(
|
|
287
716
|
*, da: xr.DataArray, values: np.ndarray | tf.Tensor
|
meridian/model/knots.py
CHANGED
|
@@ -276,7 +276,7 @@ class AKS:
|
|
|
276
276
|
feasible_idx = np.where(
|
|
277
277
|
(n_knots >= min_internal_knots) & (n_knots <= max_internal_knots)
|
|
278
278
|
)[0]
|
|
279
|
-
information_criterion = aspline[constants.
|
|
279
|
+
information_criterion = aspline[constants.AIC][feasible_idx]
|
|
280
280
|
knots_sel = [aspline[constants.KNOTS_SELECTED][i] for i in feasible_idx]
|
|
281
281
|
model = [aspline[constants.MODEL][i] for i in feasible_idx]
|
|
282
282
|
opt_idx = max(
|
meridian/model/model_test_data.py
CHANGED
|
@@ -52,7 +52,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
|
|
|
52
52
|
else:
|
|
53
53
|
pad_value = 0.0 if array.dtype.kind == "f" else 0
|
|
54
54
|
|
|
55
|
-
burnin = backend.fill(
|
|
55
|
+
burnin = backend.fill(
|
|
56
|
+
[n_burnin] + list(transposed_tensor.shape[1:]), pad_value
|
|
57
|
+
)
|
|
56
58
|
return backend.concatenate(
|
|
57
59
|
[burnin, transposed_tensor],
|
|
58
60
|
axis=0,
|
|
@@ -122,18 +124,13 @@ class WithInputDataSamples:
|
|
|
122
124
|
_N_MEDIA_CHANNELS = 3
|
|
123
125
|
_N_RF_CHANNELS = 2
|
|
124
126
|
_N_CONTROLS = 2
|
|
125
|
-
_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
126
|
-
backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
|
|
127
|
-
dtype=backend.bool_,
|
|
128
|
-
)
|
|
129
|
-
_RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
130
|
-
backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
|
|
131
|
-
dtype=backend.bool_,
|
|
132
|
-
)
|
|
133
127
|
_N_ORGANIC_MEDIA_CHANNELS = 4
|
|
134
128
|
_N_ORGANIC_RF_CHANNELS = 1
|
|
135
129
|
_N_NON_MEDIA_CHANNELS = 2
|
|
136
130
|
|
|
131
|
+
_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
132
|
+
_RF_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
133
|
+
|
|
137
134
|
# Private class variables to hold the base test data.
|
|
138
135
|
_input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
139
136
|
_input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
@@ -170,6 +167,15 @@ class WithInputDataSamples:
|
|
|
170
167
|
@classmethod
|
|
171
168
|
def setup(cls):
|
|
172
169
|
"""Sets up input data samples."""
|
|
170
|
+
cls._ROI_CALIBRATION_PERIOD = backend.cast(
|
|
171
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
|
|
172
|
+
dtype=backend.bool_,
|
|
173
|
+
)
|
|
174
|
+
cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
175
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
|
|
176
|
+
dtype=backend.bool_,
|
|
177
|
+
)
|
|
178
|
+
|
|
173
179
|
cls._input_data_non_revenue_no_revenue_per_kpi = (
|
|
174
180
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
175
181
|
n_geos=cls._N_GEOS,
|