google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
- google_meridian-1.1.2.dist-info/RECORD +46 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +29 -22
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +70 -44
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +1 -1
- meridian/analysis/visualizer.py +17 -8
- meridian/constants.py +3 -3
- meridian/data/__init__.py +4 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/data_frame_input_data_builder.py +614 -0
- meridian/data/input_data.py +12 -8
- meridian/data/input_data_builder.py +817 -0
- meridian/data/load.py +121 -428
- meridian/data/nd_array_input_data_builder.py +509 -0
- meridian/data/test_utils.py +60 -43
- meridian/data/time_coordinates.py +1 -1
- meridian/mlflow/__init__.py +17 -0
- meridian/mlflow/autolog.py +54 -0
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +1 -1
- meridian/model/model.py +65 -37
- meridian/model/model_test_data.py +75 -1
- meridian/model/posterior_sampler.py +19 -15
- meridian/model/prior_distribution.py +1 -1
- meridian/model/prior_sampler.py +32 -26
- meridian/model/spec.py +18 -8
- meridian/model/transformers.py +1 -1
- google_meridian-1.1.0.dist-info/RECORD +0 -41
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""An implementation of `InputDataBuilder` with n-dimensional array primitives."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import warnings
|
|
19
|
+
from meridian import constants
|
|
20
|
+
from meridian.data import input_data_builder
|
|
21
|
+
import numpy as np
|
|
22
|
+
import xarray as xr
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
'NDArrayInputDataBuilder',
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NDArrayInputDataBuilder(input_data_builder.InputDataBuilder):
|
|
31
|
+
"""Builds `InputData` from n-dimensional arrays."""
|
|
32
|
+
|
|
33
|
+
# Unlike `DataFrameInputDataBuilder`, each piecemeal data has no coordinate
|
|
34
|
+
# information; they're purely data values. It's up to the user to provide
|
|
35
|
+
# coordinates with setter methods inherited from `InputDataBuilder`.
|
|
36
|
+
# Validation is done on each piece w.r.t. dimensional consistency by
|
|
37
|
+
# shape alone.
|
|
38
|
+
|
|
39
|
+
def with_kpi(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Attaches KPI values supplied as a bare ndarray.

  The array carries no labels of its own; geo and time labels are taken from
  the coordinates already registered on this builder.

  Expected shape of `nd` is `(n_geos, n_time)`. For a national model that
  means `(1, n_time)`: the shape check always includes the leading geo axis,
  so a 1-D `(n_time,)` array is rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    nd: The KPI values.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If time coordinates have not been set, or if `nd.shape` does
      not match the registered geo and time coordinates.
  """
  # Coordinates must exist before the shape can be compared against them;
  # this call also fills in the default national geo when none was given.
  self._validate_coords()
  self._validate_shape(nd)

  labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
  }
  self.kpi = xr.DataArray(
      nd,
      dims=[constants.GEO, constants.TIME],
      coords=labels,
      name=constants.KPI,
  )
  return self
|
|
70
|
+
|
|
71
|
+
def with_controls(
    self, nd: np.ndarray, control_names: list[str]
) -> 'NDArrayInputDataBuilder':
  """Attaches control-variable values supplied as a bare ndarray.

  Expected shape of `nd` is `(n_geos, n_time, n_controls)`, where
  `n_controls == len(control_names)`. For a national model that means
  `(1, n_time, n_controls)`: the shape check always includes the leading geo
  axis, so a 2-D `(n_time, n_controls)` array is rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    nd: The control values.
    control_names: Labels for the last axis of `nd`; must be unique.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If time coordinates have not been set, if `control_names`
      contains duplicates, or if `nd.shape` does not match the registered
      coordinates and the number of control names.
  """
  # Coordinates first: the shape check reads them, and this call also fills
  # in the default national geo when none was given.
  self._validate_coords()
  self._validate_shape(nd, control_names)

  axis_names = [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
  labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.CONTROL_VARIABLE: control_names,
  }
  self.controls = xr.DataArray(
      nd, dims=axis_names, coords=labels, name=constants.CONTROLS
  )
  return self
|
|
106
|
+
|
|
107
|
+
def with_population(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Attaches per-geo population values supplied as a bare ndarray.

  Expected shape of `nd` is `(n_geos,)`. Population has no time axis, so
  time coordinates do not need to be set before calling this method.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used, in which case `nd` must have shape
  `(1,)`.

  Args:
    nd: The population values, one entry per geo.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If `nd.shape` does not match the number of geos.
  """
  # `is_population=True` skips the time-coordinate requirement and limits the
  # expected shape to the geo axis only.
  self._validate_coords(is_population=True)
  self._validate_shape(nd, is_population=True)

  geo_labels = {constants.GEO: self.geos}
  self.population = xr.DataArray(
      nd, dims=[constants.GEO], coords=geo_labels, name=constants.POPULATION
  )
  return self
|
|
134
|
+
|
|
135
|
+
def with_revenue_per_kpi(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Attaches revenue-per-KPI values supplied as a bare ndarray.

  Expected shape of `nd` is `(n_geos, n_time)`. For a national model that
  means `(1, n_time)`: the shape check always includes the leading geo axis,
  so a 1-D `(n_time,)` array is rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  If the builder's KPI type is revenue, the supplied values are not used: a
  warning is emitted and an array of ones with the same shape is stored
  instead. The shape of `nd` is still validated in that case.

  Args:
    nd: The revenue-per-KPI values.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If time coordinates have not been set, or if `nd.shape` does
      not match the registered geo and time coordinates.
  """
  self._validate_coords()
  self._validate_shape(nd)
  # May swap the caller's values for ones when the KPI is already revenue.
  values = self._check_revenue_per_kpi_defaults(nd)

  labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
  }
  self.revenue_per_kpi = xr.DataArray(
      values,
      dims=[constants.GEO, constants.TIME],
      coords=labels,
      name=constants.REVENUE_PER_KPI,
  )
  return self
|
|
167
|
+
|
|
168
|
+
def with_media(
    self, m_nd: np.ndarray, ms_nd: np.ndarray, media_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Attaches media and media-spend values supplied as bare ndarrays.

  Expected shapes, with `n_media_channels == len(media_channels)`:
  - `m_nd`: `(n_geos, n_media_times, n_media_channels)`
  - `ms_nd`: `(n_geos, n_times, n_media_channels)`

  For a national model the geo axis has length 1. The shape check always
  includes the leading geo axis, so 2-D arrays without it are rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    m_nd: The media values, indexed by media time.
    ms_nd: The media spend values, indexed by time.
    media_channels: Labels for the last axis of both arrays; must be unique.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If media-time or time coordinates have not been set, if
      `media_channels` contains duplicates, or if either array's shape does
      not match the registered coordinates and the number of channels.
  """
  # Media is labelled with `media_time_coords` and spend with `time_coords`,
  # so both sets of coordinates must be present before any shape is checked.
  self._validate_coords(is_media_time=True)
  self._validate_coords(is_media_time=False)
  self._validate_shape(nd=m_nd, dims=media_channels, is_media_time=True)
  self._validate_shape(nd=ms_nd, dims=media_channels, is_media_time=False)

  media_axes = [constants.GEO, constants.MEDIA_TIME, constants.MEDIA_CHANNEL]
  media_labels = {
      constants.GEO: self.geos,
      constants.MEDIA_TIME: self.media_time_coords,
      constants.MEDIA_CHANNEL: media_channels,
  }
  self.media = xr.DataArray(
      m_nd, dims=media_axes, coords=media_labels, name=constants.MEDIA
  )

  spend_axes = [constants.GEO, constants.TIME, constants.MEDIA_CHANNEL]
  spend_labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.MEDIA_CHANNEL: media_channels,
  }
  self.media_spend = xr.DataArray(
      ms_nd, dims=spend_axes, coords=spend_labels, name=constants.MEDIA_SPEND
  )
  return self
|
|
230
|
+
|
|
231
|
+
def with_reach(
    self,
    r_nd: np.ndarray,
    f_nd: np.ndarray,
    rfs_nd: np.ndarray,
    rf_channels: list[str],
) -> 'NDArrayInputDataBuilder':
  """Attaches reach, frequency and RF-spend values given as bare ndarrays.

  Expected shapes, with `n_rf_channels == len(rf_channels)`:
  - `r_nd` and `f_nd`: `(n_geos, n_media_times, n_rf_channels)`
  - `rfs_nd`: `(n_geos, n_times, n_rf_channels)`

  For a national model the geo axis has length 1. The shape check always
  includes the leading geo axis, so 2-D arrays without it are rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    r_nd: The reach values, indexed by media time.
    f_nd: The frequency values, indexed by media time.
    rfs_nd: The RF spend values, indexed by time.
    rf_channels: Labels for the last axis of all three arrays; must be
      unique.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If media-time or time coordinates have not been set, if
      `rf_channels` contains duplicates, or if any array's shape does not
      match the registered coordinates and the number of channels.
  """
  # Reach/frequency are labelled with `media_time_coords` and spend with
  # `time_coords`, so both sets of coordinates are required up front.
  self._validate_coords(is_media_time=True)
  self._validate_coords(is_media_time=False)
  self._validate_shape(nd=r_nd, dims=rf_channels, is_media_time=True)
  self._validate_shape(nd=f_nd, dims=rf_channels, is_media_time=True)
  self._validate_shape(nd=rfs_nd, dims=rf_channels, is_media_time=False)

  media_time_axes = [
      constants.GEO,
      constants.MEDIA_TIME,
      constants.RF_CHANNEL,
  ]

  reach_labels = {
      constants.GEO: self.geos,
      constants.MEDIA_TIME: self.media_time_coords,
      constants.RF_CHANNEL: rf_channels,
  }
  self.reach = xr.DataArray(
      r_nd, dims=media_time_axes, coords=reach_labels, name=constants.REACH
  )

  frequency_labels = {
      constants.GEO: self.geos,
      constants.MEDIA_TIME: self.media_time_coords,
      constants.RF_CHANNEL: rf_channels,
  }
  self.frequency = xr.DataArray(
      f_nd,
      dims=list(media_time_axes),
      coords=frequency_labels,
      name=constants.FREQUENCY,
  )

  spend_labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.RF_CHANNEL: rf_channels,
  }
  self.rf_spend = xr.DataArray(
      rfs_nd,
      dims=[constants.GEO, constants.TIME, constants.RF_CHANNEL],
      coords=spend_labels,
      name=constants.RF_SPEND,
  )
  return self
|
|
314
|
+
|
|
315
|
+
def with_organic_media(
    self, nd: np.ndarray, organic_media_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Attaches organic media values supplied as a bare ndarray.

  Expected shape of `nd` is
  `(n_geos, n_media_times, n_organic_media_channels)`, where the last size
  equals `len(organic_media_channels)`. For a national model the geo axis has
  length 1. The shape check always includes the leading geo axis, so a 2-D
  array without it is rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    nd: The organic media values, indexed by media time.
    organic_media_channels: Labels for the last axis of `nd`; must be unique.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If media-time coordinates have not been set, if
      `organic_media_channels` contains duplicates, or if `nd.shape` does not
      match the registered coordinates and the number of channels.
  """
  # Only the media-time axis is used here, so only those coordinates are
  # required; the call also fills in the default national geo if needed.
  self._validate_coords(is_media_time=True)
  self._validate_shape(nd=nd, dims=organic_media_channels, is_media_time=True)

  axis_names = [
      constants.GEO,
      constants.MEDIA_TIME,
      constants.ORGANIC_MEDIA_CHANNEL,
  ]
  labels = {
      constants.GEO: self.geos,
      constants.MEDIA_TIME: self.media_time_coords,
      constants.ORGANIC_MEDIA_CHANNEL: organic_media_channels,
  }
  self.organic_media = xr.DataArray(
      nd, dims=axis_names, coords=labels, name=constants.ORGANIC_MEDIA
  )
  return self
|
|
355
|
+
|
|
356
|
+
def with_organic_reach(
    self, or_nd: np.ndarray, of_nd: np.ndarray, organic_rf_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Reads organic reach and organic frequency data from the ndarrays.

  `or_nd` and `of_nd` must both be given with the shape
  `(n_geos, n_media_times, n_organic_rf_channels)`, where the last size
  equals `len(organic_rf_channels)`. For a national model the geo axis has
  length 1. The shape check always includes the leading geo axis, so 2-D
  arrays without it are rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    or_nd: The ndarray that contains dimensional organic reach data.
    of_nd: The ndarray that contains dimensional organic frequency data.
    organic_rf_channels: The names of the organic rf channels; must be
      unique.

  Returns:
    The `NDArrayInputDataBuilder` with the added organic reach and organic
    frequency data.

  Raises:
    ValueError: If media-time coordinates have not been set, if
      `organic_rf_channels` contains duplicates, or if either array's shape
      does not match the registered coordinates and the number of channels.
  """
  ### Validate ###
  self._validate_coords(is_media_time=True)
  self._validate_shape(nd=or_nd, dims=organic_rf_channels, is_media_time=True)
  self._validate_shape(nd=of_nd, dims=organic_rf_channels, is_media_time=True)

  ### Transform ###
  self.organic_reach = xr.DataArray(
      or_nd,
      dims=[
          constants.GEO,
          constants.MEDIA_TIME,
          constants.ORGANIC_RF_CHANNEL,
      ],
      coords={
          constants.GEO: self.geos,
          constants.MEDIA_TIME: self.media_time_coords,
          constants.ORGANIC_RF_CHANNEL: organic_rf_channels,
      },
      name=constants.ORGANIC_REACH,
  )
  self.organic_frequency = xr.DataArray(
      of_nd,
      dims=[
          constants.GEO,
          constants.MEDIA_TIME,
          constants.ORGANIC_RF_CHANNEL,
      ],
      coords={
          constants.GEO: self.geos,
          constants.MEDIA_TIME: self.media_time_coords,
          constants.ORGANIC_RF_CHANNEL: organic_rf_channels,
      },
      # The frequency array carries its own name; it was previously labelled
      # with the organic reach name, a copy-paste slip from the block above.
      name=constants.ORGANIC_FREQUENCY,
  )
  return self
|
|
413
|
+
|
|
414
|
+
def with_non_media_treatments(
    self, nd: np.ndarray, non_media_channel_names: list[str]
) -> 'NDArrayInputDataBuilder':
  """Attaches non-media treatment values supplied as a bare ndarray.

  Expected shape of `nd` is `(n_geos, n_time, n_non_media_channels)`, where
  the last size equals `len(non_media_channel_names)`. For a national model
  the geo axis has length 1. The shape check always includes the leading geo
  axis, so a 2-D array without it is rejected.

  When no geos have been set on the builder, a warning is logged and the
  single default national geo is used.

  Args:
    nd: The non-media treatment values.
    non_media_channel_names: Labels for the last axis of `nd`; must be
      unique.

  Returns:
    This builder, so that calls can be chained.

  Raises:
    ValueError: If time coordinates have not been set, if
      `non_media_channel_names` contains duplicates, or if `nd.shape` does
      not match the registered coordinates and the number of channel names.
  """
  # Coordinates first: the shape check reads them, and this call also fills
  # in the default national geo when none was given.
  self._validate_coords()
  self._validate_shape(nd, non_media_channel_names)

  axis_names = [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
  labels = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.NON_MEDIA_CHANNEL: non_media_channel_names,
  }
  self.non_media_treatments = xr.DataArray(
      nd, dims=axis_names, coords=labels, name=constants.NON_MEDIA_TREATMENTS
  )
  return self
|
|
450
|
+
|
|
451
|
+
def _validate_coords(
|
|
452
|
+
self, is_population: bool = False, is_media_time: bool = False
|
|
453
|
+
):
|
|
454
|
+
"""Validates that the data has the expected coordinates."""
|
|
455
|
+
if not is_population:
|
|
456
|
+
if is_media_time and self._media_time_coords is None:
|
|
457
|
+
raise ValueError(
|
|
458
|
+
'Media times are required first. Set using .media_time_coords()'
|
|
459
|
+
)
|
|
460
|
+
if not is_media_time and self.time_coords is None:
|
|
461
|
+
raise ValueError(
|
|
462
|
+
'Time coordinates are required first. Set using .time_coords()'
|
|
463
|
+
)
|
|
464
|
+
if self.geos is None:
|
|
465
|
+
logging.warning(
|
|
466
|
+
'No geo coordinates set. Assuming NATIONAL model and geos will be set'
|
|
467
|
+
' to the default value.'
|
|
468
|
+
)
|
|
469
|
+
self.geos = [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]
|
|
470
|
+
|
|
471
|
+
def _validate_shape(
|
|
472
|
+
self,
|
|
473
|
+
nd: np.ndarray,
|
|
474
|
+
dims: list[str] | None = None,
|
|
475
|
+
is_population: bool = False,
|
|
476
|
+
is_media_time: bool = False,
|
|
477
|
+
):
|
|
478
|
+
"""Validates that the data has the expected shape."""
|
|
479
|
+
# Since all data has a geo dimension (even for national data),
|
|
480
|
+
# Expect the first axis to have the shape of the geo dimension.
|
|
481
|
+
expected_shape = (len(self.geos),)
|
|
482
|
+
detailed_info = f'Expected: {len(self.geos)} geos'
|
|
483
|
+
if not is_population:
|
|
484
|
+
if is_media_time:
|
|
485
|
+
expected_shape += (len(self.media_time_coords),)
|
|
486
|
+
detailed_info += f' x {len(self.media_time_coords)} media times'
|
|
487
|
+
else:
|
|
488
|
+
expected_shape += (len(self.time_coords),)
|
|
489
|
+
detailed_info += f' x {len(self.time_coords)} times'
|
|
490
|
+
|
|
491
|
+
if dims is not None:
|
|
492
|
+
if len(dims) != len(set(dims)):
|
|
493
|
+
raise ValueError('given dimensions must be unique.')
|
|
494
|
+
expected_shape += (len(dims),)
|
|
495
|
+
detailed_info += f' x {len(dims)} dims'
|
|
496
|
+
|
|
497
|
+
if expected_shape != nd.shape:
|
|
498
|
+
raise ValueError(f'{detailed_info}. Got: {nd.shape}.')
|
|
499
|
+
|
|
500
|
+
def _check_revenue_per_kpi_defaults(self, nd: np.ndarray):
  """Picks the revenue-per-KPI values to store.

  When the builder's KPI type is revenue, a warning is emitted and an array
  of ones with the same shape as `nd` is returned in place of the supplied
  values. Otherwise `nd` is returned unchanged.

  Args:
    nd: The revenue-per-KPI values supplied by the caller.

  Returns:
    Either `nd` itself, or a new array of ones shaped like `nd`.
  """
  if self._kpi_type != constants.REVENUE:
    return nd

  warnings.warn(
      'with_revenue_per_kpi was called but kpi_type was set to revenue.'
      ' Assuming revenue per kpi with values [1].'
  )
  return np.ones(nd.shape)
|