google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,4608 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Test data for serde module."""
16
+
17
+ import inspect
18
+ import types
19
+ from typing import Any, Sequence
20
+ from unittest import mock
21
+
22
+ from meridian import backend
23
+ from meridian import constants as c
24
+ from meridian.model import prior_distribution
25
+ from meridian.model import spec
26
+ from mmm.v1.marketing import marketing_data_pb2 as marketing_pb
27
+ from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
28
+ import numpy as np
29
+ import xarray as xr
30
+
31
+ from google.protobuf import text_format
32
+ from tensorflow.core.framework import tensor_pb2 # pylint: disable=g-direct-tensorflow-import
33
+ from tensorflow.core.framework import tensor_shape_pb2 # pylint: disable=g-direct-tensorflow-import
34
+ from tensorflow.core.framework import types_pb2 # pylint: disable=g-direct-tensorflow-import
35
+
36
# Short aliases for names in the Meridian model proto module.
_MediaEffectsDist = meridian_pb.MediaEffectsDistribution
_PaidMediaPriorType = meridian_pb.PaidMediaPriorType
_NonPaidTreatmentsPriorType = meridian_pb.NonPaidTreatmentsPriorType
_NonMediaBaselineFunction = (
    meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
)

# Shared constants
# `_MEDIA_TIME_STRS` is `_TIME_STRS` preceded by one extra, earlier period.
_TIME_STRS = ['2021-02-01', '2021-02-08']
_MEDIA_TIME_STRS = ['2021-01-25'] + _TIME_STRS
# Every name list below holds two entries, suffixed `_0` and `_1`.
_GEO_IDS = [f'geo_{i}' for i in range(2)]
_MEDIA_CHANNEL_PAID = [f'ch_paid_{i}' for i in range(2)]
_MEDIA_CHANNEL_ORGANIC = [f'ch_organic_{i}' for i in range(2)]
_RF_CHANNEL_PAID = [f'rf_ch_paid_{i}' for i in range(2)]
_RF_CHANNEL_ORGANIC = [f'rf_ch_organic_{i}' for i in range(2)]
_CONTROL_VARIABLES = [f'control_{i}' for i in range(2)]
_NON_MEDIA_TREATMENT_VARIABLES = [
    f'non_media_treatment_{i}' for i in range(2)
]
56
+
57
+
58
def make_tensor_shape_proto(
    dims: Sequence[int],
) -> tensor_shape_pb2.TensorShapeProto:
  """Builds a `TensorShapeProto` with one `Dim` entry per size in `dims`.

  Args:
    dims: The dimension sizes, in order.

  Returns:
    A `TensorShapeProto` whose `dim` field holds one `Dim(size=...)` per
    element of `dims`; the field is left empty when `dims` is empty.
  """
  shape_proto = tensor_shape_pb2.TensorShapeProto()
  shape_proto.dim.extend(
      tensor_shape_pb2.TensorShapeProto.Dim(size=size) for size in dims
  )
  return shape_proto
65
+
66
+
67
def make_tensor_proto(
    dims: Sequence[int],
    dtype: types_pb2.DataType = types_pb2.DT_FLOAT,
    bool_vals: Sequence[bool] = (),
    string_vals: Sequence[str] = (),
    tensor_content: bytes = b'',
) -> tensor_pb2.TensorProto:
  """Builds a `TensorProto` from a shape, a dtype and optional payloads.

  Args:
    dims: The dimension sizes used to build the `tensor_shape` field.
    dtype: The value stored in the `dtype` field.
    bool_vals: The values stored in the `bool_val` field.
    string_vals: Strings that are encoded with `str.encode()` and stored in
      the `string_val` field.
    tensor_content: The raw bytes stored in the `tensor_content` field.

  Returns:
    The populated `TensorProto`.
  """
  shape_proto = make_tensor_shape_proto(dims)
  # `string_val` is populated with bytes, so encode each string first.
  encoded_strings = [text.encode() for text in string_vals]
  return tensor_pb2.TensorProto(
      dtype=dtype,
      tensor_shape=shape_proto,
      bool_val=bool_vals,
      string_val=encoded_strings,
      tensor_content=tensor_content,
  )
81
+
82
+
83
def make_sample_dataset(
    n_chains: int,
    n_draws: int,
    n_geos: int = 5,
    n_controls: int = 2,
    n_knots: int = 0,
    n_times: int = 0,
    n_media_channels: int = 0,
    n_rf_channels: int = 0,
    n_organic_media_channels: int = 0,
    n_organic_rf_channels: int = 0,
    n_non_media_channels: int = 0,
) -> xr.Dataset:
  """Builds a sample dataset of per-draw statistics for tests.

  Every data variable has dimensions `(chain, draw)`. Float-valued variables
  are drawn from `np.random.normal` without seeding, so their values differ
  between calls unless the caller seeds NumPy's global random state.

  Args:
    n_chains: The number of chains.
    n_draws: The number of draws per chain.
    n_geos: The number of geos.
    n_controls: The number of control variables.
    n_knots: The number of knots.
    n_times: The number of time periods.
    n_media_channels: The number of media channels.
    n_rf_channels: The number of reach and frequency channels.
    n_organic_media_channels: The number of organic media channels.
    n_organic_rf_channels: The number of organic reach and frequency channels.
    n_non_media_channels: The number of non-media channels.

  Returns:
    An xarray Dataset with sample data. The chain, draw, geo and control
    variable coordinates are always present; each remaining coordinate is
    added only when its count is positive.
  """
  shape = (n_chains, n_draws)

  def _random_stat():
    return ([c.CHAIN, c.DRAW], np.random.normal(size=shape))

  def _constant_flag(value):
    return ([c.CHAIN, c.DRAW], np.full(shape, value))

  # The random draws happen in dict order; keep this order so that a seeded
  # run keeps producing the same values for each variable.
  data_vars = {
      c.STEP_SIZE: _random_stat(),
      c.TUNE: _constant_flag(False),
      c.TARGET_LOG_PROBABILITY_TF: _random_stat(),
      c.DIVERGING: _constant_flag(False),
      c.ACCEPT_RATIO: _random_stat(),
      c.N_STEPS: _random_stat(),
      'is_accepted': _constant_flag(True),
  }

  required_dims = (
      (c.CHAIN, n_chains),
      (c.DRAW, n_draws),
      (c.GEO, n_geos),
      (c.CONTROL_VARIABLE, n_controls),
  )
  coords = {dim: ([dim], np.arange(size)) for dim, size in required_dims}

  # Optional dimensions only get a coordinate when they are non-empty.
  optional_dims = (
      (c.KNOTS, n_knots),
      (c.TIME, n_times),
      (c.MEDIA_CHANNEL, n_media_channels),
      (c.RF_CHANNEL, n_rf_channels),
      (c.ORGANIC_MEDIA_CHANNEL, n_organic_media_channels),
      (c.ORGANIC_RF_CHANNEL, n_organic_rf_channels),
      (c.NON_MEDIA_CHANNEL, n_non_media_channels),
  )
  for dim, size in optional_dims:
    if size > 0:
      coords[dim] = ([dim], np.arange(size))

  return xr.Dataset(data_vars, coords=coords)
191
+
192
+
193
+ # Marketing data test data
194
# Mock input data with a single geo ('national_geo'), paid media and paid
# reach/frequency channels, and a non-revenue KPI. Media time equals time
# here (no extra lagged period).
MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE = mock.MagicMock(
    kpi_type=c.NON_REVENUE,
    geo=xr.DataArray(np.array(['national_geo'])),
    time=xr.DataArray(np.array(_TIME_STRS)),
    media_time=xr.DataArray(np.array(_TIME_STRS)),
    population=xr.DataArray(
        data=np.array([1.0]),
        coords={c.GEO: ['national_geo']},
        name=c.POPULATION,
    ),
    media=xr.DataArray(
        data=np.array([[[41, 42], [43, 44]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA,
    ),
    # Media spend is given per geo and per time period.
    media_spend=xr.DataArray(
        data=np.array([[[141, 142], [143, 144]]]),
        coords={
            c.GEO: ['national_geo'],
            c.TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA_SPEND,
    ),
    media_spend_has_geo_dimension=True,
    media_spend_has_time_dimension=True,
    reach=xr.DataArray(
        data=np.array([[[51.0, 52.0], [53.0, 54.0]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.RF_CHANNEL: _RF_CHANNEL_PAID,
        },
        name=c.REACH,
    ),
    frequency=xr.DataArray(
        data=np.array([[[1.1, 1.2], [2, 3]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.RF_CHANNEL: _RF_CHANNEL_PAID,
        },
        name=c.FREQUENCY,
    ),
    # R&F spend only has the channel dimension (one total per channel).
    rf_spend=xr.DataArray(
        data=np.array([502, 504]),
        coords={c.RF_CHANNEL: _RF_CHANNEL_PAID},
        name=c.RF_SPEND,
    ),
    rf_spend_has_geo_dimension=False,
    rf_spend_has_time_dimension=False,
    kpi=xr.DataArray(
        data=np.array([[1, 2]]),
        coords={c.GEO: ['national_geo'], c.TIME: _TIME_STRS},
        name=c.KPI,
    ),
    revenue_per_kpi=xr.DataArray(
        data=np.array([[11, 12]]),
        coords={c.GEO: ['national_geo'], c.TIME: _TIME_STRS},
        name=c.REVENUE_PER_KPI,
    ),
    # Uses the shared `_CONTROL_VARIABLES` names, like the other fixtures in
    # this module, instead of repeating the literal list.
    controls=xr.DataArray(
        data=np.array([[[31, 32], [33, 34]]]),
        coords={
            c.GEO: ['national_geo'],
            c.TIME: _TIME_STRS,
            c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
        },
        name=c.CONTROLS,
    ),
    organic_media=None,
    organic_reach=None,
    organic_frequency=None,
    non_media_treatments=None,
)
279
+
280
+ MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE = text_format.Parse(
281
+ """
282
+ marketing_data_points {
283
+ geo_info {
284
+ geo_id: "national_geo"
285
+ population: 1
286
+ }
287
+ date_interval {
288
+ start_date {
289
+ year: 2021
290
+ month: 2
291
+ day: 1
292
+ }
293
+ end_date {
294
+ year: 2021
295
+ month: 2
296
+ day: 8
297
+ }
298
+ }
299
+ control_variables {
300
+ name: "control_0"
301
+ value: 31.0
302
+ }
303
+ control_variables {
304
+ name: "control_1"
305
+ value: 32.0
306
+ }
307
+ media_variables {
308
+ channel_name: "ch_paid_0"
309
+ scalar_metric {
310
+ name: "impressions"
311
+ value: 41.0
312
+ }
313
+ media_spend: 141.0
314
+ }
315
+ media_variables {
316
+ channel_name: "ch_paid_1"
317
+ scalar_metric {
318
+ name: "impressions"
319
+ value: 42.0
320
+ }
321
+ media_spend: 142.0
322
+ }
323
+ reach_frequency_variables {
324
+ channel_name: "rf_ch_paid_0"
325
+ reach: 51
326
+ average_frequency: 1.1
327
+ }
328
+ reach_frequency_variables {
329
+ channel_name: "rf_ch_paid_1"
330
+ reach: 52
331
+ average_frequency: 1.2
332
+ }
333
+ kpi {
334
+ name: "non_revenue"
335
+ non_revenue {
336
+ value: 1.0
337
+ revenue_per_kpi: 11.0
338
+ }
339
+ }
340
+ }
341
+ marketing_data_points {
342
+ geo_info {
343
+ geo_id: "national_geo"
344
+ population: 1
345
+ }
346
+ date_interval {
347
+ start_date {
348
+ year: 2021
349
+ month: 2
350
+ day: 8
351
+ }
352
+ end_date {
353
+ year: 2021
354
+ month: 2
355
+ day: 15
356
+ }
357
+ }
358
+ control_variables {
359
+ name: "control_0"
360
+ value: 33.0
361
+ }
362
+ control_variables {
363
+ name: "control_1"
364
+ value: 34.0
365
+ }
366
+ media_variables {
367
+ channel_name: "ch_paid_0"
368
+ scalar_metric {
369
+ name: "impressions"
370
+ value: 43.0
371
+ }
372
+ media_spend: 143.0
373
+ }
374
+ media_variables {
375
+ channel_name: "ch_paid_1"
376
+ scalar_metric {
377
+ name: "impressions"
378
+ value: 44.0
379
+ }
380
+ media_spend: 144.0
381
+ }
382
+ reach_frequency_variables {
383
+ channel_name: "rf_ch_paid_0"
384
+ reach: 53
385
+ average_frequency: 2.0
386
+ }
387
+ reach_frequency_variables {
388
+ channel_name: "rf_ch_paid_1"
389
+ reach: 54
390
+ average_frequency: 3.0
391
+ }
392
+ kpi {
393
+ name: "non_revenue"
394
+ non_revenue {
395
+ value: 2.0
396
+ revenue_per_kpi: 12.0
397
+ }
398
+ }
399
+ }
400
+ marketing_data_points {
401
+ date_interval {
402
+ start_date {
403
+ year: 2021
404
+ month: 2
405
+ day: 1
406
+ }
407
+ end_date {
408
+ year: 2021
409
+ month: 2
410
+ day: 15
411
+ }
412
+ }
413
+ reach_frequency_variables {
414
+ channel_name: "rf_ch_paid_0"
415
+ spend: 502.0
416
+ }
417
+ reach_frequency_variables {
418
+ channel_name: "rf_ch_paid_1"
419
+ spend: 504.0
420
+ }
421
+ }
422
+ metadata {
423
+ time_dimensions {
424
+ name: "time"
425
+ dates {
426
+ year: 2021
427
+ month: 2
428
+ day: 1
429
+ }
430
+ dates {
431
+ year: 2021
432
+ month: 2
433
+ day: 8
434
+ }
435
+ }
436
+ time_dimensions {
437
+ name: "media_time"
438
+ dates {
439
+ year: 2021
440
+ month: 2
441
+ day: 1
442
+ }
443
+ dates {
444
+ year: 2021
445
+ month: 2
446
+ day: 8
447
+ }
448
+ }
449
+ channel_dimensions {
450
+ name: "media"
451
+ channels: "ch_paid_0"
452
+ channels: "ch_paid_1"
453
+ }
454
+ channel_dimensions {
455
+ name: "reach"
456
+ channels: "rf_ch_paid_0"
457
+ channels: "rf_ch_paid_1"
458
+ }
459
+ channel_dimensions {
460
+ name: "frequency"
461
+ channels: "rf_ch_paid_0"
462
+ channels: "rf_ch_paid_1"
463
+ }
464
+ control_names: "control_0"
465
+ control_names: "control_1"
466
+ kpi_type: "non_revenue"
467
+ }
468
+ """,
469
+ marketing_pb.MarketingData(),
470
+ )
471
+
472
+ # Same as above, but with no controls.
473
MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS = mock.MagicMock(
    kpi_type=c.NON_REVENUE,
    geo=xr.DataArray(np.array(['national_geo'])),
    time=xr.DataArray(np.array(_TIME_STRS)),
    media_time=xr.DataArray(np.array(_TIME_STRS)),
    population=xr.DataArray(
        data=np.array([1.0]),
        coords={c.GEO: ['national_geo']},
        name=c.POPULATION,
    ),
    media=xr.DataArray(
        data=np.array([[[41, 42], [43, 44]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA,
    ),
    # Media spend is given per geo and per time period.
    media_spend=xr.DataArray(
        data=np.array([[[141, 142], [143, 144]]]),
        coords={
            c.GEO: ['national_geo'],
            c.TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA_SPEND,
    ),
    media_spend_has_geo_dimension=True,
    media_spend_has_time_dimension=True,
    reach=xr.DataArray(
        data=np.array([[[51.0, 52.0], [53.0, 54.0]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.RF_CHANNEL: _RF_CHANNEL_PAID,
        },
        name=c.REACH,
    ),
    frequency=xr.DataArray(
        data=np.array([[[1.1, 1.2], [2, 3]]]),
        coords={
            c.GEO: ['national_geo'],
            c.MEDIA_TIME: _TIME_STRS,
            c.RF_CHANNEL: _RF_CHANNEL_PAID,
        },
        name=c.FREQUENCY,
    ),
    # R&F spend only has the channel dimension (one total per channel).
    rf_spend=xr.DataArray(
        data=np.array([502, 504]),
        coords={c.RF_CHANNEL: _RF_CHANNEL_PAID},
        name=c.RF_SPEND,
    ),
    rf_spend_has_geo_dimension=False,
    rf_spend_has_time_dimension=False,
    kpi=xr.DataArray(
        data=np.array([[1, 2]]),
        coords={c.GEO: ['national_geo'], c.TIME: _TIME_STRS},
        name=c.KPI,
    ),
    revenue_per_kpi=xr.DataArray(
        data=np.array([[11, 12]]),
        coords={c.GEO: ['national_geo'], c.TIME: _TIME_STRS},
        name=c.REVENUE_PER_KPI,
    ),
    # This fixture deliberately carries no control variables.
    controls=None,
    organic_media=None,
    organic_reach=None,
    organic_frequency=None,
    non_media_treatments=None,
)
550
+
551
+ MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS = text_format.Parse(
552
+ """
553
+ marketing_data_points {
554
+ geo_info {
555
+ geo_id: "national_geo"
556
+ population: 1
557
+ }
558
+ date_interval {
559
+ start_date {
560
+ year: 2021
561
+ month: 2
562
+ day: 1
563
+ }
564
+ end_date {
565
+ year: 2021
566
+ month: 2
567
+ day: 8
568
+ }
569
+ }
570
+ media_variables {
571
+ channel_name: "ch_paid_0"
572
+ scalar_metric {
573
+ name: "impressions"
574
+ value: 41.0
575
+ }
576
+ media_spend: 141.0
577
+ }
578
+ media_variables {
579
+ channel_name: "ch_paid_1"
580
+ scalar_metric {
581
+ name: "impressions"
582
+ value: 42.0
583
+ }
584
+ media_spend: 142.0
585
+ }
586
+ reach_frequency_variables {
587
+ channel_name: "rf_ch_paid_0"
588
+ reach: 51
589
+ average_frequency: 1.1
590
+ }
591
+ reach_frequency_variables {
592
+ channel_name: "rf_ch_paid_1"
593
+ reach: 52
594
+ average_frequency: 1.2
595
+ }
596
+ kpi {
597
+ name: "non_revenue"
598
+ non_revenue {
599
+ value: 1.0
600
+ revenue_per_kpi: 11.0
601
+ }
602
+ }
603
+ }
604
+ marketing_data_points {
605
+ geo_info {
606
+ geo_id: "national_geo"
607
+ population: 1
608
+ }
609
+ date_interval {
610
+ start_date {
611
+ year: 2021
612
+ month: 2
613
+ day: 8
614
+ }
615
+ end_date {
616
+ year: 2021
617
+ month: 2
618
+ day: 15
619
+ }
620
+ }
621
+ media_variables {
622
+ channel_name: "ch_paid_0"
623
+ scalar_metric {
624
+ name: "impressions"
625
+ value: 43.0
626
+ }
627
+ media_spend: 143.0
628
+ }
629
+ media_variables {
630
+ channel_name: "ch_paid_1"
631
+ scalar_metric {
632
+ name: "impressions"
633
+ value: 44.0
634
+ }
635
+ media_spend: 144.0
636
+ }
637
+ reach_frequency_variables {
638
+ channel_name: "rf_ch_paid_0"
639
+ reach: 53
640
+ average_frequency: 2.0
641
+ }
642
+ reach_frequency_variables {
643
+ channel_name: "rf_ch_paid_1"
644
+ reach: 54
645
+ average_frequency: 3.0
646
+ }
647
+ kpi {
648
+ name: "non_revenue"
649
+ non_revenue {
650
+ value: 2.0
651
+ revenue_per_kpi: 12.0
652
+ }
653
+ }
654
+ }
655
+ marketing_data_points {
656
+ date_interval {
657
+ start_date {
658
+ year: 2021
659
+ month: 2
660
+ day: 1
661
+ }
662
+ end_date {
663
+ year: 2021
664
+ month: 2
665
+ day: 15
666
+ }
667
+ }
668
+ reach_frequency_variables {
669
+ channel_name: "rf_ch_paid_0"
670
+ spend: 502.0
671
+ }
672
+ reach_frequency_variables {
673
+ channel_name: "rf_ch_paid_1"
674
+ spend: 504.0
675
+ }
676
+ }
677
+ metadata {
678
+ time_dimensions {
679
+ name: "time"
680
+ dates {
681
+ year: 2021
682
+ month: 2
683
+ day: 1
684
+ }
685
+ dates {
686
+ year: 2021
687
+ month: 2
688
+ day: 8
689
+ }
690
+ }
691
+ time_dimensions {
692
+ name: "media_time"
693
+ dates {
694
+ year: 2021
695
+ month: 2
696
+ day: 1
697
+ }
698
+ dates {
699
+ year: 2021
700
+ month: 2
701
+ day: 8
702
+ }
703
+ }
704
+ channel_dimensions {
705
+ name: "media"
706
+ channels: "ch_paid_0"
707
+ channels: "ch_paid_1"
708
+ }
709
+ channel_dimensions {
710
+ name: "reach"
711
+ channels: "rf_ch_paid_0"
712
+ channels: "rf_ch_paid_1"
713
+ }
714
+ channel_dimensions {
715
+ name: "frequency"
716
+ channels: "rf_ch_paid_0"
717
+ channels: "rf_ch_paid_1"
718
+ }
719
+ kpi_type: "non_revenue"
720
+ }
721
+ """,
722
+ marketing_pb.MarketingData(),
723
+ )
724
+
725
+ # Media, Paid, Expanded, Lagged
726
MOCK_INPUT_DATA_MEDIA_PAID_EXPANDED_LAGGED = mock.MagicMock(
    kpi_type=c.REVENUE,
    geo=xr.DataArray(np.array(_GEO_IDS)),
    time=xr.DataArray(np.array(_TIME_STRS)),
    # Media time has one more (earlier) period than time.
    media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)),
    population=xr.DataArray(
        data=np.array([11.1, 12.2]),
        coords={c.GEO: _GEO_IDS},
        name=c.POPULATION,
    ),
    media=xr.DataArray(
        data=np.array(
            [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]]
        ),
        coords={
            c.GEO: _GEO_IDS,
            c.MEDIA_TIME: _MEDIA_TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA,
    ),
    # Media spend only has the channel dimension (one total per channel).
    media_spend=xr.DataArray(
        data=np.array([492, 496]),
        coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID},
        name=c.MEDIA_SPEND,
    ),
    media_spend_has_geo_dimension=False,
    media_spend_has_time_dimension=False,
    # No reach/frequency data in this fixture.
    reach=None,
    frequency=None,
    rf_spend=None,
    kpi=xr.DataArray(
        data=np.array([[2, 3], [4, 5]]),
        coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
        name=c.KPI,
    ),
    revenue_per_kpi=xr.DataArray(
        data=np.ones((2, 2)),
        coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
        name=c.REVENUE_PER_KPI,
    ),
    controls=xr.DataArray(
        data=np.array([[[32, 33], [34, 35]], [[36, 37], [38, 39]]]),
        coords={
            c.GEO: _GEO_IDS,
            c.TIME: _TIME_STRS,
            c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
        },
        name=c.CONTROLS,
    ),
    organic_media=None,
    organic_reach=None,
    organic_frequency=None,
    non_media_treatments=None,
)
787
+
788
+ MOCK_PROTO_MEDIA_PAID_EXPANDED_LAGGED = text_format.Parse(
789
+ """
790
+ marketing_data_points {
791
+ geo_info {
792
+ geo_id: "geo_0"
793
+ population: 11
794
+ }
795
+ date_interval {
796
+ start_date {
797
+ year: 2021
798
+ month: 1
799
+ day: 25
800
+ }
801
+ end_date {
802
+ year: 2021
803
+ month: 2
804
+ day: 1
805
+ }
806
+ }
807
+ media_variables {
808
+ channel_name: "ch_paid_0"
809
+ scalar_metric {
810
+ name: "impressions"
811
+ value: 39.0
812
+ }
813
+ }
814
+ media_variables {
815
+ channel_name: "ch_paid_1"
816
+ scalar_metric {
817
+ name: "impressions"
818
+ value: 40.0
819
+ }
820
+ }
821
+ }
822
+ marketing_data_points {
823
+ geo_info {
824
+ geo_id: "geo_0"
825
+ population: 11
826
+ }
827
+ date_interval {
828
+ start_date {
829
+ year: 2021
830
+ month: 2
831
+ day: 1
832
+ }
833
+ end_date {
834
+ year: 2021
835
+ month: 2
836
+ day: 8
837
+ }
838
+ }
839
+ control_variables {
840
+ name: "control_0"
841
+ value: 32.0
842
+ }
843
+ control_variables {
844
+ name: "control_1"
845
+ value: 33.0
846
+ }
847
+ media_variables {
848
+ channel_name: "ch_paid_0"
849
+ scalar_metric {
850
+ name: "impressions"
851
+ value: 41.0
852
+ }
853
+ }
854
+ media_variables {
855
+ channel_name: "ch_paid_1"
856
+ scalar_metric {
857
+ name: "impressions"
858
+ value: 42.0
859
+ }
860
+ }
861
+ kpi {
862
+ name: "revenue"
863
+ revenue {
864
+ value: 2.0
865
+ }
866
+ }
867
+ }
868
+ marketing_data_points {
869
+ geo_info {
870
+ geo_id: "geo_0"
871
+ population: 11
872
+ }
873
+ date_interval {
874
+ start_date {
875
+ year: 2021
876
+ month: 2
877
+ day: 8
878
+ }
879
+ end_date {
880
+ year: 2021
881
+ month: 2
882
+ day: 15
883
+ }
884
+ }
885
+ control_variables {
886
+ name: "control_0"
887
+ value: 34.0
888
+ }
889
+ control_variables {
890
+ name: "control_1"
891
+ value: 35.0
892
+ }
893
+ media_variables {
894
+ channel_name: "ch_paid_0"
895
+ scalar_metric {
896
+ name: "impressions"
897
+ value: 43.0
898
+ }
899
+ }
900
+ media_variables {
901
+ channel_name: "ch_paid_1"
902
+ scalar_metric {
903
+ name: "impressions"
904
+ value: 44.0
905
+ }
906
+ }
907
+ kpi {
908
+ name: "revenue"
909
+ revenue {
910
+ value: 3.0
911
+ }
912
+ }
913
+ }
914
+ marketing_data_points {
915
+ geo_info {
916
+ geo_id: "geo_1"
917
+ population: 12
918
+ }
919
+ date_interval {
920
+ start_date {
921
+ year: 2021
922
+ month: 1
923
+ day: 25
924
+ }
925
+ end_date {
926
+ year: 2021
927
+ month: 2
928
+ day: 1
929
+ }
930
+ }
931
+ media_variables {
932
+ channel_name: "ch_paid_0"
933
+ scalar_metric {
934
+ name: "impressions"
935
+ value: 45.0
936
+ }
937
+ }
938
+ media_variables {
939
+ channel_name: "ch_paid_1"
940
+ scalar_metric {
941
+ name: "impressions"
942
+ value: 46.0
943
+ }
944
+ }
945
+ }
946
+ marketing_data_points {
947
+ geo_info {
948
+ geo_id: "geo_1"
949
+ population: 12
950
+ }
951
+ date_interval {
952
+ start_date {
953
+ year: 2021
954
+ month: 2
955
+ day: 1
956
+ }
957
+ end_date {
958
+ year: 2021
959
+ month: 2
960
+ day: 8
961
+ }
962
+ }
963
+ control_variables {
964
+ name: "control_0"
965
+ value: 36.0
966
+ }
967
+ control_variables {
968
+ name: "control_1"
969
+ value: 37.0
970
+ }
971
+ media_variables {
972
+ channel_name: "ch_paid_0"
973
+ scalar_metric {
974
+ name: "impressions"
975
+ value: 47.0
976
+ }
977
+ }
978
+ media_variables {
979
+ channel_name: "ch_paid_1"
980
+ scalar_metric {
981
+ name: "impressions"
982
+ value: 48.0
983
+ }
984
+ }
985
+ kpi {
986
+ name: "revenue"
987
+ revenue {
988
+ value: 4.0
989
+ }
990
+ }
991
+ }
992
+ marketing_data_points {
993
+ geo_info {
994
+ geo_id: "geo_1"
995
+ population: 12
996
+ }
997
+ date_interval {
998
+ start_date {
999
+ year: 2021
1000
+ month: 2
1001
+ day: 8
1002
+ }
1003
+ end_date {
1004
+ year: 2021
1005
+ month: 2
1006
+ day: 15
1007
+ }
1008
+ }
1009
+ control_variables {
1010
+ name: "control_0"
1011
+ value: 38.0
1012
+ }
1013
+ control_variables {
1014
+ name: "control_1"
1015
+ value: 39.0
1016
+ }
1017
+ media_variables {
1018
+ channel_name: "ch_paid_0"
1019
+ scalar_metric {
1020
+ name: "impressions"
1021
+ value: 49.0
1022
+ }
1023
+ }
1024
+ media_variables {
1025
+ channel_name: "ch_paid_1"
1026
+ scalar_metric {
1027
+ name: "impressions"
1028
+ value: 50.0
1029
+ }
1030
+ }
1031
+ kpi {
1032
+ name: "revenue"
1033
+ revenue {
1034
+ value: 5.0
1035
+ }
1036
+ }
1037
+ }
1038
+ marketing_data_points {
1039
+ date_interval {
1040
+ start_date {
1041
+ year: 2021
1042
+ month: 1
1043
+ day: 25
1044
+ }
1045
+ end_date {
1046
+ year: 2021
1047
+ month: 2
1048
+ day: 15
1049
+ }
1050
+ }
1051
+ media_variables {
1052
+ channel_name: "ch_paid_0"
1053
+ media_spend: 492.0
1054
+ }
1055
+ media_variables {
1056
+ channel_name: "ch_paid_1"
1057
+ media_spend: 496.0
1058
+ }
1059
+ }
1060
+ metadata {
1061
+ time_dimensions {
1062
+ name: "time"
1063
+ dates {
1064
+ year: 2021
1065
+ month: 2
1066
+ day: 1
1067
+ }
1068
+ dates {
1069
+ year: 2021
1070
+ month: 2
1071
+ day: 8
1072
+ }
1073
+ }
1074
+ time_dimensions {
1075
+ name: "media_time"
1076
+ dates {
1077
+ year: 2021
1078
+ month: 1
1079
+ day: 25
1080
+ }
1081
+ dates {
1082
+ year: 2021
1083
+ month: 2
1084
+ day: 1
1085
+ }
1086
+ dates {
1087
+ year: 2021
1088
+ month: 2
1089
+ day: 8
1090
+ }
1091
+ }
1092
+ channel_dimensions {
1093
+ name: "media"
1094
+ channels: "ch_paid_0"
1095
+ channels: "ch_paid_1"
1096
+ }
1097
+ control_names: "control_0"
1098
+ control_names: "control_1"
1099
+ kpi_type: "revenue"
1100
+ }
1101
+ """,
1102
+ marketing_pb.MarketingData(),
1103
+ )
1104
+
1105
+ # Media, Paid, Granular, Not Lagged
1106
MOCK_INPUT_DATA_MEDIA_PAID_GRANULAR_NOT_LAGGED = mock.MagicMock(
    kpi_type=c.REVENUE,
    geo=xr.DataArray(np.array(_GEO_IDS)),
    time=xr.DataArray(np.array(_TIME_STRS)),
    # Media time equals time here (no extra lagged period).
    media_time=xr.DataArray(np.array(_TIME_STRS)),
    population=xr.DataArray(
        data=np.array([11.1, 12.2]),
        coords={c.GEO: _GEO_IDS},
        name=c.POPULATION,
    ),
    media=xr.DataArray(
        data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]),
        coords={
            c.GEO: _GEO_IDS,
            c.MEDIA_TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA,
    ),
    # Media spend is given per geo and per time period.
    media_spend=xr.DataArray(
        data=np.array([[[123, 124], [125, 126]], [[127, 128], [129, 130]]]),
        coords={
            c.GEO: _GEO_IDS,
            c.TIME: _TIME_STRS,
            c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
        },
        name=c.MEDIA_SPEND,
    ),
    media_spend_has_geo_dimension=True,
    media_spend_has_time_dimension=True,
    # No reach/frequency data in this fixture.
    reach=None,
    frequency=None,
    rf_spend=None,
    kpi=xr.DataArray(
        data=np.array([[1, 2], [3, 4]]),
        coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
        name=c.KPI,
    ),
    revenue_per_kpi=xr.DataArray(
        data=np.ones((2, 2)),
        coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
        name=c.REVENUE_PER_KPI,
    ),
    controls=xr.DataArray(
        data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
        coords={
            c.GEO: _GEO_IDS,
            c.TIME: _TIME_STRS,
            c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
        },
        name=c.CONTROLS,
    ),
    organic_media=None,
    organic_reach=None,
    organic_frequency=None,
    non_media_treatments=None,
)
1161
+
1162
+ MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED_STRING = """
1163
+ marketing_data_points {
1164
+ geo_info {
1165
+ geo_id: "geo_0"
1166
+ population: 11
1167
+ }
1168
+ date_interval {
1169
+ start_date {
1170
+ year: 2021
1171
+ month: 2
1172
+ day: 1
1173
+ }
1174
+ end_date {
1175
+ year: 2021
1176
+ month: 2
1177
+ day: 8
1178
+ }
1179
+ }
1180
+ control_variables {
1181
+ name: "control_0"
1182
+ value: 31.0
1183
+ }
1184
+ control_variables {
1185
+ name: "control_1"
1186
+ value: 32.0
1187
+ }
1188
+ media_variables {
1189
+ channel_name: "ch_paid_0"
1190
+ scalar_metric {
1191
+ name: "impressions"
1192
+ value: 39.0
1193
+ }
1194
+ media_spend: 123.0
1195
+ }
1196
+ media_variables {
1197
+ channel_name: "ch_paid_1"
1198
+ scalar_metric {
1199
+ name: "impressions"
1200
+ value: 40.0
1201
+ }
1202
+ media_spend: 124.0
1203
+ }
1204
+ kpi {
1205
+ name: "revenue"
1206
+ revenue {
1207
+ value: 1.0
1208
+ }
1209
+ }
1210
+ }
1211
+ marketing_data_points {
1212
+ geo_info {
1213
+ geo_id: "geo_0"
1214
+ population: 11
1215
+ }
1216
+ date_interval {
1217
+ start_date {
1218
+ year: 2021
1219
+ month: 2
1220
+ day: 8
1221
+ }
1222
+ end_date {
1223
+ year: 2021
1224
+ month: 2
1225
+ day: 15
1226
+ }
1227
+ }
1228
+ control_variables {
1229
+ name: "control_0"
1230
+ value: 33.0
1231
+ }
1232
+ control_variables {
1233
+ name: "control_1"
1234
+ value: 34.0
1235
+ }
1236
+ media_variables {
1237
+ channel_name: "ch_paid_0"
1238
+ scalar_metric {
1239
+ name: "impressions"
1240
+ value: 41.0
1241
+ }
1242
+ media_spend: 125.0
1243
+ }
1244
+ media_variables {
1245
+ channel_name: "ch_paid_1"
1246
+ scalar_metric {
1247
+ name: "impressions"
1248
+ value: 42.0
1249
+ }
1250
+ media_spend: 126.0
1251
+ }
1252
+ kpi {
1253
+ name: "revenue"
1254
+ revenue {
1255
+ value: 2.0
1256
+ }
1257
+ }
1258
+ }
1259
+ marketing_data_points {
1260
+ geo_info {
1261
+ geo_id: "geo_1"
1262
+ population: 12
1263
+ }
1264
+ date_interval {
1265
+ start_date {
1266
+ year: 2021
1267
+ month: 2
1268
+ day: 1
1269
+ }
1270
+ end_date {
1271
+ year: 2021
1272
+ month: 2
1273
+ day: 8
1274
+ }
1275
+ }
1276
+ control_variables {
1277
+ name: "control_0"
1278
+ value: 35.0
1279
+ }
1280
+ control_variables {
1281
+ name: "control_1"
1282
+ value: 36.0
1283
+ }
1284
+ media_variables {
1285
+ channel_name: "ch_paid_0"
1286
+ scalar_metric {
1287
+ name: "impressions"
1288
+ value: 43.0
1289
+ }
1290
+ media_spend: 127.0
1291
+ }
1292
+ media_variables {
1293
+ channel_name: "ch_paid_1"
1294
+ scalar_metric {
1295
+ name: "impressions"
1296
+ value: 44.0
1297
+ }
1298
+ media_spend: 128.0
1299
+ }
1300
+ kpi {
1301
+ name: "revenue"
1302
+ revenue {
1303
+ value: 3.0
1304
+ }
1305
+ }
1306
+ }
1307
+ marketing_data_points {
1308
+ geo_info {
1309
+ geo_id: "geo_1"
1310
+ population: 12
1311
+ }
1312
+ date_interval {
1313
+ start_date {
1314
+ year: 2021
1315
+ month: 2
1316
+ day: 8
1317
+ }
1318
+ end_date {
1319
+ year: 2021
1320
+ month: 2
1321
+ day: 15
1322
+ }
1323
+ }
1324
+ control_variables {
1325
+ name: "control_0"
1326
+ value: 37.0
1327
+ }
1328
+ control_variables {
1329
+ name: "control_1"
1330
+ value: 38.0
1331
+ }
1332
+ media_variables {
1333
+ channel_name: "ch_paid_0"
1334
+ scalar_metric {
1335
+ name: "impressions"
1336
+ value: 45.0
1337
+ }
1338
+ media_spend: 129.0
1339
+ }
1340
+ media_variables {
1341
+ channel_name: "ch_paid_1"
1342
+ scalar_metric {
1343
+ name: "impressions"
1344
+ value: 46.0
1345
+ }
1346
+ media_spend: 130.0
1347
+ }
1348
+ kpi {
1349
+ name: "revenue"
1350
+ revenue {
1351
+ value: 4.0
1352
+ }
1353
+ }
1354
+ }
1355
+ metadata {
1356
+ time_dimensions {
1357
+ name: "time"
1358
+ dates {
1359
+ year: 2021
1360
+ month: 2
1361
+ day: 1
1362
+ }
1363
+ dates {
1364
+ year: 2021
1365
+ month: 2
1366
+ day: 8
1367
+ }
1368
+ }
1369
+ time_dimensions {
1370
+ name: "media_time"
1371
+ dates {
1372
+ year: 2021
1373
+ month: 2
1374
+ day: 1
1375
+ }
1376
+ dates {
1377
+ year: 2021
1378
+ month: 2
1379
+ day: 8
1380
+ }
1381
+ }
1382
+ channel_dimensions {
1383
+ name: "media"
1384
+ channels: "ch_paid_0"
1385
+ channels: "ch_paid_1"
1386
+ }
1387
+ control_names: "control_0"
1388
+ control_names: "control_1"
1389
+ kpi_type: "revenue"
1390
+ }
1391
+ """
1392
+
1393
+ MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED = text_format.Parse(
1394
+ MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED_STRING,
1395
+ marketing_pb.MarketingData(),
1396
+ )
1397
+
1398
+ # Media, Organic, Expanded, Lagged
1399
+ MOCK_INPUT_DATA_MEDIA_ORGANIC_EXPANDED_LAGGED = mock.MagicMock(
1400
+ kpi_type=c.REVENUE,
1401
+ geo=xr.DataArray(np.array(_GEO_IDS)),
1402
+ time=xr.DataArray(np.array(_TIME_STRS)),
1403
+ media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)),
1404
+ population=xr.DataArray(
1405
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
1406
+ ),
1407
+ media=xr.DataArray(
1408
+ coords={
1409
+ c.GEO: _GEO_IDS,
1410
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
1411
+ c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
1412
+ },
1413
+ data=np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),
1414
+ name=c.MEDIA,
1415
+ ),
1416
+ media_spend=xr.DataArray(
1417
+ coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID},
1418
+ data=np.array([492, 496]),
1419
+ name=c.MEDIA_SPEND,
1420
+ ),
1421
+ media_spend_has_geo_dimension=False,
1422
+ media_spend_has_time_dimension=False,
1423
+ reach=None,
1424
+ frequency=None,
1425
+ rf_spend=None,
1426
+ organic_media=xr.DataArray(
1427
+ coords={
1428
+ c.GEO: _GEO_IDS,
1429
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
1430
+ c.ORGANIC_MEDIA_CHANNEL: _MEDIA_CHANNEL_ORGANIC,
1431
+ },
1432
+ data=np.array(
1433
+ [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]]
1434
+ ),
1435
+ name=c.ORGANIC_MEDIA,
1436
+ ),
1437
+ organic_reach=None,
1438
+ organic_frequency=None,
1439
+ non_media_treatments=None,
1440
+ kpi=xr.DataArray(
1441
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
1442
+ data=np.array([[2, 2], [3, 3]]),
1443
+ name=c.KPI,
1444
+ ),
1445
+ revenue_per_kpi=xr.DataArray(
1446
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
1447
+ data=np.ones((2, 2)),
1448
+ name=c.REVENUE_PER_KPI,
1449
+ ),
1450
+ controls=xr.DataArray(
1451
+ coords={
1452
+ c.GEO: _GEO_IDS,
1453
+ c.TIME: _TIME_STRS,
1454
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
1455
+ },
1456
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
1457
+ name=c.CONTROLS,
1458
+ ),
1459
+ )
1460
+
1461
+ MOCK_PROTO_MEDIA_ORGANIC_EXPANDED_LAGGED = text_format.Parse(
1462
+ """
1463
+ marketing_data_points {
1464
+ geo_info {
1465
+ geo_id: "geo_0"
1466
+ population: 11
1467
+ }
1468
+ date_interval {
1469
+ start_date {
1470
+ year: 2021
1471
+ month: 1
1472
+ day: 25
1473
+ }
1474
+ end_date {
1475
+ year: 2021
1476
+ month: 2
1477
+ day: 1
1478
+ }
1479
+ }
1480
+ media_variables {
1481
+ channel_name: "ch_paid_0"
1482
+ scalar_metric {
1483
+ name: "impressions"
1484
+ value: 1.0
1485
+ }
1486
+ }
1487
+ media_variables {
1488
+ channel_name: "ch_paid_1"
1489
+ scalar_metric {
1490
+ name: "impressions"
1491
+ value: 2.0
1492
+ }
1493
+ }
1494
+ media_variables {
1495
+ channel_name: "ch_organic_0"
1496
+ scalar_metric {
1497
+ name: "impressions"
1498
+ value: 39.0
1499
+ }
1500
+ }
1501
+ media_variables {
1502
+ channel_name: "ch_organic_1"
1503
+ scalar_metric {
1504
+ name: "impressions"
1505
+ value: 40.0
1506
+ }
1507
+ }
1508
+ }
1509
+ marketing_data_points {
1510
+ geo_info {
1511
+ geo_id: "geo_0"
1512
+ population: 11
1513
+ }
1514
+ date_interval {
1515
+ start_date {
1516
+ year: 2021
1517
+ month: 2
1518
+ day: 1
1519
+ }
1520
+ end_date {
1521
+ year: 2021
1522
+ month: 2
1523
+ day: 8
1524
+ }
1525
+ }
1526
+ control_variables {
1527
+ name: "control_0"
1528
+ value: 31.0
1529
+ }
1530
+ control_variables {
1531
+ name: "control_1"
1532
+ value: 32.0
1533
+ }
1534
+ media_variables {
1535
+ channel_name: "ch_paid_0"
1536
+ scalar_metric {
1537
+ name: "impressions"
1538
+ value: 3.0
1539
+ }
1540
+ }
1541
+ media_variables {
1542
+ channel_name: "ch_paid_1"
1543
+ scalar_metric {
1544
+ name: "impressions"
1545
+ value: 4.0
1546
+ }
1547
+ }
1548
+ media_variables {
1549
+ channel_name: "ch_organic_0"
1550
+ scalar_metric {
1551
+ name: "impressions"
1552
+ value: 41.0
1553
+ }
1554
+ }
1555
+ media_variables {
1556
+ channel_name: "ch_organic_1"
1557
+ scalar_metric {
1558
+ name: "impressions"
1559
+ value: 42.0
1560
+ }
1561
+ }
1562
+ kpi {
1563
+ name: "revenue"
1564
+ revenue {
1565
+ value: 2.0
1566
+ }
1567
+ }
1568
+ }
1569
+ marketing_data_points {
1570
+ geo_info {
1571
+ geo_id: "geo_0"
1572
+ population: 11
1573
+ }
1574
+ date_interval {
1575
+ start_date {
1576
+ year: 2021
1577
+ month: 2
1578
+ day: 8
1579
+ }
1580
+ end_date {
1581
+ year: 2021
1582
+ month: 2
1583
+ day: 15
1584
+ }
1585
+ }
1586
+ control_variables {
1587
+ name: "control_0"
1588
+ value: 33.0
1589
+ }
1590
+ control_variables {
1591
+ name: "control_1"
1592
+ value: 34.0
1593
+ }
1594
+ media_variables {
1595
+ channel_name: "ch_paid_0"
1596
+ scalar_metric {
1597
+ name: "impressions"
1598
+ value: 5.0
1599
+ }
1600
+ }
1601
+ media_variables {
1602
+ channel_name: "ch_paid_1"
1603
+ scalar_metric {
1604
+ name: "impressions"
1605
+ value: 6.0
1606
+ }
1607
+ }
1608
+ media_variables {
1609
+ channel_name: "ch_organic_0"
1610
+ scalar_metric {
1611
+ name: "impressions"
1612
+ value: 43.0
1613
+ }
1614
+ }
1615
+ media_variables {
1616
+ channel_name: "ch_organic_1"
1617
+ scalar_metric {
1618
+ name: "impressions"
1619
+ value: 44.0
1620
+ }
1621
+ }
1622
+ kpi {
1623
+ name: "revenue"
1624
+ revenue {
1625
+ value: 2.0
1626
+ }
1627
+ }
1628
+ }
1629
+ marketing_data_points {
1630
+ geo_info {
1631
+ geo_id: "geo_1"
1632
+ population: 12
1633
+ }
1634
+ date_interval {
1635
+ start_date {
1636
+ year: 2021
1637
+ month: 1
1638
+ day: 25
1639
+ }
1640
+ end_date {
1641
+ year: 2021
1642
+ month: 2
1643
+ day: 1
1644
+ }
1645
+ }
1646
+ media_variables {
1647
+ channel_name: "ch_paid_0"
1648
+ scalar_metric {
1649
+ name: "impressions"
1650
+ value: 7.0
1651
+ }
1652
+ }
1653
+ media_variables {
1654
+ channel_name: "ch_paid_1"
1655
+ scalar_metric {
1656
+ name: "impressions"
1657
+ value: 8.0
1658
+ }
1659
+ }
1660
+ media_variables {
1661
+ channel_name: "ch_organic_0"
1662
+ scalar_metric {
1663
+ name: "impressions"
1664
+ value: 45.0
1665
+ }
1666
+ }
1667
+ media_variables {
1668
+ channel_name: "ch_organic_1"
1669
+ scalar_metric {
1670
+ name: "impressions"
1671
+ value: 46.0
1672
+ }
1673
+ }
1674
+ }
1675
+ marketing_data_points {
1676
+ geo_info {
1677
+ geo_id: "geo_1"
1678
+ population: 12
1679
+ }
1680
+ date_interval {
1681
+ start_date {
1682
+ year: 2021
1683
+ month: 2
1684
+ day: 1
1685
+ }
1686
+ end_date {
1687
+ year: 2021
1688
+ month: 2
1689
+ day: 8
1690
+ }
1691
+ }
1692
+ control_variables {
1693
+ name: "control_0"
1694
+ value: 35.0
1695
+ }
1696
+ control_variables {
1697
+ name: "control_1"
1698
+ value: 36.0
1699
+ }
1700
+ media_variables {
1701
+ channel_name: "ch_paid_0"
1702
+ scalar_metric {
1703
+ name: "impressions"
1704
+ value: 9.0
1705
+ }
1706
+ }
1707
+ media_variables {
1708
+ channel_name: "ch_paid_1"
1709
+ scalar_metric {
1710
+ name: "impressions"
1711
+ value: 10.0
1712
+ }
1713
+ }
1714
+ media_variables {
1715
+ channel_name: "ch_organic_0"
1716
+ scalar_metric {
1717
+ name: "impressions"
1718
+ value: 47.0
1719
+ }
1720
+ }
1721
+ media_variables {
1722
+ channel_name: "ch_organic_1"
1723
+ scalar_metric {
1724
+ name: "impressions"
1725
+ value: 48.0
1726
+ }
1727
+ }
1728
+ kpi {
1729
+ name: "revenue"
1730
+ revenue {
1731
+ value: 3.0
1732
+ }
1733
+ }
1734
+ }
1735
+ marketing_data_points {
1736
+ geo_info {
1737
+ geo_id: "geo_1"
1738
+ population: 12
1739
+ }
1740
+ date_interval {
1741
+ start_date {
1742
+ year: 2021
1743
+ month: 2
1744
+ day: 8
1745
+ }
1746
+ end_date {
1747
+ year: 2021
1748
+ month: 2
1749
+ day: 15
1750
+ }
1751
+ }
1752
+ control_variables {
1753
+ name: "control_0"
1754
+ value: 37.0
1755
+ }
1756
+ control_variables {
1757
+ name: "control_1"
1758
+ value: 38.0
1759
+ }
1760
+ media_variables {
1761
+ channel_name: "ch_paid_0"
1762
+ scalar_metric {
1763
+ name: "impressions"
1764
+ value: 11.0
1765
+ }
1766
+ }
1767
+ media_variables {
1768
+ channel_name: "ch_paid_1"
1769
+ scalar_metric {
1770
+ name: "impressions"
1771
+ value: 12.0
1772
+ }
1773
+ }
1774
+ media_variables {
1775
+ channel_name: "ch_organic_0"
1776
+ scalar_metric {
1777
+ name: "impressions"
1778
+ value: 49.0
1779
+ }
1780
+ }
1781
+ media_variables {
1782
+ channel_name: "ch_organic_1"
1783
+ scalar_metric {
1784
+ name: "impressions"
1785
+ value: 50.0
1786
+ }
1787
+ }
1788
+ kpi {
1789
+ name: "revenue"
1790
+ revenue {
1791
+ value: 3.0
1792
+ }
1793
+ }
1794
+ }
1795
+ marketing_data_points {
1796
+ date_interval {
1797
+ start_date {
1798
+ year: 2021
1799
+ month: 1
1800
+ day: 25
1801
+ }
1802
+ end_date {
1803
+ year: 2021
1804
+ month: 2
1805
+ day: 15
1806
+ }
1807
+ }
1808
+ media_variables {
1809
+ channel_name: "ch_paid_0"
1810
+ media_spend: 492.0
1811
+ }
1812
+ media_variables {
1813
+ channel_name: "ch_paid_1"
1814
+ media_spend: 496.0
1815
+ }
1816
+ }
1817
+ metadata {
1818
+ time_dimensions {
1819
+ name: "time"
1820
+ dates {
1821
+ year: 2021
1822
+ month: 2
1823
+ day: 1
1824
+ }
1825
+ dates {
1826
+ year: 2021
1827
+ month: 2
1828
+ day: 8
1829
+ }
1830
+ }
1831
+ time_dimensions {
1832
+ name: "media_time"
1833
+ dates {
1834
+ year: 2021
1835
+ month: 1
1836
+ day: 25
1837
+ }
1838
+ dates {
1839
+ year: 2021
1840
+ month: 2
1841
+ day: 1
1842
+ }
1843
+ dates {
1844
+ year: 2021
1845
+ month: 2
1846
+ day: 8
1847
+ }
1848
+ }
1849
+ channel_dimensions {
1850
+ name: "media"
1851
+ channels: "ch_paid_0"
1852
+ channels: "ch_paid_1"
1853
+ }
1854
+ channel_dimensions {
1855
+ name: "organic_media"
1856
+ channels: "ch_organic_0"
1857
+ channels: "ch_organic_1"
1858
+ }
1859
+ control_names: "control_0"
1860
+ control_names: "control_1"
1861
+ kpi_type: "revenue"
1862
+ }
1863
+ """,
1864
+ marketing_pb.MarketingData(),
1865
+ )
1866
+
1867
+ # Media, Organic, Granular, Not Lagged
1868
+ MOCK_INPUT_DATA_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED = mock.MagicMock(
1869
+ kpi_type=c.REVENUE,
1870
+ geo=xr.DataArray(np.array(_GEO_IDS)),
1871
+ time=xr.DataArray(np.array(_TIME_STRS)),
1872
+ media_time=xr.DataArray(np.array(_TIME_STRS)),
1873
+ population=xr.DataArray(
1874
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
1875
+ ),
1876
+ media=xr.DataArray(
1877
+ coords={
1878
+ c.GEO: _GEO_IDS,
1879
+ c.MEDIA_TIME: _TIME_STRS,
1880
+ c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
1881
+ },
1882
+ data=np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
1883
+ name=c.MEDIA,
1884
+ ),
1885
+ media_spend=xr.DataArray(
1886
+ coords={
1887
+ c.GEO: _GEO_IDS,
1888
+ c.TIME: _TIME_STRS,
1889
+ c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
1890
+ },
1891
+ data=np.array([[[123, 124], [125, 126]], [[127, 128], [129, 130]]]),
1892
+ name=c.MEDIA_SPEND,
1893
+ ),
1894
+ media_spend_has_geo_dimension=True,
1895
+ media_spend_has_time_dimension=True,
1896
+ reach=None,
1897
+ frequency=None,
1898
+ rf_spend=None,
1899
+ organic_media=xr.DataArray(
1900
+ coords={
1901
+ c.GEO: _GEO_IDS,
1902
+ c.MEDIA_TIME: _TIME_STRS,
1903
+ c.ORGANIC_MEDIA_CHANNEL: _MEDIA_CHANNEL_ORGANIC,
1904
+ },
1905
+ data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]),
1906
+ name=c.ORGANIC_MEDIA,
1907
+ ),
1908
+ organic_reach=None,
1909
+ organic_frequency=None,
1910
+ non_media_treatments=None,
1911
+ kpi=xr.DataArray(
1912
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
1913
+ data=np.array([[2, 2], [3, 3]]),
1914
+ name=c.KPI,
1915
+ ),
1916
+ revenue_per_kpi=xr.DataArray(
1917
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
1918
+ data=np.ones((2, 2)),
1919
+ name=c.REVENUE_PER_KPI,
1920
+ ),
1921
+ controls=xr.DataArray(
1922
+ coords={
1923
+ c.GEO: _GEO_IDS,
1924
+ c.TIME: _TIME_STRS,
1925
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
1926
+ },
1927
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
1928
+ name=c.CONTROLS,
1929
+ ),
1930
+ )
1931
+
1932
+ MOCK_PROTO_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED = text_format.Parse(
1933
+ """
1934
+ marketing_data_points {
1935
+ geo_info {
1936
+ geo_id: "geo_0"
1937
+ population: 11
1938
+ }
1939
+ date_interval {
1940
+ start_date {
1941
+ year: 2021
1942
+ month: 2
1943
+ day: 1
1944
+ }
1945
+ end_date {
1946
+ year: 2021
1947
+ month: 2
1948
+ day: 8
1949
+ }
1950
+ }
1951
+ control_variables {
1952
+ name: "control_0"
1953
+ value: 31.0
1954
+ }
1955
+ control_variables {
1956
+ name: "control_1"
1957
+ value: 32.0
1958
+ }
1959
+ media_variables {
1960
+ channel_name: "ch_paid_0"
1961
+ scalar_metric {
1962
+ name: "impressions"
1963
+ value: 1.0
1964
+ }
1965
+ media_spend: 123.0
1966
+ }
1967
+ media_variables {
1968
+ channel_name: "ch_paid_1"
1969
+ scalar_metric {
1970
+ name: "impressions"
1971
+ value: 2.0
1972
+ }
1973
+ media_spend: 124.0
1974
+ }
1975
+ media_variables {
1976
+ channel_name: "ch_organic_0"
1977
+ scalar_metric {
1978
+ name: "impressions"
1979
+ value: 39.0
1980
+ }
1981
+ }
1982
+ media_variables {
1983
+ channel_name: "ch_organic_1"
1984
+ scalar_metric {
1985
+ name: "impressions"
1986
+ value: 40.0
1987
+ }
1988
+ }
1989
+ kpi {
1990
+ name: "revenue"
1991
+ revenue {
1992
+ value: 2.0
1993
+ }
1994
+ }
1995
+ }
1996
+ marketing_data_points {
1997
+ geo_info {
1998
+ geo_id: "geo_0"
1999
+ population: 11
2000
+ }
2001
+ date_interval {
2002
+ start_date {
2003
+ year: 2021
2004
+ month: 2
2005
+ day: 8
2006
+ }
2007
+ end_date {
2008
+ year: 2021
2009
+ month: 2
2010
+ day: 15
2011
+ }
2012
+ }
2013
+ control_variables {
2014
+ name: "control_0"
2015
+ value: 33.0
2016
+ }
2017
+ control_variables {
2018
+ name: "control_1"
2019
+ value: 34.0
2020
+ }
2021
+ media_variables {
2022
+ channel_name: "ch_paid_0"
2023
+ scalar_metric {
2024
+ name: "impressions"
2025
+ value: 3.0
2026
+ }
2027
+ media_spend: 125.0
2028
+ }
2029
+ media_variables {
2030
+ channel_name: "ch_paid_1"
2031
+ scalar_metric {
2032
+ name: "impressions"
2033
+ value: 4.0
2034
+ }
2035
+ media_spend: 126.0
2036
+ }
2037
+ media_variables {
2038
+ channel_name: "ch_organic_0"
2039
+ scalar_metric {
2040
+ name: "impressions"
2041
+ value: 41.0
2042
+ }
2043
+ }
2044
+ media_variables {
2045
+ channel_name: "ch_organic_1"
2046
+ scalar_metric {
2047
+ name: "impressions"
2048
+ value: 42.0
2049
+ }
2050
+ }
2051
+ kpi {
2052
+ name: "revenue"
2053
+ revenue {
2054
+ value: 2.0
2055
+ }
2056
+ }
2057
+ }
2058
+ marketing_data_points {
2059
+ geo_info {
2060
+ geo_id: "geo_1"
2061
+ population: 12
2062
+ }
2063
+ date_interval {
2064
+ start_date {
2065
+ year: 2021
2066
+ month: 2
2067
+ day: 1
2068
+ }
2069
+ end_date {
2070
+ year: 2021
2071
+ month: 2
2072
+ day: 8
2073
+ }
2074
+ }
2075
+ control_variables {
2076
+ name: "control_0"
2077
+ value: 35.0
2078
+ }
2079
+ control_variables {
2080
+ name: "control_1"
2081
+ value: 36.0
2082
+ }
2083
+ media_variables {
2084
+ channel_name: "ch_paid_0"
2085
+ scalar_metric {
2086
+ name: "impressions"
2087
+ value: 5.0
2088
+ }
2089
+ media_spend: 127.0
2090
+ }
2091
+ media_variables {
2092
+ channel_name: "ch_paid_1"
2093
+ scalar_metric {
2094
+ name: "impressions"
2095
+ value: 6.0
2096
+ }
2097
+ media_spend: 128.0
2098
+ }
2099
+ media_variables {
2100
+ channel_name: "ch_organic_0"
2101
+ scalar_metric {
2102
+ name: "impressions"
2103
+ value: 43.0
2104
+ }
2105
+ }
2106
+ media_variables {
2107
+ channel_name: "ch_organic_1"
2108
+ scalar_metric {
2109
+ name: "impressions"
2110
+ value: 44.0
2111
+ }
2112
+ }
2113
+ kpi {
2114
+ name: "revenue"
2115
+ revenue {
2116
+ value: 3.0
2117
+ }
2118
+ }
2119
+ }
2120
+ marketing_data_points {
2121
+ geo_info {
2122
+ geo_id: "geo_1"
2123
+ population: 12
2124
+ }
2125
+ date_interval {
2126
+ start_date {
2127
+ year: 2021
2128
+ month: 2
2129
+ day: 8
2130
+ }
2131
+ end_date {
2132
+ year: 2021
2133
+ month: 2
2134
+ day: 15
2135
+ }
2136
+ }
2137
+ control_variables {
2138
+ name: "control_0"
2139
+ value: 37.0
2140
+ }
2141
+ control_variables {
2142
+ name: "control_1"
2143
+ value: 38.0
2144
+ }
2145
+ media_variables {
2146
+ channel_name: "ch_paid_0"
2147
+ scalar_metric {
2148
+ name: "impressions"
2149
+ value: 7.0
2150
+ }
2151
+ media_spend: 129.0
2152
+ }
2153
+ media_variables {
2154
+ channel_name: "ch_paid_1"
2155
+ scalar_metric {
2156
+ name: "impressions"
2157
+ value: 8.0
2158
+ }
2159
+ media_spend: 130.0
2160
+ }
2161
+ media_variables {
2162
+ channel_name: "ch_organic_0"
2163
+ scalar_metric {
2164
+ name: "impressions"
2165
+ value: 45.0
2166
+ }
2167
+ }
2168
+ media_variables {
2169
+ channel_name: "ch_organic_1"
2170
+ scalar_metric {
2171
+ name: "impressions"
2172
+ value: 46.0
2173
+ }
2174
+ }
2175
+ kpi {
2176
+ name: "revenue"
2177
+ revenue {
2178
+ value: 3.0
2179
+ }
2180
+ }
2181
+ }
2182
+ metadata {
2183
+ time_dimensions {
2184
+ name: "time"
2185
+ dates {
2186
+ year: 2021
2187
+ month: 2
2188
+ day: 1
2189
+ }
2190
+ dates {
2191
+ year: 2021
2192
+ month: 2
2193
+ day: 8
2194
+ }
2195
+ }
2196
+ time_dimensions {
2197
+ name: "media_time"
2198
+ dates {
2199
+ year: 2021
2200
+ month: 2
2201
+ day: 1
2202
+ }
2203
+ dates {
2204
+ year: 2021
2205
+ month: 2
2206
+ day: 8
2207
+ }
2208
+ }
2209
+ channel_dimensions {
2210
+ name: "media"
2211
+ channels: "ch_paid_0"
2212
+ channels: "ch_paid_1"
2213
+ }
2214
+ channel_dimensions {
2215
+ name: "organic_media"
2216
+ channels: "ch_organic_0"
2217
+ channels: "ch_organic_1"
2218
+ }
2219
+ control_names: "control_0"
2220
+ control_names: "control_1"
2221
+ kpi_type: "revenue"
2222
+ }
2223
+ """,
2224
+ marketing_pb.MarketingData(),
2225
+ )
2226
+
2227
+ # Reach and Frequency, Paid, Expanded, Lagged
2228
+ MOCK_INPUT_DATA_RF_PAID_EXPANDED_LAGGED = mock.MagicMock(
2229
+ kpi_type=c.REVENUE,
2230
+ geo=xr.DataArray(np.array(_GEO_IDS)),
2231
+ time=xr.DataArray(np.array(_TIME_STRS)),
2232
+ media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)),
2233
+ population=xr.DataArray(
2234
+ coords={c.GEO: _GEO_IDS},
2235
+ data=np.array([11.1, 12.2]),
2236
+ name=c.POPULATION,
2237
+ ),
2238
+ media=None,
2239
+ media_spend=None,
2240
+ reach=xr.DataArray(
2241
+ coords={
2242
+ c.GEO: _GEO_IDS,
2243
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2244
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2245
+ },
2246
+ data=np.array(
2247
+ [[[51, 52], [53, 54], [55, 56]], [[57, 58], [59, 60], [61, 62]]]
2248
+ ),
2249
+ name=c.REACH,
2250
+ ),
2251
+ frequency=xr.DataArray(
2252
+ coords={
2253
+ c.GEO: _GEO_IDS,
2254
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2255
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2256
+ },
2257
+ data=np.array([
2258
+ [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]],
2259
+ [[1.7, 1.8], [1.9, 2.0], [2.1, 2.2]],
2260
+ ]),
2261
+ name=c.FREQUENCY,
2262
+ ),
2263
+ rf_spend=xr.DataArray(
2264
+ coords={c.RF_CHANNEL: _RF_CHANNEL_PAID},
2265
+ data=np.array([1004, 1008]),
2266
+ name=c.RF_SPEND,
2267
+ ),
2268
+ rf_spend_has_geo_dimension=False,
2269
+ rf_spend_has_time_dimension=False,
2270
+ kpi=xr.DataArray(
2271
+ coords={
2272
+ c.GEO: _GEO_IDS,
2273
+ c.TIME: _TIME_STRS,
2274
+ },
2275
+ data=np.array([[2, 3], [4, 5]]),
2276
+ name=c.KPI,
2277
+ ),
2278
+ revenue_per_kpi=xr.DataArray(
2279
+ coords={
2280
+ c.GEO: _GEO_IDS,
2281
+ c.TIME: _TIME_STRS,
2282
+ },
2283
+ data=np.ones((2, 2)),
2284
+ name=c.REVENUE_PER_KPI,
2285
+ ),
2286
+ controls=xr.DataArray(
2287
+ coords={
2288
+ c.GEO: _GEO_IDS,
2289
+ c.TIME: _TIME_STRS,
2290
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
2291
+ },
2292
+ data=np.array([[[32, 33], [34, 35]], [[36, 37], [38, 39]]]),
2293
+ name=c.CONTROLS,
2294
+ ),
2295
+ organic_media=None,
2296
+ organic_reach=None,
2297
+ organic_frequency=None,
2298
+ non_media_treatments=None,
2299
+ )
2300
+
2301
+ MOCK_PROTO_RF_PAID_EXPANDED_LAGGED = text_format.Parse(
2302
+ """
2303
+ marketing_data_points {
2304
+ geo_info {
2305
+ geo_id: "geo_0"
2306
+ population: 11
2307
+ }
2308
+ date_interval {
2309
+ start_date {
2310
+ year: 2021
2311
+ month: 1
2312
+ day: 25
2313
+ }
2314
+ end_date {
2315
+ year: 2021
2316
+ month: 2
2317
+ day: 1
2318
+ }
2319
+ }
2320
+ reach_frequency_variables {
2321
+ channel_name: "rf_ch_paid_0"
2322
+ reach: 51
2323
+ average_frequency: 1.1
2324
+ }
2325
+ reach_frequency_variables {
2326
+ channel_name: "rf_ch_paid_1"
2327
+ reach: 52
2328
+ average_frequency: 1.2
2329
+ }
2330
+ }
2331
+ marketing_data_points {
2332
+ geo_info {
2333
+ geo_id: "geo_0"
2334
+ population: 11
2335
+ }
2336
+ date_interval {
2337
+ start_date {
2338
+ year: 2021
2339
+ month: 2
2340
+ day: 1
2341
+ }
2342
+ end_date {
2343
+ year: 2021
2344
+ month: 2
2345
+ day: 8
2346
+ }
2347
+ }
2348
+ control_variables {
2349
+ name: "control_0"
2350
+ value: 32.0
2351
+ }
2352
+ control_variables {
2353
+ name: "control_1"
2354
+ value: 33.0
2355
+ }
2356
+ reach_frequency_variables {
2357
+ channel_name: "rf_ch_paid_0"
2358
+ reach: 53
2359
+ average_frequency: 1.3
2360
+ }
2361
+ reach_frequency_variables {
2362
+ channel_name: "rf_ch_paid_1"
2363
+ reach: 54
2364
+ average_frequency: 1.4
2365
+ }
2366
+ kpi {
2367
+ name: "revenue"
2368
+ revenue {
2369
+ value: 2.0
2370
+ }
2371
+ }
2372
+ }
2373
+ marketing_data_points {
2374
+ geo_info {
2375
+ geo_id: "geo_0"
2376
+ population: 11
2377
+ }
2378
+ date_interval {
2379
+ start_date {
2380
+ year: 2021
2381
+ month: 2
2382
+ day: 8
2383
+ }
2384
+ end_date {
2385
+ year: 2021
2386
+ month: 2
2387
+ day: 15
2388
+ }
2389
+ }
2390
+ control_variables {
2391
+ name: "control_0"
2392
+ value: 34.0
2393
+ }
2394
+ control_variables {
2395
+ name: "control_1"
2396
+ value: 35.0
2397
+ }
2398
+ reach_frequency_variables {
2399
+ channel_name: "rf_ch_paid_0"
2400
+ reach: 55
2401
+ average_frequency: 1.5
2402
+ }
2403
+ reach_frequency_variables {
2404
+ channel_name: "rf_ch_paid_1"
2405
+ reach: 56
2406
+ average_frequency: 1.6
2407
+ }
2408
+ kpi {
2409
+ name: "revenue"
2410
+ revenue {
2411
+ value: 3.0
2412
+ }
2413
+ }
2414
+ }
2415
+ marketing_data_points {
2416
+ geo_info {
2417
+ geo_id: "geo_1"
2418
+ population: 12
2419
+ }
2420
+ date_interval {
2421
+ start_date {
2422
+ year: 2021
2423
+ month: 1
2424
+ day: 25
2425
+ }
2426
+ end_date {
2427
+ year: 2021
2428
+ month: 2
2429
+ day: 1
2430
+ }
2431
+ }
2432
+ reach_frequency_variables {
2433
+ channel_name: "rf_ch_paid_0"
2434
+ reach: 57
2435
+ average_frequency: 1.7
2436
+ }
2437
+ reach_frequency_variables {
2438
+ channel_name: "rf_ch_paid_1"
2439
+ reach: 58
2440
+ average_frequency: 1.8
2441
+ }
2442
+ }
2443
+ marketing_data_points {
2444
+ geo_info {
2445
+ geo_id: "geo_1"
2446
+ population: 12
2447
+ }
2448
+ date_interval {
2449
+ start_date {
2450
+ year: 2021
2451
+ month: 2
2452
+ day: 1
2453
+ }
2454
+ end_date {
2455
+ year: 2021
2456
+ month: 2
2457
+ day: 8
2458
+ }
2459
+ }
2460
+ control_variables {
2461
+ name: "control_0"
2462
+ value: 36.0
2463
+ }
2464
+ control_variables {
2465
+ name: "control_1"
2466
+ value: 37.0
2467
+ }
2468
+ reach_frequency_variables {
2469
+ channel_name: "rf_ch_paid_0"
2470
+ reach: 59
2471
+ average_frequency: 1.9
2472
+ }
2473
+ reach_frequency_variables {
2474
+ channel_name: "rf_ch_paid_1"
2475
+ reach: 60
2476
+ average_frequency: 2.0
2477
+ }
2478
+ kpi {
2479
+ name: "revenue"
2480
+ revenue {
2481
+ value: 4.0
2482
+ }
2483
+ }
2484
+ }
2485
+ marketing_data_points {
2486
+ geo_info {
2487
+ geo_id: "geo_1"
2488
+ population: 12
2489
+ }
2490
+ date_interval {
2491
+ start_date {
2492
+ year: 2021
2493
+ month: 2
2494
+ day: 8
2495
+ }
2496
+ end_date {
2497
+ year: 2021
2498
+ month: 2
2499
+ day: 15
2500
+ }
2501
+ }
2502
+ control_variables {
2503
+ name: "control_0"
2504
+ value: 38.0
2505
+ }
2506
+ control_variables {
2507
+ name: "control_1"
2508
+ value: 39.0
2509
+ }
2510
+ reach_frequency_variables {
2511
+ channel_name: "rf_ch_paid_0"
2512
+ reach: 61
2513
+ average_frequency: 2.1
2514
+ }
2515
+ reach_frequency_variables {
2516
+ channel_name: "rf_ch_paid_1"
2517
+ reach: 62
2518
+ average_frequency: 2.2
2519
+ }
2520
+ kpi {
2521
+ name: "revenue"
2522
+ revenue {
2523
+ value: 5.0
2524
+ }
2525
+ }
2526
+ }
2527
+ marketing_data_points {
2528
+ date_interval {
2529
+ start_date {
2530
+ year: 2021
2531
+ month: 1
2532
+ day: 25
2533
+ }
2534
+ end_date {
2535
+ year: 2021
2536
+ month: 2
2537
+ day: 15
2538
+ }
2539
+ }
2540
+ reach_frequency_variables {
2541
+ channel_name: "rf_ch_paid_0"
2542
+ spend: 1004.0
2543
+ }
2544
+ reach_frequency_variables {
2545
+ channel_name: "rf_ch_paid_1"
2546
+ spend: 1008.0
2547
+ }
2548
+ }
2549
+ metadata {
2550
+ time_dimensions {
2551
+ name: "time"
2552
+ dates {
2553
+ year: 2021
2554
+ month: 2
2555
+ day: 1
2556
+ }
2557
+ dates {
2558
+ year: 2021
2559
+ month: 2
2560
+ day: 8
2561
+ }
2562
+ }
2563
+ time_dimensions {
2564
+ name: "media_time"
2565
+ dates {
2566
+ year: 2021
2567
+ month: 1
2568
+ day: 25
2569
+ }
2570
+ dates {
2571
+ year: 2021
2572
+ month: 2
2573
+ day: 1
2574
+ }
2575
+ dates {
2576
+ year: 2021
2577
+ month: 2
2578
+ day: 8
2579
+ }
2580
+ }
2581
+ channel_dimensions {
2582
+ name: "reach"
2583
+ channels: "rf_ch_paid_0"
2584
+ channels: "rf_ch_paid_1"
2585
+ }
2586
+ channel_dimensions {
2587
+ name: "frequency"
2588
+ channels: "rf_ch_paid_0"
2589
+ channels: "rf_ch_paid_1"
2590
+ }
2591
+ control_names: "control_0"
2592
+ control_names: "control_1"
2593
+ kpi_type: "revenue"
2594
+ }
2595
+ """,
2596
+ marketing_pb.MarketingData(),
2597
+ )
2598
+
2599
+ # Reach and Frequency, Paid, Granular, Not Lagged
2600
+ MOCK_INPUT_DATA_RF_PAID_GRANULAR_NOT_LAGGED = mock.MagicMock(
2601
+ kpi_type=c.REVENUE,
2602
+ geo=xr.DataArray(np.array(_GEO_IDS)),
2603
+ time=xr.DataArray(np.array(_TIME_STRS)),
2604
+ media_time=xr.DataArray(np.array(_TIME_STRS)),
2605
+ population=xr.DataArray(
2606
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
2607
+ ),
2608
+ media=None,
2609
+ media_spend=None,
2610
+ reach=xr.DataArray(
2611
+ coords={
2612
+ c.GEO: _GEO_IDS,
2613
+ c.MEDIA_TIME: _TIME_STRS,
2614
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2615
+ },
2616
+ data=np.array(
2617
+ [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]]
2618
+ ),
2619
+ name=c.REACH,
2620
+ ),
2621
+ frequency=xr.DataArray(
2622
+ coords={
2623
+ c.GEO: _GEO_IDS,
2624
+ c.MEDIA_TIME: _TIME_STRS,
2625
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2626
+ },
2627
+ data=np.array([[[1.1, 1.2], [2, 3]], [[4, 5], [6, 7]]]),
2628
+ name=c.FREQUENCY,
2629
+ ),
2630
+ rf_spend=xr.DataArray(
2631
+ coords={
2632
+ c.GEO: _GEO_IDS,
2633
+ c.TIME: _TIME_STRS,
2634
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2635
+ },
2636
+ data=np.array([[[252, 253], [254, 255]], [[256, 257], [258, 259]]]),
2637
+ name=c.RF_SPEND,
2638
+ ),
2639
+ rf_spend_has_geo_dimension=True,
2640
+ rf_spend_has_time_dimension=True,
2641
+ kpi=xr.DataArray(
2642
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
2643
+ data=np.array([[1, 2], [3, 4]]),
2644
+ name=c.KPI,
2645
+ ),
2646
+ revenue_per_kpi=xr.DataArray(
2647
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
2648
+ data=np.ones((2, 2)),
2649
+ name=c.REVENUE_PER_KPI,
2650
+ ),
2651
+ controls=xr.DataArray(
2652
+ coords={
2653
+ c.GEO: _GEO_IDS,
2654
+ c.TIME: _TIME_STRS,
2655
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
2656
+ },
2657
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
2658
+ name=c.CONTROLS,
2659
+ ),
2660
+ organic_media=None,
2661
+ organic_reach=None,
2662
+ organic_frequency=None,
2663
+ non_media_treatments=None,
2664
+ )
2665
+
2666
+ MOCK_PROTO_RF_PAID_GRANULAR_NOT_LAGGED = text_format.Parse(
2667
+ """
2668
+ marketing_data_points {
2669
+ geo_info {
2670
+ geo_id: "geo_0"
2671
+ population: 11
2672
+ }
2673
+ date_interval {
2674
+ start_date {
2675
+ year: 2021
2676
+ month: 2
2677
+ day: 1
2678
+ }
2679
+ end_date {
2680
+ year: 2021
2681
+ month: 2
2682
+ day: 8
2683
+ }
2684
+ }
2685
+ control_variables {
2686
+ name: "control_0"
2687
+ value: 31.0
2688
+ }
2689
+ control_variables {
2690
+ name: "control_1"
2691
+ value: 32.0
2692
+ }
2693
+ reach_frequency_variables {
2694
+ channel_name: "rf_ch_paid_0"
2695
+ reach: 51
2696
+ average_frequency: 1.1
2697
+ spend: 252.0
2698
+ }
2699
+ reach_frequency_variables {
2700
+ channel_name: "rf_ch_paid_1"
2701
+ reach: 52
2702
+ average_frequency: 1.2
2703
+ spend: 253.0
2704
+ }
2705
+ kpi {
2706
+ name: "revenue"
2707
+ revenue {
2708
+ value: 1.0
2709
+ }
2710
+ }
2711
+ }
2712
+ marketing_data_points {
2713
+ geo_info {
2714
+ geo_id: "geo_0"
2715
+ population: 11
2716
+ }
2717
+ date_interval {
2718
+ start_date {
2719
+ year: 2021
2720
+ month: 2
2721
+ day: 8
2722
+ }
2723
+ end_date {
2724
+ year: 2021
2725
+ month: 2
2726
+ day: 15
2727
+ }
2728
+ }
2729
+ control_variables {
2730
+ name: "control_0"
2731
+ value: 33.0
2732
+ }
2733
+ control_variables {
2734
+ name: "control_1"
2735
+ value: 34.0
2736
+ }
2737
+ reach_frequency_variables {
2738
+ channel_name: "rf_ch_paid_0"
2739
+ reach: 53
2740
+ average_frequency: 2.0
2741
+ spend: 254.0
2742
+ }
2743
+ reach_frequency_variables {
2744
+ channel_name: "rf_ch_paid_1"
2745
+ reach: 54
2746
+ average_frequency: 3.0
2747
+ spend: 255.0
2748
+ }
2749
+ kpi {
2750
+ name: "revenue"
2751
+ revenue {
2752
+ value: 2.0
2753
+ }
2754
+ }
2755
+ }
2756
+ marketing_data_points {
2757
+ geo_info {
2758
+ geo_id: "geo_1"
2759
+ population: 12
2760
+ }
2761
+ date_interval {
2762
+ start_date {
2763
+ year: 2021
2764
+ month: 2
2765
+ day: 1
2766
+ }
2767
+ end_date {
2768
+ year: 2021
2769
+ month: 2
2770
+ day: 8
2771
+ }
2772
+ }
2773
+ control_variables {
2774
+ name: "control_0"
2775
+ value: 35.0
2776
+ }
2777
+ control_variables {
2778
+ name: "control_1"
2779
+ value: 36.0
2780
+ }
2781
+ reach_frequency_variables {
2782
+ channel_name: "rf_ch_paid_0"
2783
+ reach: 55
2784
+ average_frequency: 4.0
2785
+ spend: 256.0
2786
+ }
2787
+ reach_frequency_variables {
2788
+ channel_name: "rf_ch_paid_1"
2789
+ reach: 56
2790
+ average_frequency: 5.0
2791
+ spend: 257.0
2792
+ }
2793
+ kpi {
2794
+ name: "revenue"
2795
+ revenue {
2796
+ value: 3.0
2797
+ }
2798
+ }
2799
+ }
2800
+ marketing_data_points {
2801
+ geo_info {
2802
+ geo_id: "geo_1"
2803
+ population: 12
2804
+ }
2805
+ date_interval {
2806
+ start_date {
2807
+ year: 2021
2808
+ month: 2
2809
+ day: 8
2810
+ }
2811
+ end_date {
2812
+ year: 2021
2813
+ month: 2
2814
+ day: 15
2815
+ }
2816
+ }
2817
+ control_variables {
2818
+ name: "control_0"
2819
+ value: 37.0
2820
+ }
2821
+ control_variables {
2822
+ name: "control_1"
2823
+ value: 38.0
2824
+ }
2825
+ reach_frequency_variables {
2826
+ channel_name: "rf_ch_paid_0"
2827
+ reach: 57
2828
+ average_frequency: 6.0
2829
+ spend: 258.0
2830
+ }
2831
+ reach_frequency_variables {
2832
+ channel_name: "rf_ch_paid_1"
2833
+ reach: 58
2834
+ average_frequency: 7.0
2835
+ spend: 259.0
2836
+ }
2837
+ kpi {
2838
+ name: "revenue"
2839
+ revenue {
2840
+ value: 4.0
2841
+ }
2842
+ }
2843
+ }
2844
+ metadata {
2845
+ time_dimensions {
2846
+ name: "time"
2847
+ dates {
2848
+ year: 2021
2849
+ month: 2
2850
+ day: 1
2851
+ }
2852
+ dates {
2853
+ year: 2021
2854
+ month: 2
2855
+ day: 8
2856
+ }
2857
+ }
2858
+ time_dimensions {
2859
+ name: "media_time"
2860
+ dates {
2861
+ year: 2021
2862
+ month: 2
2863
+ day: 1
2864
+ }
2865
+ dates {
2866
+ year: 2021
2867
+ month: 2
2868
+ day: 8
2869
+ }
2870
+ }
2871
+ channel_dimensions {
2872
+ name: "reach"
2873
+ channels: "rf_ch_paid_0"
2874
+ channels: "rf_ch_paid_1"
2875
+ }
2876
+ channel_dimensions {
2877
+ name: "frequency"
2878
+ channels: "rf_ch_paid_0"
2879
+ channels: "rf_ch_paid_1"
2880
+ }
2881
+ control_names: "control_0"
2882
+ control_names: "control_1"
2883
+ kpi_type: "revenue"
2884
+ }
2885
+ """,
2886
+ marketing_pb.MarketingData(),
2887
+ )
2888
+
2889
+ # Reach and Frequency, Organic, Expanded, Lagged
2890
+ MOCK_INPUT_DATA_RF_ORGANIC_EXPANDED_LAGGED = mock.MagicMock(
2891
+ kpi_type=c.REVENUE,
2892
+ geo=xr.DataArray(np.array(_GEO_IDS)),
2893
+ time=xr.DataArray(np.array(_TIME_STRS)),
2894
+ media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)),
2895
+ population=xr.DataArray(
2896
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
2897
+ ),
2898
+ media=None,
2899
+ media_spend=None,
2900
+ reach=xr.DataArray(
2901
+ coords={
2902
+ c.GEO: _GEO_IDS,
2903
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2904
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2905
+ },
2906
+ data=np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),
2907
+ name=c.REACH,
2908
+ ),
2909
+ frequency=xr.DataArray(
2910
+ coords={
2911
+ c.GEO: _GEO_IDS,
2912
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2913
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
2914
+ },
2915
+ data=np.array([
2916
+ [[2.1, 2.2], [2.3, 2.4], [2.5, 2.6]],
2917
+ [[2.7, 2.8], [2.9, 3.0], [3.1, 3.2]],
2918
+ ]),
2919
+ name=c.FREQUENCY,
2920
+ ),
2921
+ rf_spend=xr.DataArray(
2922
+ coords={c.RF_CHANNEL: _RF_CHANNEL_PAID},
2923
+ data=np.array([1004, 1008]),
2924
+ name=c.RF_SPEND,
2925
+ ),
2926
+ rf_spend_has_geo_dimension=False,
2927
+ rf_spend_has_time_dimension=False,
2928
+ organic_media=None,
2929
+ organic_reach=xr.DataArray(
2930
+ coords={
2931
+ c.GEO: _GEO_IDS,
2932
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2933
+ c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC,
2934
+ },
2935
+ data=np.array(
2936
+ [[[51, 52], [53, 54], [55, 56]], [[57, 58], [59, 60], [61, 62]]]
2937
+ ),
2938
+ name=c.ORGANIC_REACH,
2939
+ ),
2940
+ organic_frequency=xr.DataArray(
2941
+ coords={
2942
+ c.GEO: _GEO_IDS,
2943
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
2944
+ c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC,
2945
+ },
2946
+ data=np.array([
2947
+ [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]],
2948
+ [[1.7, 1.8], [1.9, 2.0], [2.1, 2.2]],
2949
+ ]),
2950
+ name=c.ORGANIC_FREQUENCY,
2951
+ ),
2952
+ non_media_treatments=None,
2953
+ kpi=xr.DataArray(
2954
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
2955
+ data=np.array([[2, 2], [3, 3]]),
2956
+ name=c.KPI,
2957
+ ),
2958
+ revenue_per_kpi=xr.DataArray(
2959
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
2960
+ data=np.ones((2, 2)),
2961
+ name=c.REVENUE_PER_KPI,
2962
+ ),
2963
+ controls=xr.DataArray(
2964
+ coords={
2965
+ c.GEO: _GEO_IDS,
2966
+ c.TIME: _TIME_STRS,
2967
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
2968
+ },
2969
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
2970
+ name=c.CONTROLS,
2971
+ ),
2972
+ )
2973
+
2974
+ MOCK_PROTO_RF_ORGANIC_EXPANDED_LAGGED = text_format.Parse(
2975
+ """
2976
+ marketing_data_points {
2977
+ geo_info {
2978
+ geo_id: "geo_0"
2979
+ population: 11
2980
+ }
2981
+ date_interval {
2982
+ start_date {
2983
+ year: 2021
2984
+ month: 1
2985
+ day: 25
2986
+ }
2987
+ end_date {
2988
+ year: 2021
2989
+ month: 2
2990
+ day: 1
2991
+ }
2992
+ }
2993
+ reach_frequency_variables {
2994
+ channel_name: "rf_ch_paid_0"
2995
+ reach: 1
2996
+ average_frequency: 2.1
2997
+ }
2998
+ reach_frequency_variables {
2999
+ channel_name: "rf_ch_paid_1"
3000
+ reach: 2
3001
+ average_frequency: 2.2
3002
+ }
3003
+ reach_frequency_variables {
3004
+ channel_name: "rf_ch_organic_0"
3005
+ reach: 51
3006
+ average_frequency: 1.1
3007
+ }
3008
+ reach_frequency_variables {
3009
+ channel_name: "rf_ch_organic_1"
3010
+ reach: 52
3011
+ average_frequency: 1.2
3012
+ }
3013
+ }
3014
+ marketing_data_points {
3015
+ geo_info {
3016
+ geo_id: "geo_0"
3017
+ population: 11
3018
+ }
3019
+ date_interval {
3020
+ start_date {
3021
+ year: 2021
3022
+ month: 2
3023
+ day: 1
3024
+ }
3025
+ end_date {
3026
+ year: 2021
3027
+ month: 2
3028
+ day: 8
3029
+ }
3030
+ }
3031
+ control_variables {
3032
+ name: "control_0"
3033
+ value: 31.0
3034
+ }
3035
+ control_variables {
3036
+ name: "control_1"
3037
+ value: 32.0
3038
+ }
3039
+ reach_frequency_variables {
3040
+ channel_name: "rf_ch_paid_0"
3041
+ reach: 3
3042
+ average_frequency: 2.3
3043
+ }
3044
+ reach_frequency_variables {
3045
+ channel_name: "rf_ch_paid_1"
3046
+ reach: 4
3047
+ average_frequency: 2.4
3048
+ }
3049
+ reach_frequency_variables {
3050
+ channel_name: "rf_ch_organic_0"
3051
+ reach: 53
3052
+ average_frequency: 1.3
3053
+ }
3054
+ reach_frequency_variables {
3055
+ channel_name: "rf_ch_organic_1"
3056
+ reach: 54
3057
+ average_frequency: 1.4
3058
+ }
3059
+ kpi {
3060
+ name: "revenue"
3061
+ revenue {
3062
+ value: 2.0
3063
+ }
3064
+ }
3065
+ }
3066
+ marketing_data_points {
3067
+ geo_info {
3068
+ geo_id: "geo_0"
3069
+ population: 11
3070
+ }
3071
+ date_interval {
3072
+ start_date {
3073
+ year: 2021
3074
+ month: 2
3075
+ day: 8
3076
+ }
3077
+ end_date {
3078
+ year: 2021
3079
+ month: 2
3080
+ day: 15
3081
+ }
3082
+ }
3083
+ control_variables {
3084
+ name: "control_0"
3085
+ value: 33.0
3086
+ }
3087
+ control_variables {
3088
+ name: "control_1"
3089
+ value: 34.0
3090
+ }
3091
+ reach_frequency_variables {
3092
+ channel_name: "rf_ch_paid_0"
3093
+ reach: 5
3094
+ average_frequency: 2.5
3095
+ }
3096
+ reach_frequency_variables {
3097
+ channel_name: "rf_ch_paid_1"
3098
+ reach: 6
3099
+ average_frequency: 2.6
3100
+ }
3101
+ reach_frequency_variables {
3102
+ channel_name: "rf_ch_organic_0"
3103
+ reach: 55
3104
+ average_frequency: 1.5
3105
+ }
3106
+ reach_frequency_variables {
3107
+ channel_name: "rf_ch_organic_1"
3108
+ reach: 56
3109
+ average_frequency: 1.6
3110
+ }
3111
+ kpi {
3112
+ name: "revenue"
3113
+ revenue {
3114
+ value: 2.0
3115
+ }
3116
+ }
3117
+ }
3118
+ marketing_data_points {
3119
+ geo_info {
3120
+ geo_id: "geo_1"
3121
+ population: 12
3122
+ }
3123
+ date_interval {
3124
+ start_date {
3125
+ year: 2021
3126
+ month: 1
3127
+ day: 25
3128
+ }
3129
+ end_date {
3130
+ year: 2021
3131
+ month: 2
3132
+ day: 1
3133
+ }
3134
+ }
3135
+ reach_frequency_variables {
3136
+ channel_name: "rf_ch_paid_0"
3137
+ reach: 7
3138
+ average_frequency: 2.7
3139
+ }
3140
+ reach_frequency_variables {
3141
+ channel_name: "rf_ch_paid_1"
3142
+ reach: 8
3143
+ average_frequency: 2.8
3144
+ }
3145
+ reach_frequency_variables {
3146
+ channel_name: "rf_ch_organic_0"
3147
+ reach: 57
3148
+ average_frequency: 1.7
3149
+ }
3150
+ reach_frequency_variables {
3151
+ channel_name: "rf_ch_organic_1"
3152
+ reach: 58
3153
+ average_frequency: 1.8
3154
+ }
3155
+ }
3156
+ marketing_data_points {
3157
+ geo_info {
3158
+ geo_id: "geo_1"
3159
+ population: 12
3160
+ }
3161
+ date_interval {
3162
+ start_date {
3163
+ year: 2021
3164
+ month: 2
3165
+ day: 1
3166
+ }
3167
+ end_date {
3168
+ year: 2021
3169
+ month: 2
3170
+ day: 8
3171
+ }
3172
+ }
3173
+ control_variables {
3174
+ name: "control_0"
3175
+ value: 35.0
3176
+ }
3177
+ control_variables {
3178
+ name: "control_1"
3179
+ value: 36.0
3180
+ }
3181
+ reach_frequency_variables {
3182
+ channel_name: "rf_ch_paid_0"
3183
+ reach: 9
3184
+ average_frequency: 2.9
3185
+ }
3186
+ reach_frequency_variables {
3187
+ channel_name: "rf_ch_paid_1"
3188
+ reach: 10
3189
+ average_frequency: 3.0
3190
+ }
3191
+ reach_frequency_variables {
3192
+ channel_name: "rf_ch_organic_0"
3193
+ reach: 59
3194
+ average_frequency: 1.9
3195
+ }
3196
+ reach_frequency_variables {
3197
+ channel_name: "rf_ch_organic_1"
3198
+ reach: 60
3199
+ average_frequency: 2.0
3200
+ }
3201
+ kpi {
3202
+ name: "revenue"
3203
+ revenue {
3204
+ value: 3.0
3205
+ }
3206
+ }
3207
+ }
3208
+ marketing_data_points {
3209
+ geo_info {
3210
+ geo_id: "geo_1"
3211
+ population: 12
3212
+ }
3213
+ date_interval {
3214
+ start_date {
3215
+ year: 2021
3216
+ month: 2
3217
+ day: 8
3218
+ }
3219
+ end_date {
3220
+ year: 2021
3221
+ month: 2
3222
+ day: 15
3223
+ }
3224
+ }
3225
+ control_variables {
3226
+ name: "control_0"
3227
+ value: 37.0
3228
+ }
3229
+ control_variables {
3230
+ name: "control_1"
3231
+ value: 38.0
3232
+ }
3233
+ reach_frequency_variables {
3234
+ channel_name: "rf_ch_paid_0"
3235
+ reach: 11
3236
+ average_frequency: 3.1
3237
+ }
3238
+ reach_frequency_variables {
3239
+ channel_name: "rf_ch_paid_1"
3240
+ reach: 12
3241
+ average_frequency: 3.2
3242
+ }
3243
+ reach_frequency_variables {
3244
+ channel_name: "rf_ch_organic_0"
3245
+ reach: 61
3246
+ average_frequency: 2.1
3247
+ }
3248
+ reach_frequency_variables {
3249
+ channel_name: "rf_ch_organic_1"
3250
+ reach: 62
3251
+ average_frequency: 2.2
3252
+ }
3253
+ kpi {
3254
+ name: "revenue"
3255
+ revenue {
3256
+ value: 3.0
3257
+ }
3258
+ }
3259
+ }
3260
+ marketing_data_points {
3261
+ date_interval {
3262
+ start_date {
3263
+ year: 2021
3264
+ month: 1
3265
+ day: 25
3266
+ }
3267
+ end_date {
3268
+ year: 2021
3269
+ month: 2
3270
+ day: 15
3271
+ }
3272
+ }
3273
+ reach_frequency_variables {
3274
+ channel_name: "rf_ch_paid_0"
3275
+ spend: 1004.0
3276
+ }
3277
+ reach_frequency_variables {
3278
+ channel_name: "rf_ch_paid_1"
3279
+ spend: 1008.0
3280
+ }
3281
+ }
3282
+ metadata {
3283
+ time_dimensions {
3284
+ name: "time"
3285
+ dates {
3286
+ year: 2021
3287
+ month: 2
3288
+ day: 1
3289
+ }
3290
+ dates {
3291
+ year: 2021
3292
+ month: 2
3293
+ day: 8
3294
+ }
3295
+ }
3296
+ time_dimensions {
3297
+ name: "media_time"
3298
+ dates {
3299
+ year: 2021
3300
+ month: 1
3301
+ day: 25
3302
+ }
3303
+ dates {
3304
+ year: 2021
3305
+ month: 2
3306
+ day: 1
3307
+ }
3308
+ dates {
3309
+ year: 2021
3310
+ month: 2
3311
+ day: 8
3312
+ }
3313
+ }
3314
+ channel_dimensions {
3315
+ name: "reach"
3316
+ channels: "rf_ch_paid_0"
3317
+ channels: "rf_ch_paid_1"
3318
+ }
3319
+ channel_dimensions {
3320
+ name: "frequency"
3321
+ channels: "rf_ch_paid_0"
3322
+ channels: "rf_ch_paid_1"
3323
+ }
3324
+ channel_dimensions {
3325
+ name: "organic_reach"
3326
+ channels: "rf_ch_organic_0"
3327
+ channels: "rf_ch_organic_1"
3328
+ }
3329
+ channel_dimensions {
3330
+ name: "organic_frequency"
3331
+ channels: "rf_ch_organic_0"
3332
+ channels: "rf_ch_organic_1"
3333
+ }
3334
+ control_names: "control_0"
3335
+ control_names: "control_1"
3336
+ kpi_type: "revenue"
3337
+ }
3338
+ """,
3339
+ marketing_pb.MarketingData(),
3340
+ )
3341
+
3342
+ # Reach and Frequency, Organic, Granular, Not Lagged
3343
+ MOCK_INPUT_DATA_RF_ORGANIC_GRANULAR_NOT_LAGGED = mock.MagicMock(
3344
+ kpi_type=c.REVENUE,
3345
+ geo=xr.DataArray(np.array(_GEO_IDS)),
3346
+ time=xr.DataArray(np.array(_TIME_STRS)),
3347
+ media_time=xr.DataArray(np.array(_TIME_STRS)),
3348
+ population=xr.DataArray(
3349
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
3350
+ ),
3351
+ media=None,
3352
+ media_spend=None,
3353
+ reach=xr.DataArray(
3354
+ coords={
3355
+ c.GEO: _GEO_IDS,
3356
+ c.MEDIA_TIME: _TIME_STRS,
3357
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
3358
+ },
3359
+ data=np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),
3360
+ name=c.REACH,
3361
+ ),
3362
+ frequency=xr.DataArray(
3363
+ coords={
3364
+ c.GEO: _GEO_IDS,
3365
+ c.MEDIA_TIME: _TIME_STRS,
3366
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
3367
+ },
3368
+ data=np.array([[[2.1, 2.2], [3, 4]], [[5, 6], [7, 8]]]),
3369
+ name=c.FREQUENCY,
3370
+ ),
3371
+ rf_spend=xr.DataArray(
3372
+ coords={
3373
+ c.GEO: _GEO_IDS,
3374
+ c.TIME: _TIME_STRS,
3375
+ c.RF_CHANNEL: _RF_CHANNEL_PAID,
3376
+ },
3377
+ data=np.array([[[252, 253], [254, 255]], [[256, 257], [258, 259]]]),
3378
+ name=c.RF_SPEND,
3379
+ ),
3380
+ rf_spend_has_geo_dimension=True,
3381
+ rf_spend_has_time_dimension=True,
3382
+ organic_media=None,
3383
+ organic_reach=xr.DataArray(
3384
+ coords={
3385
+ c.GEO: _GEO_IDS,
3386
+ c.MEDIA_TIME: _TIME_STRS,
3387
+ c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC,
3388
+ },
3389
+ data=np.array(
3390
+ [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]]
3391
+ ),
3392
+ name=c.ORGANIC_REACH,
3393
+ ),
3394
+ organic_frequency=xr.DataArray(
3395
+ coords={
3396
+ c.GEO: _GEO_IDS,
3397
+ c.MEDIA_TIME: _TIME_STRS,
3398
+ c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC,
3399
+ },
3400
+ data=np.array(
3401
+ [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]]
3402
+ ),
3403
+ name=c.ORGANIC_FREQUENCY,
3404
+ ),
3405
+ non_media_treatments=None,
3406
+ kpi=xr.DataArray(
3407
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
3408
+ data=np.array([[2, 2], [3, 3]]),
3409
+ name=c.KPI,
3410
+ ),
3411
+ revenue_per_kpi=xr.DataArray(
3412
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
3413
+ data=np.ones((2, 2)),
3414
+ name=c.REVENUE_PER_KPI,
3415
+ ),
3416
+ controls=xr.DataArray(
3417
+ coords={
3418
+ c.GEO: _GEO_IDS,
3419
+ c.TIME: _TIME_STRS,
3420
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
3421
+ },
3422
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
3423
+ name=c.CONTROLS,
3424
+ ),
3425
+ )
3426
+
3427
+ MOCK_PROTO_RF_ORGANIC_GRANULAR_NOT_LAGGED = text_format.Parse(
3428
+ """
3429
+ marketing_data_points {
3430
+ geo_info {
3431
+ geo_id: "geo_0"
3432
+ population: 11
3433
+ }
3434
+ date_interval {
3435
+ start_date {
3436
+ year: 2021
3437
+ month: 2
3438
+ day: 1
3439
+ }
3440
+ end_date {
3441
+ year: 2021
3442
+ month: 2
3443
+ day: 8
3444
+ }
3445
+ }
3446
+ control_variables {
3447
+ name: "control_0"
3448
+ value: 31.0
3449
+ }
3450
+ control_variables {
3451
+ name: "control_1"
3452
+ value: 32.0
3453
+ }
3454
+ reach_frequency_variables {
3455
+ channel_name: "rf_ch_paid_0"
3456
+ reach: 1
3457
+ average_frequency: 2.1
3458
+ spend: 252.0
3459
+ }
3460
+ reach_frequency_variables {
3461
+ channel_name: "rf_ch_paid_1"
3462
+ reach: 2
3463
+ average_frequency: 2.2
3464
+ spend: 253.0
3465
+ }
3466
+ reach_frequency_variables {
3467
+ channel_name: "rf_ch_organic_0"
3468
+ reach: 51
3469
+ average_frequency: 51.0
3470
+ }
3471
+ reach_frequency_variables {
3472
+ channel_name: "rf_ch_organic_1"
3473
+ reach: 52
3474
+ average_frequency: 52.0
3475
+ }
3476
+ kpi {
3477
+ name: "revenue"
3478
+ revenue {
3479
+ value: 2.0
3480
+ }
3481
+ }
3482
+ }
3483
+ marketing_data_points {
3484
+ geo_info {
3485
+ geo_id: "geo_0"
3486
+ population: 11
3487
+ }
3488
+ date_interval {
3489
+ start_date {
3490
+ year: 2021
3491
+ month: 2
3492
+ day: 8
3493
+ }
3494
+ end_date {
3495
+ year: 2021
3496
+ month: 2
3497
+ day: 15
3498
+ }
3499
+ }
3500
+ control_variables {
3501
+ name: "control_0"
3502
+ value: 33.0
3503
+ }
3504
+ control_variables {
3505
+ name: "control_1"
3506
+ value: 34.0
3507
+ }
3508
+ reach_frequency_variables {
3509
+ channel_name: "rf_ch_paid_0"
3510
+ reach: 3
3511
+ average_frequency: 3.0
3512
+ spend: 254.0
3513
+ }
3514
+ reach_frequency_variables {
3515
+ channel_name: "rf_ch_paid_1"
3516
+ reach: 4
3517
+ average_frequency: 4.0
3518
+ spend: 255.0
3519
+ }
3520
+ reach_frequency_variables {
3521
+ channel_name: "rf_ch_organic_0"
3522
+ reach: 53
3523
+ average_frequency: 53.0
3524
+ }
3525
+ reach_frequency_variables {
3526
+ channel_name: "rf_ch_organic_1"
3527
+ reach: 54
3528
+ average_frequency: 54.0
3529
+ }
3530
+ kpi {
3531
+ name: "revenue"
3532
+ revenue {
3533
+ value: 2.0
3534
+ }
3535
+ }
3536
+ }
3537
+ marketing_data_points {
3538
+ geo_info {
3539
+ geo_id: "geo_1"
3540
+ population: 12
3541
+ }
3542
+ date_interval {
3543
+ start_date {
3544
+ year: 2021
3545
+ month: 2
3546
+ day: 1
3547
+ }
3548
+ end_date {
3549
+ year: 2021
3550
+ month: 2
3551
+ day: 8
3552
+ }
3553
+ }
3554
+ control_variables {
3555
+ name: "control_0"
3556
+ value: 35.0
3557
+ }
3558
+ control_variables {
3559
+ name: "control_1"
3560
+ value: 36.0
3561
+ }
3562
+ reach_frequency_variables {
3563
+ channel_name: "rf_ch_paid_0"
3564
+ reach: 5
3565
+ average_frequency: 5.0
3566
+ spend: 256.0
3567
+ }
3568
+ reach_frequency_variables {
3569
+ channel_name: "rf_ch_paid_1"
3570
+ reach: 6
3571
+ average_frequency: 6.0
3572
+ spend: 257.0
3573
+ }
3574
+ reach_frequency_variables {
3575
+ channel_name: "rf_ch_organic_0"
3576
+ reach: 55
3577
+ average_frequency: 55.0
3578
+ }
3579
+ reach_frequency_variables {
3580
+ channel_name: "rf_ch_organic_1"
3581
+ reach: 56
3582
+ average_frequency: 56.0
3583
+ }
3584
+ kpi {
3585
+ name: "revenue"
3586
+ revenue {
3587
+ value: 3.0
3588
+ }
3589
+ }
3590
+ }
3591
+ marketing_data_points {
3592
+ geo_info {
3593
+ geo_id: "geo_1"
3594
+ population: 12
3595
+ }
3596
+ date_interval {
3597
+ start_date {
3598
+ year: 2021
3599
+ month: 2
3600
+ day: 8
3601
+ }
3602
+ end_date {
3603
+ year: 2021
3604
+ month: 2
3605
+ day: 15
3606
+ }
3607
+ }
3608
+ control_variables {
3609
+ name: "control_0"
3610
+ value: 37.0
3611
+ }
3612
+ control_variables {
3613
+ name: "control_1"
3614
+ value: 38.0
3615
+ }
3616
+ reach_frequency_variables {
3617
+ channel_name: "rf_ch_paid_0"
3618
+ reach: 7
3619
+ average_frequency: 7.0
3620
+ spend: 258.0
3621
+ }
3622
+ reach_frequency_variables {
3623
+ channel_name: "rf_ch_paid_1"
3624
+ reach: 8
3625
+ average_frequency: 8.0
3626
+ spend: 259.0
3627
+ }
3628
+ reach_frequency_variables {
3629
+ channel_name: "rf_ch_organic_0"
3630
+ reach: 57
3631
+ average_frequency: 57.0
3632
+ }
3633
+ reach_frequency_variables {
3634
+ channel_name: "rf_ch_organic_1"
3635
+ reach: 58
3636
+ average_frequency: 58.0
3637
+ }
3638
+ kpi {
3639
+ name: "revenue"
3640
+ revenue {
3641
+ value: 3.0
3642
+ }
3643
+ }
3644
+ }
3645
+ metadata {
3646
+ time_dimensions {
3647
+ name: "time"
3648
+ dates {
3649
+ year: 2021
3650
+ month: 2
3651
+ day: 1
3652
+ }
3653
+ dates {
3654
+ year: 2021
3655
+ month: 2
3656
+ day: 8
3657
+ }
3658
+ }
3659
+ time_dimensions {
3660
+ name: "media_time"
3661
+ dates {
3662
+ year: 2021
3663
+ month: 2
3664
+ day: 1
3665
+ }
3666
+ dates {
3667
+ year: 2021
3668
+ month: 2
3669
+ day: 8
3670
+ }
3671
+ }
3672
+ channel_dimensions {
3673
+ name: "reach"
3674
+ channels: "rf_ch_paid_0"
3675
+ channels: "rf_ch_paid_1"
3676
+ }
3677
+ channel_dimensions {
3678
+ name: "frequency"
3679
+ channels: "rf_ch_paid_0"
3680
+ channels: "rf_ch_paid_1"
3681
+ }
3682
+ channel_dimensions {
3683
+ name: "organic_reach"
3684
+ channels: "rf_ch_organic_0"
3685
+ channels: "rf_ch_organic_1"
3686
+ }
3687
+ channel_dimensions {
3688
+ name: "organic_frequency"
3689
+ channels: "rf_ch_organic_0"
3690
+ channels: "rf_ch_organic_1"
3691
+ }
3692
+ control_names: "control_0"
3693
+ control_names: "control_1"
3694
+ kpi_type: "revenue"
3695
+ }
3696
+ """,
3697
+ marketing_pb.MarketingData(),
3698
+ )
3699
+
3700
+ MOCK_INPUT_DATA_NON_MEDIA_TREATMENTS = mock.MagicMock(
3701
+ kpi_type=c.REVENUE,
3702
+ geo=xr.DataArray(np.array(_GEO_IDS)),
3703
+ time=xr.DataArray(np.array(_TIME_STRS)),
3704
+ media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)),
3705
+ population=xr.DataArray(
3706
+ coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION
3707
+ ),
3708
+ kpi=xr.DataArray(
3709
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
3710
+ data=np.array([[1, 2], [3, 4]]),
3711
+ name=c.KPI,
3712
+ ),
3713
+ revenue_per_kpi=xr.DataArray(
3714
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
3715
+ data=np.ones((2, 2)),
3716
+ name=c.REVENUE_PER_KPI,
3717
+ ),
3718
+ controls=xr.DataArray(
3719
+ coords={
3720
+ c.GEO: _GEO_IDS,
3721
+ c.TIME: _TIME_STRS,
3722
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
3723
+ },
3724
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
3725
+ name=c.CONTROLS,
3726
+ ),
3727
+ non_media_treatments=xr.DataArray(
3728
+ coords={
3729
+ c.GEO: _GEO_IDS,
3730
+ c.TIME: _TIME_STRS,
3731
+ c.NON_MEDIA_CHANNEL: _NON_MEDIA_TREATMENT_VARIABLES,
3732
+ },
3733
+ data=np.array([[[61, 62], [63, 64]], [[65, 66], [67, 68]]]),
3734
+ name=c.NON_MEDIA_TREATMENTS,
3735
+ ),
3736
+ media=xr.DataArray(
3737
+ coords={
3738
+ c.GEO: _GEO_IDS,
3739
+ c.MEDIA_TIME: _MEDIA_TIME_STRS,
3740
+ c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
3741
+ },
3742
+ data=np.array(
3743
+ [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]]
3744
+ ),
3745
+ name=c.MEDIA,
3746
+ ),
3747
+ media_spend=xr.DataArray(
3748
+ coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID},
3749
+ data=np.array([492, 496]),
3750
+ name=c.MEDIA_SPEND,
3751
+ ),
3752
+ media_spend_has_geo_dimension=False,
3753
+ media_spend_has_time_dimension=False,
3754
+ reach=None,
3755
+ frequency=None,
3756
+ rf_spend=None,
3757
+ rf_spend_has_geo_dimension=False,
3758
+ rf_spend_has_time_dimension=False,
3759
+ organic_media=None,
3760
+ organic_reach=None,
3761
+ organic_frequency=None,
3762
+ )
3763
+
3764
+ MOCK_PROTO_NON_MEDIA_TREATMENTS = text_format.Parse(
3765
+ """
3766
+ marketing_data_points {
3767
+ geo_info {
3768
+ geo_id: "geo_0"
3769
+ population: 11
3770
+ }
3771
+ date_interval {
3772
+ start_date {
3773
+ year: 2021
3774
+ month: 1
3775
+ day: 25
3776
+ }
3777
+ end_date {
3778
+ year: 2021
3779
+ month: 2
3780
+ day: 1
3781
+ }
3782
+ }
3783
+ media_variables {
3784
+ channel_name: "ch_paid_0"
3785
+ scalar_metric {
3786
+ name: "impressions"
3787
+ value: 39.0
3788
+ }
3789
+ }
3790
+ media_variables {
3791
+ channel_name: "ch_paid_1"
3792
+ scalar_metric {
3793
+ name: "impressions"
3794
+ value: 40.0
3795
+ }
3796
+ }
3797
+ }
3798
+ marketing_data_points {
3799
+ geo_info {
3800
+ geo_id: "geo_0"
3801
+ population: 11
3802
+ }
3803
+ date_interval {
3804
+ start_date {
3805
+ year: 2021
3806
+ month: 2
3807
+ day: 1
3808
+ }
3809
+ end_date {
3810
+ year: 2021
3811
+ month: 2
3812
+ day: 8
3813
+ }
3814
+ }
3815
+ control_variables {
3816
+ name: "control_0"
3817
+ value: 31.0
3818
+ }
3819
+ control_variables {
3820
+ name: "control_1"
3821
+ value: 32.0
3822
+ }
3823
+ media_variables {
3824
+ channel_name: "ch_paid_0"
3825
+ scalar_metric {
3826
+ name: "impressions"
3827
+ value: 41.0
3828
+ }
3829
+ }
3830
+ media_variables {
3831
+ channel_name: "ch_paid_1"
3832
+ scalar_metric {
3833
+ name: "impressions"
3834
+ value: 42.0
3835
+ }
3836
+ }
3837
+ non_media_treatment_variables {
3838
+ name: "non_media_treatment_0"
3839
+ value: 61.0
3840
+ }
3841
+ non_media_treatment_variables {
3842
+ name: "non_media_treatment_1"
3843
+ value: 62.0
3844
+ }
3845
+ kpi {
3846
+ name: "revenue"
3847
+ revenue {
3848
+ value: 1.0
3849
+ }
3850
+ }
3851
+ }
3852
+ marketing_data_points {
3853
+ geo_info {
3854
+ geo_id: "geo_0"
3855
+ population: 11
3856
+ }
3857
+ date_interval {
3858
+ start_date {
3859
+ year: 2021
3860
+ month: 2
3861
+ day: 8
3862
+ }
3863
+ end_date {
3864
+ year: 2021
3865
+ month: 2
3866
+ day: 15
3867
+ }
3868
+ }
3869
+ control_variables {
3870
+ name: "control_0"
3871
+ value: 33.0
3872
+ }
3873
+ control_variables {
3874
+ name: "control_1"
3875
+ value: 34.0
3876
+ }
3877
+ media_variables {
3878
+ channel_name: "ch_paid_0"
3879
+ scalar_metric {
3880
+ name: "impressions"
3881
+ value: 43.0
3882
+ }
3883
+ }
3884
+ media_variables {
3885
+ channel_name: "ch_paid_1"
3886
+ scalar_metric {
3887
+ name: "impressions"
3888
+ value: 44.0
3889
+ }
3890
+ }
3891
+ non_media_treatment_variables {
3892
+ name: "non_media_treatment_0"
3893
+ value: 63.0
3894
+ }
3895
+ non_media_treatment_variables {
3896
+ name: "non_media_treatment_1"
3897
+ value: 64.0
3898
+ }
3899
+ kpi {
3900
+ name: "revenue"
3901
+ revenue {
3902
+ value: 2.0
3903
+ }
3904
+ }
3905
+ }
3906
+ marketing_data_points {
3907
+ geo_info {
3908
+ geo_id: "geo_1"
3909
+ population: 12
3910
+ }
3911
+ date_interval {
3912
+ start_date {
3913
+ year: 2021
3914
+ month: 1
3915
+ day: 25
3916
+ }
3917
+ end_date {
3918
+ year: 2021
3919
+ month: 2
3920
+ day: 1
3921
+ }
3922
+ }
3923
+ media_variables {
3924
+ channel_name: "ch_paid_0"
3925
+ scalar_metric {
3926
+ name: "impressions"
3927
+ value: 45.0
3928
+ }
3929
+ }
3930
+ media_variables {
3931
+ channel_name: "ch_paid_1"
3932
+ scalar_metric {
3933
+ name: "impressions"
3934
+ value: 46.0
3935
+ }
3936
+ }
3937
+ }
3938
+ marketing_data_points {
3939
+ geo_info {
3940
+ geo_id: "geo_1"
3941
+ population: 12
3942
+ }
3943
+ date_interval {
3944
+ start_date {
3945
+ year: 2021
3946
+ month: 2
3947
+ day: 1
3948
+ }
3949
+ end_date {
3950
+ year: 2021
3951
+ month: 2
3952
+ day: 8
3953
+ }
3954
+ }
3955
+ control_variables {
3956
+ name: "control_0"
3957
+ value: 35.0
3958
+ }
3959
+ control_variables {
3960
+ name: "control_1"
3961
+ value: 36.0
3962
+ }
3963
+ media_variables {
3964
+ channel_name: "ch_paid_0"
3965
+ scalar_metric {
3966
+ name: "impressions"
3967
+ value: 47.0
3968
+ }
3969
+ }
3970
+ media_variables {
3971
+ channel_name: "ch_paid_1"
3972
+ scalar_metric {
3973
+ name: "impressions"
3974
+ value: 48.0
3975
+ }
3976
+ }
3977
+ non_media_treatment_variables {
3978
+ name: "non_media_treatment_0"
3979
+ value: 65.0
3980
+ }
3981
+ non_media_treatment_variables {
3982
+ name: "non_media_treatment_1"
3983
+ value: 66.0
3984
+ }
3985
+ kpi {
3986
+ name: "revenue"
3987
+ revenue {
3988
+ value: 3.0
3989
+ }
3990
+ }
3991
+ }
3992
+ marketing_data_points {
3993
+ geo_info {
3994
+ geo_id: "geo_1"
3995
+ population: 12
3996
+ }
3997
+ date_interval {
3998
+ start_date {
3999
+ year: 2021
4000
+ month: 2
4001
+ day: 8
4002
+ }
4003
+ end_date {
4004
+ year: 2021
4005
+ month: 2
4006
+ day: 15
4007
+ }
4008
+ }
4009
+ control_variables {
4010
+ name: "control_0"
4011
+ value: 37.0
4012
+ }
4013
+ control_variables {
4014
+ name: "control_1"
4015
+ value: 38.0
4016
+ }
4017
+ media_variables {
4018
+ channel_name: "ch_paid_0"
4019
+ scalar_metric {
4020
+ name: "impressions"
4021
+ value: 49.0
4022
+ }
4023
+ }
4024
+ media_variables {
4025
+ channel_name: "ch_paid_1"
4026
+ scalar_metric {
4027
+ name: "impressions"
4028
+ value: 50.0
4029
+ }
4030
+ }
4031
+ non_media_treatment_variables {
4032
+ name: "non_media_treatment_0"
4033
+ value: 67.0
4034
+ }
4035
+ non_media_treatment_variables {
4036
+ name: "non_media_treatment_1"
4037
+ value: 68.0
4038
+ }
4039
+ kpi {
4040
+ name: "revenue"
4041
+ revenue {
4042
+ value: 4.0
4043
+ }
4044
+ }
4045
+ }
4046
+ marketing_data_points {
4047
+ date_interval {
4048
+ start_date {
4049
+ year: 2021
4050
+ month: 1
4051
+ day: 25
4052
+ }
4053
+ end_date {
4054
+ year: 2021
4055
+ month: 2
4056
+ day: 15
4057
+ }
4058
+ }
4059
+ media_variables {
4060
+ channel_name: "ch_paid_0"
4061
+ media_spend: 492.0
4062
+ }
4063
+ media_variables {
4064
+ channel_name: "ch_paid_1"
4065
+ media_spend: 496.0
4066
+ }
4067
+ }
4068
+ metadata {
4069
+ time_dimensions {
4070
+ name: "time"
4071
+ dates {
4072
+ year: 2021
4073
+ month: 2
4074
+ day: 1
4075
+ }
4076
+ dates {
4077
+ year: 2021
4078
+ month: 2
4079
+ day: 8
4080
+ }
4081
+ }
4082
+ time_dimensions {
4083
+ name: "media_time"
4084
+ dates {
4085
+ year: 2021
4086
+ month: 1
4087
+ day: 25
4088
+ }
4089
+ dates {
4090
+ year: 2021
4091
+ month: 2
4092
+ day: 1
4093
+ }
4094
+ dates {
4095
+ year: 2021
4096
+ month: 2
4097
+ day: 8
4098
+ }
4099
+ }
4100
+ channel_dimensions {
4101
+ name: "media"
4102
+ channels: "ch_paid_0"
4103
+ channels: "ch_paid_1"
4104
+ }
4105
+ control_names: "control_0"
4106
+ control_names: "control_1"
4107
+ non_media_treatment_names: "non_media_treatment_0"
4108
+ non_media_treatment_names: "non_media_treatment_1"
4109
+ kpi_type: "revenue"
4110
+ }
4111
+ """,
4112
+ marketing_pb.MarketingData(),
4113
+ )
4114
+
4115
+ MOCK_INPUT_DATA_NO_REVENUE_PER_KPI = mock.MagicMock(
4116
+ kpi_type=c.NON_REVENUE,
4117
+ geo=xr.DataArray(np.array(_GEO_IDS)),
4118
+ time=xr.DataArray(np.array(_TIME_STRS)),
4119
+ media_time=xr.DataArray(np.array(_TIME_STRS)),
4120
+ population=xr.DataArray(
4121
+ coords={c.GEO: _GEO_IDS},
4122
+ data=np.array([1000.0, 1200.0]),
4123
+ name=c.POPULATION,
4124
+ ),
4125
+ kpi=xr.DataArray(
4126
+ coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS},
4127
+ data=np.array([[50, 60], [70, 80]]),
4128
+ name=c.KPI,
4129
+ ),
4130
+ revenue_per_kpi=None,
4131
+ controls=xr.DataArray(
4132
+ coords={
4133
+ c.GEO: _GEO_IDS,
4134
+ c.TIME: _TIME_STRS,
4135
+ c.CONTROL_VARIABLE: _CONTROL_VARIABLES,
4136
+ },
4137
+ data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]),
4138
+ name=c.CONTROLS,
4139
+ ),
4140
+ media=xr.DataArray(
4141
+ coords={
4142
+ c.GEO: _GEO_IDS,
4143
+ c.MEDIA_TIME: _TIME_STRS,
4144
+ c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID,
4145
+ },
4146
+ data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]),
4147
+ name=c.MEDIA,
4148
+ ),
4149
+ media_spend=xr.DataArray(
4150
+ coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID},
4151
+ data=np.array([492, 496]),
4152
+ name=c.MEDIA_SPEND,
4153
+ ),
4154
+ media_spend_has_geo_dimension=False,
4155
+ media_spend_has_time_dimension=False,
4156
+ reach=None,
4157
+ frequency=None,
4158
+ rf_spend=None,
4159
+ organic_media=None,
4160
+ organic_reach=None,
4161
+ organic_frequency=None,
4162
+ non_media_treatments=None,
4163
+ )
4164
+
4165
# Expected `MarketingData` proto (textproto format) with a "non_revenue" KPI.
# It holds four geo-level data points (geo_0 and geo_1, each over the
# 2021-02-01..02-08 and 2021-02-08..02-15 intervals) with control values,
# per-channel impressions and a KPI value, followed by one data point without
# `geo_info` that carries `media_spend` per channel for 2021-02-01..02-15.
MOCK_PROTO_NO_REVENUE_PER_KPI = text_format.Parse(
    """
    marketing_data_points {
      geo_info { geo_id: "geo_0" population: 1000 }
      date_interval {
        start_date { year: 2021 month: 2 day: 1 }
        end_date { year: 2021 month: 2 day: 8 }
      }
      control_variables { name: "control_0" value: 31.0 }
      control_variables { name: "control_1" value: 32.0 }
      media_variables {
        channel_name: "ch_paid_0"
        scalar_metric {
          name: "impressions"
          value: 39.0
        }
      }
      media_variables {
        channel_name: "ch_paid_1"
        scalar_metric {
          name: "impressions"
          value: 40.0
        }
      }
      kpi { name: "non_revenue" non_revenue { value: 50.0 } }
    }
    marketing_data_points {
      geo_info { geo_id: "geo_0" population: 1000 }
      date_interval {
        start_date { year: 2021 month: 2 day: 8 }
        end_date { year: 2021 month: 2 day: 15 }
      }
      control_variables { name: "control_0" value: 33.0 }
      control_variables { name: "control_1" value: 34.0 }
      media_variables {
        channel_name: "ch_paid_0"
        scalar_metric {
          name: "impressions"
          value: 41.0
        }
      }
      media_variables {
        channel_name: "ch_paid_1"
        scalar_metric {
          name: "impressions"
          value: 42.0
        }
      }
      kpi { name: "non_revenue" non_revenue { value: 60.0 } }
    }
    marketing_data_points {
      geo_info { geo_id: "geo_1" population: 1200 }
      date_interval {
        start_date { year: 2021 month: 2 day: 1 }
        end_date { year: 2021 month: 2 day: 8 }
      }
      control_variables { name: "control_0" value: 35.0 }
      control_variables { name: "control_1" value: 36.0 }
      media_variables {
        channel_name: "ch_paid_0"
        scalar_metric {
          name: "impressions"
          value: 43.0
        }
      }
      media_variables {
        channel_name: "ch_paid_1"
        scalar_metric {
          name: "impressions"
          value: 44.0
        }
      }
      kpi { name: "non_revenue" non_revenue { value: 70.0 } }
    }
    marketing_data_points {
      geo_info {
        geo_id: "geo_1"
        population: 1200
      }
      date_interval {
        start_date {
          year: 2021
          month: 2
          day: 8
        }
        end_date {
          year: 2021
          month: 2
          day: 15
        }
      }
      control_variables {
        name: "control_0"
        value: 37.0
      }
      control_variables {
        name: "control_1"
        value: 38.0
      }
      media_variables {
        channel_name: "ch_paid_0"
        scalar_metric {
          name: "impressions"
          value: 45.0
        }
      }
      media_variables {
        channel_name: "ch_paid_1"
        scalar_metric {
          name: "impressions"
          value: 46.0
        }
      }
      kpi { name: "non_revenue" non_revenue { value: 80.0 } }
    }
    marketing_data_points {
      date_interval {
        start_date {
          year: 2021
          month: 2
          day: 1
        }
        end_date {
          year: 2021
          month: 2
          day: 15
        }
      }
      media_variables {
        channel_name: "ch_paid_0"
        media_spend: 492.0
      }
      media_variables {
        channel_name: "ch_paid_1"
        media_spend: 496.0
      }
    }
    metadata {
      time_dimensions { name: "time" dates { year: 2021 month: 2 day: 1 } dates { year: 2021 month: 2 day: 8} }
      time_dimensions { name: "media_time" dates { year: 2021 month: 2 day: 1 } dates { year: 2021 month: 2 day: 8 } }
      channel_dimensions { name: "media" channels: "ch_paid_0" channels: "ch_paid_1" }
      control_names: "control_0"
      control_names: "control_1"
      kpi_type: "non_revenue"
    }
    """,
    marketing_pb.MarketingData(),
)
4314
+
4315
+
4316
# Hyperparameters test data.
def get_default_model_spec() -> spec.ModelSpec:
  """Returns a `ModelSpec` built with no arguments, i.e. all defaults."""
  default_spec = spec.ModelSpec()
  return default_spec
4319
+
4320
+
4321
# `Hyperparameters` proto paired with `get_default_model_spec()` (presumably
# its expected serialized form -- confirm against the serde tests).
DEFAULT_HYPERPARAMETERS_PROTO = meridian_pb.Hyperparameters(
    # Scalar model settings.
    max_lag=8,
    hill_before_adstock=False,
    unique_sigma_for_each_geo=False,
    enable_aks=False,
    media_effects_dist=_MediaEffectsDist.LOG_NORMAL,
    global_adstock_decay='geometric',
    # Paid-media prior types.
    media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED,
    rf_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED,
    paid_media_prior_type=(
        _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
    ),
    # Non-paid treatment prior types.
    organic_media_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    organic_rf_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    non_media_treatments_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
)
4335
+
4336
+
4337
def get_custom_model_spec_1() -> spec.ModelSpec:
  """Returns a custom `ModelSpec` with MROI paid-media priors.

  The spec uses a scalar `knots` value, a string `baseline_geo`, a single
  global `adstock_decay_spec` ('binomial'), and passes `None` explicitly for
  the calibration-period, holdout and control-scaling arguments.
  """
  # Arguments that are explicitly passed as `None`.
  unset_args = dict.fromkeys((
      'roi_calibration_period',
      'rf_roi_calibration_period',
      'holdout_id',
      'control_population_scaling_id',
  ))
  return spec.ModelSpec(
      prior=prior_distribution.PriorDistribution(),
      media_effects_dist=c.MEDIA_EFFECTS_NORMAL,
      media_prior_type=c.TREATMENT_PRIOR_TYPE_MROI,
      rf_prior_type=c.TREATMENT_PRIOR_TYPE_MROI,
      hill_before_adstock=True,
      unique_sigma_for_each_geo=True,
      max_lag=777,
      knots=2,
      baseline_geo='baseline_geo',
      adstock_decay_spec='binomial',
      **unset_args,
  )
4354
+
4355
+
4356
# `Hyperparameters` proto paired with `get_custom_model_spec_1()` (presumably
# its expected serialized form -- confirm against the serde tests).
CUSTOM_HYPERPARAMETERS_PROTO_1 = meridian_pb.Hyperparameters(
    # Scalar model settings.
    max_lag=777,
    hill_before_adstock=True,
    unique_sigma_for_each_geo=True,
    enable_aks=False,
    media_effects_dist=_MediaEffectsDist.NORMAL,
    global_adstock_decay='binomial',
    # Knots and baseline geo.
    knots=[2],
    baseline_geo_string='baseline_geo',
    # Paid-media prior types.
    media_prior_type=_PaidMediaPriorType.MROI,
    rf_prior_type=_PaidMediaPriorType.MROI,
    paid_media_prior_type=(
        _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
    ),
    # Non-paid treatment prior types.
    organic_media_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    organic_rf_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    non_media_treatments_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
)
4372
+
4373
+
4374
def get_custom_model_spec_2() -> spec.ModelSpec:
  """Returns a custom `ModelSpec` that sets array- and mapping-valued fields.

  Compared with the other custom specs in this file, this one passes a list of
  knots, an integer `baseline_geo`, boolean NumPy arrays for every mask-like
  argument, mixed `non_media_baseline_values`, and a per-channel
  `adstock_decay_spec` mapping.
  """
  # Prior types for paid and non-paid treatments.
  prior_types = {
      'media_prior_type': c.TREATMENT_PRIOR_TYPE_ROI,
      'rf_prior_type': c.TREATMENT_PRIOR_TYPE_ROI,
      'organic_media_prior_type': c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
      'organic_rf_prior_type': c.TREATMENT_PRIOR_TYPE_COEFFICIENT,
      'non_media_treatments_prior_type': c.TREATMENT_PRIOR_TYPE_COEFFICIENT,
  }
  # Constant boolean arrays, each with its own distinct shape.
  masks = {
      'roi_calibration_period': np.full((2, 3), True),
      'rf_roi_calibration_period': np.full((4, 5), False),
      'holdout_id': np.full((6,), True),
      'control_population_scaling_id': np.full((7, 8), False),
      'non_media_population_scaling_id': np.full((9, 10), False),
  }
  channel_decays = {'ch_paid_0': 'binomial', 'rf_ch_paid_1': 'geometric'}
  return spec.ModelSpec(
      prior=prior_distribution.PriorDistribution(),
      media_effects_dist='log_normal',
      hill_before_adstock=True,
      max_lag=777,
      unique_sigma_for_each_geo=True,
      non_media_baseline_values=['min', 0.5, 'max'],
      knots=[1, 5, 8],
      baseline_geo=3,
      adstock_decay_spec=channel_decays,
      **prior_types,
      **masks,
  )
4396
+
4397
# `Hyperparameters` proto paired with `get_custom_model_spec_2()` (presumably
# its expected serialized form -- confirm against the serde tests).
CUSTOM_HYPERPARAMETERS_PROTO_2 = meridian_pb.Hyperparameters(
    # Scalar model settings.
    max_lag=777,
    hill_before_adstock=True,
    unique_sigma_for_each_geo=True,
    enable_aks=False,
    media_effects_dist=_MediaEffectsDist.LOG_NORMAL,
    # Knots and baseline geo.
    knots=[1, 5, 8],
    baseline_geo_int=3,
    # Paid-media prior types.
    media_prior_type=_PaidMediaPriorType.ROI,
    rf_prior_type=_PaidMediaPriorType.ROI,
    paid_media_prior_type=(
        _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
    ),
    # Non-paid treatment prior types.
    organic_media_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    organic_rf_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT
    ),
    non_media_treatments_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT
    ),
    # Boolean tensors; `bool_vals` holds one entry per element of `dims`.
    roi_calibration_period=make_tensor_proto(
        dims=[2, 3],
        dtype=types_pb2.DT_BOOL,
        bool_vals=[True] * (2 * 3),
    ),
    rf_roi_calibration_period=make_tensor_proto(
        dims=[4, 5],
        dtype=types_pb2.DT_BOOL,
        bool_vals=[False] * (4 * 5),
    ),
    holdout_id=make_tensor_proto(
        dims=[6],
        dtype=types_pb2.DT_BOOL,
        bool_vals=[True] * 6,
    ),
    control_population_scaling_id=make_tensor_proto(
        dims=[7, 8],
        dtype=types_pb2.DT_BOOL,
        bool_vals=[False] * (7 * 8),
    ),
    non_media_population_scaling_id=make_tensor_proto(
        dims=[9, 10],
        dtype=types_pb2.DT_BOOL,
        bool_vals=[False] * (9 * 10),
    ),
    # Baseline values: a function (MIN), a literal value, a function (MAX).
    non_media_baseline_values=[
        meridian_pb.NonMediaBaselineValue(
            function_value=_NonMediaBaselineFunction.MIN
        ),
        meridian_pb.NonMediaBaselineValue(value=0.5),
        meridian_pb.NonMediaBaselineValue(
            function_value=_NonMediaBaselineFunction.MAX
        ),
    ],
    # Per-channel adstock decay functions.
    adstock_decay_by_channel=meridian_pb.AdstockDecayByChannel(
        channel_decays={'ch_paid_0': 'binomial', 'rf_ch_paid_1': 'geometric'}
    ),
)
4449
+
4450
+
4451
def get_custom_model_spec_3() -> spec.ModelSpec:
  """Returns a custom `ModelSpec` with `enable_aks=True`.

  Paid-media priors are MROI, all non-paid treatment priors are contribution
  priors, `baseline_geo` is a string, and the calibration-period, holdout and
  control-scaling arguments are passed explicitly as `None`.
  """
  # Arguments that are explicitly passed as `None`.
  unset_args = dict.fromkeys((
      'roi_calibration_period',
      'rf_roi_calibration_period',
      'holdout_id',
      'control_population_scaling_id',
  ))
  # All three non-paid treatment prior types share the same value.
  non_paid_prior_types = dict.fromkeys(
      (
          'organic_media_prior_type',
          'organic_rf_prior_type',
          'non_media_treatments_prior_type',
      ),
      c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
  )
  return spec.ModelSpec(
      prior=prior_distribution.PriorDistribution(),
      media_effects_dist=c.MEDIA_EFFECTS_NORMAL,
      media_prior_type=c.TREATMENT_PRIOR_TYPE_MROI,
      rf_prior_type=c.TREATMENT_PRIOR_TYPE_MROI,
      hill_before_adstock=True,
      unique_sigma_for_each_geo=True,
      max_lag=777,
      baseline_geo='baseline_geo',
      enable_aks=True,
      **non_paid_prior_types,
      **unset_args,
  )
4470
+
4471
+
4472
# `Hyperparameters` proto paired with `get_custom_model_spec_3()` (presumably
# its expected serialized form -- confirm against the serde tests).
CUSTOM_HYPERPARAMETERS_PROTO_3 = meridian_pb.Hyperparameters(
    # Scalar model settings.
    max_lag=777,
    hill_before_adstock=True,
    unique_sigma_for_each_geo=True,
    enable_aks=True,
    media_effects_dist=_MediaEffectsDist.NORMAL,
    global_adstock_decay='geometric',
    # Baseline geo.
    baseline_geo_string='baseline_geo',
    # Paid-media prior types.
    media_prior_type=_PaidMediaPriorType.MROI,
    rf_prior_type=_PaidMediaPriorType.MROI,
    paid_media_prior_type=(
        _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
    ),
    # Non-paid treatment prior types.
    organic_media_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    organic_rf_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
    non_media_treatments_prior_type=(
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    ),
)
4487
+
4488
+
4489
def _create_tfp_params_from_dict(
    param_dict: dict[str, Any],
    distribution: backend.tfd.Distribution | backend.bijectors.Bijector,
) -> dict[str, meridian_pb.TfpParameterValue]:
  """Converts a parameter dict into `TfpParameterValue` protos.

  A `validate_args=False` entry is always included in the result, overriding
  any `validate_args` value present in `param_dict`.

  Args:
    param_dict: Mapping from parameter name to its Python value. It is not
      modified.
    distribution: The distribution or bijector whose `__init__` signature is
      consulted when a parameter value is a function.

  Returns:
    A dict mapping each parameter name to its `TfpParameterValue` proto.
  """
  # Build a new mapping rather than calling `param_dict.update(...)`, so the
  # caller's dict is never mutated as a side effect. Key order matches what
  # the in-place update would have produced.
  params = {**param_dict, 'validate_args': False}
  return {
      key: _create_tfp_param(key, value, distribution)
      for key, value in params.items()
  }
4500
+
4501
+
4502
def create_distribution_proto(
    distribution_type: str, **kwargs
) -> meridian_pb.TfpDistribution:
  """Builds a `TfpDistribution` proto for a distribution in `backend.tfd`.

  Args:
    distribution_type: Attribute name of the distribution on `backend.tfd`.
    **kwargs: Distribution parameters, converted to `TfpParameterValue` protos
      by `_create_tfp_params_from_dict`.

  Returns:
    A `TfpDistribution` proto with its type and parameters populated.
  """
  distribution_cls = getattr(backend.tfd, distribution_type)
  parameters = _create_tfp_params_from_dict(kwargs, distribution_cls)
  return meridian_pb.TfpDistribution(
      distribution_type=distribution_type, parameters=parameters
  )
4510
+
4511
+
4512
def create_bijector_proto(
    bijector_type: str, **kwargs
) -> meridian_pb.TfpBijector:
  """Builds a `TfpBijector` proto for a bijector in `backend.bijectors`.

  Args:
    bijector_type: Attribute name of the bijector on `backend.bijectors`.
    **kwargs: Bijector parameters, converted to `TfpParameterValue` protos by
      `_create_tfp_params_from_dict`.

  Returns:
    A `TfpBijector` proto with its type and parameters populated.
  """
  bijector_cls = getattr(backend.bijectors, bijector_type)
  parameters = _create_tfp_params_from_dict(kwargs, bijector_cls)
  return meridian_pb.TfpBijector(
      bijector_type=bijector_type, parameters=parameters
  )
4520
+
4521
+
4522
+ def _create_tfp_param(param_name, param_value, distribution):
4523
+ """Creates a TfpParameterValue object based on the input value's type."""
4524
+ match param_value:
4525
+ case float():
4526
+ return meridian_pb.TfpParameterValue(scalar_value=param_value)
4527
+ case int():
4528
+ return meridian_pb.TfpParameterValue(int_value=param_value)
4529
+ case bool():
4530
+ return meridian_pb.TfpParameterValue(bool_value=param_value)
4531
+ case str():
4532
+ return meridian_pb.TfpParameterValue(string_value=param_value)
4533
+ case None:
4534
+ return meridian_pb.TfpParameterValue(none_value=True)
4535
+ case list():
4536
+ value_generator = (
4537
+ _create_tfp_param(param_name, v, distribution) for v in param_value
4538
+ )
4539
+ tfp_list_value = meridian_pb.TfpParameterValue.List(
4540
+ values=value_generator
4541
+ )
4542
+ return meridian_pb.TfpParameterValue(list_value=tfp_list_value)
4543
+ case dict():
4544
+ dict_value = {
4545
+ key: _create_tfp_param(key, v, distribution)
4546
+ for key, v in param_value.items()
4547
+ }
4548
+ return meridian_pb.TfpParameterValue(dict_value=dict_value)
4549
+ case _ if isinstance(param_value, (np.ndarray, backend.Tensor)):
4550
+ return meridian_pb.TfpParameterValue(
4551
+ tensor_value=backend.make_tensor_proto(param_value)
4552
+ )
4553
+ case meridian_pb.TfpDistribution():
4554
+ return meridian_pb.TfpParameterValue(distribution_value=param_value)
4555
+ case meridian_pb.TfpBijector():
4556
+ return meridian_pb.TfpParameterValue(bijector_value=param_value)
4557
+ case backend.tfd.ReparameterizationType():
4558
+ fully_reparameterized = param_value == backend.tfd.FULLY_REPARAMETERIZED
4559
+ return meridian_pb.TfpParameterValue(
4560
+ fully_reparameterized=fully_reparameterized
4561
+ )
4562
+ case types.FunctionType():
4563
+ # Add custom functions used for tests.
4564
+ test_registry = {'distribution_fn': distribution_fn}
4565
+
4566
+ for function_key, func in test_registry.items():
4567
+ if func == param_value: # pylint: disable=comparison-with-callable
4568
+ return meridian_pb.TfpParameterValue(
4569
+ function_param=meridian_pb.TfpParameterValue.FunctionParam(
4570
+ function_key=function_key
4571
+ )
4572
+ )
4573
+ # Function has default value.
4574
+ signature = inspect.signature(distribution.__init__)
4575
+ param = signature.parameters[param_name]
4576
+ if param.default:
4577
+ return meridian_pb.TfpParameterValue(
4578
+ function_param=meridian_pb.TfpParameterValue.FunctionParam(
4579
+ uses_default=True
4580
+ )
4581
+ )
4582
+ raise TypeError(
4583
+ f'No function found in registry for "{param_value.__name__}"'
4584
+ )
4585
+ case _:
4586
+ # Handle unsupported types.
4587
+ raise TypeError(f'Unsupported type: {type(param_value)}')
4588
+
4589
+
4590
# Arbitrary function used for testing `tfd.Autoregressive`.
# https://github.com/tensorflow/probability/blob/65f265c62bb1e2d15ef3e25104afb245a6d52429/tensorflow_probability/python/distributions/autoregressive_test.py#L89
def distribution_fn(sample0):
  """Returns an `Independent` `Categorical` distribution built from `sample0`.

  The probabilities are the depth-3 one-hot encoding of `sample0`, rolled by
  one position along axis -2. Where the index-0 mask is set they are replaced
  by `[0.5, 0.5, 0]`.

  Args:
    sample0: Tensor of category indices; the size of its last dimension is
      used as the number of frames.

  Returns:
    A `backend.tfd.Independent` wrapping a `backend.tfd.Categorical`, with
    `reinterpreted_batch_ndims=1`.
  """
  frame_count = sample0.shape[-1]
  # One-hot at index 0 over the frames, with a trailing axis added.
  first_frame_mask = backend.one_hot(0, frame_count)[:, backend.newaxis]
  shifted = backend.roll(backend.one_hot(sample0, 3), shift=1, axis=-2)
  carried_over = shifted * (1.0 - first_frame_mask)
  first_frame_probs = backend.to_tensor([0.5, 0.5, 0]) * first_frame_mask
  categorical = backend.tfd.Categorical(probs=carried_over + first_frame_probs)
  return backend.tfd.Independent(categorical, reinterpreted_batch_ndims=1)
4600
+
4601
+
4602
def get_default_kwargs_split_fn():
  """Returns the default `kwargs_split_fn` used for tfd Distributions.

  The default is read from the `__init__` signature of
  `backend.tfd.TransformedDistribution`.
  """
  # Any Distribution with `kwargs_split_fn` in its signature would work here.
  init_params = inspect.signature(
      backend.tfd.TransformedDistribution.__init__
  ).parameters
  return init_params['kwargs_split_fn'].default