google-meridian 1.3.0__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1321 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serialization and deserialization of `InputData` for Meridian models."""
16
+
17
+ from collections.abc import Mapping
18
+ import dataclasses
19
+ import datetime as dt
20
+ import functools
21
+ import itertools
22
+ from typing import Sequence
23
+
24
+ from meridian import constants as c
25
+ from meridian.data import input_data as meridian_input_data
26
+ from mmm.v1.common import date_interval_pb2
27
+ from mmm.v1.marketing import marketing_data_pb2 as marketing_pb
28
+ from schema.serde import constants as sc
29
+ from schema.serde import serde
30
+ from schema.utils import time_record
31
+ import numpy as np
32
+ import xarray as xr
33
+
34
+ from google.type import date_pb2
35
+
36
# Maps a DataArray name (e.g. `media`) to the name of the coordinate that
# holds its channel labels (e.g. `media_channel`). Used when serializing
# channel dimensions and when resolving a channel's type during
# deserialization.
_COORD_NAME_MAP = {
    c.MEDIA: c.MEDIA_CHANNEL,
    c.REACH: c.RF_CHANNEL,
    c.FREQUENCY: c.RF_CHANNEL,
    c.ORGANIC_MEDIA: c.ORGANIC_MEDIA_CHANNEL,
    c.ORGANIC_REACH: c.ORGANIC_RF_CHANNEL,
    c.ORGANIC_FREQUENCY: c.ORGANIC_RF_CHANNEL,
    c.NON_MEDIA_TREATMENTS: c.NON_MEDIA_CHANNEL,
}
46
+
47
+
48
@dataclasses.dataclass(frozen=True)
class _DeserializedTimeDimension:
  """Read-only helper view over a `TimeDimension` proto used during deserialization."""

  _time_dimension: marketing_pb.MarketingDataMetadata.TimeDimension

  def __post_init__(self):
    # An empty time dimension cannot produce coordinates or an interval.
    if not self._time_dimension.dates:
      raise ValueError("TimeDimension proto must have at least one date.")

  @functools.cached_property
  def date_coordinates(self) -> list[dt.date]:
    """All date coordinates of this time dimension as `datetime.date` objects."""
    return [
        dt.date(proto_date.year, proto_date.month, proto_date.day)
        for proto_date in self._time_dimension.dates
    ]

  @functools.cached_property
  def time_dimension_interval(self) -> date_interval_pb2.DateInterval:
    """The `[start, end)` interval that spans this time dimension.

    Covers every date coordinate in this time dimension.
    """
    intervals_by_time = time_record.convert_times_to_date_intervals(
        self.date_coordinates
    )
    return _get_date_interval_from_date_intervals(
        list(intervals_by_time.values())
    )
73
+
74
+
75
@dataclasses.dataclass(frozen=True)
class _DeserializedMetadata:
  """A container for parsed metadata from the `MarketingData` proto.

  Attributes:
    _metadata: The `MarketingDataMetadata` proto.
  """

  _metadata: marketing_pb.MarketingDataMetadata

  def __post_init__(self):
    # Touch both time dimensions so that missing ones fail fast.
    _ = self.time_dimension
    _ = self.media_time_dimension

  def _get_time_dimension(self, name: str) -> _DeserializedTimeDimension:
    """Looks up a TimeDimension proto by name, raising if absent."""
    match = next(
        (td for td in self._metadata.time_dimensions if td.name == name),
        None,
    )
    if match is None:
      raise ValueError(f"No TimeDimension found with name '{name}' in metadata.")
    return _DeserializedTimeDimension(match)

  @functools.cached_property
  def time_dimension(self) -> _DeserializedTimeDimension:
    """The TimeDimension named 'time'."""
    return self._get_time_dimension(c.TIME)

  @functools.cached_property
  def media_time_dimension(self) -> _DeserializedTimeDimension:
    """The TimeDimension named 'media_time'."""
    return self._get_time_dimension(c.MEDIA_TIME)

  @functools.cached_property
  def channel_dimensions(self) -> Mapping[str, list[str]]:
    """Maps each channel dimension name to its list of channel names."""
    return {
        dim.name: list(dim.channels)
        for dim in self._metadata.channel_dimensions
    }

  @functools.cached_property
  def channel_types(self) -> Mapping[str, str | None]:
    """Maps each individual channel name to its coordinate-type name (or None)."""
    return {
        channel: _COORD_NAME_MAP.get(dim_name)
        for dim_name, channels in self.channel_dimensions.items()
        for channel in channels
    }
124
+
125
+
126
def _extract_data_array(
    serialized_data_points: Sequence[marketing_pb.MarketingDataPoint],
    data_extractor_fn,
    data_name,
) -> xr.DataArray | None:
  """Helper function to extract data into an `xr.DataArray`.

  Args:
    serialized_data_points: A Sequence of MarketingDataPoint protos.
    data_extractor_fn: A function that takes a data point and returns either a
      tuple of `(geo_id, time_str, value)`, or `None` if the data point should
      be skipped.
    data_name: The desired name for the `xr.DataArray`.

  Returns:
    An `xr.DataArray` containing the extracted data, or `None` if no data is
    found.
  """
  data_dict = {}  # (geo_id, time_str) -> value
  # Insertion-ordered "sets": dict keys preserve first-seen order while
  # giving O(1) membership, replacing the previous O(n)-per-point list
  # membership checks (quadratic over the number of data points).
  geo_ids = {}
  times = {}

  for data_point in serialized_data_points:
    extraction_result = data_extractor_fn(data_point)
    if extraction_result is None:
      continue

    geo_id, time_str, value = extraction_result

    # TODO: Enforce dimension uniqueness in Meridian.
    geo_ids[geo_id] = None
    times[time_str] = None
    data_dict[(geo_id, time_str)] = value

  if not data_dict:
    return None

  # Missing (geo, time) combinations are filled with NaN.
  data_values = np.array([
      [data_dict.get((geo_id, time), np.nan) for time in times]
      for geo_id in geo_ids
  ])

  return xr.DataArray(
      data=data_values,
      coords={
          c.GEO: list(geo_ids),
          c.TIME: list(times),
      },
      dims=(c.GEO, c.TIME),
      name=data_name,
  )
180
+
181
+
182
def _extract_3d_data_array(
    serialized_data_points: Sequence[marketing_pb.MarketingDataPoint],
    data_extractor_fn,
    data_name,
    third_dim_name,
    time_dim_name=c.TIME,
) -> xr.DataArray | None:
  """Helper function to extract data with 3 dimensions into an `xr.DataArray`.

  The first dimension is always `GEO`, and the second is the time dimension
  (default: `TIME`).

  Args:
    serialized_data_points: A sequence of MarketingDataPoint protos.
    data_extractor_fn: A function that takes a data point and yields tuples of
      `(geo_id, time_str, third_dim_key, value)`; it may yield nothing for
      data points that should be skipped.
    data_name: The desired name for the `xr.DataArray`.
    third_dim_name: The name of the third dimension.
    time_dim_name: The name of the time dimension. Default is `TIME`.

  Returns:
    An `xr.DataArray` containing the extracted data, or `None` if no data is
    found.
  """
  data_dict = {}  # (geo_id, time_str, third_dim_key) -> value
  # Insertion-ordered "sets": dict keys preserve first-seen order while
  # giving O(1) membership, replacing the previous O(n)-per-item list
  # membership checks (quadratic over the number of extracted items).
  geo_ids = {}
  times = {}
  third_dim_keys = {}

  for data_point in serialized_data_points:
    for geo_id, time_str, third_dim_key, value in data_extractor_fn(
        data_point
    ):
      geo_ids[geo_id] = None
      times[time_str] = None
      third_dim_keys[third_dim_key] = None

      # TODO: Enforce dimension uniqueness in Meridian.
      data_dict[(geo_id, time_str, third_dim_key)] = value

  if not data_dict:
    return None

  # Missing combinations are filled with NaN.
  data_values = np.array([
      [
          [
              data_dict.get((geo_id, time, third_dim_key), np.nan)
              for third_dim_key in third_dim_keys
          ]
          for time in times
      ]
      for geo_id in geo_ids
  ])

  return xr.DataArray(
      data=data_values,
      coords={
          c.GEO: list(geo_ids),
          time_dim_name: list(times),
          third_dim_name: list(third_dim_keys),
      },
      dims=(c.GEO, time_dim_name, third_dim_name),
      name=data_name,
  )
250
+
251
+
252
def _get_date_interval_from_date_intervals(
    date_intervals: Sequence[date_interval_pb2.DateInterval],
) -> date_interval_pb2.DateInterval:
  """Gets the date interval based on the earliest start date and latest end date.

  Args:
    date_intervals: A list of DateInterval protos.

  Returns:
    A DateInterval representing the earliest start date and latest end date.
  """

  # Named `def`s instead of lambdas assigned to names (PEP 8 E731).
  def get_start_date(interval) -> dt.date:
    return dt.date(
        interval.start_date.year,
        interval.start_date.month,
        interval.start_date.day,
    )

  def get_end_date(interval) -> dt.date:
    return dt.date(
        interval.end_date.year, interval.end_date.month, interval.end_date.day
    )

  min_start_date_interval = min(date_intervals, key=get_start_date)
  max_end_date_interval = max(date_intervals, key=get_end_date)

  return date_interval_pb2.DateInterval(
      start_date=date_pb2.Date(
          year=min_start_date_interval.start_date.year,
          month=min_start_date_interval.start_date.month,
          day=min_start_date_interval.start_date.day,
      ),
      end_date=date_pb2.Date(
          year=max_end_date_interval.end_date.year,
          month=max_end_date_interval.end_date.month,
          day=max_end_date_interval.end_date.day,
      ),
  )
287
+
288
+
289
class _InputDataSerializer:
  """Serializes an `InputData` container in Meridian model.

  Walks the `(geo, media_time)` grid of the input data and emits one
  `MarketingDataPoint` per cell, plus optional stand-alone data points for
  spend aggregated over both geo and time, and a metadata message describing
  the time and channel dimensions.
  """

  def __init__(self, input_data: meridian_input_data.InputData):
    self._input_data = input_data

  @property
  def _n_geos(self) -> int:
    # Number of geos in the input data.
    return len(self._input_data.geo)

  @property
  def _n_times(self) -> int:
    # Number of (non-media) time coordinates in the input data.
    return len(self._input_data.time)

  def __call__(self) -> marketing_pb.MarketingData:
    """Serializes the input data into a MarketingData proto."""
    marketing_proto = marketing_pb.MarketingData()
    # Use media_time since it covers larger range.
    times_to_date_intervals = time_record.convert_times_to_date_intervals(
        self._input_data.media_time.data
    )
    geos_and_times = itertools.product(
        self._input_data.geo.data, self._input_data.media_time.data
    )

    for geo, time in geos_and_times:
      data_point = self._serialize_data_point(
          geo,
          time,
          times_to_date_intervals,
      )
      marketing_proto.marketing_data_points.append(data_point)

    # Spend aggregated over both geo and time gets its own data point;
    # granular spend was already attached per (geo, time) above. A spend
    # array with exactly one of the two dimensions is rejected.
    if self._input_data.media_spend is not None:
      if (
          not self._input_data.media_spend_has_geo_dimension
          and not self._input_data.media_spend_has_time_dimension
      ):
        marketing_proto.marketing_data_points.append(
            self._serialize_aggregated_media_spend_data_point(
                self._input_data.media_spend,
                times_to_date_intervals,
            )
        )
      elif (
          self._input_data.media_spend_has_geo_dimension
          != self._input_data.media_spend_has_time_dimension
      ):
        raise AssertionError(
            "Invalid input data: media_spend must either be fully granular"
            " (both geo and time dimensions) or fully aggregated (neither geo"
            " nor time dimensions)."
        )

    if self._input_data.rf_spend is not None:
      if (
          not self._input_data.rf_spend_has_geo_dimension
          and not self._input_data.rf_spend_has_time_dimension
      ):
        marketing_proto.marketing_data_points.append(
            self._serialize_aggregated_rf_spend_data_point(
                self._input_data.rf_spend, times_to_date_intervals
            )
        )
      elif (
          self._input_data.rf_spend_has_geo_dimension
          != self._input_data.rf_spend_has_time_dimension
      ):
        raise AssertionError(
            "Invalid input data: rf_spend must either be fully granular (both"
            " geo and time dimensions) or fully aggregated (neither geo nor"
            " time dimensions)."
        )

    marketing_proto.metadata.CopyFrom(self._serialize_metadata())

    return marketing_proto

  def _serialize_media_variables(
      self,
      geo: str,
      time: str,
      channel_dim_name: str,
      impressions_data_array: xr.DataArray,
      spend_data_array: xr.DataArray | None = None,
  ) -> list[marketing_pb.MediaVariable]:
    """Serializes media variables for a given geo and time.

    Args:
      geo: The geo ID.
      time: The time string.
      channel_dim_name: The name of the channel dimension.
      impressions_data_array: The DataArray containing impressions data.
        Expected dimensions: `(n_geos, n_media_times, n_channels)`.
      spend_data_array: The optional DataArray containing spend data. Expected
        dimensions are `(n_geos, n_times, n_media_channels)`.

    Returns:
      A list of MediaVariable protos.
    """
    media_variables = []
    for media_data in impressions_data_array.sel(geo=geo, media_time=time):
      channel = media_data[channel_dim_name].item()
      media_variable = marketing_pb.MediaVariable(
          channel_name=channel,
          scalar_metric=marketing_pb.ScalarMetric(
              name=c.IMPRESSIONS, value=media_data.item()
          ),
      )
      # Spend only exists on the (shorter) `time` coordinate, so skip media
      # times that precede it.
      if spend_data_array is not None and time in spend_data_array.time:
        media_variable.media_spend = spend_data_array.sel(
            geo=geo, time=time, **{channel_dim_name: channel}
        ).item()
      media_variables.append(media_variable)
    return media_variables

  def _serialize_reach_frequency_variables(
      self,
      geo: str,
      time: str,
      channel_dim_name: str,
      reach_data_array: xr.DataArray,
      frequency_data_array: xr.DataArray,
      spend_data_array: xr.DataArray | None = None,
  ) -> list[marketing_pb.ReachFrequencyVariable]:
    """Serializes reach and frequency variables for a given geo and time.

    Iterates through the R&F channels separately, creating a MediaVariable
    for each. It's safe to assume that Meridian media channel names are
    unique across `media_data` and `reach_data`. This assumption is
    checked when an `InputData` is created in model training.

    Dimensions of `reach_data_array` and `frequency_data_array` are expected
    to be `(n_geos, n_media_times, n_rf_channels)`.

    Args:
      geo: The geo ID.
      time: The time string.
      channel_dim_name: The name of the channel dimension (e.g., 'rf_channel').
      reach_data_array: The DataArray containing reach data.
      frequency_data_array: The DataArray containing frequency data.
      spend_data_array: The optional DataArray containing spend data.

    Returns:
      A list of MediaVariable protos.
    """
    rf_variables = []
    for reach_data in reach_data_array.sel(geo=geo, media_time=time):
      reach_value = reach_data.item()
      channel = reach_data[channel_dim_name].item()
      frequency_value = frequency_data_array.sel(
          geo=geo,
          media_time=time,
          **{channel_dim_name: channel},
      ).item()
      rf_variable = marketing_pb.ReachFrequencyVariable(
          channel_name=channel,
          # The proto field is an integer count; reach is truncated here.
          reach=int(reach_value),
          average_frequency=frequency_value,
      )
      if spend_data_array is not None and time in spend_data_array.time:
        rf_variable.spend = spend_data_array.sel(
            geo=geo, time=time, **{channel_dim_name: channel}
        ).item()
      rf_variables.append(rf_variable)
    return rf_variables

  def _serialize_non_media_treatment_variables(
      self, geo: str, time: str
  ) -> list[marketing_pb.NonMediaTreatmentVariable]:
    """Serializes non-media treatment variables for a given geo and time.

    Args:
      geo: The geo ID.
      time: The time string.

    Returns:
      A list of NonMediaTreatmentVariable protos; empty when there is no
      non-media treatment data for this (geo, time).
    """
    non_media_treatment_variables = []
    if (
        self._input_data.non_media_treatments is not None
        and geo in self._input_data.non_media_treatments.geo
        and time in self._input_data.non_media_treatments.time
    ):
      for non_media_treatment_data in self._input_data.non_media_treatments.sel(
          geo=geo, time=time
      ):
        non_media_treatment_variables.append(
            marketing_pb.NonMediaTreatmentVariable(
                name=non_media_treatment_data[c.NON_MEDIA_CHANNEL].item(),
                value=non_media_treatment_data.item(),
            )
        )
    return non_media_treatment_variables

  def _serialize_data_point(
      self,
      geo: str,
      time: str,
      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
  ) -> marketing_pb.MarketingDataPoint:
    """Serializes a MarketingDataPoint proto for a given geo and time."""
    data_point = marketing_pb.MarketingDataPoint(
        geo_info=marketing_pb.GeoInfo(
            geo_id=geo,
            population=round(self._input_data.population.sel(geo=geo).item()),
        ),
        date_interval=times_to_date_intervals.get(time),
    )

    # Controls exist only on the (shorter) `time` coordinate.
    if self._input_data.controls is not None:
      if time in self._input_data.controls.time:
        for control_data in self._input_data.controls.sel(geo=geo, time=time):
          data_point.control_variables.add(
              name=control_data.control_variable.item(),
              value=control_data.item(),
          )

    if self._input_data.media is not None:
      if (
          self._input_data.media_spend_has_geo_dimension
          and self._input_data.media_spend_has_time_dimension
      ):
        spend_data_array = self._input_data.media_spend
      else:
        # Aggregated spend data is serialized in a separate data point.
        spend_data_array = None
      media_variables = self._serialize_media_variables(
          geo,
          time,
          c.MEDIA_CHANNEL,
          self._input_data.media,
          spend_data_array,
      )
      data_point.media_variables.extend(media_variables)

    if (
        self._input_data.reach is not None
        and self._input_data.frequency is not None
    ):
      if (
          self._input_data.rf_spend_has_geo_dimension
          and self._input_data.rf_spend_has_time_dimension
      ):
        rf_spend_data_array = self._input_data.rf_spend
      else:
        # Aggregated spend data is serialized in a separate data point.
        rf_spend_data_array = None
      rf_variables = self._serialize_reach_frequency_variables(
          geo,
          time,
          c.RF_CHANNEL,
          self._input_data.reach,
          self._input_data.frequency,
          rf_spend_data_array,
      )
      data_point.reach_frequency_variables.extend(rf_variables)

    # Organic media/R&F carry no spend, so no spend array is passed.
    if self._input_data.organic_media is not None:
      organic_media_variables = self._serialize_media_variables(
          geo, time, c.ORGANIC_MEDIA_CHANNEL, self._input_data.organic_media
      )
      data_point.media_variables.extend(organic_media_variables)

    if (
        self._input_data.organic_reach is not None
        and self._input_data.organic_frequency is not None
    ):
      organic_rf_variables = self._serialize_reach_frequency_variables(
          geo,
          time,
          c.ORGANIC_RF_CHANNEL,
          self._input_data.organic_reach,
          self._input_data.organic_frequency,
      )
      data_point.reach_frequency_variables.extend(organic_rf_variables)

    non_media_treatment_variables = (
        self._serialize_non_media_treatment_variables(geo, time)
    )
    data_point.non_media_treatment_variables.extend(
        non_media_treatment_variables
    )

    # KPI exists only on the (shorter) `time` coordinate.
    if time in self._input_data.kpi.time:
      kpi_proto = self._make_kpi_proto(geo, time)
      data_point.kpi.CopyFrom(kpi_proto)

    return data_point

  def _serialize_aggregated_media_spend_data_point(
      self,
      spend_data_array: xr.DataArray,
      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
  ) -> marketing_pb.MarketingDataPoint:
    """Serializes a geo-less data point carrying aggregated media spend.

    The data point's date interval spans all of the given time intervals.
    """
    spend_data_point = marketing_pb.MarketingDataPoint()
    date_interval = _get_date_interval_from_date_intervals(
        list(times_to_date_intervals.values())
    )
    spend_data_point.date_interval.CopyFrom(date_interval)

    for channel_name in spend_data_array.coords[c.MEDIA_CHANNEL].values:
      spend_value = spend_data_array.sel(
          **{c.MEDIA_CHANNEL: channel_name}
      ).item()
      spend_data_point.media_variables.add(
          channel_name=channel_name, media_spend=spend_value
      )

    return spend_data_point

  def _serialize_aggregated_rf_spend_data_point(
      self,
      spend_data_array: xr.DataArray,
      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
  ) -> marketing_pb.MarketingDataPoint:
    """Serializes a geo-less data point carrying aggregated R&F spend.

    The data point's date interval spans all of the given time intervals.
    """
    spend_data_point = marketing_pb.MarketingDataPoint()
    date_interval = _get_date_interval_from_date_intervals(
        list(times_to_date_intervals.values())
    )
    spend_data_point.date_interval.CopyFrom(date_interval)

    for channel_name in spend_data_array.coords[c.RF_CHANNEL].values:
      spend_value = spend_data_array.sel(**{c.RF_CHANNEL: channel_name}).item()
      spend_data_point.reach_frequency_variables.add(
          channel_name=channel_name, spend=spend_value
      )

    return spend_data_point

  def _serialize_time_dimensions(
      self, name: str, time_data: xr.DataArray
  ) -> marketing_pb.MarketingDataMetadata.TimeDimension:
    """Creates a TimeDimension message from a coordinate of date strings."""
    time_dimensions = marketing_pb.MarketingDataMetadata.TimeDimension(
        name=name
    )
    for date in time_data.values:
      date_obj = dt.datetime.strptime(date, c.DATE_FORMAT).date()
      time_dimensions.dates.add(
          year=date_obj.year, month=date_obj.month, day=date_obj.day
      )
    return time_dimensions

  def _serialize_channel_dimensions(
      self, channel_data: xr.DataArray | None
  ) -> marketing_pb.MarketingDataMetadata.ChannelDimension | None:
    """Creates a ChannelDimension message if the corresponding attribute exists.

    Raises:
      ValueError: If the DataArray's name has no entry in `_COORD_NAME_MAP`.
    """
    if channel_data is None:
      return None

    coord_name = _COORD_NAME_MAP.get(channel_data.name)
    if coord_name:
      return marketing_pb.MarketingDataMetadata.ChannelDimension(
          name=channel_data.name,
          channels=channel_data.coords[coord_name].values.tolist(),
      )
    else:
      # Make sure that all channel dimensions are handled.
      raise ValueError(f"Unknown channel data name: {channel_data.name}. ")

  def _serialize_metadata(self) -> marketing_pb.MarketingDataMetadata:
    """Serializes metadata from InputData to MarketingDataMetadata."""
    metadata = marketing_pb.MarketingDataMetadata()

    metadata.time_dimensions.append(
        self._serialize_time_dimensions(c.TIME, self._input_data.time)
    )
    metadata.time_dimensions.append(
        self._serialize_time_dimensions(
            c.MEDIA_TIME, self._input_data.media_time
        )
    )

    # Absent arrays serialize to no ChannelDimension at all.
    channel_data_arrays = [
        self._input_data.media,
        self._input_data.reach,
        self._input_data.frequency,
        self._input_data.organic_media,
        self._input_data.organic_reach,
        self._input_data.organic_frequency,
    ]

    for channel_data_array in channel_data_arrays:
      channel_names_message = self._serialize_channel_dimensions(
          channel_data_array
      )
      if channel_names_message:
        metadata.channel_dimensions.append(channel_names_message)

    if self._input_data.controls is not None:
      metadata.control_names.extend(
          self._input_data.controls.control_variable.values
      )

    if self._input_data.non_media_treatments is not None:
      metadata.non_media_treatment_names.extend(
          self._input_data.non_media_treatments.non_media_channel.values
      )

    metadata.kpi_type = self._input_data.kpi_type

    return metadata

  def _make_kpi_proto(self, geo: str, time: str) -> marketing_pb.Kpi:
    """Constructs a Kpi proto for a given geo and time."""
    kpi_proto = marketing_pb.Kpi(name=self._input_data.kpi_type)
    # `kpi` and `revenue_per_kpi` dimensions: `(n_geos, n_times)`.
    if self._input_data.kpi_type == c.REVENUE:
      kpi_proto.revenue.CopyFrom(
          marketing_pb.Kpi.Revenue(
              value=self._input_data.kpi.sel(geo=geo, time=time).item()
          )
      )
    else:
      kpi_proto.non_revenue.CopyFrom(
          marketing_pb.Kpi.NonRevenue(
              value=self._input_data.kpi.sel(geo=geo, time=time).item()
          )
      )
      # revenue_per_kpi only applies to the non-revenue KPI type.
      if self._input_data.revenue_per_kpi is not None:
        kpi_proto.non_revenue.revenue_per_kpi = (
            self._input_data.revenue_per_kpi.sel(geo=geo, time=time).item()
        )
    return kpi_proto
717
+
718
+
719
+ class _InputDataDeserializer:
720
+ """Deserializes a `MarketingData` proto into a Meridian `InputData`."""
721
+
722
+ def __init__(self, serialized: marketing_pb.MarketingData):
723
+ self._serialized = serialized
724
+
725
+ def __post_init__(self):
726
+ if not self._serialized.HasField(sc.METADATA):
727
+ raise ValueError(
728
+ f"MarketingData proto is missing the '{sc.METADATA}' field."
729
+ )
730
+
731
  @functools.cached_property
  def _metadata(self) -> _DeserializedMetadata:
    """Lazily wraps the proto's metadata, validating time dimensions on first access."""
    return _DeserializedMetadata(self._serialized.metadata)
735
+
736
+ def _extract_population(self) -> xr.DataArray:
737
+ """Extracts population data from the serialized proto."""
738
+ geo_populations = {}
739
+
740
+ for data_point in self._serialized.marketing_data_points:
741
+ geo_id = data_point.geo_info.geo_id
742
+ if not geo_id:
743
+ continue
744
+
745
+ geo_populations[geo_id] = data_point.geo_info.population
746
+
747
+ return xr.DataArray(
748
+ coords={c.GEO: list(geo_populations.keys())},
749
+ data=np.array(list(geo_populations.values())),
750
+ name=c.POPULATION,
751
+ )
752
+
753
+ def _extract_kpi_type(self) -> str:
754
+ """Extracts the kpi_type from the serialized proto."""
755
+ kpi_type = None
756
+ for data_point in self._serialized.marketing_data_points:
757
+ if data_point.HasField(c.KPI):
758
+ current_kpi_type = data_point.kpi.WhichOneof(c.TYPE)
759
+
760
+ if kpi_type is None:
761
+ kpi_type = current_kpi_type
762
+ elif kpi_type != current_kpi_type:
763
+ raise ValueError(
764
+ "Inconsistent kpi_type found in the data. "
765
+ f"Expected {kpi_type}, found {current_kpi_type}"
766
+ )
767
+
768
+ if kpi_type is None:
769
+ raise ValueError("kpi_type not found in the data.")
770
+ return kpi_type
771
+
772
+ def _extract_geo_and_time(self, data_point) -> tuple[str | None, str]:
773
+ """Extracts geo_id and time_str from a data_point."""
774
+ geo_id = data_point.geo_info.geo_id
775
+ start_date = data_point.date_interval.start_date
776
+ time_str = dt.datetime(
777
+ start_date.year, start_date.month, start_date.day
778
+ ).strftime(c.DATE_FORMAT)
779
+ return geo_id, time_str
780
+
781
+ def _extract_kpi(self, kpi_type: str) -> xr.DataArray:
782
+ """Extracts KPI data from the serialized proto."""
783
+
784
+ def _kpi_extractor(data_point):
785
+ if not data_point.HasField(c.KPI):
786
+ return None
787
+
788
+ geo_id, time_str = self._extract_geo_and_time(data_point)
789
+
790
+ if data_point.kpi.WhichOneof(c.TYPE) != kpi_type:
791
+ raise ValueError(
792
+ "Inconsistent kpi_type found in the data. "
793
+ f"Expected {kpi_type}, found"
794
+ f" {data_point.kpi.WhichOneof(c.TYPE)}"
795
+ )
796
+
797
+ kpi_value = (
798
+ data_point.kpi.revenue.value
799
+ if kpi_type == c.REVENUE
800
+ else data_point.kpi.non_revenue.value
801
+ )
802
+ return geo_id, time_str, kpi_value
803
+
804
+ kpi = _extract_data_array(
805
+ serialized_data_points=self._serialized.marketing_data_points,
806
+ data_extractor_fn=_kpi_extractor,
807
+ data_name=c.KPI,
808
+ )
809
+
810
+ if kpi is None:
811
+ raise ValueError(f"{c.KPI} is not found in the data.")
812
+
813
+ return kpi
814
+
815
+ def _extract_revenue_per_kpi(self, kpi_type: str) -> xr.DataArray | None:
816
+ """Extracts revenue per KPI data from the serialized proto."""
817
+
818
+ if kpi_type == c.REVENUE:
819
+ raise ValueError(
820
+ f"{c.REVENUE_PER_KPI} is not applicable when kpi_type is {c.REVENUE}."
821
+ )
822
+
823
+ def _revenue_per_kpi_extractor(data_point):
824
+ if not data_point.HasField(c.KPI):
825
+ return None
826
+
827
+ if not data_point.kpi.non_revenue.HasField(c.REVENUE_PER_KPI):
828
+ return None
829
+
830
+ geo_id, time_str = self._extract_geo_and_time(data_point)
831
+
832
+ if data_point.kpi.WhichOneof(c.TYPE) != kpi_type:
833
+ raise ValueError(
834
+ "Inconsistent kpi_type found in the data. "
835
+ f"Expected {kpi_type}, found"
836
+ f" {data_point.kpi.WhichOneof(c.TYPE)}"
837
+ )
838
+
839
+ return geo_id, time_str, data_point.kpi.non_revenue.revenue_per_kpi
840
+
841
+ return _extract_data_array(
842
+ serialized_data_points=self._serialized.marketing_data_points,
843
+ data_extractor_fn=_revenue_per_kpi_extractor,
844
+ data_name=c.REVENUE_PER_KPI,
845
+ )
846
+
847
+ def _extract_controls(self) -> xr.DataArray | None:
848
+ """Extracts control variables data from the serialized proto, if any."""
849
+
850
+ def _controls_extractor(data_point):
851
+ if not data_point.control_variables:
852
+ return None
853
+
854
+ geo_id, time_str = self._extract_geo_and_time(data_point)
855
+
856
+ for control_variable in data_point.control_variables:
857
+ control_name = control_variable.name
858
+ control_value = control_variable.value
859
+ yield geo_id, time_str, control_name, control_value
860
+
861
+ return _extract_3d_data_array(
862
+ serialized_data_points=self._serialized.marketing_data_points,
863
+ data_extractor_fn=_controls_extractor,
864
+ data_name=c.CONTROLS,
865
+ third_dim_name=c.CONTROL_VARIABLE,
866
+ )
867
+
868
+ def _extract_media(self) -> xr.DataArray | None:
869
+ """Extracts media variables data from the serialized proto."""
870
+
871
+ def _media_extractor(data_point):
872
+ geo_id, time_str = self._extract_geo_and_time(data_point)
873
+
874
+ if not geo_id:
875
+ return None
876
+
877
+ for media_variable in data_point.media_variables:
878
+ channel_name = media_variable.channel_name
879
+ if self._metadata.channel_types.get(channel_name) != c.MEDIA_CHANNEL:
880
+ continue
881
+
882
+ media_value = media_variable.scalar_metric.value
883
+ yield geo_id, time_str, channel_name, media_value
884
+
885
+ return _extract_3d_data_array(
886
+ serialized_data_points=self._serialized.marketing_data_points,
887
+ data_extractor_fn=_media_extractor,
888
+ data_name=c.MEDIA,
889
+ third_dim_name=c.MEDIA_CHANNEL,
890
+ time_dim_name=c.MEDIA_TIME,
891
+ )
892
+
893
+ def _extract_reach(self) -> xr.DataArray | None:
894
+ """Extracts reach data from the serialized proto."""
895
+
896
+ def _reach_extractor(data_point):
897
+ geo_id, time_str = self._extract_geo_and_time(data_point)
898
+
899
+ if not geo_id:
900
+ return None
901
+
902
+ for rf_variable in data_point.reach_frequency_variables:
903
+ channel_name = rf_variable.channel_name
904
+ if self._metadata.channel_types.get(channel_name) != c.RF_CHANNEL:
905
+ continue
906
+
907
+ reach_value = rf_variable.reach
908
+ yield geo_id, time_str, channel_name, reach_value
909
+
910
+ return _extract_3d_data_array(
911
+ serialized_data_points=self._serialized.marketing_data_points,
912
+ data_extractor_fn=_reach_extractor,
913
+ data_name=c.REACH,
914
+ third_dim_name=c.RF_CHANNEL,
915
+ time_dim_name=c.MEDIA_TIME,
916
+ )
917
+
918
+ def _extract_frequency(self) -> xr.DataArray | None:
919
+ """Extracts frequency data from the serialized proto."""
920
+
921
+ def _frequency_extractor(data_point):
922
+ geo_id, time_str = self._extract_geo_and_time(data_point)
923
+
924
+ if not geo_id:
925
+ return None
926
+
927
+ for rf_variable in data_point.reach_frequency_variables:
928
+ channel_name = rf_variable.channel_name
929
+ if self._metadata.channel_types.get(channel_name) != c.RF_CHANNEL:
930
+ continue
931
+
932
+ frequency_value = rf_variable.average_frequency
933
+ yield geo_id, time_str, channel_name, frequency_value
934
+
935
+ return _extract_3d_data_array(
936
+ serialized_data_points=self._serialized.marketing_data_points,
937
+ data_extractor_fn=_frequency_extractor,
938
+ data_name=c.FREQUENCY,
939
+ third_dim_name=c.RF_CHANNEL,
940
+ time_dim_name=c.MEDIA_TIME,
941
+ )
942
+
943
+ def _extract_organic_media(self) -> xr.DataArray | None:
944
+ """Extracts organic media variables data from the serialized proto."""
945
+
946
+ def _organic_media_extractor(data_point):
947
+ geo_id, time_str = self._extract_geo_and_time(data_point)
948
+
949
+ if not geo_id:
950
+ return None
951
+
952
+ for media_variable in data_point.media_variables:
953
+ channel_name = media_variable.channel_name
954
+ if (
955
+ self._metadata.channel_types.get(channel_name)
956
+ != c.ORGANIC_MEDIA_CHANNEL
957
+ ):
958
+ continue
959
+
960
+ media_value = media_variable.scalar_metric.value
961
+ yield geo_id, time_str, channel_name, media_value
962
+
963
+ return _extract_3d_data_array(
964
+ serialized_data_points=self._serialized.marketing_data_points,
965
+ data_extractor_fn=_organic_media_extractor,
966
+ data_name=c.ORGANIC_MEDIA,
967
+ third_dim_name=c.ORGANIC_MEDIA_CHANNEL,
968
+ time_dim_name=c.MEDIA_TIME,
969
+ )
970
+
971
+ def _extract_organic_reach(self) -> xr.DataArray | None:
972
+ """Extracts organic reach data from the serialized proto."""
973
+
974
+ def _organic_reach_extractor(data_point):
975
+ geo_id, time_str = self._extract_geo_and_time(data_point)
976
+
977
+ if not geo_id:
978
+ return None
979
+
980
+ for rf_variable in data_point.reach_frequency_variables:
981
+ channel_name = rf_variable.channel_name
982
+ if (
983
+ self._metadata.channel_types.get(channel_name)
984
+ != c.ORGANIC_RF_CHANNEL
985
+ ):
986
+ continue
987
+
988
+ reach_value = rf_variable.reach
989
+ yield geo_id, time_str, channel_name, reach_value
990
+
991
+ return _extract_3d_data_array(
992
+ serialized_data_points=self._serialized.marketing_data_points,
993
+ data_extractor_fn=_organic_reach_extractor,
994
+ data_name=c.ORGANIC_REACH,
995
+ third_dim_name=c.ORGANIC_RF_CHANNEL,
996
+ time_dim_name=c.MEDIA_TIME,
997
+ )
998
+
999
+ def _extract_organic_frequency(self) -> xr.DataArray | None:
1000
+ """Extracts organic frequency data from the serialized proto."""
1001
+
1002
+ def _organic_frequency_extractor(data_point):
1003
+ geo_id, time_str = self._extract_geo_and_time(data_point)
1004
+
1005
+ if not geo_id:
1006
+ return None
1007
+
1008
+ for rf_variable in data_point.reach_frequency_variables:
1009
+ channel_name = rf_variable.channel_name
1010
+ if (
1011
+ self._metadata.channel_types.get(channel_name)
1012
+ != c.ORGANIC_RF_CHANNEL
1013
+ ):
1014
+ continue
1015
+
1016
+ frequency_value = rf_variable.average_frequency
1017
+ yield geo_id, time_str, channel_name, frequency_value
1018
+
1019
+ return _extract_3d_data_array(
1020
+ serialized_data_points=self._serialized.marketing_data_points,
1021
+ data_extractor_fn=_organic_frequency_extractor,
1022
+ data_name=c.ORGANIC_FREQUENCY,
1023
+ third_dim_name=c.ORGANIC_RF_CHANNEL,
1024
+ time_dim_name=c.MEDIA_TIME,
1025
+ )
1026
+
1027
+ def _extract_granular_media_spend(
1028
+ self,
1029
+ data_points_with_spend: list[marketing_pb.MarketingDataPoint],
1030
+ ) -> xr.DataArray | None:
1031
+ """Extracts granular spend data.
1032
+
1033
+ Args:
1034
+ data_points_with_spend: List of MarketingDataPoint protos with spend data.
1035
+
1036
+ Returns:
1037
+ An xr.DataArray with granular spend data or None if no data found.
1038
+ """
1039
+
1040
+ def _granular_spend_extractor(data_point):
1041
+ geo_id, time_str = self._extract_geo_and_time(data_point)
1042
+ for media_variable in data_point.media_variables:
1043
+ if (
1044
+ media_variable.HasField(c.MEDIA_SPEND)
1045
+ and self._metadata.channel_types.get(media_variable.channel_name)
1046
+ == c.MEDIA_CHANNEL
1047
+ ):
1048
+ yield geo_id, time_str, media_variable.channel_name, media_variable.media_spend
1049
+
1050
+ return _extract_3d_data_array(
1051
+ serialized_data_points=data_points_with_spend,
1052
+ data_extractor_fn=_granular_spend_extractor,
1053
+ data_name=c.MEDIA_SPEND,
1054
+ third_dim_name=c.MEDIA_CHANNEL,
1055
+ time_dim_name=c.TIME,
1056
+ )
1057
+
1058
+ def _extract_granular_rf_spend(
1059
+ self,
1060
+ data_points_with_spend: list[marketing_pb.MarketingDataPoint],
1061
+ ) -> xr.DataArray | None:
1062
+ """Extracts granular spend data.
1063
+
1064
+ Args:
1065
+ data_points_with_spend: List of MarketingDataPoint protos with spend data.
1066
+
1067
+ Returns:
1068
+ An xr.DataArray with granular spend data or None if no data found.
1069
+ """
1070
+
1071
+ def _granular_spend_extractor(data_point):
1072
+ geo_id, time_str = self._extract_geo_and_time(data_point)
1073
+ for rf_variable in data_point.reach_frequency_variables:
1074
+ if (
1075
+ rf_variable.HasField(c.SPEND)
1076
+ and self._metadata.channel_types.get(rf_variable.channel_name)
1077
+ == c.RF_CHANNEL
1078
+ ):
1079
+ yield geo_id, time_str, rf_variable.channel_name, rf_variable.spend
1080
+
1081
+ return _extract_3d_data_array(
1082
+ serialized_data_points=data_points_with_spend,
1083
+ data_extractor_fn=_granular_spend_extractor,
1084
+ data_name=c.RF_SPEND,
1085
+ third_dim_name=c.RF_CHANNEL,
1086
+ time_dim_name=c.TIME,
1087
+ )
1088
+
1089
+ def _extract_aggregated_media_spend(
1090
+ self,
1091
+ data_points_with_spend: list[marketing_pb.MarketingDataPoint],
1092
+ ) -> xr.DataArray | None:
1093
+ """Extracts aggregated spend data.
1094
+
1095
+ Args:
1096
+ data_points_with_spend: List of MarketingDataPoint protos with spend data.
1097
+
1098
+ Returns:
1099
+ An xr.DataArray with aggregated spend data or None if no data found.
1100
+ """
1101
+ channel_names = self._metadata.channel_dimensions.get(c.MEDIA, [])
1102
+ channel_spend_map = {}
1103
+
1104
+ for spend_data_point in data_points_with_spend:
1105
+ for media_variable in spend_data_point.media_variables:
1106
+ if (
1107
+ media_variable.channel_name in channel_names
1108
+ and media_variable.HasField(c.MEDIA_SPEND)
1109
+ ):
1110
+ channel_spend_map[media_variable.channel_name] = (
1111
+ media_variable.media_spend
1112
+ )
1113
+
1114
+ if not channel_spend_map:
1115
+ return None
1116
+
1117
+ return xr.DataArray(
1118
+ data=list(channel_spend_map.values()),
1119
+ coords={c.MEDIA_CHANNEL: list(channel_spend_map.keys())},
1120
+ dims=[c.MEDIA_CHANNEL],
1121
+ name=c.MEDIA_SPEND,
1122
+ )
1123
+
1124
+ def _extract_aggregated_rf_spend(
1125
+ self,
1126
+ data_points_with_spend: list[marketing_pb.MarketingDataPoint],
1127
+ ) -> xr.DataArray | None:
1128
+ """Extracts aggregated spend data.
1129
+
1130
+ Args:
1131
+ data_points_with_spend: List of MarketingDataPoint protos with spend data.
1132
+
1133
+ Returns:
1134
+ An xr.DataArray with aggregated spend data or None if no data found.
1135
+ """
1136
+ channel_names = self._metadata.channel_dimensions.get(c.REACH, [])
1137
+ channel_spend_map = {}
1138
+
1139
+ for spend_data_point in data_points_with_spend:
1140
+ for rf_variable in spend_data_point.reach_frequency_variables:
1141
+ if rf_variable.channel_name in channel_names and rf_variable.HasField(
1142
+ c.SPEND
1143
+ ):
1144
+ channel_spend_map[rf_variable.channel_name] = rf_variable.spend
1145
+
1146
+ if not channel_spend_map:
1147
+ return None
1148
+
1149
+ return xr.DataArray(
1150
+ data=list(channel_spend_map.values()),
1151
+ coords={c.RF_CHANNEL: list(channel_spend_map.keys())},
1152
+ dims=[c.RF_CHANNEL],
1153
+ name=c.RF_SPEND,
1154
+ )
1155
+
1156
+ def _is_aggregated_spend_data_point(
1157
+ self, dp: marketing_pb.MarketingDataPoint
1158
+ ) -> bool:
1159
+ """Checks if a MarketingDataPoint with spend represents aggregated spend data.
1160
+
1161
+ Args:
1162
+ dp: A marketing_pb.MarketingDataPoint representing a spend data point.
1163
+
1164
+ Returns:
1165
+ True if the data point represents aggregated spend, False otherwise.
1166
+ """
1167
+ if not dp.HasField(sc.GEO_INFO) and self._metadata.media_time_dimension:
1168
+ media_time_interval = (
1169
+ self._metadata.media_time_dimension.time_dimension_interval
1170
+ )
1171
+ return (
1172
+ media_time_interval.start_date == dp.date_interval.start_date
1173
+ and media_time_interval.end_date == dp.date_interval.end_date
1174
+ )
1175
+ return False
1176
+
1177
+ def _extract_media_spend(self) -> xr.DataArray | None:
1178
+ """Extracts media spend data from the serialized proto.
1179
+
1180
+ Returns:
1181
+ An xr.DataArray with spend data or None if no data found.
1182
+ """
1183
+ # Filter data points relevant to spend based on channel type map
1184
+ media_channels = {
1185
+ channel
1186
+ for channel, metadata_channel_type in self._metadata.channel_types.items()
1187
+ if metadata_channel_type == c.MEDIA_CHANNEL
1188
+ }
1189
+ spend_data_points = [
1190
+ dp
1191
+ for dp in self._serialized.marketing_data_points
1192
+ if any(
1193
+ mv.HasField(c.MEDIA_SPEND) and mv.channel_name in media_channels
1194
+ for mv in dp.media_variables
1195
+ )
1196
+ ]
1197
+
1198
+ if not spend_data_points:
1199
+ return None
1200
+
1201
+ aggregated_spend_data_points = [
1202
+ dp
1203
+ for dp in spend_data_points
1204
+ if self._is_aggregated_spend_data_point(dp)
1205
+ ]
1206
+
1207
+ if aggregated_spend_data_points:
1208
+ return self._extract_aggregated_media_spend(aggregated_spend_data_points)
1209
+
1210
+ return self._extract_granular_media_spend(spend_data_points)
1211
+
1212
+ def _extract_rf_spend(self) -> xr.DataArray | None:
1213
+ """Extracts reach and frequency spend data from the serialized proto.
1214
+
1215
+ Returns:
1216
+ An xr.DataArray with spend data or None if no data found.
1217
+ """
1218
+ # Filter data points relevant to spend based on channel type map
1219
+ rf_channels = {
1220
+ channel
1221
+ for channel, metadata_channel_type in self._metadata.channel_types.items()
1222
+ if metadata_channel_type == c.RF_CHANNEL
1223
+ }
1224
+ spend_data_points = [
1225
+ dp
1226
+ for dp in self._serialized.marketing_data_points
1227
+ if any(
1228
+ mv.HasField(c.SPEND) and mv.channel_name in rf_channels
1229
+ for mv in dp.reach_frequency_variables
1230
+ )
1231
+ ]
1232
+
1233
+ if not spend_data_points:
1234
+ return None
1235
+
1236
+ aggregated_spend_data_points = [
1237
+ dp
1238
+ for dp in spend_data_points
1239
+ if self._is_aggregated_spend_data_point(dp)
1240
+ ]
1241
+
1242
+ if aggregated_spend_data_points:
1243
+ return self._extract_aggregated_rf_spend(aggregated_spend_data_points)
1244
+
1245
+ return self._extract_granular_rf_spend(spend_data_points)
1246
+
1247
+ def _extract_non_media_treatments(self) -> xr.DataArray | None:
1248
+ """Extracts non-media treatment variables data from the serialized proto."""
1249
+
1250
+ def _non_media_treatments_extractor(data_point):
1251
+ if not data_point.non_media_treatment_variables:
1252
+ return None
1253
+
1254
+ geo_id, time_str = self._extract_geo_and_time(data_point)
1255
+
1256
+ for (
1257
+ non_media_treatment_variable
1258
+ ) in data_point.non_media_treatment_variables:
1259
+ treatment_name = non_media_treatment_variable.name
1260
+ treatment_value = non_media_treatment_variable.value
1261
+ yield geo_id, time_str, treatment_name, treatment_value
1262
+
1263
+ non_media_treatments_data_array = _extract_3d_data_array(
1264
+ serialized_data_points=self._serialized.marketing_data_points,
1265
+ data_extractor_fn=_non_media_treatments_extractor,
1266
+ data_name=c.NON_MEDIA_TREATMENTS,
1267
+ third_dim_name=c.NON_MEDIA_CHANNEL,
1268
+ )
1269
+
1270
+ return non_media_treatments_data_array
1271
+
1272
  def __call__(self) -> meridian_input_data.InputData:
    """Converts the `MarketingData` proto to a Meridian `InputData`.

    Each component is pulled from the serialized proto by its dedicated
    extractor; the extractors return None when the proto has no data for
    that component.
    """
    # KPI type is extracted first because both the KPI and the
    # revenue-per-KPI extraction depend on it.
    kpi_type = self._extract_kpi_type()
    return meridian_input_data.InputData(
        kpi=self._extract_kpi(kpi_type),
        kpi_type=kpi_type,
        controls=self._extract_controls(),
        population=self._extract_population(),
        # Revenue-per-KPI is only extracted for non-revenue KPI types.
        revenue_per_kpi=(
            self._extract_revenue_per_kpi(kpi_type)
            if kpi_type == c.NON_REVENUE
            else None
        ),
        media=self._extract_media(),
        media_spend=self._extract_media_spend(),
        reach=self._extract_reach(),
        frequency=self._extract_frequency(),
        rf_spend=self._extract_rf_spend(),
        organic_media=self._extract_organic_media(),
        organic_reach=self._extract_organic_reach(),
        organic_frequency=self._extract_organic_frequency(),
        non_media_treatments=self._extract_non_media_treatments(),
    )
1295
+
1296
+
1297
class MarketingDataSerde(
    serde.Serde[marketing_pb.MarketingData, meridian_input_data.InputData]
):
  """Serializes and deserializes an `InputData` container in Meridian."""

  def serialize(
      self, obj: meridian_input_data.InputData
  ) -> marketing_pb.MarketingData:
    """Serializes the given Meridian input data into a `MarketingData` proto."""
    serializer = _InputDataSerializer(obj)
    return serializer()

  def deserialize(
      self, serialized: marketing_pb.MarketingData, serialized_version: str = ""
  ) -> meridian_input_data.InputData:
    """Deserializes the given `MarketingData` proto.

    Args:
      serialized: The serialized `MarketingData` proto.
      serialized_version: The version of the serialized model. This is used to
        handle changes in deserialization logic across different versions.

    Returns:
      A Meridian input data container.
    """
    deserializer = _InputDataDeserializer(serialized)
    return deserializer()