google-meridian 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.1.dist-info → google_meridian-1.1.3.dist-info}/METADATA +6 -2
- {google_meridian-1.1.1.dist-info → google_meridian-1.1.3.dist-info}/RECORD +23 -17
- meridian/__init__.py +6 -4
- meridian/analysis/analyzer.py +61 -19
- meridian/analysis/optimizer.py +75 -44
- meridian/analysis/visualizer.py +15 -5
- meridian/constants.py +1 -0
- meridian/data/__init__.py +3 -0
- meridian/data/data_frame_input_data_builder.py +614 -0
- meridian/data/input_data_builder.py +823 -0
- meridian/data/load.py +138 -402
- meridian/data/nd_array_input_data_builder.py +509 -0
- meridian/mlflow/__init__.py +17 -0
- meridian/mlflow/autolog.py +206 -0
- meridian/model/media.py +7 -0
- meridian/model/model.py +32 -26
- meridian/model/posterior_sampler.py +13 -9
- meridian/model/prior_sampler.py +4 -6
- meridian/model/spec.py +17 -7
- meridian/version.py +17 -0
- {google_meridian-1.1.1.dist-info → google_meridian-1.1.3.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.1.dist-info → google_meridian-1.1.3.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.1.dist-info → google_meridian-1.1.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,823 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""This module defines a Builder API for Meridian `InputData`.
|
|
16
|
+
|
|
17
|
+
The Builder API for `InputData` exposes piecewise data ingestion with its own
|
|
18
|
+
validation logic and an overall final validation logic before a valid
|
|
19
|
+
`InputData` is constructed.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import abc
|
|
23
|
+
from collections.abc import Sequence
|
|
24
|
+
import datetime
|
|
25
|
+
import warnings
|
|
26
|
+
from meridian import constants
|
|
27
|
+
from meridian.data import input_data
|
|
28
|
+
from meridian.data import time_coordinates as tc
|
|
29
|
+
import natsort
|
|
30
|
+
import numpy as np
|
|
31
|
+
import xarray as xr
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
__all__ = [
|
|
35
|
+
'InputDataBuilder',
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class InputDataBuilder(abc.ABC):
  """Abstract base class for `InputData` builders.

  Subclasses feed user data into the setters defined here piece by piece.
  Each setter normalizes coordinates and cross-checks consistency against
  previously set components; `build()` performs a final validation pass and
  assembles an `input_data.InputData`.
  """
|
|
41
|
+
|
|
42
|
+
def __init__(self, kpi_type: str):
    """Initializes an empty builder.

    Args:
      kpi_type: The KPI type string, forwarded verbatim to the constructed
        `InputData`.
    """
    self._kpi_type = kpi_type

    # These working attributes are going to be set along the way as the builder
    # is provided piecemeal with the user's input data.
    # In the course of processing each DataFrame piece, dimension coordinates
    # will be discovered and set with, e.g., `self.time_coords = ...`.
    # The setter code will perform basic validation
    # checks, e.g.:
    # * If previous dataframe input already set it, then it should be consistent
    # * If not, set it for the first time.
    # * When setting, make consistency checks against other dimensions
    # * etc...

    # Working dimensions and their coordinates. All start as None ("not yet
    # discovered"), hence the `| None` annotations.
    self._time_coords: Sequence[str] | None = None
    self._media_time_coords: Sequence[str] | None = None
    self._geos: Sequence[str] | None = None

    # Working data arrays (components of the final `InputData` object)
    self._kpi: xr.DataArray | None = None
    self._controls: xr.DataArray | None = None
    self._population: xr.DataArray | None = None
    self._revenue_per_kpi: xr.DataArray | None = None
    self._media: xr.DataArray | None = None
    self._media_spend: xr.DataArray | None = None
    self._reach: xr.DataArray | None = None
    self._frequency: xr.DataArray | None = None
    self._rf_spend: xr.DataArray | None = None
    self._organic_media: xr.DataArray | None = None
    self._organic_reach: xr.DataArray | None = None
    self._organic_frequency: xr.DataArray | None = None
    self._non_media_treatments: xr.DataArray | None = None
|
|
75
|
+
|
|
76
|
+
@property
def time_coords(self) -> Sequence[str]:
  """The `time` dimension coordinates discovered so far, or None."""
  return self._time_coords

@time_coords.setter
def time_coords(self, value: Sequence[str]):
  """Registers `time` coords, validating uniqueness and cross-consistency."""
  unique = set(value)
  if len(unique) != len(value):
    raise ValueError('`times` coords must be unique.')
  previous = self.time_coords
  if previous is not None and set(previous) != unique:
    raise ValueError(f'`times` coords already set to {previous}.')
  media_times = self.media_time_coords
  if media_times is not None:
    if not unique.issubset(media_times):
      raise ValueError(
          '`times` coords must be subset of previously set `media_times`'
          ' coords.'
      )
    self._validate_lagged_media(
        media_time_coords=media_times, time_coords=value
    )
  # Raises if the sorted dates do not parse as regular time coordinates.
  _ = tc.TimeCoordinates.from_dates(sorted(value)).interval_days
  self._time_coords = value
|
|
99
|
+
|
|
100
|
+
@property
def media_time_coords(self) -> Sequence[str]:
  """The `media_time` dimension coordinates discovered so far, or None."""
  return self._media_time_coords

@media_time_coords.setter
def media_time_coords(self, value: Sequence[str]):
  """Registers `media_time` coords, validating uniqueness and consistency."""
  unique = set(value)
  if len(unique) != len(value):
    raise ValueError('`media_times` coords must be unique.')
  previous = self.media_time_coords
  if previous is not None and set(previous) != unique:
    raise ValueError(
        f'`media_times` coords already set to {previous}.'
    )
  times = self.time_coords
  if times is not None:
    if not unique.issuperset(times):
      raise ValueError(
          '`media_times` coords must be superset of previously set `times`'
          ' coords.'
      )
    self._validate_lagged_media(media_time_coords=value, time_coords=times)
  # Raises if the sorted dates do not parse as regular time coordinates.
  _ = tc.TimeCoordinates.from_dates(sorted(value)).interval_days
  self._media_time_coords = value
|
|
127
|
+
|
|
128
|
+
@property
def geos(self) -> Sequence[str]:
  """The `geo` dimension coordinates discovered so far, or None."""
  return self._geos

@geos.setter
def geos(self, value: Sequence[str]):
  """Registers `geo` coords, validating uniqueness and consistency."""
  if len(set(value)) != len(value):
    raise ValueError('Geos must be unique.')
  if self._geos is not None and set(self._geos) != set(value):
    raise ValueError(f'geos already set to {self._geos}.')
  self._geos = value
|
|
139
|
+
|
|
140
|
+
@property
def kpi(self) -> xr.DataArray:
  """The `kpi` data array, or None if not yet provided."""
  return self._kpi

@kpi.setter
def kpi(self, kpi: xr.DataArray):
  """Sets the `kpi` data array.

  The array must have dims `['geo', 'time']` with matching coordinates.

  Args:
    kpi: Kpi DataArray.
  """
  self._validate_set('KPI', self.kpi)

  normalized = self._normalize_coords(kpi, constants.TIME)
  self._kpi = normalized
  # Registering coords via the property setters triggers consistency checks.
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
170
|
+
|
|
171
|
+
@property
def controls(self) -> xr.DataArray:
  """The `controls` data array, or None if not yet provided."""
  return self._controls

@controls.setter
def controls(self, controls: xr.DataArray):
  """Sets the `controls` data array.

  The array must have dims `['geo', 'time', 'control_variable']` with
  matching coordinates.

  Args:
    controls: Controls DataArray.
  """
  self._validate_set('Controls', self.controls)

  normalized = self._normalize_coords(controls, constants.TIME)
  self._controls = normalized
  # Registering coords via the property setters triggers consistency checks.
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
202
|
+
|
|
203
|
+
@property
def population(self) -> xr.DataArray:
  return self._population

@population.setter
def population(self, population: xr.DataArray):
  """Sets the `population` data array.

  `population` must have the following `DataArray` signature:

  ```
  xarray.DataArray(
      data=...,
      name='population',
      dims=['geo'],
      coords={
          'geo': ...,
      },
  )
  ```

  Args:
    population: Population DataArray.
  """
  self._validate_set('Population', self.population)

  # No time dimension here: population is indexed by geo only, so time
  # normalization is skipped.
  self._population = self._normalize_coords(population)
  self.geos = self.population.coords[constants.GEO].values.tolist()
|
|
231
|
+
|
|
232
|
+
@property
def revenue_per_kpi(self) -> xr.DataArray:
  """The `revenue_per_kpi` data array, or None if not yet provided."""
  return self._revenue_per_kpi

@revenue_per_kpi.setter
def revenue_per_kpi(self, revenue_per_kpi: xr.DataArray):
  """Sets the `revenue_per_kpi` data array.

  The array must have dims `['geo', 'time']` with matching coordinates.

  Args:
    revenue_per_kpi: Revenue per kpi DataArray.
  """
  self._validate_set('Revenue per KPI', self.revenue_per_kpi)

  normalized = self._normalize_coords(revenue_per_kpi, constants.TIME)
  self._revenue_per_kpi = normalized
  # Registering coords via the property setters triggers consistency checks.
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
265
|
+
|
|
266
|
+
@property
def media(self) -> xr.DataArray:
  """The `media` data array, or None if not yet provided."""
  return self._media

@media.setter
def media(self, media: xr.DataArray):
  """Sets the `media` data array.

  The array must have dims `['geo', 'media_time', 'media_channel']` with
  matching coordinates.

  Args:
    media: Media DataArray.
  """
  self._validate_set('Media', self.media)
  # Media and media spend must agree on their channel coordinates.
  self._validate_channels_consistency(
      constants.MEDIA_CHANNEL, [media, self.media_spend]
  )

  normalized = self._normalize_coords(media, constants.MEDIA_TIME)
  self._media = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
302
|
+
|
|
303
|
+
@property
def media_spend(self) -> xr.DataArray:
  """The `media_spend` data array, or None if not yet provided."""
  return self._media_spend

@media_spend.setter
def media_spend(self, media_spend: xr.DataArray):
  """Sets the `media_spend` data array.

  The array must have dims `['geo', 'time', 'media_channel']` with matching
  coordinates.

  Args:
    media_spend: Media spend DataArray.
  """
  self._validate_set('Media spend', self.media_spend)
  # Media spend and media must agree on their channel coordinates.
  self._validate_channels_consistency(
      constants.MEDIA_CHANNEL, [media_spend, self.media]
  )

  normalized = self._normalize_coords(media_spend, constants.TIME)
  self._media_spend = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
337
|
+
|
|
338
|
+
@property
def reach(self) -> xr.DataArray:
  """The `reach` data array, or None if not yet provided."""
  return self._reach

@reach.setter
def reach(self, reach: xr.DataArray):
  """Sets the `reach` data array.

  The array must have dims `['geo', 'media_time', 'rf_channel']` with
  matching coordinates.

  Args:
    reach: Reach DataArray.
  """
  self._validate_set('Reach', self.reach)
  # All RF components must agree on their channel coordinates.
  self._validate_channels_consistency(
      constants.RF_CHANNEL, [reach, self.frequency, self.rf_spend]
  )

  normalized = self._normalize_coords(reach, constants.MEDIA_TIME)
  self._reach = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
374
|
+
|
|
375
|
+
@property
def frequency(self) -> xr.DataArray:
  """The `frequency` data array, or None if not yet provided."""
  return self._frequency

@frequency.setter
def frequency(self, frequency: xr.DataArray):
  """Sets the `frequency` data array.

  The array must have dims `['geo', 'media_time', 'rf_channel']` with
  matching coordinates.

  Args:
    frequency: Frequency DataArray.
  """
  self._validate_set('Frequency', self.frequency)
  # All RF components must agree on their channel coordinates.
  self._validate_channels_consistency(
      constants.RF_CHANNEL, [frequency, self.reach, self.rf_spend]
  )

  normalized = self._normalize_coords(frequency, constants.MEDIA_TIME)
  self._frequency = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
411
|
+
|
|
412
|
+
@property
def rf_spend(self) -> xr.DataArray:
  """The `rf_spend` data array, or None if not yet provided."""
  return self._rf_spend

@rf_spend.setter
def rf_spend(self, rf_spend: xr.DataArray):
  """Sets the `rf_spend` data array.

  The array must have dims `['geo', 'time', 'rf_channel']` with matching
  coordinates.

  Args:
    rf_spend: RF spend DataArray.
  """
  self._validate_set('RF spend', self.rf_spend)
  # All RF components must agree on their channel coordinates.
  self._validate_channels_consistency(
      constants.RF_CHANNEL, [rf_spend, self.reach, self.frequency]
  )

  normalized = self._normalize_coords(rf_spend, constants.TIME)
  self._rf_spend = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
446
|
+
|
|
447
|
+
@property
def organic_media(self) -> xr.DataArray:
  """The `organic_media` data array, or None if not yet provided."""
  return self._organic_media

@organic_media.setter
def organic_media(self, organic_media: xr.DataArray):
  """Sets the `organic_media` data array.

  The array must have dims `['geo', 'media_time']` with matching
  coordinates.

  Args:
    organic_media: Organic media DataArray.
  """
  self._validate_set('Organic media', self.organic_media)

  normalized = self._normalize_coords(organic_media, constants.MEDIA_TIME)
  self._organic_media = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
481
|
+
|
|
482
|
+
@property
def organic_reach(self) -> xr.DataArray:
  """The `organic_reach` data array, or None if not yet provided."""
  return self._organic_reach

@organic_reach.setter
def organic_reach(self, organic_reach: xr.DataArray):
  """Sets the `organic_reach` data array.

  The array must have dims `['geo', 'media_time', 'organic_rf_channel']`
  with matching coordinates.

  Args:
    organic_reach: Organic reach DataArray.
  """
  self._validate_set('Organic reach', self.organic_reach)
  # Organic reach and organic frequency must agree on channel coordinates.
  self._validate_channels_consistency(
      constants.ORGANIC_RF_CHANNEL, [organic_reach, self.organic_frequency]
  )

  normalized = self._normalize_coords(organic_reach, constants.MEDIA_TIME)
  self._organic_reach = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
520
|
+
|
|
521
|
+
@property
def organic_frequency(self) -> xr.DataArray:
  """The `organic_frequency` data array, or None if not yet provided."""
  return self._organic_frequency

@organic_frequency.setter
def organic_frequency(self, organic_frequency: xr.DataArray):
  """Sets the `organic_frequency` data array.

  The array must have dims `['geo', 'media_time', 'organic_rf_channel']`
  with matching coordinates.

  Args:
    organic_frequency: Organic frequency DataArray.
  """
  self._validate_set('Organic frequency', self.organic_frequency)
  # Organic frequency and organic reach must agree on channel coordinates.
  self._validate_channels_consistency(
      constants.ORGANIC_RF_CHANNEL, [organic_frequency, self.organic_reach]
  )

  normalized = self._normalize_coords(
      organic_frequency, constants.MEDIA_TIME
  )
  self._organic_frequency = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.media_time_coords = normalized.coords[
      constants.MEDIA_TIME
  ].values.tolist()
|
|
559
|
+
|
|
560
|
+
@property
def non_media_treatments(self) -> xr.DataArray:
  """The `non_media_treatments` data array, or None if not yet provided."""
  return self._non_media_treatments

@non_media_treatments.setter
def non_media_treatments(self, non_media_treatments: xr.DataArray):
  """Sets the `non_media_treatments` data array.

  The array must have dims `['geo', 'time', 'non_media_channel']` with
  matching coordinates.

  Args:
    non_media_treatments: Non-media treatments DataArray.
  """
  self._validate_set('Non-media treatments', self.non_media_treatments)

  normalized = self._normalize_coords(non_media_treatments, constants.TIME)
  self._non_media_treatments = normalized
  self.geos = normalized.coords[constants.GEO].values.tolist()
  self.time_coords = normalized.coords[constants.TIME].values.tolist()
|
|
595
|
+
|
|
596
|
+
def build(self) -> input_data.InputData:
  """Builds an `InputData`.

  Runs a final validation pass over all data arrays provided so far, sorts
  every component naturally along its `geo` and `time`/`media_time`
  dimensions, then assembles the result.

  Returns:
    A validated `InputData`.
  """
  self._validate_required_components()
  self._validate_nas()

  # TODO: move logic from input_data to here: all channel names
  # should be unique across media channels, rf channels, organic media
  # channels, and organic rf channels.
  geos_sorted = natsort.natsorted(self.geos)
  times_sorted = natsort.natsorted(self.time_coords)
  media_times_sorted = natsort.natsorted(self.media_time_coords)

  def _reindexed(da: xr.DataArray | None, is_media_time: bool = False):
    """Naturally sorts the DataArray by geo and time/media time."""
    if da is None:
      return None
    if is_media_time:
      return da.reindex(geo=geos_sorted, media_time=media_times_sorted)
    return da.reindex(geo=geos_sorted, time=times_sorted)

  return input_data.InputData(
      kpi_type=self._kpi_type,
      kpi=_reindexed(self.kpi),
      revenue_per_kpi=_reindexed(self.revenue_per_kpi),
      controls=_reindexed(self.controls),
      population=self.population.reindex(geo=geos_sorted),
      media=_reindexed(self.media, True),
      media_spend=_reindexed(self.media_spend),
      reach=_reindexed(self.reach, True),
      frequency=_reindexed(self.frequency, True),
      rf_spend=_reindexed(self.rf_spend),
      non_media_treatments=_reindexed(self.non_media_treatments),
      organic_media=_reindexed(self.organic_media, True),
      organic_reach=_reindexed(self.organic_reach, True),
      organic_frequency=_reindexed(self.organic_frequency, True),
  )
|
|
642
|
+
|
|
643
|
+
def _normalize_coords(
    self, da: xr.DataArray, time_dimension_name: str | None = None
) -> xr.DataArray:
  """Normalizes the given `DataArray`'s coordinates in Meridian convention.

  Validates that time values are in the conventional Meridian format and
  that geos have the national default name if national. If geo coordinates
  are not string-typed, they are converted to strings.

  Args:
    da: The DataArray to normalize.
    time_dimension_name: The name of the time dimension. If None, time
      normalization is skipped.

  Returns:
    The normalized DataArray.
  """
  if time_dimension_name is not None:
    # Time values are expected to be
    # (a) strings formatted in `"yyyy-mm-dd"`
    # or
    # (b) `datetime` values as numpy's `datetime64` types.
    # All other types are not currently supported.

    # If (b), `datetime` coord values will be normalized as formatted strings.
    if da.coords.dtypes[time_dimension_name] == np.dtype('datetime64[ns]'):
      date_strvalues = np.datetime_as_string(
          da.coords[time_dimension_name], unit='D'
      )
      da = da.assign_coords({time_dimension_name: date_strvalues})

    # Assume that the time coordinate labels are date-formatted strings.
    # We don't currently support other, arbitrary object types in the builder.
    for time in da.coords[time_dimension_name].values:
      try:
        _ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
      except ValueError as exc:
        raise ValueError(
            f"Invalid time label: '{time}'. Expected format:"
            f" '{constants.DATE_FORMAT}'"
        ) from exc

  # A single geo denotes a national model: rename it to the default national
  # geo name. Otherwise coerce geo labels to strings.
  if len(da.coords[constants.GEO].values.tolist()) == 1:
    da = da.assign_coords(
        {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]},
    )
  else:
    da = da.assign_coords(
        {constants.GEO: da.coords[constants.GEO].astype(str)}
    )

  return da
|
|
696
|
+
|
|
697
|
+
def _validate_set(self, component: str, da: xr.DataArray):
  """Raises if `component` has already been assigned a DataArray."""
  if da is None:
    return
  raise ValueError(f'{component} was already set to {da}.')
|
|
700
|
+
|
|
701
|
+
def _validate_channels_consistency(
    self, channel_dimension_name: str, da_list: list[xr.DataArray | None]
):
  """Validates that all given arrays share the same channel coordinates.

  Args:
    channel_dimension_name: Name of the channel dimension to compare.
    da_list: Arrays to check; `None` entries (components not yet set) are
      ignored. The first entry is the reference array.

  Raises:
    ValueError: If any non-None array's channel coordinates differ (as a
      set) from the reference array's.
  """
  present = [da for da in da_list if da is not None]
  if not present:
    return
  # Hoisted out of the loop: the original recomputed the reference set on
  # every iteration (and compared the reference against itself).
  reference = set(da_list[0].coords[channel_dimension_name].values.tolist())
  for candidate in present:
    if set(candidate.coords[channel_dimension_name].values.tolist()) != (
        reference
    ):
      raise ValueError(
          f'{channel_dimension_name} coordinates must be the same between'
          f' {[da.name for da in present]}.'
      )
|
|
712
|
+
|
|
713
|
+
def _validate_required_components(self):
  """Validates that all required data arrays are defined."""
  if self.kpi is None:
    raise ValueError('KPI is required.')

  # A single geo means a nationally aggregated model: population is forced
  # to the default value (any user-provided population is discarded with a
  # warning). Bypass the property setter here since this is an intentional
  # overwrite.
  if len(self.geos) == 1:
    if self.population is not None:
      warnings.warn(
          'The `population` argument is ignored in a nationally aggregated'
          ' model. It will be reset to [1, 1, ..., 1]'
      )
    self._population = xr.DataArray(
        [constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
        dims=[constants.GEO],
        coords={
            constants.GEO: self.geos,
        },
        name=constants.POPULATION,
    )
  # For multi-geo models the user must have supplied population explicitly.
  if self.population is None:
    raise ValueError('Population is required for non national models.')

  # Paired components must be provided together (XOR detects one-sided
  # input).
  if (self.media is None) ^ (self.media_spend is None):
    raise ValueError('Media and media spend must be provided together.')
  if (
      self.reach is not None
      or self.frequency is not None
      or self.rf_spend is not None
  ) and (
      self.reach is None or self.frequency is None or self.rf_spend is None
  ):
    raise ValueError(
        'Reach, frequency, and rf_spend must be provided together.'
    )
  if (self.organic_reach is None) ^ (self.organic_frequency is None):
    raise ValueError(
        'Organic reach and organic frequency must be provided together.'
    )
  # At least one paid-treatment family (media or reach/frequency) must
  # exist for the model to have anything to fit.
  if (
      self.reach is None
      and self.frequency is None
      and self.rf_spend is None
      and self.media_spend is None
      and self.media is None
  ):
    raise ValueError(
        'It is required to have at least one of media or reach + frequency.'
    )
|
|
761
|
+
|
|
762
|
+
def _validate_nas(self):
  """Check for NAs in all of the DataArrays.

  Since the DataArray components should already distinguish between media
  time and time coords, there are no media times to infer so there should
  be no NAs.
  """
  # `kpi` and `population` are required and checked unconditionally; every
  # other component is optional and only checked when present. The order
  # below matches the original check order so the first error reported for
  # a given builder state is unchanged.
  checks = (
      ('kpi', self.kpi, True),
      ('population', self.population, True),
      ('controls', self.controls, False),
      ('revenue per kpi', self.revenue_per_kpi, False),
      ('media spend', self.media_spend, False),
      ('rf spend', self.rf_spend, False),
      ('non media treatments', self.non_media_treatments, False),
      ('media', self.media, False),
      ('reach', self.reach, False),
      ('frequency', self.frequency, False),
      ('organic media', self.organic_media, False),
      ('organic reach', self.organic_reach, False),
      ('organic frequency', self.organic_frequency, False),
  )
  for label, da, required in checks:
    if (required or da is not None) and da.isnull().any(axis=None):
      raise ValueError(f'NA values found in the {label} data.')
|
|
813
|
+
|
|
814
|
+
def _validate_lagged_media(
    self, media_time_coords: Sequence[str], time_coords: Sequence[str]
):
  """Validates that media-only times form a prefix of the sorted media times.

  The "lagged media" period consists of media-time coordinates that have no
  corresponding `time` coordinate; these must be a contiguous window at the
  start of the (sorted) media timeline.
  """
  lagged = np.sort(list(set(media_time_coords) - set(time_coords)))
  earliest = np.sort(media_time_coords)[: len(lagged)]
  if not np.array_equal(lagged, earliest):
    raise ValueError(
        "The 'lagged media' period (period with 100% NA values in all"
        f' non-media columns) {lagged} is not a continuous window'
        ' starting from the earliest time period.'
    )