google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
  2. google_meridian-1.1.2.dist-info/RECORD +46 -0
  3. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
  4. meridian/__init__.py +2 -2
  5. meridian/analysis/__init__.py +1 -1
  6. meridian/analysis/analyzer.py +29 -22
  7. meridian/analysis/formatter.py +1 -1
  8. meridian/analysis/optimizer.py +70 -44
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/summary_text.py +1 -1
  11. meridian/analysis/test_utils.py +1 -1
  12. meridian/analysis/visualizer.py +17 -8
  13. meridian/constants.py +3 -3
  14. meridian/data/__init__.py +4 -1
  15. meridian/data/arg_builder.py +1 -1
  16. meridian/data/data_frame_input_data_builder.py +614 -0
  17. meridian/data/input_data.py +12 -8
  18. meridian/data/input_data_builder.py +817 -0
  19. meridian/data/load.py +121 -428
  20. meridian/data/nd_array_input_data_builder.py +509 -0
  21. meridian/data/test_utils.py +60 -43
  22. meridian/data/time_coordinates.py +1 -1
  23. meridian/mlflow/__init__.py +17 -0
  24. meridian/mlflow/autolog.py +54 -0
  25. meridian/model/__init__.py +1 -1
  26. meridian/model/adstock_hill.py +1 -1
  27. meridian/model/knots.py +1 -1
  28. meridian/model/media.py +1 -1
  29. meridian/model/model.py +65 -37
  30. meridian/model/model_test_data.py +75 -1
  31. meridian/model/posterior_sampler.py +19 -15
  32. meridian/model/prior_distribution.py +1 -1
  33. meridian/model/prior_sampler.py +32 -26
  34. meridian/model/spec.py +18 -8
  35. meridian/model/transformers.py +1 -1
  36. google_meridian-1.1.0.dist-info/RECORD +0 -41
  37. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
  38. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,817 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """This module defines a Builder API for Meridian `InputData`.
16
+
17
+ The Builder API for `InputData` exposes piecewise data ingestion with its own
18
+ validation logic and an overall final validation logic before a valid
19
+ `InputData` is constructed.
20
+ """
21
+
22
+ import abc
23
+ from collections.abc import Sequence
24
+ import datetime
25
+ import warnings
26
+ from meridian import constants
27
+ from meridian.data import input_data
28
+ from meridian.data import time_coordinates as tc
29
+ import natsort
30
+ import numpy as np
31
+ import xarray as xr
32
+
33
+
34
+ __all__ = [
35
+ 'InputDataBuilder',
36
+ ]
37
+
38
+
39
class InputDataBuilder(abc.ABC):
  """Abstract base class for `InputData` builders.

  Components (`kpi`, `media`, `controls`, ...) are assigned one at a time
  through property setters. Each setter normalizes the incoming `DataArray`'s
  coordinates and checks them for consistency against the geo, time and media
  time coordinates collected so far. `build()` then runs a final validation
  pass and assembles the `InputData`.

  Each component may be set at most once per builder instance.
  """

  def __init__(self, kpi_type: str):
    """Initializes an empty builder.

    Args:
      kpi_type: The KPI type. Stored as-is and forwarded to `InputData` in
        `build()`.
    """
    self._kpi_type = kpi_type

    # These working attributes are going to be set along the way as the builder
    # is provided piecemeal with the user's input data.
    # In the course of processing each DataFrame piece, dimension coordinates
    # will be discovered and set with, e.g., `self.time_coords = ...`.
    # The setter code will perform basic validation
    # checks, e.g.:
    # * If previous dataframe input already set it, then it should be consistent
    # * If not, set it for the first time.
    # * When setting, make consistency checks against other dimensions
    # * etc...

    # Working dimensions and their coordinates.
    self._time_coords: Sequence[str] | None = None
    self._media_time_coords: Sequence[str] | None = None
    self._geos: Sequence[str] | None = None

    # Working data arrays (components of the final `InputData` object)
    self._kpi: xr.DataArray | None = None
    self._controls: xr.DataArray | None = None
    self._population: xr.DataArray | None = None
    self._revenue_per_kpi: xr.DataArray | None = None
    self._media: xr.DataArray | None = None
    self._media_spend: xr.DataArray | None = None
    self._reach: xr.DataArray | None = None
    self._frequency: xr.DataArray | None = None
    self._rf_spend: xr.DataArray | None = None
    self._organic_media: xr.DataArray | None = None
    self._organic_reach: xr.DataArray | None = None
    self._organic_frequency: xr.DataArray | None = None
    self._non_media_treatments: xr.DataArray | None = None

  @property
  def time_coords(self) -> Sequence[str] | None:
    """The `time` coordinates collected so far, or None if not yet set."""
    return self._time_coords

  @time_coords.setter
  def time_coords(self, value: Sequence[str]):
    """Sets the `time` coordinates after validating them.

    Args:
      value: The `time` coordinate labels.

    Raises:
      ValueError: If `value` contains duplicates, differs (as a set) from
        previously set `time` coordinates, is not a subset of previously set
        `media_time` coordinates, or leaves a lagged-media period that is not
        a leading window of the media times.
    """
    if len(value) != len(set(value)):
      raise ValueError('`times` coords must be unique.')
    if self.time_coords is not None and set(self.time_coords) != set(value):
      raise ValueError(f'`times` coords already set to {self.time_coords}.')
    if self.media_time_coords is not None and not set(value).issubset(
        self.media_time_coords
    ):
      raise ValueError(
          '`times` coords must be subset of previously set `media_times`'
          ' coords.'
      )
    if self.media_time_coords is not None:
      self._validate_lagged_media(
          media_time_coords=self.media_time_coords, time_coords=value
      )
    # The result is discarded; `interval_days` is evaluated only for whatever
    # validation `TimeCoordinates` performs on the sorted dates.
    _ = tc.TimeCoordinates.from_dates(sorted(value)).interval_days
    self._time_coords = value

  @property
  def media_time_coords(self) -> Sequence[str] | None:
    """The `media_time` coordinates collected so far, or None if not yet set."""
    return self._media_time_coords

  @media_time_coords.setter
  def media_time_coords(self, value: Sequence[str]):
    """Sets the `media_time` coordinates after validating them.

    Args:
      value: The `media_time` coordinate labels.

    Raises:
      ValueError: If `value` contains duplicates, differs (as a set) from
        previously set `media_time` coordinates, is not a superset of
        previously set `time` coordinates, or leaves a lagged-media period
        that is not a leading window of the media times.
    """
    if len(value) != len(set(value)):
      raise ValueError('`media_times` coords must be unique.')
    if self.media_time_coords is not None and set(
        self.media_time_coords
    ) != set(value):
      raise ValueError(
          f'`media_times` coords already set to {self.media_time_coords}.'
      )
    if self.time_coords is not None and not set(value).issuperset(
        self.time_coords
    ):
      raise ValueError(
          '`media_times` coords must be superset of previously set `times`'
          ' coords.'
      )
    if self.time_coords is not None:
      self._validate_lagged_media(
          media_time_coords=value, time_coords=self.time_coords
      )
    # The result is discarded; `interval_days` is evaluated only for whatever
    # validation `TimeCoordinates` performs on the sorted dates.
    _ = tc.TimeCoordinates.from_dates(sorted(value)).interval_days
    self._media_time_coords = value

  @property
  def geos(self) -> Sequence[str] | None:
    """The `geo` coordinates collected so far, or None if not yet set."""
    return self._geos

  @geos.setter
  def geos(self, value: Sequence[str]):
    """Sets the `geo` coordinates after validating them.

    Args:
      value: The `geo` coordinate labels.

    Raises:
      ValueError: If `value` contains duplicates or differs (as a set) from
        previously set geos.
    """
    if len(value) != len(set(value)):
      raise ValueError('Geos must be unique.')
    if self.geos is not None and set(self.geos) != set(value):
      raise ValueError(f'geos already set to {self.geos}.')
    self._geos = value

  @property
  def kpi(self) -> xr.DataArray | None:
    return self._kpi

  @kpi.setter
  def kpi(self, kpi: xr.DataArray):
    """Sets the `kpi` data array.

    `kpi` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='kpi',
      dims=['geo', 'time'],
      coords={
        'geo': ...,
        'time': ...,
      },
    )
    ```

    Args:
      kpi: Kpi DataArray.
    """
    self._validate_set('KPI', self.kpi)

    self._kpi = self._normalize_coords(kpi, constants.TIME)
    self.geos = self.kpi.coords[constants.GEO].values.tolist()
    self.time_coords = self.kpi.coords[constants.TIME].values.tolist()

  @property
  def controls(self) -> xr.DataArray | None:
    return self._controls

  @controls.setter
  def controls(self, controls: xr.DataArray):
    """Sets the `controls` data array.

    `controls` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='controls',
      dims=['geo', 'time', 'control_variable'],
      coords={
        'geo': ...,
        'time': ...,
        'control_variable': ...,
      },
    )
    ```

    Args:
      controls: Controls DataArray.
    """
    self._validate_set('Controls', self.controls)

    self._controls = self._normalize_coords(controls, constants.TIME)
    self.geos = self.controls.coords[constants.GEO].values.tolist()
    self.time_coords = self.controls.coords[constants.TIME].values.tolist()

  @property
  def population(self) -> xr.DataArray | None:
    return self._population

  @population.setter
  def population(self, population: xr.DataArray):
    """Sets the `population` data array.

    `population` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='population',
      dims=['geo'],
      coords={
        'geo': ...,
      },
    )
    ```

    Args:
      population: Population DataArray.
    """
    self._validate_set('Population', self.population)

    # Population has no time dimension, so only geo normalization applies.
    self._population = self._normalize_coords(population)
    self.geos = self.population.coords[constants.GEO].values.tolist()

  @property
  def revenue_per_kpi(self) -> xr.DataArray | None:
    return self._revenue_per_kpi

  @revenue_per_kpi.setter
  def revenue_per_kpi(self, revenue_per_kpi: xr.DataArray):
    """Sets the `revenue_per_kpi` data array.

    `revenue_per_kpi` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='revenue_per_kpi',
      dims=['geo', 'time'],
      coords={
        'geo': ...,
        'time': ...,
      },
    )
    ```

    Args:
      revenue_per_kpi: Revenue per kpi DataArray.
    """
    self._validate_set('Revenue per KPI', self.revenue_per_kpi)

    self._revenue_per_kpi = self._normalize_coords(
        revenue_per_kpi, constants.TIME
    )
    self.geos = self.revenue_per_kpi.coords[constants.GEO].values.tolist()
    self.time_coords = self.revenue_per_kpi.coords[
        constants.TIME
    ].values.tolist()

  @property
  def media(self) -> xr.DataArray | None:
    return self._media

  @media.setter
  def media(self, media: xr.DataArray):
    """Sets the `media` data array.

    `media` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='media',
      dims=['geo', 'media_time', 'media_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'media_channel': ...,
      },
    )
    ```

    Args:
      media: Media DataArray.
    """
    self._validate_set('Media', self.media)
    self._validate_channels_consistency(
        constants.MEDIA_CHANNEL, [media, self.media_spend]
    )

    self._media = self._normalize_coords(media, constants.MEDIA_TIME)
    self.geos = self.media.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.media.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def media_spend(self) -> xr.DataArray | None:
    return self._media_spend

  @media_spend.setter
  def media_spend(self, media_spend: xr.DataArray):
    """Sets the `media_spend` data array.

    `media_spend` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='media_spend',
      dims=['geo', 'time', 'media_channel'],
      coords={
        'geo': ...,
        'time': ...,
        'media_channel': ...,
      },
    )
    ```

    Args:
      media_spend: Media spend DataArray.
    """
    self._validate_set('Media spend', self.media_spend)
    self._validate_channels_consistency(
        constants.MEDIA_CHANNEL, [media_spend, self.media]
    )

    self._media_spend = self._normalize_coords(media_spend, constants.TIME)
    self.geos = self.media_spend.coords[constants.GEO].values.tolist()
    self.time_coords = self.media_spend.coords[constants.TIME].values.tolist()

  @property
  def reach(self) -> xr.DataArray | None:
    return self._reach

  @reach.setter
  def reach(self, reach: xr.DataArray):
    """Sets the `reach` data array.

    `reach` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='reach',
      dims=['geo', 'media_time', 'rf_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'rf_channel': ...,
      },
    )
    ```

    Args:
      reach: Reach DataArray.
    """
    self._validate_set('Reach', self.reach)
    self._validate_channels_consistency(
        constants.RF_CHANNEL, [reach, self.frequency, self.rf_spend]
    )

    self._reach = self._normalize_coords(reach, constants.MEDIA_TIME)
    self.geos = self.reach.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.reach.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def frequency(self) -> xr.DataArray | None:
    return self._frequency

  @frequency.setter
  def frequency(self, frequency: xr.DataArray):
    """Sets the `frequency` data array.

    `frequency` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='frequency',
      dims=['geo', 'media_time', 'rf_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'rf_channel': ...,
      },
    )
    ```

    Args:
      frequency: Frequency DataArray.
    """
    self._validate_set('Frequency', self.frequency)
    self._validate_channels_consistency(
        constants.RF_CHANNEL, [frequency, self.reach, self.rf_spend]
    )

    self._frequency = self._normalize_coords(frequency, constants.MEDIA_TIME)
    self.geos = self.frequency.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.frequency.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def rf_spend(self) -> xr.DataArray | None:
    return self._rf_spend

  @rf_spend.setter
  def rf_spend(self, rf_spend: xr.DataArray):
    """Sets the `rf_spend` data array.

    `rf_spend` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='rf_spend',
      dims=['geo', 'time', 'rf_channel'],
      coords={
        'geo': ...,
        'time': ...,
        'rf_channel': ...,
      },
    )
    ```

    Args:
      rf_spend: RF spend DataArray.
    """
    self._validate_set('RF spend', self.rf_spend)
    self._validate_channels_consistency(
        constants.RF_CHANNEL, [rf_spend, self.reach, self.frequency]
    )

    self._rf_spend = self._normalize_coords(rf_spend, constants.TIME)
    self.geos = self.rf_spend.coords[constants.GEO].values.tolist()
    self.time_coords = self.rf_spend.coords[constants.TIME].values.tolist()

  @property
  def organic_media(self) -> xr.DataArray | None:
    return self._organic_media

  @organic_media.setter
  def organic_media(self, organic_media: xr.DataArray):
    """Sets the `organic_media` data array.

    `organic_media` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='organic_media',
      dims=['geo', 'media_time', 'organic_media_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'organic_media_channel': ...,
      },
    )
    ```

    Args:
      organic_media: Organic media DataArray.
    """
    self._validate_set('Organic media', self.organic_media)

    self._organic_media = self._normalize_coords(
        organic_media, constants.MEDIA_TIME
    )
    self.geos = self.organic_media.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.organic_media.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def organic_reach(self) -> xr.DataArray | None:
    return self._organic_reach

  @organic_reach.setter
  def organic_reach(self, organic_reach: xr.DataArray):
    """Sets the `organic_reach` data array.

    `organic_reach` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='organic_reach',
      dims=['geo', 'media_time', 'organic_rf_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'organic_rf_channel': ...,
      },
    )
    ```

    Args:
      organic_reach: Organic reach DataArray.
    """
    self._validate_set('Organic reach', self.organic_reach)
    self._validate_channels_consistency(
        constants.ORGANIC_RF_CHANNEL, [organic_reach, self.organic_frequency]
    )

    self._organic_reach = self._normalize_coords(
        organic_reach, constants.MEDIA_TIME
    )
    self.geos = self.organic_reach.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.organic_reach.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def organic_frequency(self) -> xr.DataArray | None:
    return self._organic_frequency

  @organic_frequency.setter
  def organic_frequency(self, organic_frequency: xr.DataArray):
    """Sets the `organic_frequency` data array.

    `organic_frequency` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='organic_frequency',
      dims=['geo', 'media_time', 'organic_rf_channel'],
      coords={
        'geo': ...,
        'media_time': ...,
        'organic_rf_channel': ...,
      },
    )
    ```

    Args:
      organic_frequency: Organic frequency DataArray.
    """
    self._validate_set('Organic frequency', self.organic_frequency)
    self._validate_channels_consistency(
        constants.ORGANIC_RF_CHANNEL, [organic_frequency, self.organic_reach]
    )

    self._organic_frequency = self._normalize_coords(
        organic_frequency, constants.MEDIA_TIME
    )
    self.geos = self.organic_frequency.coords[constants.GEO].values.tolist()
    self.media_time_coords = self.organic_frequency.coords[
        constants.MEDIA_TIME
    ].values.tolist()

  @property
  def non_media_treatments(self) -> xr.DataArray | None:
    return self._non_media_treatments

  @non_media_treatments.setter
  def non_media_treatments(self, non_media_treatments: xr.DataArray):
    """Sets the `non_media_treatments` data array.

    `non_media_treatments` must have the following `DataArray` signature:

    ```
    xarray.DataArray(
      data=...,
      name='non_media_treatments',
      dims=['geo', 'time', 'non_media_channel'],
      coords={
        'geo': ...,
        'time': ...,
        'non_media_channel': ...,
      },
    )
    ```

    Args:
      non_media_treatments: Non-media treatments DataArray.
    """
    self._validate_set('Non-media treatments', self.non_media_treatments)

    self._non_media_treatments = self._normalize_coords(
        non_media_treatments, constants.TIME
    )
    self.geos = self.non_media_treatments.coords[constants.GEO].values.tolist()
    self.time_coords = self.non_media_treatments.coords[
        constants.TIME
    ].values.tolist()

  def build(self) -> input_data.InputData:
    """Builds an `InputData`.

    Constructs an `InputData` from constituent `DataArray`s given to this
    builder thus far after performing one final validation pass over all data
    arrays for consistency checks.

    Every component is reindexed so that its geo and time (or media time)
    coordinates appear in natural sort order.

    Returns:
      A validated `InputData`.

    Raises:
      ValueError: If a required component is missing, components that must be
        provided together are only partially provided, or any component
        contains NA values.
    """
    self._validate_required_components()
    self._validate_nas()

    # TODO: move logic from input_data to here: all channel names
    # should be unique across media channels, rf channels, organic media
    # channels, and organic rf channels.
    sorted_geos = natsort.natsorted(self.geos)
    sorted_times = natsort.natsorted(self.time_coords)
    sorted_media_times = natsort.natsorted(self.media_time_coords)

    def _get_sorted(da: xr.DataArray | None, is_media_time: bool = False):
      """Naturally sorts the DataArray by geo and time/media time."""

      if da is None:
        return None
      if is_media_time:
        return da.reindex(geo=sorted_geos, media_time=sorted_media_times)
      else:
        return da.reindex(geo=sorted_geos, time=sorted_times)

    return input_data.InputData(
        kpi_type=self._kpi_type,
        kpi=_get_sorted(self.kpi),
        revenue_per_kpi=_get_sorted(self.revenue_per_kpi),
        controls=_get_sorted(self.controls),
        # Population has only a geo dimension, so it is reindexed directly.
        population=self.population.reindex(geo=sorted_geos),
        media=_get_sorted(self.media, True),
        media_spend=_get_sorted(self.media_spend),
        reach=_get_sorted(self.reach, True),
        frequency=_get_sorted(self.frequency, True),
        rf_spend=_get_sorted(self.rf_spend),
        non_media_treatments=_get_sorted(self.non_media_treatments),
        organic_media=_get_sorted(self.organic_media, True),
        organic_reach=_get_sorted(self.organic_reach, True),
        organic_frequency=_get_sorted(self.organic_frequency, True),
    )

  def _normalize_coords(
      self, da: xr.DataArray, time_dimension_name: str | None = None
  ) -> xr.DataArray:
    """Normalizes the given `DataArray`'s coordinates in Meridian convention.

    Validates that time values are in the conventional Meridian format and
    that geos have national name if national.

    Args:
      da: The DataArray to normalize.
      time_dimension_name: The name of the time dimension. If None, time
        normalization is skipped.

    Returns:
      The normalized DataArray.

    Raises:
      ValueError: If a time label does not match `constants.DATE_FORMAT`.
    """
    if time_dimension_name is not None:
      # Time values are expected to be
      #   (a) strings formatted in `"yyyy-mm-dd"`
      # or
      #   (b) `datetime` values as numpy's `datetime64` types.
      # All other types are not currently supported.

      # If (b), `datetime` coord values will be normalized as formatted strings.
      # Any datetime64 precision is accepted (dtype kind 'M'), not just
      # nanoseconds, so coordinates built with e.g. `datetime64[s]` or
      # `datetime64[D]` are converted too.
      time_dtype = da.coords.dtypes[time_dimension_name]
      if isinstance(time_dtype, np.dtype) and time_dtype.kind == 'M':
        date_strvalues = np.datetime_as_string(
            da.coords[time_dimension_name].values, unit='D'
        )
        da = da.assign_coords({time_dimension_name: date_strvalues})

      # Assume that the time coordinate labels are date-formatted strings.
      # We don't currently support other, arbitrary object types in the builder.
      for time in da.coords[time_dimension_name].values:
        try:
          _ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
        except ValueError as exc:
          raise ValueError(
              f"Invalid time label: '{time}'. Expected format:"
              f" '{constants.DATE_FORMAT}'"
          ) from exc

    # A single geo means a national model: relabel it with the default name.
    if len(da.coords[constants.GEO].values.tolist()) == 1:
      da = da.assign_coords(
          {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]},
      )
    return da

  def _validate_set(self, component: str, da: xr.DataArray | None):
    """Ensures a component has not been set before.

    Args:
      component: Human-readable component name used in the error message.
      da: The component's current value on this builder.

    Raises:
      ValueError: If `da` is not None, i.e. the component was already set.
    """
    if da is not None:
      raise ValueError(f'{component} was already set to {da}.')

  def _validate_channels_consistency(
      self, channel_dimension_name: str, da_list: list[xr.DataArray | None]
  ):
    """Checks that related data arrays share the same channel coordinates.

    Args:
      channel_dimension_name: Name of the channel dimension to compare.
      da_list: Data arrays to compare. The first element is the reference
        (callers in this class pass the array being set); remaining elements
        may be None if not yet set.

    Raises:
      ValueError: If any non-None array's channel coordinates differ (as a
        set) from those of the first array.
    """
    others = [da for da in da_list[1:] if da is not None]
    if not others:
      return

    # Build the reference channel set once instead of once per comparison.
    expected = set(da_list[0].coords[channel_dimension_name].values.tolist())
    for other in others:
      if set(other.coords[channel_dimension_name].values.tolist()) != expected:
        raise ValueError(
            f'{channel_dimension_name} coordinates must be the same between'
            f' {[da.name for da in da_list if da is not None]}.'
        )

  def _validate_required_components(self):
    """Validates that all required data arrays are defined.

    For a single-geo (national) model, any provided population is replaced
    with the default national population value and a warning is emitted.

    Raises:
      ValueError: If KPI is missing; if population is missing for a
        multi-geo model; if media/media spend, reach/frequency/rf_spend, or
        organic reach/organic frequency are only partially provided; or if
        neither media nor reach + frequency is provided.
    """
    if self.kpi is None:
      raise ValueError('KPI is required.')

    if len(self.geos) == 1:
      if self.population is not None:
        warnings.warn(
            'The `population` argument is ignored in a nationally aggregated'
            ' model. It will be reset to [1, 1, ..., 1]'
        )
      self._population = xr.DataArray(
          [constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
          dims=[constants.GEO],
          coords={
              constants.GEO: self.geos,
          },
          name=constants.POPULATION,
      )
    if self.population is None:
      raise ValueError('Population is required for non national models.')

    if (self.media is None) ^ (self.media_spend is None):
      raise ValueError('Media and media spend must be provided together.')

    # Reach, frequency and RF spend are all-or-nothing.
    rf_missing = [
        da is None for da in (self.reach, self.frequency, self.rf_spend)
    ]
    if any(rf_missing) and not all(rf_missing):
      raise ValueError(
          'Reach, frequency, and rf_spend must be provided together.'
      )
    if (self.organic_reach is None) ^ (self.organic_frequency is None):
      raise ValueError(
          'Organic reach and organic frequency must be provided together.'
      )
    if all(rf_missing) and self.media_spend is None and self.media is None:
      raise ValueError(
          'It is required to have at least one of media or reach + frequency.'
      )

  def _validate_nas(self):
    """Check for NAs in all of the DataArrays.

    Since the DataArray components should already distinguish between media time
    and time coords, there are no media times to infer so there should be no
    NAs.

    Components that have not been set are skipped; presence of the required
    ones is checked separately by `_validate_required_components`.

    Raises:
      ValueError: If any set component contains an NA value.
    """
    # (label used in the error message, component) in validation order.
    components = (
        ('kpi', self.kpi),
        ('population', self.population),
        ('controls', self.controls),
        ('revenue per kpi', self.revenue_per_kpi),
        ('media spend', self.media_spend),
        ('rf spend', self.rf_spend),
        ('non media treatments', self.non_media_treatments),
        ('media', self.media),
        ('reach', self.reach),
        ('frequency', self.frequency),
        ('organic media', self.organic_media),
        ('organic reach', self.organic_reach),
        ('organic frequency', self.organic_frequency),
    )
    for label, da in components:
      if da is not None and da.isnull().any(axis=None):
        raise ValueError(f'NA values found in the {label} data.')

  def _validate_lagged_media(
      self, media_time_coords: Sequence[str], time_coords: Sequence[str]
  ):
    """Checks that the lagged-media period is a leading window of media times.

    The lagged-media period is the set of media times that are not in
    `time_coords`. It must equal the earliest `len(period)` media times.

    Args:
      media_time_coords: The `media_time` coordinate labels.
      time_coords: The `time` coordinate labels.

    Raises:
      ValueError: If the lagged-media period is not a continuous window
        starting from the earliest media time.
    """
    na_period = np.sort(list(set(media_time_coords) - set(time_coords)))
    if na_period.size == 0:
      # No lagged period at all. Returning here also avoids comparing an empty
      # float array with a string array, whose result depends on the NumPy
      # version.
      return
    if not np.all(na_period == np.sort(media_time_coords)[: len(na_period)]):
      raise ValueError(
          "The 'lagged media' period (period with 100% NA values in all"
          f' non-media columns) {na_period} is not a continuous window'
          ' starting from the earliest time period.'
      )