google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
  2. google_meridian-1.1.2.dist-info/RECORD +46 -0
  3. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
  4. meridian/__init__.py +2 -2
  5. meridian/analysis/__init__.py +1 -1
  6. meridian/analysis/analyzer.py +29 -22
  7. meridian/analysis/formatter.py +1 -1
  8. meridian/analysis/optimizer.py +70 -44
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/summary_text.py +1 -1
  11. meridian/analysis/test_utils.py +1 -1
  12. meridian/analysis/visualizer.py +17 -8
  13. meridian/constants.py +3 -3
  14. meridian/data/__init__.py +4 -1
  15. meridian/data/arg_builder.py +1 -1
  16. meridian/data/data_frame_input_data_builder.py +614 -0
  17. meridian/data/input_data.py +12 -8
  18. meridian/data/input_data_builder.py +817 -0
  19. meridian/data/load.py +121 -428
  20. meridian/data/nd_array_input_data_builder.py +509 -0
  21. meridian/data/test_utils.py +60 -43
  22. meridian/data/time_coordinates.py +1 -1
  23. meridian/mlflow/__init__.py +17 -0
  24. meridian/mlflow/autolog.py +54 -0
  25. meridian/model/__init__.py +1 -1
  26. meridian/model/adstock_hill.py +1 -1
  27. meridian/model/knots.py +1 -1
  28. meridian/model/media.py +1 -1
  29. meridian/model/model.py +65 -37
  30. meridian/model/model_test_data.py +75 -1
  31. meridian/model/posterior_sampler.py +19 -15
  32. meridian/model/prior_distribution.py +1 -1
  33. meridian/model/prior_sampler.py +32 -26
  34. meridian/model/spec.py +18 -8
  35. meridian/model/transformers.py +1 -1
  36. google_meridian-1.1.0.dist-info/RECORD +0 -41
  37. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
  38. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
meridian/data/data_frame_input_data_builder.py
@@ -0,0 +1,614 @@
+ # Copyright 2025 The Meridian Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """An implementation of `InputDataBuilder` with DataFrame primitives."""
+
+ import logging
+ import warnings
+
+ from meridian import constants
+ from meridian.data import input_data_builder
+ import pandas as pd
+
+
+ __all__ = [
+ 'DataFrameInputDataBuilder',
+ ]
+
+
+ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
+ """Builds `InputData` from DataFrames."""
+
+ def with_kpi(
+ self,
+ df: pd.DataFrame,
+ kpi_col: str = constants.KPI,
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads KPI data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the KPI data from.
+ kpi_col: The name of the column containing the KPI values. If not
+ provided, the default name is `kpi`.
+ time_col: The name of the column containing the time coordinates. If not
+ provided, the default name is `time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added KPI data.
+ """
+ kpi_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(kpi_df, [kpi_col, time_col], [geo_col])
+ self._validate_coords(kpi_df, geo_col, time_col)
+
+ ### Transform ###
+ data = kpi_df.set_index([geo_col, time_col])[kpi_col].dropna()
+ self.kpi = (
+ data.rename(constants.KPI)
+ .rename_axis([constants.GEO, constants.TIME])
+ .to_xarray()
+ )
+ return self
+
+ def with_controls(
+ self,
+ df: pd.DataFrame,
+ control_cols: list[str],
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads controls data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the controls data from.
+ control_cols: The names of the columns containing the controls values.
+ time_col: The name of the column containing the time coordinates. If not
+ provided, the default name is `time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added controls data.
+ """
+ controls_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(
+ controls_df,
+ control_cols + [time_col],
+ [geo_col],
+ )
+ self._validate_coords(controls_df, geo_col, time_col)
+
+ ### Transform ###
+ data = controls_df.set_index([geo_col, time_col])[control_cols].stack()
+ self.controls = (
+ data.rename(constants.CONTROLS)
+ .rename_axis(
+ [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
+ )
+ .to_xarray()
+ )
+ return self
+
+ def with_population(
+ self,
+ df: pd.DataFrame,
+ population_col: str = constants.POPULATION,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads population data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the population data from.
+ population_col: The name of the column containing the population values.
+ If not provided, the default name is `population`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added population data.
+ """
+ population_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(population_df, [population_col], [geo_col])
+ self._validate_coords(population_df, geo_col)
+
+ ### Transform ###
+ data = (
+ population_df.set_index([geo_col])[population_col]
+ .groupby(geo_col)
+ .mean()
+ )
+ self.population = (
+ data.rename(constants.POPULATION)
+ .rename_axis([constants.GEO])
+ .to_xarray()
+ )
+
+ return self
+
+ def with_revenue_per_kpi(
+ self,
+ df: pd.DataFrame,
+ revenue_per_kpi_col: str = constants.REVENUE_PER_KPI,
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads revenue per KPI data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the revenue per KPI data from.
+ revenue_per_kpi_col: The name of the column containing the revenue per KPI
+ values. If not provided, the default name is `revenue_per_kpi`.
+ time_col: The name of the column containing the time coordinates. If not
+ provided, the default name is `time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added revenue per KPI data.
+ """
+ revenue_per_kpi_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(
+ revenue_per_kpi_df,
+ [revenue_per_kpi_col, time_col],
+ [geo_col],
+ )
+ self._check_revenue_per_kpi_defaults(
+ revenue_per_kpi_df, revenue_per_kpi_col
+ )
+ self._validate_coords(revenue_per_kpi_df, geo_col, time_col)
+
+ ### Transform ###
+ data = revenue_per_kpi_df.set_index([geo_col, time_col])[
+ revenue_per_kpi_col
+ ].dropna()
+ self.revenue_per_kpi = (
+ data.rename(constants.REVENUE_PER_KPI)
+ .rename_axis([constants.GEO, constants.TIME])
+ .to_xarray()
+ )
+
+ return self
+
+ def with_media(
+ self,
+ df: pd.DataFrame,
+ media_cols: list[str],
+ media_spend_cols: list[str],
+ media_channels: list[str],
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads media and media spend data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the media and media spend data from.
+ media_cols: The name of the columns containing the media values.
+ media_spend_cols: The name of the columns containing the media spend
+ values.
+ media_channels: The desired media channel coordinate names. Must match
+ `media_cols` and `media_spend_cols` in length. These are also index
+ mapped.
+ time_col: The name of the column containing the time coordinates for media
+ spend and media time coordinates for media. If not provided, the default
+ name is `time`. Media time coordinates will be shorter than time
+ coordinates if media spend values are missing (NaN) for some t in
+ `time`. Media time must be equal or a subset of time.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added media and media spend data.
+ """
+ media_df = df.copy()
+
+ ### Validate ###
+ # For a media dataframe, media and media_spend columns may be the same
+ # (e.g. if using media spend as media execution value), so here we validate
+ # execution and spend columns separately when checking for duplicates.
+ self._validate_cols(media_df, media_cols + [time_col], [geo_col])
+ self._validate_cols(media_df, media_spend_cols + [time_col], [geo_col])
+ self._validate_coords(media_df, geo_col, time_col)
+ self._validate_channel_cols(media_channels, [media_cols, media_spend_cols])
+ ### Transform ###
+ media_data = media_df.set_index([geo_col, time_col])[media_cols]
+ media_data.columns = media_channels
+ self.media = (
+ media_data.stack()
+ .rename(constants.MEDIA)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.MEDIA_CHANNEL,
+ ])
+ .to_xarray()
+ )
+ media_spend_data = media_df.set_index([geo_col, time_col])[media_spend_cols]
+ media_spend_data.columns = media_channels
+ self.media_spend = (
+ media_spend_data.stack()
+ .rename(constants.MEDIA_SPEND)
+ .rename_axis([
+ constants.GEO,
+ constants.TIME,
+ constants.MEDIA_CHANNEL,
+ ])
+ .to_xarray()
+ )
+ return self
+
+ def with_reach(
+ self,
+ df: pd.DataFrame,
+ reach_cols: list[str],
+ frequency_cols: list[str],
+ rf_spend_cols: list[str],
+ rf_channels: list[str],
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads reach, frequency, and rf spend data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the reach, frequency, and rf spend data from.
+ reach_cols: The name of the columns containing the reach values.
+ frequency_cols: The name of the columns containing the frequency values.
+ rf_spend_cols: The name of the columns containing the rf spend values.
+ rf_channels: The desired rf channel coordinate names. Must match
+ `reach_cols`, `frequency_cols`, and `rf_spend_cols` in length. These are
+ also index mapped.
+ time_col: The name of the column containing the time coordinates for rf
+ spend and media time coordinates for reach and frequency. If not
+ provided, the default name is `time`. Media time coordinates will be
+ shorter than time coordinates if media spend values are missing (NaN)
+ for some t in `time`. Media time must be equal or a subset of time.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added reach, frequency, and rf
+ spend data.
+ """
+ reach_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(
+ reach_df,
+ reach_cols + frequency_cols + rf_spend_cols + [time_col],
+ [geo_col],
+ )
+ self._validate_coords(reach_df, geo_col, time_col)
+ self._validate_channel_cols(
+ rf_channels,
+ [reach_cols, frequency_cols, rf_spend_cols],
+ )
+
+ ### Transform ###
+ reach_data = reach_df.set_index([geo_col, time_col])[reach_cols]
+ reach_data.columns = rf_channels
+ self.reach = (
+ reach_data.stack()
+ .rename(constants.REACH)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.RF_CHANNEL,
+ ])
+ .to_xarray()
+ )
+
+ frequency_data = reach_df.set_index([geo_col, time_col])[frequency_cols]
+ frequency_data.columns = rf_channels
+ self.frequency = (
+ frequency_data.stack()
+ .rename(constants.FREQUENCY)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.RF_CHANNEL,
+ ])
+ .to_xarray()
+ )
+
+ rf_spend_data = reach_df.set_index([geo_col, time_col])[rf_spend_cols]
+ rf_spend_data.columns = rf_channels
+ self.rf_spend = (
+ rf_spend_data.stack()
+ .rename(constants.RF_SPEND)
+ .rename_axis([
+ constants.GEO,
+ constants.TIME,
+ constants.RF_CHANNEL,
+ ])
+ .to_xarray()
+ )
+ return self
+
+ def with_organic_media(
+ self,
+ df: pd.DataFrame,
+ organic_media_cols: list[str],
+ organic_media_channels: list[str] | None = None,
+ media_time_col: str = constants.MEDIA_TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads organic media data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the organic media data from.
+ organic_media_cols: The name of the columns containing the organic media
+ values.
+ organic_media_channels: The desired organic media channel coordinate
+ names. Will default to the organic media columns if not given. If
+ provided, must match `organic_media_cols` in length. This is index
+ mapped.
+ media_time_col: The name of the column containing the media time
+ coordinates. If not provided, the default name is `media_time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added organic media data.
+ """
+ organic_media_df = df.copy()
+
+ ### Validate ###
+ if not organic_media_channels:
+ organic_media_channels = organic_media_cols
+ self._validate_cols(
+ organic_media_df,
+ organic_media_cols + [media_time_col],
+ [geo_col],
+ )
+ self._validate_coords(organic_media_df, geo_col, media_time_col)
+ self._validate_channel_cols(
+ organic_media_channels,
+ [organic_media_cols],
+ )
+
+ ### Transform ###
+ data = organic_media_df.set_index([geo_col, media_time_col])[
+ organic_media_cols
+ ]
+ data.columns = organic_media_channels
+ self.organic_media = (
+ data.stack()
+ .rename(constants.ORGANIC_MEDIA)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.ORGANIC_MEDIA_CHANNEL,
+ ])
+ .to_xarray()
+ )
+
+ return self
+
+ def with_organic_reach(
+ self,
+ df: pd.DataFrame,
+ organic_reach_cols: list[str],
+ organic_frequency_cols: list[str],
+ organic_rf_channels: list[str],
+ media_time_col: str = constants.MEDIA_TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads organic reach and organic frequency data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the organic reach and frequency data from.
+ organic_reach_cols: The name of the columns containing the organic reach
+ values.
+ organic_frequency_cols: The name of the columns containing the organic
+ frequency values.
+ organic_rf_channels: The desired organic rf channel coordinate names. Must
+ match `organic_reach_cols` and `organic_frequency_cols` in length. These
+ are also index mapped.
+ media_time_col: The name of the column containing the media time
+ coordinates. If not provided, the default name is `media_time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added organic reach and organic
+ frequency data.
+ """
+ organic_reach_frequency_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(
+ organic_reach_frequency_df,
+ organic_reach_cols + organic_frequency_cols + [media_time_col],
+ [geo_col],
+ )
+ self._validate_coords(organic_reach_frequency_df, geo_col, media_time_col)
+ self._validate_channel_cols(
+ organic_rf_channels,
+ [organic_reach_cols, organic_frequency_cols],
+ )
+ ### Transform ###
+ organic_reach_data = organic_reach_frequency_df.set_index(
+ [geo_col, media_time_col]
+ )[organic_reach_cols]
+ organic_reach_data.columns = organic_rf_channels
+ self.organic_reach = (
+ organic_reach_data.stack()
+ .rename(constants.ORGANIC_REACH)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.ORGANIC_RF_CHANNEL,
+ ])
+ .to_xarray()
+ )
+ organic_frequency_data = organic_reach_frequency_df.set_index(
+ [geo_col, media_time_col]
+ )[organic_frequency_cols]
+ organic_frequency_data.columns = organic_rf_channels
+ self.organic_frequency = (
+ organic_frequency_data.stack()
+ .rename(constants.ORGANIC_FREQUENCY)
+ .rename_axis([
+ constants.GEO,
+ constants.MEDIA_TIME,
+ constants.ORGANIC_RF_CHANNEL,
+ ])
+ .to_xarray()
+ )
+ return self
+
+ def with_non_media_treatments(
+ self,
+ df: pd.DataFrame,
+ non_media_treatment_cols: list[str],
+ time_col: str = constants.TIME,
+ geo_col: str = constants.GEO,
+ ) -> 'DataFrameInputDataBuilder':
+ """Reads non-media treatments data from a DataFrame.
+
+ Args:
+ df: The DataFrame to read the non-media treatments data from.
+ non_media_treatment_cols: The names of the columns containing the
+ non-media treatments values.
+ time_col: The name of the column containing the time coordinates. If not
+ provided, the default name is `time`.
+ geo_col: (Optional) The name of the column containing the geo coordinates.
+ If not provided, the default name is `geo`. If the DataFrame provided
+ has no geo column, a national model data is assumed and a geo dimension
+ will be created internally with a single coordinate value
+ `national_geo`.
+
+ Returns:
+ The `DataFrameInputDataBuilder` with the added non-media treatments data.
+ """
+ non_media_treatments_df = df.copy()
+
+ ### Validate ###
+ self._validate_cols(
+ non_media_treatments_df,
+ non_media_treatment_cols + [time_col],
+ [geo_col],
+ )
+ self._validate_coords(non_media_treatments_df, geo_col, time_col)
+
+ ### Transform ###
+ data = non_media_treatments_df.set_index([geo_col, time_col])[
+ non_media_treatment_cols
+ ].stack()
+ self.non_media_treatments = (
+ data.rename(constants.NON_MEDIA_TREATMENTS)
+ .rename_axis(
+ [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
+ )
+ .to_xarray()
+ )
+ return self
+
+ def _validate_cols(
+ self, df: pd.DataFrame, required_cols: list[str], optional_cols: list[str]
+ ):
+ """Validates that the DataFrame has all the expected columns and there are no duplicates."""
+ if len(required_cols + optional_cols) != len(
+ set(required_cols + optional_cols)
+ ):
+ raise ValueError(
+ 'DataFrame has duplicate columns from'
+ f' {required_cols + optional_cols}'
+ )
+
+ if not all(column in df.columns for column in required_cols):
+ raise ValueError(
+ f'DataFrame is missing one or more columns from {required_cols}'
+ )
+
+ def _validate_coords(
+ self,
+ df: pd.DataFrame,
+ geo_col: str,
+ time_col: str | None = None,
+ ):
+ """Adds geo columns in a national model if necessary and validates that for every geo the list of `time`s is the same for non-population dfs."""
+ if geo_col not in df.columns:
+ df[geo_col] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
+ logging.info('DataFrame has no geo column. Assuming "National".')
+
+ if time_col is not None:
+ df_grouped = df.sort_values(time_col).groupby(geo_col)[time_col]
+ if any(df_grouped.count() != df_grouped.nunique()):
+ # Currently we raise errors for all duplicate geo time entries. Might
+ # want to consider silently dropping dupes for column values that are
+ # the same (e.g. {geo: ['a', 'a'], 'time': ['1', '1'], kpi: [120, 120]})
+ raise ValueError("Duplicate entries found in the 'time' column.")
+
+ times_by_geo = df_grouped.apply(list).reset_index(drop=True)
+ if any(t != times_by_geo[0] for t in times_by_geo[1:]):
+ raise ValueError(
+ "Values in the 'time' column not consistent across different geos."
+ )
+
+ def _check_revenue_per_kpi_defaults(
+ self, df: pd.DataFrame, revenue_per_kpi_col: str
+ ):
+ """Sets revenue_per_kpi to default if kpi type is revenue and with_revenue_per_kpi is called."""
+ if self._kpi_type == constants.REVENUE:
+ df[revenue_per_kpi_col] = 1.0
+ warnings.warn(
+ 'with_revenue_per_kpi was called but kpi_type was set to revenue.'
+ ' Assuming revenue per kpi with values [1].'
+ )
+
+ def _validate_channel_cols(
+ self, channel_names: list[str], all_channel_cols: list[list[str]]
+ ):
+ if len(channel_names) != len(set(channel_names)):
+ raise ValueError('Channel names must be unique.')
+ for channel_cols in all_channel_cols:
+ if len(channel_cols) != len(channel_names):
+ raise ValueError(
+ 'Given channel columns must have same length as channel names.'
+ )
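The new `DataFrameInputDataBuilder` exposes chainable `with_*` methods that each copy the given DataFrame, validate its columns and coordinates, and store that component as an `xarray.DataArray`. As a minimal sketch of how the methods shown above might be combined: the `kpi_type` constructor argument and the final `build()` call belong to the parent `InputDataBuilder` (added in `meridian/data/input_data_builder.py`, not shown in this diff) and are therefore assumptions, and all column and channel names below are hypothetical.

import pandas as pd
from meridian.data import data_frame_input_data_builder

# Hypothetical wide-format DataFrame with one row per (geo, time) pair.
df = pd.read_csv('geo_weekly_data.csv')

builder = (
    data_frame_input_data_builder.DataFrameInputDataBuilder(
        kpi_type='non_revenue'  # assumed parent-class constructor argument
    )
    .with_kpi(df, kpi_col='conversions')
    .with_population(df)
    .with_controls(df, control_cols=['gqv', 'competitor_sales'])
    .with_revenue_per_kpi(df, revenue_per_kpi_col='avg_order_value')
    .with_media(
        df,
        media_cols=['tv_impressions', 'search_impressions'],
        media_spend_cols=['tv_spend', 'search_spend'],
        media_channels=['tv', 'search'],
    )
)
input_data = builder.build()  # assumed finalizer defined on the parent builder

Because each `with_*` method works on a copy and returns `self`, the same DataFrame can safely feed several calls in one chain.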
meridian/data/input_data.py
@@ -1,4 +1,4 @@
- # Copyright 2024 The Meridian Authors.
+ # Copyright 2025 The Meridian Authors.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -121,11 +121,11 @@ class InputData:
  `revenue_per_kpi` exists, ROI calibration is used and the analysis is run
  on revenue. When the `revenue_per_kpi` doesn't exist for the same
  `kpi_type`, custom ROI calibration is used and the analysis is run on KPI.
- controls: A DataArray of dimensions `(n_geos, n_times, n_controls)`
- containing control variable values.
  population: A DataArray of dimensions `(n_geos,)` containing the population
  of each group. This variable is used to scale the KPI and media for
  modeling.
+ controls: An optional DataArray of dimensions `(n_geos, n_times,
+ n_controls)` containing control variable values.
  revenue_per_kpi: An optional DataArray of dimensions `(n_geos, n_times)`
  containing the average revenue amount per KPI unit. Although modeling is
  done on `kpi`, model analysis and optimization are done on `KPI *
@@ -275,8 +275,8 @@ class InputData:

  kpi: xr.DataArray
  kpi_type: str
- controls: xr.DataArray
  population: xr.DataArray
+ controls: xr.DataArray | None = None
  revenue_per_kpi: xr.DataArray | None = None
  media: xr.DataArray | None = None
  media_spend: xr.DataArray | None = None
@@ -409,9 +409,12 @@ class InputData:
  return None

  @property
- def control_variable(self) -> xr.DataArray:
+ def control_variable(self) -> xr.DataArray | None:
  """Returns the control variable dimension."""
- return self.controls[constants.CONTROL_VARIABLE]
+ if self.controls is not None:
+ return self.controls[constants.CONTROL_VARIABLE]
+ else:
+ return None

  @property
  def media_spend_has_geo_dimension(self) -> bool:
@@ -502,8 +505,8 @@ class InputData:
  # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
  arrays = (
  self.kpi,
- self.controls,
  self.population,
+ self.controls,
  self.revenue_per_kpi,
  self.organic_media,
  self.organic_reach,
@@ -786,9 +789,10 @@ class InputData:
  """Returns data as a single `xarray.Dataset` object."""
  data = [
  self.kpi,
- self.controls,
  self.population,
  ]
+ if self.controls is not None:
+ data.append(self.controls)
  if self.revenue_per_kpi is not None:
  data.append(self.revenue_per_kpi)
  if self.media is not None:
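The hunks above make `controls` optional on `InputData`: the field now defaults to `None`, `control_variable` returns `None` when controls are absent, and the dataset assembly only appends controls when they exist. A minimal sketch of what this permits, assuming `kpi`, `population`, `media`, and `media_spend` are already valid `xarray.DataArray` objects with the documented dimensions:

from meridian.data import input_data

# Sketch only: constructing InputData without a controls array, which
# 1.1.2 allows because `controls` now defaults to None.
data = input_data.InputData(
    kpi=kpi,                  # (geo, time)
    kpi_type='non_revenue',
    population=population,    # (geo,)
    media=media,              # (geo, media_time, media_channel)
    media_spend=media_spend,  # (geo, time, media_channel)
)
assert data.controls is None
assert data.control_variable is None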