google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,735 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Meridian EDA Engine."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import functools
|
|
19
|
+
from typing import Callable, Dict, Optional, TypeAlias
|
|
20
|
+
from meridian import constants
|
|
21
|
+
from meridian.model import model
|
|
22
|
+
from meridian.model import transformers
|
|
23
|
+
import numpy as np
|
|
24
|
+
import tensorflow as tf
|
|
25
|
+
import xarray as xr
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
|
|
29
|
+
AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class ReachFrequencyData:
  """Immutable bundle of the data arrays derived for a group of RF channels.

  The same structure is used for paid and for organic reach and frequency.
  Fields without the `_national` suffix are geo-level arrays; fields with it
  are the corresponding national-level arrays.

  Attributes:
    reach_raw_da: Raw reach data.
    reach_scaled_da: Scaled reach data.
    reach_raw_da_national: National raw reach data.
    reach_scaled_da_national: National scaled reach data.
    frequency_da: Frequency data.
    frequency_da_national: National frequency data.
    rf_impressions_scaled_da: Scaled impressions, i.e. reach * frequency.
    rf_impressions_scaled_da_national: National scaled impressions
      (reach * frequency).
    rf_impressions_raw_da: Raw impressions, i.e. reach * frequency.
    rf_impressions_raw_da_national: National raw impressions
      (reach * frequency).
  """

  # Reach: raw and scaled, geo level then national level.
  reach_raw_da: xr.DataArray
  reach_scaled_da: xr.DataArray
  reach_raw_da_national: xr.DataArray
  reach_scaled_da_national: xr.DataArray
  # Frequency: geo level then national level.
  frequency_da: xr.DataArray
  frequency_da_national: xr.DataArray
  # Impressions (reach * frequency): scaled pair, then raw pair.
  rf_impressions_scaled_da: xr.DataArray
  rf_impressions_scaled_da_national: xr.DataArray
  rf_impressions_raw_da: xr.DataArray
  rf_impressions_raw_da_national: xr.DataArray
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class AggregationConfig:
  """Configuration for custom geo-to-national aggregation functions.

  Each function is applied to one variable at a time through
  `xr.DataArray.reduce` over the geo dimension. It is therefore called as
  `func(values, axis=geo_axis)` with the variable's underlying array (not an
  `xr.DataArray`), and must reduce along `axis`, as `np.sum` or `np.mean` do.

  Attributes:
    control_variables: A dictionary mapping control variable names to
      aggregation functions. Defaults to `np.sum` if a variable is not
      specified.
    non_media_treatments: A dictionary mapping non-media variable names to
      aggregation functions. Defaults to `np.sum` if a variable is not
      specified.

  Raises:
    TypeError: On construction, if any mapped aggregation function is not
      callable.
  """

  control_variables: AggregationMap = dataclasses.field(default_factory=dict)
  non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)

  def __post_init__(self):
    # Fail fast with a clear message instead of an opaque error raised from
    # inside xarray the first time a national aggregate is computed.
    for field_name in ('control_variables', 'non_media_treatments'):
      # A None map is tolerated: the engine treats it as "use the defaults".
      mapping = getattr(self, field_name) or {}
      for var_name, func in mapping.items():
        if not callable(func):
          raise TypeError(
              f'Aggregation function for {field_name}[{var_name!r}] must be'
              f' callable, got {type(func).__name__}.'
          )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class EDAEngine:
  """Meridian EDA Engine.

  Exposes a Meridian model's input data as `xr.DataArray`s for exploratory
  data analysis. Quantities are available at the geo level and at the national
  level (the `*_national` properties), raw and/or scaled with the model's
  tensor transformers. Arrays indexed by media time are truncated to their
  last `n_times` entries, and their `media_time` dimension is renamed `time`.
  """

  def __init__(
      self,
      meridian: model.Meridian,
      agg_config: AggregationConfig | None = None,
  ):
    """Initializes the engine.

    Args:
      meridian: The Meridian model whose input data is explored.
      agg_config: Per-variable geo-to-national aggregation functions for
        control variables and non-media treatments. If None (the default), an
        empty `AggregationConfig()` is used, so every such variable falls back
        to the default aggregation function (`np.sum`).
    """
    self._meridian = meridian
    # Build the default per engine. A default instance in the signature would
    # be created once, and its mutable dicts shared by every engine.
    self._agg_config = (
        AggregationConfig() if agg_config is None else agg_config
    )

  @functools.cached_property
  def controls_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled controls, or None if there are none."""
    if self._meridian.input_data.controls is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.controls,
        values=self._meridian.controls_scaled,
    )

  @functools.cached_property
  def controls_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national controls data array."""
    if self._meridian.input_data.controls is None:
      return None
    if self._meridian.is_national:
      if self.controls_scaled_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'controls_scaled_da is None when controls is not None.'
        )
      return self.controls_scaled_da.squeeze(constants.GEO)
    # Geo model: aggregate the raw controls over geos with the configured
    # per-variable functions, then center and scale the national series.
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.controls,
        transformers.CenteringAndScalingTransformer,
        constants.CONTROL_VARIABLE,
        self._agg_config.control_variables,
    )

  @functools.cached_property
  def media_raw_da(self) -> xr.DataArray | None:
    """Returns the time-truncated raw media, or None if there is none."""
    if self._meridian.input_data.media is None:
      return None
    return self._truncate_media_time(self._meridian.input_data.media)

  @functools.cached_property
  def media_scaled_da(self) -> xr.DataArray | None:
    """Returns the time-truncated scaled media, or None if there is none."""
    if self._meridian.input_data.media is None:
      return None
    media_scaled_da = _data_array_like(
        da=self._meridian.input_data.media,
        values=self._meridian.media_tensors.media_scaled,
    )
    return self._truncate_media_time(media_scaled_da)

  @functools.cached_property
  def media_spend_da(self) -> xr.DataArray | None:
    """Returns the media spend, or None if there is none."""
    if self._meridian.input_data.media_spend is None:
      return None
    # No need to truncate the media time for media spend.
    return _data_array_like(
        da=self._meridian.input_data.media_spend,
        values=self._meridian.media_tensors.media_spend,
    )

  @functools.cached_property
  def media_spend_da_national(self) -> xr.DataArray | None:
    """Returns the national media spend data array."""
    if self._meridian.input_data.media_spend is None:
      return None
    if self.media_spend_da is None:
      # This case should be impossible given the check above.
      raise RuntimeError(
          'media_spend_da is None when media_spend is not None.'
      )
    return self._spend_da_national(
        self.media_spend_da, self._meridian.input_data.media_spend
    )

  @functools.cached_property
  def media_raw_da_national(self) -> xr.DataArray | None:
    """Returns the raw media at national level, or None if there is none."""
    if self.media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.media_raw_da.squeeze(constants.GEO)
    # Note that media is summable by assumption.
    return self._aggregate_and_scale_geo_da(self.media_raw_da, None)

  @functools.cached_property
  def media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the scaled media at national level, or None if there is none."""
    if self.media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.media_scaled_da.squeeze(constants.GEO)
    # Note that media is summable by assumption. Sum the raw media over geos,
    # then scale the national series.
    return self._aggregate_and_scale_geo_da(
        self.media_raw_da, transformers.MediaTransformer
    )

  @functools.cached_property
  def organic_media_raw_da(self) -> xr.DataArray | None:
    """Returns the time-truncated raw organic media, or None if absent."""
    if self._meridian.input_data.organic_media is None:
      return None
    return self._truncate_media_time(self._meridian.input_data.organic_media)

  @functools.cached_property
  def organic_media_scaled_da(self) -> xr.DataArray | None:
    """Returns the time-truncated scaled organic media, or None if absent."""
    if self._meridian.input_data.organic_media is None:
      return None
    organic_media_scaled_da = _data_array_like(
        da=self._meridian.input_data.organic_media,
        values=self._meridian.organic_media_tensors.organic_media_scaled,
    )
    return self._truncate_media_time(organic_media_scaled_da)

  @functools.cached_property
  def organic_media_raw_da_national(self) -> xr.DataArray | None:
    """Returns the raw organic media at national level, or None if absent."""
    if self.organic_media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_raw_da.squeeze(constants.GEO)
    # Note that organic media is summable by assumption.
    return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)

  @functools.cached_property
  def organic_media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the scaled organic media at national level, or None."""
    if self.organic_media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_scaled_da.squeeze(constants.GEO)
    # Note that organic media is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self.organic_media_raw_da, transformers.MediaTransformer
    )

  @functools.cached_property
  def non_media_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled non-media treatments, or None."""
    if self._meridian.input_data.non_media_treatments is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.non_media_treatments,
        values=self._meridian.non_media_treatments_normalized,
    )

  @functools.cached_property
  def non_media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national non-media treatment data array."""
    if self._meridian.input_data.non_media_treatments is None:
      return None
    if self._meridian.is_national:
      if self.non_media_scaled_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'non_media_scaled_da is None when non_media_treatments is not None.'
        )
      return self.non_media_scaled_da.squeeze(constants.GEO)
    # Geo model: aggregate with the configured per-variable functions, then
    # center and scale the national series.
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.non_media_treatments,
        transformers.CenteringAndScalingTransformer,
        constants.NON_MEDIA_CHANNEL,
        self._agg_config.non_media_treatments,
    )

  @functools.cached_property
  def rf_spend_da(self) -> xr.DataArray | None:
    """Returns the RF spend, or None if there is none."""
    if self._meridian.input_data.rf_spend is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.rf_spend,
        values=self._meridian.rf_tensors.rf_spend,
    )

  @functools.cached_property
  def rf_spend_da_national(self) -> xr.DataArray | None:
    """Returns the national RF spend data array."""
    if self._meridian.input_data.rf_spend is None:
      return None
    if self.rf_spend_da is None:
      # This case should be impossible given the check above.
      raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
    return self._spend_da_national(
        self.rf_spend_da, self._meridian.input_data.rf_spend
    )

  @functools.cached_property
  def _rf_data(self) -> ReachFrequencyData | None:
    """Derived arrays for paid RF channels, or None without reach data."""
    if self._meridian.input_data.reach is None:
      return None
    return self._get_rf_data(
        self._meridian.input_data.reach,
        self._meridian.input_data.frequency,
        is_organic=False,
    )

  @property
  def reach_raw_da(self) -> xr.DataArray | None:
    return None if self._rf_data is None else self._rf_data.reach_raw_da

  @property
  def reach_scaled_da(self) -> xr.DataArray | None:
    return None if self._rf_data is None else self._rf_data.reach_scaled_da

  @property
  def reach_raw_da_national(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.reach_raw_da_national

  @property
  def reach_scaled_da_national(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.reach_scaled_da_national

  @property
  def frequency_da(self) -> xr.DataArray | None:
    return None if self._rf_data is None else self._rf_data.frequency_da

  @property
  def frequency_da_national(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.frequency_da_national

  @property
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_raw_da

  @property
  def rf_impressions_raw_da_national(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_raw_da_national

  @property
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_scaled_da

  @property
  def rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_scaled_da_national

  @functools.cached_property
  def _organic_rf_data(self) -> ReachFrequencyData | None:
    """Derived arrays for organic RF channels, or None without organic reach."""
    if self._meridian.input_data.organic_reach is None:
      return None
    return self._get_rf_data(
        self._meridian.input_data.organic_reach,
        self._meridian.input_data.organic_frequency,
        is_organic=True,
    )

  @property
  def organic_reach_raw_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_raw_da

  @property
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_scaled_da

  @property
  def organic_reach_raw_da_national(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_raw_da_national

  @property
  def organic_reach_scaled_da_national(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_scaled_da_national

  @property
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_scaled_da

  @property
  def organic_rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_scaled_da_national

  @property
  def organic_frequency_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.frequency_da

  @property
  def organic_frequency_da_national(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.frequency_da_national

  @property
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_raw_da

  @property
  def organic_rf_impressions_raw_da_national(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_raw_da_national

  @functools.cached_property
  def geo_population_da(self) -> xr.DataArray | None:
    """Returns the population per geo, or None for a national model."""
    if self._meridian.is_national:
      return None
    return xr.DataArray(
        self._meridian.population,
        coords={constants.GEO: self._meridian.input_data.geo.values},
        dims=[constants.GEO],
        name=constants.POPULATION,
    )

  @functools.cached_property
  def kpi_scaled_da(self) -> xr.DataArray:
    """Returns the geo-level scaled KPI."""
    return _data_array_like(
        da=self._meridian.input_data.kpi,
        values=self._meridian.kpi_scaled,
    )

  @functools.cached_property
  def kpi_scaled_da_national(self) -> xr.DataArray:
    """Returns the scaled KPI at national level."""
    if self._meridian.is_national:
      return self.kpi_scaled_da.squeeze(constants.GEO)
    # Note that kpi is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.kpi,
        transformers.CenteringAndScalingTransformer,
    )

  @functools.cached_property
  def treatment_control_scaled_ds(self) -> xr.Dataset:
    """Returns a Dataset containing all scaled treatments and controls.

    This includes media, RF impressions, organic media, organic RF impressions,
    non-media treatments, and control variables, all at the geo level. Inputs
    that are absent (None) are skipped; the rest are merged with an inner join.
    """
    to_merge = [
        da
        for da in (
            self.media_scaled_da,
            self.rf_impressions_scaled_da,
            self.organic_media_scaled_da,
            self.organic_rf_impressions_scaled_da,
            self.controls_scaled_da,
            self.non_media_scaled_da,
        )
        if da is not None
    ]
    return xr.merge(to_merge, join='inner')

  @functools.cached_property
  def treatment_control_scaled_ds_national(self) -> xr.Dataset:
    """Returns a Dataset containing all scaled treatments and controls.

    This includes media, RF impressions, organic media, organic RF impressions,
    non-media treatments, and control variables, all at the national level.
    Inputs that are absent (None) are skipped; the rest are merged with an
    inner join.
    """
    to_merge_national = [
        da
        for da in (
            self.media_scaled_da_national,
            self.rf_impressions_scaled_da_national,
            self.organic_media_scaled_da_national,
            self.organic_rf_impressions_scaled_da_national,
            self.controls_scaled_da_national,
            self.non_media_scaled_da_national,
        )
        if da is not None
    ]
    return xr.merge(to_merge_national, join='inner')

  def _spend_da_national(
      self,
      spend_da: xr.DataArray,
      raw_spend_da: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a national-level version of a spend data array.

    Args:
      spend_da: The spend array built from the model's spend tensor.
      raw_spend_da: The same spend as provided in the input data.

    Returns:
      `spend_da` unchanged if it has no geo dimension. Otherwise, for a
      national model, `spend_da` with the geo dimension squeezed out, and for
      a geo model, `raw_spend_da` summed over geos.
    """
    # NOTE(review): spend may be supplied already aggregated per channel,
    # without a geo dimension. Squeezing or summing over a missing geo
    # dimension would raise, and there is nothing to aggregate in that case.
    if constants.GEO not in spend_da.dims:
      return spend_da
    if self._meridian.is_national:
      return spend_da.squeeze(constants.GEO)
    return self._aggregate_and_scale_geo_da(raw_spend_da, None)

  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
    """Truncates the first `start` elements of the media time of a variable.

    `start` is `n_media_times - n_times`, so the last `n_times` media-time
    entries are kept. The `media_time` dimension is renamed to `time`. The
    input is deep-copied first, so the result does not alias it.

    Raises:
      ValueError: If `da` has no media time coordinate.
    """
    # This should not happen. If it does, it means this function is mis-used.
    if constants.MEDIA_TIME not in da.coords:
      raise ValueError(
          f'Variable does not have a media time coordinate: {da.name}.'
      )

    start = self._meridian.n_media_times - self._meridian.n_times
    da = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
    return da.rename({constants.MEDIA_TIME: constants.TIME})

  def _scale_xarray(
      self,
      xarray: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      population: Optional[tf.Tensor] = None,
  ) -> xr.DataArray:
    """Scales xarray values with a TensorTransformer.

    Args:
      xarray: The DataArray to scale. It is copied; the input is not modified.
      transformer_class: None (no scaling), CenteringAndScalingTransformer, or
        MediaTransformer.
      population: Population tensor passed to the transformer. If None,
        `tf.constant([1.0], dtype=tf.float32)` is used.

    Returns:
      A copy of `xarray` whose values have been passed through the
      transformer's `forward` method. If `transformer_class` is None, the copy
      is returned unmodified.

    Raises:
      ValueError: If `transformer_class` is not a supported class.
    """
    da = xarray.copy()

    if transformer_class is None:
      return da
    if population is None:
      # Built per call rather than as a signature default, so that no tensor
      # is created at class-definition (import) time.
      population = tf.constant([1.0], dtype=tf.float32)

    if transformer_class is transformers.CenteringAndScalingTransformer:
      xarray_transformer = transformers.CenteringAndScalingTransformer(
          tensor=da.values, population=population
      )
    elif transformer_class is transformers.MediaTransformer:
      xarray_transformer = transformers.MediaTransformer(
          media=da.values, population=population
      )
    else:
      raise ValueError(
          f'Unknown transformer class: {transformer_class}.\nMust be one of:'
          ' CenteringAndScalingTransformer or MediaTransformer.'
      )
    da.values = xarray_transformer.forward(da.values)
    return da

  def _aggregate_variables(
      self,
      da_geo: xr.DataArray,
      channel_dim: str,
      da_var_agg_map: AggregationMap,
      keepdims: bool = True,
  ) -> xr.DataArray:
    """Aggregates variables within a DataArray based on user-defined functions.

    Args:
      da_geo: The geo-level DataArray containing multiple variables along
        channel_dim.
      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
        constants.CONTROL_VARIABLE).
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Each function is applied through
        `DataArray.reduce`, i.e. called as `func(values, axis=geo_axis)`.
        Variables not in the map use `np.sum`.
      keepdims: Whether to keep the dimensions of the aggregated DataArray.

    Returns:
      An xr.DataArray aggregated to the national level, with each variable
      aggregated according to the da_var_agg_map. `channel_dim` is the last
      dimension.
    """
    agg_results = []
    for var_name in da_geo[channel_dim].values:
      var_data = da_geo.sel({channel_dim: var_name})
      agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
      # Apply the aggregation function over the GEO dimension.
      agg_results.append(
          var_data.reduce(agg_func, dim=constants.GEO, keepdims=keepdims)
      )

    # Combine the aggregated variables back into a single DataArray.
    return xr.concat(agg_results, dim=channel_dim).transpose(..., channel_dim)

  def _aggregate_and_scale_geo_da(
      self,
      da_geo: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      channel_dim: Optional[str] = None,
      da_var_agg_map: Optional[AggregationMap] = None,
  ) -> xr.DataArray:
    """Aggregate geo-level xr.DataArray to national level and then scale values.

    Args:
      da_geo: The geo-level DataArray to convert.
      transformer_class: The TensorTransformer class to apply after summing to
        national level. Must be None, CenteringAndScalingTransformer, or
        MediaTransformer.
      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
        constants.CONTROL_VARIABLE). If None, standard sum aggregation is used.
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Used only if channel_dim is not None.

    Returns:
      An xr.DataArray representing the aggregated and scaled national-level
      data, cast to float32 and without a geo dimension.
    """
    temp_geo_dim = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME

    if da_var_agg_map is None:
      da_var_agg_map = {}

    if channel_dim is not None:
      da_national = self._aggregate_variables(
          da_geo, channel_dim, da_var_agg_map
      )
    else:
      # Default to sum aggregation if no channel dimension is provided.
      da_national = da_geo.sum(
          dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
      )

    # Keep a single placeholder geo while scaling, since the transformers
    # operate on geo-indexed tensors; it is dropped before returning.
    da_national = da_national.assign_coords({constants.GEO: [temp_geo_dim]})
    da_national.values = tf.cast(da_national.values, tf.float32)
    da_national = self._scale_xarray(da_national, transformer_class)

    return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)

  def _get_rf_data(
      self,
      reach_raw_da: xr.DataArray,
      freq_raw_da: xr.DataArray,
      is_organic: bool,
  ) -> ReachFrequencyData:
    """Get impressions and frequencies data arrays for RF channels.

    Args:
      reach_raw_da: Raw geo-level reach, indexed by media time.
      freq_raw_da: Raw geo-level frequency, indexed by media time.
      is_organic: Whether these are organic RF channels. Selects the scaled
        reach tensor and the names of the derived arrays.

    Returns:
      A `ReachFrequencyData` with time-truncated geo-level arrays and their
      national-level counterparts.
    """
    if is_organic:
      scaled_reach_values = (
          self._meridian.organic_rf_tensors.organic_reach_scaled
      )
    else:
      scaled_reach_values = self._meridian.rf_tensors.reach_scaled
    reach_scaled_da = _data_array_like(
        da=reach_raw_da, values=scaled_reach_values
    )
    # Truncate the media time for reach and scaled reach.
    reach_raw_da = self._truncate_media_time(reach_raw_da)
    reach_scaled_da = self._truncate_media_time(reach_scaled_da)

    # The geo level frequency.
    frequency_da = self._truncate_media_time(freq_raw_da)

    # The raw geo level impressions, equal to reach * frequency.
    impressions_raw_da = reach_raw_da * frequency_da
    impressions_raw_da.name = (
        constants.ORGANIC_RF_IMPRESSIONS
        if is_organic
        else constants.RF_IMPRESSIONS
    )
    impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)

    if self._meridian.is_national:
      reach_raw_da_national = reach_raw_da.squeeze(constants.GEO)
      reach_scaled_da_national = reach_scaled_da.squeeze(constants.GEO)
      impressions_raw_da_national = impressions_raw_da.squeeze(constants.GEO)
      frequency_da_national = frequency_da.squeeze(constants.GEO)

      # Scaled impressions, using the transformer's default unit population.
      impressions_scaled_da = self._scale_xarray(
          impressions_raw_da, transformers.MediaTransformer
      )
      impressions_scaled_da_national = impressions_scaled_da.squeeze(
          constants.GEO
      )
    else:
      reach_raw_da_national = self._aggregate_and_scale_geo_da(
          reach_raw_da, None
      )
      reach_scaled_da_national = self._aggregate_and_scale_geo_da(
          reach_raw_da, transformers.MediaTransformer
      )
      impressions_raw_da_national = self._aggregate_and_scale_geo_da(
          impressions_raw_da, None
      )

      # National frequency is a weighted average of geo frequencies, weighted
      # by reach: total impressions / total reach (0 where reach is 0).
      frequency_da_national = xr.where(
          reach_raw_da_national == 0.0,
          0.0,
          impressions_raw_da_national / reach_raw_da_national,
      )
      frequency_da_national.name = (
          constants.ORGANIC_PREFIX if is_organic else ''
      ) + constants.FREQUENCY
      frequency_da_national.values = tf.cast(
          frequency_da_national.values, tf.float32
      )

      # Scale the geo-level impressions by population.
      impressions_scaled_da = self._scale_xarray(
          impressions_raw_da,
          transformers.MediaTransformer,
          population=self._meridian.population,
      )

      # Scale the national impressions.
      impressions_scaled_da_national = self._aggregate_and_scale_geo_da(
          impressions_raw_da,
          transformers.MediaTransformer,
      )

    return ReachFrequencyData(
        reach_raw_da=reach_raw_da,
        reach_scaled_da=reach_scaled_da,
        reach_raw_da_national=reach_raw_da_national,
        reach_scaled_da_national=reach_scaled_da_national,
        frequency_da=frequency_da,
        frequency_da_national=frequency_da_national,
        rf_impressions_scaled_da=impressions_scaled_da,
        rf_impressions_scaled_da_national=impressions_scaled_da_national,
        rf_impressions_raw_da=impressions_raw_da,
        rf_impressions_raw_da_national=impressions_raw_da_national,
    )
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def _data_array_like(
    *, da: xr.DataArray, values: np.ndarray | tf.Tensor
) -> xr.DataArray:
  """Wraps `values` in a new DataArray laid out exactly like `da`.

  Args:
    da: Template DataArray. Its dimensions, coordinates, name and attrs are
      reused for the result.
    values: NumPy array or TensorFlow tensor holding the new data. It must be
      shape-compatible with `da.dims` / `da.coords` for xarray to accept it.

  Returns:
    A fresh `xr.DataArray` holding `values`, with `da`'s structure.
  """
  template_structure = {
      'coords': da.coords,
      'dims': da.dims,
      'name': da.name,
      'attrs': da.attrs,
  }
  return xr.DataArray(values, **template_structure)
|