google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1321 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of `InputData` for Meridian models."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Mapping
|
|
18
|
+
import dataclasses
|
|
19
|
+
import datetime as dt
|
|
20
|
+
import functools
|
|
21
|
+
import itertools
|
|
22
|
+
from typing import Sequence
|
|
23
|
+
|
|
24
|
+
from meridian import constants as c
|
|
25
|
+
from meridian.data import input_data as meridian_input_data
|
|
26
|
+
from mmm.v1.common import date_interval_pb2
|
|
27
|
+
from mmm.v1.marketing import marketing_data_pb2 as marketing_pb
|
|
28
|
+
from schema.serde import constants as sc
|
|
29
|
+
from schema.serde import serde
|
|
30
|
+
from schema.utils import time_record
|
|
31
|
+
import numpy as np
|
|
32
|
+
import xarray as xr
|
|
33
|
+
|
|
34
|
+
from google.type import date_pb2
|
|
35
|
+
|
|
36
|
+
# Mapping from each channel-bearing `InputData` DataArray name to the name of
# the coordinate that holds its channel labels. Paid reach and frequency share
# the RF channel coordinate; organic reach and frequency share the organic RF
# channel coordinate.
_COORD_NAME_MAP = {
    c.MEDIA: c.MEDIA_CHANNEL,
    c.REACH: c.RF_CHANNEL,
    c.FREQUENCY: c.RF_CHANNEL,
    c.ORGANIC_MEDIA: c.ORGANIC_MEDIA_CHANNEL,
    c.ORGANIC_REACH: c.ORGANIC_RF_CHANNEL,
    c.ORGANIC_FREQUENCY: c.ORGANIC_RF_CHANNEL,
    c.NON_MEDIA_TREATMENTS: c.NON_MEDIA_CHANNEL,
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclasses.dataclass(frozen=True)
class _DeserializedTimeDimension:
  """Read-only view over a `TimeDimension` proto used while deserializing.

  Construction fails fast with a `ValueError` when the wrapped proto carries
  no dates, so every instance is guaranteed to have at least one date.
  """

  _time_dimension: marketing_pb.MarketingDataMetadata.TimeDimension

  def __post_init__(self):
    has_dates = bool(self._time_dimension.dates)
    if not has_dates:
      raise ValueError("TimeDimension proto must have at least one date.")

  @functools.cached_property
  def date_coordinates(self) -> list[dt.date]:
    """The proto's dates, in proto order, as `datetime.date` objects."""
    return [
        dt.date(year=proto_date.year, month=proto_date.month, day=proto_date.day)
        for proto_date in self._time_dimension.dates
    ]

  @functools.cached_property
  def time_dimension_interval(self) -> date_interval_pb2.DateInterval:
    """The `[start, end)` date interval covering every date coordinate.

    Each coordinate is first turned into its own interval via
    `time_record.convert_times_to_date_intervals`; the result spans the
    earliest start and the latest end among those intervals.
    """
    intervals_by_time = time_record.convert_times_to_date_intervals(
        self.date_coordinates
    )
    all_intervals = list(intervals_by_time.values())
    return _get_date_interval_from_date_intervals(all_intervals)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclasses.dataclass(frozen=True)
class _DeserializedMetadata:
  """Parsed view of a `MarketingDataMetadata` proto.

  The `time` and `media_time` time dimensions are resolved eagerly on
  construction. A proto that lacks either one, or has one with no dates, is
  rejected with a `ValueError`.

  Attributes:
    _metadata: The wrapped `MarketingDataMetadata` proto.
  """

  _metadata: marketing_pb.MarketingDataMetadata

  def __post_init__(self):
    # Touch both cached properties (left to right) so that missing or empty
    # time dimensions surface immediately rather than on first use.
    _ = (self.time_dimension, self.media_time_dimension)

  def _get_time_dimension(self, name: str) -> _DeserializedTimeDimension:
    """Wraps the first `TimeDimension` in the metadata whose name is `name`.

    Args:
      name: The time dimension name to look up.

    Returns:
      A `_DeserializedTimeDimension` wrapping the matching proto.

    Raises:
      ValueError: If no time dimension has that name, or if the matching one
        has no dates.
    """
    found = next(
        (
            candidate
            for candidate in self._metadata.time_dimensions
            if candidate.name == name
        ),
        None,
    )
    if found is None:
      raise ValueError(f"No TimeDimension found with name '{name}' in metadata.")
    return _DeserializedTimeDimension(found)

  @functools.cached_property
  def time_dimension(self) -> _DeserializedTimeDimension:
    """The wrapped time dimension named `c.TIME`."""
    return self._get_time_dimension(c.TIME)

  @functools.cached_property
  def media_time_dimension(self) -> _DeserializedTimeDimension:
    """The wrapped time dimension named `c.MEDIA_TIME`."""
    return self._get_time_dimension(c.MEDIA_TIME)

  @functools.cached_property
  def channel_dimensions(self) -> Mapping[str, list[str]]:
    """Channel dimension name -> list of channel names in that dimension.

    If several channel dimensions share a name, the last one wins.
    """
    dims = {}
    for channel_dimension in self._metadata.channel_dimensions:
      dims[channel_dimension.name] = list(channel_dimension.channels)
    return dims

  @functools.cached_property
  def channel_types(self) -> Mapping[str, str | None]:
    """Channel name -> coordinate name of the dimension that lists it.

    The coordinate name comes from `_COORD_NAME_MAP`, keyed by channel
    dimension name. `None` is used for dimension names absent from that map.
    When a channel appears under several dimensions, the last one wins.
    """
    return {
        channel: _COORD_NAME_MAP.get(dimension_name)
        for dimension_name, channels in self.channel_dimensions.items()
        for channel in channels
    }
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _extract_data_array(
    serialized_data_points: Sequence[marketing_pb.MarketingDataPoint],
    data_extractor_fn,
    data_name,
) -> xr.DataArray | None:
  """Helper function to extract data into a `(geo, time)` `xr.DataArray`.

  Geo and time coordinates are ordered by their first appearance across
  `serialized_data_points`. Cells with no extracted value are filled with
  `np.nan`. When the same `(geo_id, time_str)` pair is extracted more than
  once, the last value wins.

  Args:
    serialized_data_points: A Sequence of MarketingDataPoint protos.
    data_extractor_fn: A function that takes a data point and returns either a
      tuple of `(geo_id, time_str, value)`, or `None` if the data point should
      be skipped.
    data_name: The desired name for the `xr.DataArray`.

  Returns:
    An `xr.DataArray` containing the extracted data, or `None` if no data is
    found.
  """
  data_dict = {}  # (geo_id, time_str) -> value
  # Dicts double as insertion-ordered sets: membership tests are O(1) instead
  # of the O(n) list scan that would otherwise run for every data point.
  geo_ids = {}
  times = {}

  for data_point in serialized_data_points:
    extraction_result = data_extractor_fn(data_point)
    if extraction_result is None:
      continue

    geo_id, time_str, value = extraction_result

    # TODO: Enforce dimension uniqueness in Meridian.
    geo_ids.setdefault(geo_id, None)
    times.setdefault(time_str, None)

    data_dict[(geo_id, time_str)] = value

  if not data_dict:
    return None

  geo_coords = list(geo_ids)
  time_coords = list(times)
  data_values = np.array([
      [data_dict.get((geo_id, time), np.nan) for time in time_coords]
      for geo_id in geo_coords
  ])

  return xr.DataArray(
      data=data_values,
      coords={
          c.GEO: geo_coords,
          c.TIME: time_coords,
      },
      dims=(c.GEO, c.TIME),
      name=data_name,
  )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _extract_3d_data_array(
    serialized_data_points: Sequence[marketing_pb.MarketingDataPoint],
    data_extractor_fn,
    data_name,
    third_dim_name,
    time_dim_name=c.TIME,
) -> xr.DataArray | None:
  """Helper function to extract data with 3 dimensions into an `xr.DataArray`.

  The first dimension is always `GEO`, and the second is the time dimension
  (default: `TIME`). All coordinates are ordered by first appearance. Cells
  with no extracted value are filled with `np.nan`. When the same
  `(geo_id, time_str, third_dim_key)` triple is extracted more than once, the
  last value wins.

  Args:
    serialized_data_points: A sequence of MarketingDataPoint protos.
    data_extractor_fn: A function that takes a data point and returns an
      iterable of `(geo_id, time_str, third_dim_key, value)` tuples. To skip a
      data point, return an empty iterable or `None`.
    data_name: The desired name for the `xr.DataArray`.
    third_dim_name: The name of the third dimension.
    time_dim_name: The name of the time dimension. Default is `TIME`.

  Returns:
    An `xr.DataArray` containing the extracted data, or `None` if no data is
    found.
  """
  data_dict = {}  # (geo_id, time_str, third_dim_key) -> value
  # Dicts double as insertion-ordered sets with O(1) membership tests.
  geo_ids = {}
  times = {}
  third_dim_keys = {}

  for data_point in serialized_data_points:
    extraction_results = data_extractor_fn(data_point)
    # Treat `None` as "nothing to extract", matching `_extract_data_array`.
    if extraction_results is None:
      continue

    for geo_id, time_str, third_dim_key, value in extraction_results:
      geo_ids.setdefault(geo_id, None)
      times.setdefault(time_str, None)
      third_dim_keys.setdefault(third_dim_key, None)

      # TODO: Enforce dimension uniqueness in Meridian.
      data_dict[(geo_id, time_str, third_dim_key)] = value

  if not data_dict:
    return None

  geo_coords = list(geo_ids)
  time_coords = list(times)
  third_dim_coords = list(third_dim_keys)
  data_values = np.array([
      [
          [
              data_dict.get((geo_id, time, third_dim_key), np.nan)
              for third_dim_key in third_dim_coords
          ]
          for time in time_coords
      ]
      for geo_id in geo_coords
  ])

  return xr.DataArray(
      data=data_values,
      coords={
          c.GEO: geo_coords,
          time_dim_name: time_coords,
          third_dim_name: third_dim_coords,
      },
      dims=(c.GEO, time_dim_name, third_dim_name),
      name=data_name,
  )
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _get_date_interval_from_date_intervals(
    date_intervals: Sequence[date_interval_pb2.DateInterval],
) -> date_interval_pb2.DateInterval:
  """Gets the date interval spanning the earliest start and latest end date.

  Args:
    date_intervals: A non-empty sequence of DateInterval protos.

  Returns:
    A new DateInterval whose start date is the earliest start date and whose
    end date is the latest end date found in `date_intervals`.

  Raises:
    ValueError: If `date_intervals` is empty.
  """
  if not date_intervals:
    raise ValueError("At least one date interval is required.")

  def to_date(proto_date) -> dt.date:
    # Compare as `datetime.date` so that ordering is chronological.
    return dt.date(proto_date.year, proto_date.month, proto_date.day)

  min_start_date_interval = min(
      date_intervals, key=lambda interval: to_date(interval.start_date)
  )
  max_end_date_interval = max(
      date_intervals, key=lambda interval: to_date(interval.end_date)
  )

  return date_interval_pb2.DateInterval(
      start_date=date_pb2.Date(
          year=min_start_date_interval.start_date.year,
          month=min_start_date_interval.start_date.month,
          day=min_start_date_interval.start_date.day,
      ),
      end_date=date_pb2.Date(
          year=max_end_date_interval.end_date.year,
          month=max_end_date_interval.end_date.month,
          day=max_end_date_interval.end_date.day,
      ),
  )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class _InputDataSerializer:
|
|
290
|
+
"""Serializes an `InputData` container in Meridian model."""
|
|
291
|
+
|
|
292
|
+
def __init__(self, input_data: meridian_input_data.InputData):
|
|
293
|
+
self._input_data = input_data
|
|
294
|
+
|
|
295
|
+
@property
|
|
296
|
+
def _n_geos(self) -> int:
|
|
297
|
+
return len(self._input_data.geo)
|
|
298
|
+
|
|
299
|
+
@property
|
|
300
|
+
def _n_times(self) -> int:
|
|
301
|
+
return len(self._input_data.time)
|
|
302
|
+
|
|
303
|
+
def __call__(self) -> marketing_pb.MarketingData:
|
|
304
|
+
"""Serializes the input data into a MarketingData proto."""
|
|
305
|
+
marketing_proto = marketing_pb.MarketingData()
|
|
306
|
+
# Use media_time since it covers larger range.
|
|
307
|
+
times_to_date_intervals = time_record.convert_times_to_date_intervals(
|
|
308
|
+
self._input_data.media_time.data
|
|
309
|
+
)
|
|
310
|
+
geos_and_times = itertools.product(
|
|
311
|
+
self._input_data.geo.data, self._input_data.media_time.data
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
for geo, time in geos_and_times:
|
|
315
|
+
data_point = self._serialize_data_point(
|
|
316
|
+
geo,
|
|
317
|
+
time,
|
|
318
|
+
times_to_date_intervals,
|
|
319
|
+
)
|
|
320
|
+
marketing_proto.marketing_data_points.append(data_point)
|
|
321
|
+
|
|
322
|
+
if self._input_data.media_spend is not None:
|
|
323
|
+
if (
|
|
324
|
+
not self._input_data.media_spend_has_geo_dimension
|
|
325
|
+
and not self._input_data.media_spend_has_time_dimension
|
|
326
|
+
):
|
|
327
|
+
marketing_proto.marketing_data_points.append(
|
|
328
|
+
self._serialize_aggregated_media_spend_data_point(
|
|
329
|
+
self._input_data.media_spend,
|
|
330
|
+
times_to_date_intervals,
|
|
331
|
+
)
|
|
332
|
+
)
|
|
333
|
+
elif (
|
|
334
|
+
self._input_data.media_spend_has_geo_dimension
|
|
335
|
+
!= self._input_data.media_spend_has_time_dimension
|
|
336
|
+
):
|
|
337
|
+
raise AssertionError(
|
|
338
|
+
"Invalid input data: media_spend must either be fully granular"
|
|
339
|
+
" (both geo and time dimensions) or fully aggregated (neither geo"
|
|
340
|
+
" nor time dimensions)."
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
if self._input_data.rf_spend is not None:
|
|
344
|
+
if (
|
|
345
|
+
not self._input_data.rf_spend_has_geo_dimension
|
|
346
|
+
and not self._input_data.rf_spend_has_time_dimension
|
|
347
|
+
):
|
|
348
|
+
marketing_proto.marketing_data_points.append(
|
|
349
|
+
self._serialize_aggregated_rf_spend_data_point(
|
|
350
|
+
self._input_data.rf_spend, times_to_date_intervals
|
|
351
|
+
)
|
|
352
|
+
)
|
|
353
|
+
elif (
|
|
354
|
+
self._input_data.rf_spend_has_geo_dimension
|
|
355
|
+
!= self._input_data.rf_spend_has_time_dimension
|
|
356
|
+
):
|
|
357
|
+
raise AssertionError(
|
|
358
|
+
"Invalid input data: rf_spend must either be fully granular (both"
|
|
359
|
+
" geo and time dimensions) or fully aggregated (neither geo nor"
|
|
360
|
+
" time dimensions)."
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
marketing_proto.metadata.CopyFrom(self._serialize_metadata())
|
|
364
|
+
|
|
365
|
+
return marketing_proto
|
|
366
|
+
|
|
367
|
+
def _serialize_media_variables(
|
|
368
|
+
self,
|
|
369
|
+
geo: str,
|
|
370
|
+
time: str,
|
|
371
|
+
channel_dim_name: str,
|
|
372
|
+
impressions_data_array: xr.DataArray,
|
|
373
|
+
spend_data_array: xr.DataArray | None = None,
|
|
374
|
+
) -> list[marketing_pb.MediaVariable]:
|
|
375
|
+
"""Serializes media variables for a given geo and time.
|
|
376
|
+
|
|
377
|
+
Args:
|
|
378
|
+
geo: The geo ID.
|
|
379
|
+
time: The time string.
|
|
380
|
+
channel_dim_name: The name of the channel dimension.
|
|
381
|
+
impressions_data_array: The DataArray containing impressions data.
|
|
382
|
+
Expected dimensions: `(n_geos, n_media_times, n_channels)`.
|
|
383
|
+
spend_data_array: The optional DataArray containing spend data. Expected
|
|
384
|
+
dimensions are `(n_geos, n_times, n_media_channels)`.
|
|
385
|
+
|
|
386
|
+
Returns:
|
|
387
|
+
A list of MediaVariable protos.
|
|
388
|
+
"""
|
|
389
|
+
media_variables = []
|
|
390
|
+
for media_data in impressions_data_array.sel(geo=geo, media_time=time):
|
|
391
|
+
channel = media_data[channel_dim_name].item()
|
|
392
|
+
media_variable = marketing_pb.MediaVariable(
|
|
393
|
+
channel_name=channel,
|
|
394
|
+
scalar_metric=marketing_pb.ScalarMetric(
|
|
395
|
+
name=c.IMPRESSIONS, value=media_data.item()
|
|
396
|
+
),
|
|
397
|
+
)
|
|
398
|
+
if spend_data_array is not None and time in spend_data_array.time:
|
|
399
|
+
media_variable.media_spend = spend_data_array.sel(
|
|
400
|
+
geo=geo, time=time, **{channel_dim_name: channel}
|
|
401
|
+
).item()
|
|
402
|
+
media_variables.append(media_variable)
|
|
403
|
+
return media_variables
|
|
404
|
+
|
|
405
|
+
def _serialize_reach_frequency_variables(
|
|
406
|
+
self,
|
|
407
|
+
geo: str,
|
|
408
|
+
time: str,
|
|
409
|
+
channel_dim_name: str,
|
|
410
|
+
reach_data_array: xr.DataArray,
|
|
411
|
+
frequency_data_array: xr.DataArray,
|
|
412
|
+
spend_data_array: xr.DataArray | None = None,
|
|
413
|
+
) -> list[marketing_pb.ReachFrequencyVariable]:
|
|
414
|
+
"""Serializes reach and frequency variables for a given geo and time.
|
|
415
|
+
|
|
416
|
+
Iterates through the R&F channels separately, creating a MediaVariable
|
|
417
|
+
for each. It's safe to assume that Meridian media channel names are
|
|
418
|
+
unique across `media_data` and `reach_data`. This assumption is
|
|
419
|
+
checked when an `InputData` is created in model training.
|
|
420
|
+
|
|
421
|
+
Dimensions of `reach_data_array` and `frequency_data_array` are expected
|
|
422
|
+
to be `(n_geos, n_media_times, n_rf_channels)`.
|
|
423
|
+
|
|
424
|
+
Args:
|
|
425
|
+
geo: The geo ID.
|
|
426
|
+
time: The time string.
|
|
427
|
+
channel_dim_name: The name of the channel dimension (e.g., 'rf_channel').
|
|
428
|
+
reach_data_array: The DataArray containing reach data.
|
|
429
|
+
frequency_data_array: The DataArray containing frequency data.
|
|
430
|
+
spend_data_array: The optional DataArray containing spend data.
|
|
431
|
+
|
|
432
|
+
Returns:
|
|
433
|
+
A list of MediaVariable protos.
|
|
434
|
+
"""
|
|
435
|
+
rf_variables = []
|
|
436
|
+
for reach_data in reach_data_array.sel(geo=geo, media_time=time):
|
|
437
|
+
reach_value = reach_data.item()
|
|
438
|
+
channel = reach_data[channel_dim_name].item()
|
|
439
|
+
frequency_value = frequency_data_array.sel(
|
|
440
|
+
geo=geo,
|
|
441
|
+
media_time=time,
|
|
442
|
+
**{channel_dim_name: channel},
|
|
443
|
+
).item()
|
|
444
|
+
rf_variable = marketing_pb.ReachFrequencyVariable(
|
|
445
|
+
channel_name=channel,
|
|
446
|
+
reach=int(reach_value),
|
|
447
|
+
average_frequency=frequency_value,
|
|
448
|
+
)
|
|
449
|
+
if spend_data_array is not None and time in spend_data_array.time:
|
|
450
|
+
rf_variable.spend = spend_data_array.sel(
|
|
451
|
+
geo=geo, time=time, **{channel_dim_name: channel}
|
|
452
|
+
).item()
|
|
453
|
+
rf_variables.append(rf_variable)
|
|
454
|
+
return rf_variables
|
|
455
|
+
|
|
456
|
+
def _serialize_non_media_treatment_variables(
|
|
457
|
+
self, geo: str, time: str
|
|
458
|
+
) -> list[marketing_pb.NonMediaTreatmentVariable]:
|
|
459
|
+
"""Serializes non-media treatment variables for a given geo and time.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
geo: The geo ID.
|
|
463
|
+
time: The time string.
|
|
464
|
+
|
|
465
|
+
Returns:
|
|
466
|
+
A list of NonMediaTreatmentVariable protos.
|
|
467
|
+
"""
|
|
468
|
+
non_media_treatment_variables = []
|
|
469
|
+
if (
|
|
470
|
+
self._input_data.non_media_treatments is not None
|
|
471
|
+
and geo in self._input_data.non_media_treatments.geo
|
|
472
|
+
and time in self._input_data.non_media_treatments.time
|
|
473
|
+
):
|
|
474
|
+
for non_media_treatment_data in self._input_data.non_media_treatments.sel(
|
|
475
|
+
geo=geo, time=time
|
|
476
|
+
):
|
|
477
|
+
non_media_treatment_variables.append(
|
|
478
|
+
marketing_pb.NonMediaTreatmentVariable(
|
|
479
|
+
name=non_media_treatment_data[c.NON_MEDIA_CHANNEL].item(),
|
|
480
|
+
value=non_media_treatment_data.item(),
|
|
481
|
+
)
|
|
482
|
+
)
|
|
483
|
+
return non_media_treatment_variables
|
|
484
|
+
|
|
485
|
+
def _serialize_data_point(
|
|
486
|
+
self,
|
|
487
|
+
geo: str,
|
|
488
|
+
time: str,
|
|
489
|
+
times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
|
|
490
|
+
) -> marketing_pb.MarketingDataPoint:
|
|
491
|
+
"""Serializes a MarketingDataPoint proto for a given geo and time."""
|
|
492
|
+
data_point = marketing_pb.MarketingDataPoint(
|
|
493
|
+
geo_info=marketing_pb.GeoInfo(
|
|
494
|
+
geo_id=geo,
|
|
495
|
+
population=round(self._input_data.population.sel(geo=geo).item()),
|
|
496
|
+
),
|
|
497
|
+
date_interval=times_to_date_intervals.get(time),
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
if self._input_data.controls is not None:
|
|
501
|
+
if time in self._input_data.controls.time:
|
|
502
|
+
for control_data in self._input_data.controls.sel(geo=geo, time=time):
|
|
503
|
+
data_point.control_variables.add(
|
|
504
|
+
name=control_data.control_variable.item(),
|
|
505
|
+
value=control_data.item(),
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
if self._input_data.media is not None:
|
|
509
|
+
if (
|
|
510
|
+
self._input_data.media_spend_has_geo_dimension
|
|
511
|
+
and self._input_data.media_spend_has_time_dimension
|
|
512
|
+
):
|
|
513
|
+
spend_data_array = self._input_data.media_spend
|
|
514
|
+
else:
|
|
515
|
+
# Aggregated spend data is serialized in a separate data point.
|
|
516
|
+
spend_data_array = None
|
|
517
|
+
media_variables = self._serialize_media_variables(
|
|
518
|
+
geo,
|
|
519
|
+
time,
|
|
520
|
+
c.MEDIA_CHANNEL,
|
|
521
|
+
self._input_data.media,
|
|
522
|
+
spend_data_array,
|
|
523
|
+
)
|
|
524
|
+
data_point.media_variables.extend(media_variables)
|
|
525
|
+
|
|
526
|
+
if (
|
|
527
|
+
self._input_data.reach is not None
|
|
528
|
+
and self._input_data.frequency is not None
|
|
529
|
+
):
|
|
530
|
+
if (
|
|
531
|
+
self._input_data.rf_spend_has_geo_dimension
|
|
532
|
+
and self._input_data.rf_spend_has_time_dimension
|
|
533
|
+
):
|
|
534
|
+
rf_spend_data_array = self._input_data.rf_spend
|
|
535
|
+
else:
|
|
536
|
+
# Aggregated spend data is serialized in a separate data point.
|
|
537
|
+
rf_spend_data_array = None
|
|
538
|
+
rf_variables = self._serialize_reach_frequency_variables(
|
|
539
|
+
geo,
|
|
540
|
+
time,
|
|
541
|
+
c.RF_CHANNEL,
|
|
542
|
+
self._input_data.reach,
|
|
543
|
+
self._input_data.frequency,
|
|
544
|
+
rf_spend_data_array,
|
|
545
|
+
)
|
|
546
|
+
data_point.reach_frequency_variables.extend(rf_variables)
|
|
547
|
+
|
|
548
|
+
if self._input_data.organic_media is not None:
|
|
549
|
+
organic_media_variables = self._serialize_media_variables(
|
|
550
|
+
geo, time, c.ORGANIC_MEDIA_CHANNEL, self._input_data.organic_media
|
|
551
|
+
)
|
|
552
|
+
data_point.media_variables.extend(organic_media_variables)
|
|
553
|
+
|
|
554
|
+
if (
|
|
555
|
+
self._input_data.organic_reach is not None
|
|
556
|
+
and self._input_data.organic_frequency is not None
|
|
557
|
+
):
|
|
558
|
+
organic_rf_variables = self._serialize_reach_frequency_variables(
|
|
559
|
+
geo,
|
|
560
|
+
time,
|
|
561
|
+
c.ORGANIC_RF_CHANNEL,
|
|
562
|
+
self._input_data.organic_reach,
|
|
563
|
+
self._input_data.organic_frequency,
|
|
564
|
+
)
|
|
565
|
+
data_point.reach_frequency_variables.extend(organic_rf_variables)
|
|
566
|
+
|
|
567
|
+
non_media_treatment_variables = (
|
|
568
|
+
self._serialize_non_media_treatment_variables(geo, time)
|
|
569
|
+
)
|
|
570
|
+
data_point.non_media_treatment_variables.extend(
|
|
571
|
+
non_media_treatment_variables
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
if time in self._input_data.kpi.time:
|
|
575
|
+
kpi_proto = self._make_kpi_proto(geo, time)
|
|
576
|
+
data_point.kpi.CopyFrom(kpi_proto)
|
|
577
|
+
|
|
578
|
+
return data_point
|
|
579
|
+
|
|
580
|
+
def _serialize_aggregated_media_spend_data_point(
|
|
581
|
+
self,
|
|
582
|
+
spend_data_array: xr.DataArray,
|
|
583
|
+
times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
|
|
584
|
+
) -> marketing_pb.MarketingDataPoint:
|
|
585
|
+
"""Serializes and appends a data point for aggregated spend."""
|
|
586
|
+
spend_data_point = marketing_pb.MarketingDataPoint()
|
|
587
|
+
date_interval = _get_date_interval_from_date_intervals(
|
|
588
|
+
list(times_to_date_intervals.values())
|
|
589
|
+
)
|
|
590
|
+
spend_data_point.date_interval.CopyFrom(date_interval)
|
|
591
|
+
|
|
592
|
+
for channel_name in spend_data_array.coords[c.MEDIA_CHANNEL].values:
|
|
593
|
+
spend_value = spend_data_array.sel(
|
|
594
|
+
**{c.MEDIA_CHANNEL: channel_name}
|
|
595
|
+
).item()
|
|
596
|
+
spend_data_point.media_variables.add(
|
|
597
|
+
channel_name=channel_name, media_spend=spend_value
|
|
598
|
+
)
|
|
599
|
+
|
|
600
|
+
return spend_data_point
|
|
601
|
+
|
|
602
|
+
def _serialize_aggregated_rf_spend_data_point(
|
|
603
|
+
self,
|
|
604
|
+
spend_data_array: xr.DataArray,
|
|
605
|
+
times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
|
|
606
|
+
) -> marketing_pb.MarketingDataPoint:
|
|
607
|
+
"""Serializes and appends a data point for aggregated spend."""
|
|
608
|
+
spend_data_point = marketing_pb.MarketingDataPoint()
|
|
609
|
+
date_interval = _get_date_interval_from_date_intervals(
|
|
610
|
+
list(times_to_date_intervals.values())
|
|
611
|
+
)
|
|
612
|
+
spend_data_point.date_interval.CopyFrom(date_interval)
|
|
613
|
+
|
|
614
|
+
for channel_name in spend_data_array.coords[c.RF_CHANNEL].values:
|
|
615
|
+
spend_value = spend_data_array.sel(**{c.RF_CHANNEL: channel_name}).item()
|
|
616
|
+
spend_data_point.reach_frequency_variables.add(
|
|
617
|
+
channel_name=channel_name, spend=spend_value
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
return spend_data_point
|
|
621
|
+
|
|
622
|
+
def _serialize_time_dimensions(
|
|
623
|
+
self, name: str, time_data: xr.DataArray
|
|
624
|
+
) -> marketing_pb.MarketingDataMetadata.TimeDimension:
|
|
625
|
+
"""Creates a TimeDimension message."""
|
|
626
|
+
time_dimensions = marketing_pb.MarketingDataMetadata.TimeDimension(
|
|
627
|
+
name=name
|
|
628
|
+
)
|
|
629
|
+
for date in time_data.values:
|
|
630
|
+
date_obj = dt.datetime.strptime(date, c.DATE_FORMAT).date()
|
|
631
|
+
time_dimensions.dates.add(
|
|
632
|
+
year=date_obj.year, month=date_obj.month, day=date_obj.day
|
|
633
|
+
)
|
|
634
|
+
return time_dimensions
|
|
635
|
+
|
|
636
|
+
def _serialize_channel_dimensions(
|
|
637
|
+
self, channel_data: xr.DataArray | None
|
|
638
|
+
) -> marketing_pb.MarketingDataMetadata.ChannelDimension | None:
|
|
639
|
+
"""Creates a ChannelDimension message if the corresponding attribute exists."""
|
|
640
|
+
if channel_data is None:
|
|
641
|
+
return None
|
|
642
|
+
|
|
643
|
+
coord_name = _COORD_NAME_MAP.get(channel_data.name)
|
|
644
|
+
if coord_name:
|
|
645
|
+
return marketing_pb.MarketingDataMetadata.ChannelDimension(
|
|
646
|
+
name=channel_data.name,
|
|
647
|
+
channels=channel_data.coords[coord_name].values.tolist(),
|
|
648
|
+
)
|
|
649
|
+
else:
|
|
650
|
+
# Make sure that all channel dimensions are handled.
|
|
651
|
+
raise ValueError(f"Unknown channel data name: {channel_data.name}. ")
|
|
652
|
+
|
|
653
|
+
def _serialize_metadata(self) -> marketing_pb.MarketingDataMetadata:
|
|
654
|
+
"""Serializes metadata from InputData to MarketingDataMetadata."""
|
|
655
|
+
metadata = marketing_pb.MarketingDataMetadata()
|
|
656
|
+
|
|
657
|
+
metadata.time_dimensions.append(
|
|
658
|
+
self._serialize_time_dimensions(c.TIME, self._input_data.time)
|
|
659
|
+
)
|
|
660
|
+
metadata.time_dimensions.append(
|
|
661
|
+
self._serialize_time_dimensions(
|
|
662
|
+
c.MEDIA_TIME, self._input_data.media_time
|
|
663
|
+
)
|
|
664
|
+
)
|
|
665
|
+
|
|
666
|
+
channel_data_arrays = [
|
|
667
|
+
self._input_data.media,
|
|
668
|
+
self._input_data.reach,
|
|
669
|
+
self._input_data.frequency,
|
|
670
|
+
self._input_data.organic_media,
|
|
671
|
+
self._input_data.organic_reach,
|
|
672
|
+
self._input_data.organic_frequency,
|
|
673
|
+
]
|
|
674
|
+
|
|
675
|
+
for channel_data_array in channel_data_arrays:
|
|
676
|
+
channel_names_message = self._serialize_channel_dimensions(
|
|
677
|
+
channel_data_array
|
|
678
|
+
)
|
|
679
|
+
if channel_names_message:
|
|
680
|
+
metadata.channel_dimensions.append(channel_names_message)
|
|
681
|
+
|
|
682
|
+
if self._input_data.controls is not None:
|
|
683
|
+
metadata.control_names.extend(
|
|
684
|
+
self._input_data.controls.control_variable.values
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
if self._input_data.non_media_treatments is not None:
|
|
688
|
+
metadata.non_media_treatment_names.extend(
|
|
689
|
+
self._input_data.non_media_treatments.non_media_channel.values
|
|
690
|
+
)
|
|
691
|
+
|
|
692
|
+
metadata.kpi_type = self._input_data.kpi_type
|
|
693
|
+
|
|
694
|
+
return metadata
|
|
695
|
+
|
|
696
|
+
def _make_kpi_proto(self, geo: str, time: str) -> marketing_pb.Kpi:
  """Builds the `Kpi` proto for one `(geo, time)` cell of the InputData.

  Args:
    geo: The geo coordinate to select.
    time: The time coordinate to select.

  Returns:
    A `marketing_pb.Kpi` whose `revenue` sub-message is set when the KPI type
    is revenue. Otherwise its `non_revenue` sub-message is set, including
    `revenue_per_kpi` when the InputData provides it.
  """
  input_data = self._input_data
  kpi_proto = marketing_pb.Kpi(name=input_data.kpi_type)
  # `kpi` and `revenue_per_kpi` dimensions: `(n_geos, n_times)`.
  kpi_value = input_data.kpi.sel(geo=geo, time=time).item()

  if input_data.kpi_type != c.REVENUE:
    kpi_proto.non_revenue.CopyFrom(
        marketing_pb.Kpi.NonRevenue(value=kpi_value)
    )
    revenue_per_kpi = input_data.revenue_per_kpi
    if revenue_per_kpi is not None:
      kpi_proto.non_revenue.revenue_per_kpi = revenue_per_kpi.sel(
          geo=geo, time=time
      ).item()
    return kpi_proto

  kpi_proto.revenue.CopyFrom(marketing_pb.Kpi.Revenue(value=kpi_value))
  return kpi_proto
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
class _InputDataDeserializer:
  """Deserializes a `MarketingData` proto into a Meridian `InputData`.

  Usage: `_InputDataDeserializer(proto)()` returns the `InputData`.
  """

  def __init__(self, serialized: marketing_pb.MarketingData):
    """Stores and validates the proto to deserialize.

    Args:
      serialized: The `MarketingData` proto to convert.

    Raises:
      ValueError: If `serialized` does not have its `metadata` field set.
    """
    self._serialized = serialized
    # This is a plain class, not a dataclass, so a `__post_init__` hook would
    # never be invoked automatically. Validate explicitly on construction.
    self._validate()

  def _validate(self) -> None:
    """Checks that the serialized proto carries the required metadata.

    Raises:
      ValueError: If the `metadata` field is not set.
    """
    if not self._serialized.HasField(sc.METADATA):
      raise ValueError(
          f"MarketingData proto is missing the '{sc.METADATA}' field."
      )

  @functools.cached_property
  def _metadata(self) -> _DeserializedMetadata:
    """Parses metadata and extracts time dimensions, channel dimensions, and channel type map."""
    return _DeserializedMetadata(self._serialized.metadata)

  @staticmethod
  def _check_kpi_type(data_point, kpi_type: str) -> None:
    """Raises if the KPI oneof set on `data_point` is not `kpi_type`.

    Args:
      data_point: A `MarketingDataPoint` proto whose `kpi` field is set.
      kpi_type: The expected name of the set field in the KPI `type` oneof.

    Raises:
      ValueError: If the data point's KPI type differs from `kpi_type`.
    """
    found_kpi_type = data_point.kpi.WhichOneof(c.TYPE)
    if found_kpi_type != kpi_type:
      raise ValueError(
          "Inconsistent kpi_type found in the data. "
          f"Expected {kpi_type}, found {found_kpi_type}"
      )

  def _extract_population(self) -> xr.DataArray:
    """Extracts per-geo population data from the serialized proto.

    Data points without a geo id are ignored. If a geo appears in several data
    points, the population of the last one wins.

    Returns:
      A 1-D `xr.DataArray` over the `geo` dimension named `population`.
    """
    geo_populations = {}
    for data_point in self._serialized.marketing_data_points:
      geo_id = data_point.geo_info.geo_id
      if not geo_id:
        continue
      geo_populations[geo_id] = data_point.geo_info.population

    return xr.DataArray(
        coords={c.GEO: list(geo_populations.keys())},
        data=np.array(list(geo_populations.values())),
        name=c.POPULATION,
    )

  def _extract_kpi_type(self) -> str:
    """Extracts the kpi_type shared by all data points that carry a KPI.

    Returns:
      The name of the KPI `type` oneof field that is set.

    Raises:
      ValueError: If data points disagree on the KPI type, or no data point
        carries a KPI.
    """
    kpi_type = None
    for data_point in self._serialized.marketing_data_points:
      if not data_point.HasField(c.KPI):
        continue
      current_kpi_type = data_point.kpi.WhichOneof(c.TYPE)
      if kpi_type is None:
        kpi_type = current_kpi_type
      elif kpi_type != current_kpi_type:
        raise ValueError(
            "Inconsistent kpi_type found in the data. "
            f"Expected {kpi_type}, found {current_kpi_type}"
        )

    if kpi_type is None:
      raise ValueError("kpi_type not found in the data.")
    return kpi_type

  def _extract_geo_and_time(self, data_point) -> tuple[str, str]:
    """Extracts `(geo_id, time_str)` from a data point.

    Args:
      data_point: A `MarketingDataPoint` proto.

    Returns:
      The geo id (an empty string when no geo is set) and the start date of
      the data point's date interval formatted with `c.DATE_FORMAT`.
    """
    geo_id = data_point.geo_info.geo_id
    start_date = data_point.date_interval.start_date
    time_str = dt.datetime(
        start_date.year, start_date.month, start_date.day
    ).strftime(c.DATE_FORMAT)
    return geo_id, time_str

  def _extract_kpi(self, kpi_type: str) -> xr.DataArray:
    """Extracts KPI data from the serialized proto.

    Args:
      kpi_type: The expected KPI type of every data point carrying a KPI.

    Returns:
      The KPI data array.

    Raises:
      ValueError: If a data point has a different KPI type, or no KPI data is
        found.
    """

    def _kpi_extractor(data_point):
      if not data_point.HasField(c.KPI):
        return None
      geo_id, time_str = self._extract_geo_and_time(data_point)
      self._check_kpi_type(data_point, kpi_type)
      if kpi_type == c.REVENUE:
        kpi_value = data_point.kpi.revenue.value
      else:
        kpi_value = data_point.kpi.non_revenue.value
      return geo_id, time_str, kpi_value

    kpi = _extract_data_array(
        serialized_data_points=self._serialized.marketing_data_points,
        data_extractor_fn=_kpi_extractor,
        data_name=c.KPI,
    )
    if kpi is None:
      raise ValueError(f"{c.KPI} is not found in the data.")
    return kpi

  def _extract_revenue_per_kpi(self, kpi_type: str) -> xr.DataArray | None:
    """Extracts revenue per KPI data from the serialized proto.

    Args:
      kpi_type: The expected KPI type; must not be revenue.

    Returns:
      The revenue-per-KPI data array as built by `_extract_data_array`.

    Raises:
      ValueError: If `kpi_type` is revenue, or a data point has a different
        KPI type.
    """
    if kpi_type == c.REVENUE:
      raise ValueError(
          f"{c.REVENUE_PER_KPI} is not applicable when kpi_type is {c.REVENUE}."
      )

    def _revenue_per_kpi_extractor(data_point):
      if not data_point.HasField(c.KPI):
        return None
      if not data_point.kpi.non_revenue.HasField(c.REVENUE_PER_KPI):
        return None
      geo_id, time_str = self._extract_geo_and_time(data_point)
      self._check_kpi_type(data_point, kpi_type)
      return geo_id, time_str, data_point.kpi.non_revenue.revenue_per_kpi

    return _extract_data_array(
        serialized_data_points=self._serialized.marketing_data_points,
        data_extractor_fn=_revenue_per_kpi_extractor,
        data_name=c.REVENUE_PER_KPI,
    )

  def _extract_controls(self) -> xr.DataArray | None:
    """Extracts control variables data from the serialized proto, if any."""

    def _controls_extractor(data_point):
      if not data_point.control_variables:
        return None
      geo_id, time_str = self._extract_geo_and_time(data_point)
      for control_variable in data_point.control_variables:
        yield geo_id, time_str, control_variable.name, control_variable.value

    return _extract_3d_data_array(
        serialized_data_points=self._serialized.marketing_data_points,
        data_extractor_fn=_controls_extractor,
        data_name=c.CONTROLS,
        third_dim_name=c.CONTROL_VARIABLE,
    )

  def _extract_channel_metric(
      self,
      get_variables,
      get_value,
      channel_type: str,
      data_name: str,
  ) -> xr.DataArray | None:
    """Extracts one per-channel metric indexed by geo, media time and channel.

    Only data points with a geo id are used, and only channel variables whose
    metadata channel type equals `channel_type` are kept.

    Args:
      get_variables: Maps a data point to its repeated channel variables, e.g.
        `media_variables` or `reach_frequency_variables`.
      get_value: Maps one channel variable to the value to extract.
      channel_type: The metadata channel type to keep. Also used as the name of
        the channel dimension of the result.
      data_name: The name of the resulting data array.

    Returns:
      The data array as built by `_extract_3d_data_array`.
    """

    def _extractor(data_point):
      geo_id, time_str = self._extract_geo_and_time(data_point)
      if not geo_id:
        return None
      for variable in get_variables(data_point):
        channel_name = variable.channel_name
        if self._metadata.channel_types.get(channel_name) != channel_type:
          continue
        yield geo_id, time_str, channel_name, get_value(variable)

    return _extract_3d_data_array(
        serialized_data_points=self._serialized.marketing_data_points,
        data_extractor_fn=_extractor,
        data_name=data_name,
        third_dim_name=channel_type,
        time_dim_name=c.MEDIA_TIME,
    )

  def _extract_media(self) -> xr.DataArray | None:
    """Extracts media variables data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.media_variables,
        get_value=lambda mv: mv.scalar_metric.value,
        channel_type=c.MEDIA_CHANNEL,
        data_name=c.MEDIA,
    )

  def _extract_reach(self) -> xr.DataArray | None:
    """Extracts reach data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.reach_frequency_variables,
        get_value=lambda rf: rf.reach,
        channel_type=c.RF_CHANNEL,
        data_name=c.REACH,
    )

  def _extract_frequency(self) -> xr.DataArray | None:
    """Extracts frequency data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.reach_frequency_variables,
        get_value=lambda rf: rf.average_frequency,
        channel_type=c.RF_CHANNEL,
        data_name=c.FREQUENCY,
    )

  def _extract_organic_media(self) -> xr.DataArray | None:
    """Extracts organic media variables data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.media_variables,
        get_value=lambda mv: mv.scalar_metric.value,
        channel_type=c.ORGANIC_MEDIA_CHANNEL,
        data_name=c.ORGANIC_MEDIA,
    )

  def _extract_organic_reach(self) -> xr.DataArray | None:
    """Extracts organic reach data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.reach_frequency_variables,
        get_value=lambda rf: rf.reach,
        channel_type=c.ORGANIC_RF_CHANNEL,
        data_name=c.ORGANIC_REACH,
    )

  def _extract_organic_frequency(self) -> xr.DataArray | None:
    """Extracts organic frequency data from the serialized proto."""
    return self._extract_channel_metric(
        get_variables=lambda dp: dp.reach_frequency_variables,
        get_value=lambda rf: rf.average_frequency,
        channel_type=c.ORGANIC_RF_CHANNEL,
        data_name=c.ORGANIC_FREQUENCY,
    )

  def _extract_granular_media_spend(
      self,
      data_points_with_spend: list[marketing_pb.MarketingDataPoint],
  ) -> xr.DataArray | None:
    """Extracts per-geo, per-time media spend data.

    Args:
      data_points_with_spend: List of MarketingDataPoint protos with spend data.

    Returns:
      An xr.DataArray with granular media spend data, as built by
      `_extract_3d_data_array`.
    """

    def _granular_spend_extractor(data_point):
      geo_id, time_str = self._extract_geo_and_time(data_point)
      for media_variable in data_point.media_variables:
        if (
            media_variable.HasField(c.MEDIA_SPEND)
            and self._metadata.channel_types.get(media_variable.channel_name)
            == c.MEDIA_CHANNEL
        ):
          yield (
              geo_id,
              time_str,
              media_variable.channel_name,
              media_variable.media_spend,
          )

    return _extract_3d_data_array(
        serialized_data_points=data_points_with_spend,
        data_extractor_fn=_granular_spend_extractor,
        data_name=c.MEDIA_SPEND,
        third_dim_name=c.MEDIA_CHANNEL,
        time_dim_name=c.TIME,
    )

  def _extract_granular_rf_spend(
      self,
      data_points_with_spend: list[marketing_pb.MarketingDataPoint],
  ) -> xr.DataArray | None:
    """Extracts per-geo, per-time reach-and-frequency spend data.

    Args:
      data_points_with_spend: List of MarketingDataPoint protos with spend data.

    Returns:
      An xr.DataArray with granular RF spend data, as built by
      `_extract_3d_data_array`.
    """

    def _granular_spend_extractor(data_point):
      geo_id, time_str = self._extract_geo_and_time(data_point)
      for rf_variable in data_point.reach_frequency_variables:
        if (
            rf_variable.HasField(c.SPEND)
            and self._metadata.channel_types.get(rf_variable.channel_name)
            == c.RF_CHANNEL
        ):
          yield geo_id, time_str, rf_variable.channel_name, rf_variable.spend

    return _extract_3d_data_array(
        serialized_data_points=data_points_with_spend,
        data_extractor_fn=_granular_spend_extractor,
        data_name=c.RF_SPEND,
        third_dim_name=c.RF_CHANNEL,
        time_dim_name=c.TIME,
    )

  def _extract_aggregated_media_spend(
      self,
      data_points_with_spend: list[marketing_pb.MarketingDataPoint],
  ) -> xr.DataArray | None:
    """Extracts aggregated (geo-less, whole media-time) media spend data.

    Args:
      data_points_with_spend: List of MarketingDataPoint protos with spend data.

    Returns:
      A 1-D xr.DataArray over `media_channel` with aggregated spend data, or
      None if no data found. If a channel appears more than once, the last
      value wins.
    """
    # A set gives O(1) membership tests inside the nested loop.
    channel_names = set(self._metadata.channel_dimensions.get(c.MEDIA, []))
    channel_spend_map = {}

    for spend_data_point in data_points_with_spend:
      for media_variable in spend_data_point.media_variables:
        if (
            media_variable.channel_name in channel_names
            and media_variable.HasField(c.MEDIA_SPEND)
        ):
          channel_spend_map[media_variable.channel_name] = (
              media_variable.media_spend
          )

    if not channel_spend_map:
      return None

    return xr.DataArray(
        data=list(channel_spend_map.values()),
        coords={c.MEDIA_CHANNEL: list(channel_spend_map.keys())},
        dims=[c.MEDIA_CHANNEL],
        name=c.MEDIA_SPEND,
    )

  def _extract_aggregated_rf_spend(
      self,
      data_points_with_spend: list[marketing_pb.MarketingDataPoint],
  ) -> xr.DataArray | None:
    """Extracts aggregated (geo-less, whole media-time) RF spend data.

    Args:
      data_points_with_spend: List of MarketingDataPoint protos with spend data.

    Returns:
      A 1-D xr.DataArray over `rf_channel` with aggregated spend data, or None
      if no data found. If a channel appears more than once, the last value
      wins.
    """
    # RF channel names are recorded under the `reach` channel dimension.
    channel_names = set(self._metadata.channel_dimensions.get(c.REACH, []))
    channel_spend_map = {}

    for spend_data_point in data_points_with_spend:
      for rf_variable in spend_data_point.reach_frequency_variables:
        if rf_variable.channel_name in channel_names and rf_variable.HasField(
            c.SPEND
        ):
          channel_spend_map[rf_variable.channel_name] = rf_variable.spend

    if not channel_spend_map:
      return None

    return xr.DataArray(
        data=list(channel_spend_map.values()),
        coords={c.RF_CHANNEL: list(channel_spend_map.keys())},
        dims=[c.RF_CHANNEL],
        name=c.RF_SPEND,
    )

  def _is_aggregated_spend_data_point(
      self, dp: marketing_pb.MarketingDataPoint
  ) -> bool:
    """Checks if a MarketingDataPoint with spend represents aggregated spend data.

    A data point is aggregated when it has no `geo_info` and its date interval
    matches the media time dimension's interval exactly.

    Args:
      dp: A marketing_pb.MarketingDataPoint representing a spend data point.

    Returns:
      True if the data point represents aggregated spend, False otherwise.
    """
    if not dp.HasField(sc.GEO_INFO) and self._metadata.media_time_dimension:
      media_time_interval = (
          self._metadata.media_time_dimension.time_dimension_interval
      )
      return (
          media_time_interval.start_date == dp.date_interval.start_date
          and media_time_interval.end_date == dp.date_interval.end_date
      )
    return False

  def _extract_media_spend(self) -> xr.DataArray | None:
    """Extracts media spend data from the serialized proto.

    Aggregated spend data points take precedence over granular ones when both
    are present.

    Returns:
      An xr.DataArray with spend data or None if no data found.
    """
    # Filter data points relevant to spend based on channel type map.
    media_channels = {
        channel
        for channel, metadata_channel_type in self._metadata.channel_types.items()
        if metadata_channel_type == c.MEDIA_CHANNEL
    }
    spend_data_points = [
        dp
        for dp in self._serialized.marketing_data_points
        if any(
            mv.HasField(c.MEDIA_SPEND) and mv.channel_name in media_channels
            for mv in dp.media_variables
        )
    ]
    if not spend_data_points:
      return None

    aggregated_spend_data_points = [
        dp
        for dp in spend_data_points
        if self._is_aggregated_spend_data_point(dp)
    ]
    if aggregated_spend_data_points:
      return self._extract_aggregated_media_spend(aggregated_spend_data_points)
    return self._extract_granular_media_spend(spend_data_points)

  def _extract_rf_spend(self) -> xr.DataArray | None:
    """Extracts reach and frequency spend data from the serialized proto.

    Aggregated spend data points take precedence over granular ones when both
    are present.

    Returns:
      An xr.DataArray with spend data or None if no data found.
    """
    # Filter data points relevant to spend based on channel type map.
    rf_channels = {
        channel
        for channel, metadata_channel_type in self._metadata.channel_types.items()
        if metadata_channel_type == c.RF_CHANNEL
    }
    spend_data_points = [
        dp
        for dp in self._serialized.marketing_data_points
        if any(
            rf.HasField(c.SPEND) and rf.channel_name in rf_channels
            for rf in dp.reach_frequency_variables
        )
    ]
    if not spend_data_points:
      return None

    aggregated_spend_data_points = [
        dp
        for dp in spend_data_points
        if self._is_aggregated_spend_data_point(dp)
    ]
    if aggregated_spend_data_points:
      return self._extract_aggregated_rf_spend(aggregated_spend_data_points)
    return self._extract_granular_rf_spend(spend_data_points)

  def _extract_non_media_treatments(self) -> xr.DataArray | None:
    """Extracts non-media treatment variables data from the serialized proto."""

    def _non_media_treatments_extractor(data_point):
      if not data_point.non_media_treatment_variables:
        return None
      geo_id, time_str = self._extract_geo_and_time(data_point)
      for treatment in data_point.non_media_treatment_variables:
        yield geo_id, time_str, treatment.name, treatment.value

    return _extract_3d_data_array(
        serialized_data_points=self._serialized.marketing_data_points,
        data_extractor_fn=_non_media_treatments_extractor,
        data_name=c.NON_MEDIA_TREATMENTS,
        third_dim_name=c.NON_MEDIA_CHANNEL,
    )

  def __call__(self) -> meridian_input_data.InputData:
    """Converts the `MarketingData` proto to a Meridian `InputData`."""
    kpi_type = self._extract_kpi_type()
    return meridian_input_data.InputData(
        kpi=self._extract_kpi(kpi_type),
        kpi_type=kpi_type,
        controls=self._extract_controls(),
        population=self._extract_population(),
        revenue_per_kpi=(
            self._extract_revenue_per_kpi(kpi_type)
            if kpi_type == c.NON_REVENUE
            else None
        ),
        media=self._extract_media(),
        media_spend=self._extract_media_spend(),
        reach=self._extract_reach(),
        frequency=self._extract_frequency(),
        rf_spend=self._extract_rf_spend(),
        organic_media=self._extract_organic_media(),
        organic_reach=self._extract_organic_reach(),
        organic_frequency=self._extract_organic_frequency(),
        non_media_treatments=self._extract_non_media_treatments(),
    )
|
|
1295
|
+
|
|
1296
|
+
|
|
1297
|
+
class MarketingDataSerde(
    serde.Serde[marketing_pb.MarketingData, meridian_input_data.InputData]
):
  """Converts a Meridian `InputData` container to and from `MarketingData`."""

  def serialize(
      self, obj: meridian_input_data.InputData
  ) -> marketing_pb.MarketingData:
    """Converts Meridian input data into a `MarketingData` proto.

    Args:
      obj: The Meridian `InputData` to serialize.

    Returns:
      The equivalent `MarketingData` proto.
    """
    serializer = _InputDataSerializer(obj)
    return serializer()

  def deserialize(
      self, serialized: marketing_pb.MarketingData, serialized_version: str = ""
  ) -> meridian_input_data.InputData:
    """Converts a `MarketingData` proto back into Meridian input data.

    Args:
      serialized: The serialized `MarketingData` proto.
      serialized_version: The version of the serialized model. This is used to
        handle changes in deserialization logic across different versions.

    Returns:
      A Meridian input data container.
    """
    # NOTE: `serialized_version` is accepted for interface compatibility but
    # is not currently consulted by the deserialization logic.
    deserializer = _InputDataDeserializer(serialized)
    return deserializer()
|