google-meridian 1.1.0__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {google_meridian-1.1.0/google_meridian.egg-info → google_meridian-1.1.1}/PKG-INFO +2 -2
  2. {google_meridian-1.1.0 → google_meridian-1.1.1}/README.md +1 -1
  3. {google_meridian-1.1.0 → google_meridian-1.1.1/google_meridian.egg-info}/PKG-INFO +2 -2
  4. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/__init__.py +2 -2
  5. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/__init__.py +1 -1
  6. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/analyzer.py +18 -17
  7. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/formatter.py +1 -1
  8. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/optimizer.py +1 -1
  9. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/summarizer.py +1 -1
  10. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/summary_text.py +1 -1
  11. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/test_utils.py +1 -1
  12. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/visualizer.py +2 -3
  13. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/constants.py +3 -3
  14. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/__init__.py +1 -1
  15. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/arg_builder.py +1 -1
  16. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/input_data.py +12 -8
  17. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/load.py +53 -40
  18. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/test_utils.py +60 -43
  19. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/data/time_coordinates.py +1 -1
  20. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/__init__.py +1 -1
  21. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/adstock_hill.py +1 -1
  22. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/knots.py +1 -1
  23. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/media.py +1 -1
  24. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/model.py +47 -27
  25. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/model_test_data.py +75 -1
  26. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/posterior_sampler.py +19 -15
  27. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/prior_distribution.py +1 -1
  28. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/prior_sampler.py +32 -26
  29. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/spec.py +1 -1
  30. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/model/transformers.py +1 -1
  31. {google_meridian-1.1.0 → google_meridian-1.1.1}/setup.py +1 -1
  32. {google_meridian-1.1.0 → google_meridian-1.1.1}/LICENSE +0 -0
  33. {google_meridian-1.1.0 → google_meridian-1.1.1}/MANIFEST.in +0 -0
  34. {google_meridian-1.1.0 → google_meridian-1.1.1}/google_meridian.egg-info/SOURCES.txt +0 -0
  35. {google_meridian-1.1.0 → google_meridian-1.1.1}/google_meridian.egg-info/dependency_links.txt +0 -0
  36. {google_meridian-1.1.0 → google_meridian-1.1.1}/google_meridian.egg-info/requires.txt +0 -0
  37. {google_meridian-1.1.0 → google_meridian-1.1.1}/google_meridian.egg-info/top_level.txt +0 -0
  38. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/card.html.jinja +0 -0
  39. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/chart.html.jinja +0 -0
  40. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/chips.html.jinja +0 -0
  41. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/insights.html.jinja +0 -0
  42. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/stats.html.jinja +0 -0
  43. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/style.scss +0 -0
  44. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/summary.html.jinja +0 -0
  45. {google_meridian-1.1.0 → google_meridian-1.1.1}/meridian/analysis/templates/table.html.jinja +0 -0
  46. {google_meridian-1.1.0 → google_meridian-1.1.1}/pyproject.toml +0 -0
  47. {google_meridian-1.1.0 → google_meridian-1.1.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  License:
@@ -393,7 +393,7 @@ To cite this repository:
393
393
  author = {Google Meridian Marketing Mix Modeling Team},
394
394
  title = {Meridian: Marketing Mix Modeling},
395
395
  url = {https://github.com/google/meridian},
396
- version = {1.1.0},
396
+ version = {1.1.1},
397
397
  year = {2025},
398
398
  }
399
399
  ```
@@ -151,7 +151,7 @@ To cite this repository:
151
151
  author = {Google Meridian Marketing Mix Modeling Team},
152
152
  title = {Meridian: Marketing Mix Modeling},
153
153
  url = {https://github.com/google/meridian},
154
- version = {1.1.0},
154
+ version = {1.1.1},
155
155
  year = {2025},
156
156
  }
157
157
  ```
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  License:
@@ -393,7 +393,7 @@ To cite this repository:
393
393
  author = {Google Meridian Marketing Mix Modeling Team},
394
394
  title = {Meridian: Marketing Mix Modeling},
395
395
  url = {https://github.com/google/meridian},
396
- version = {1.1.0},
396
+ version = {1.1.1},
397
397
  year = {2025},
398
398
  }
399
399
  ```
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
14
14
 
15
15
  """Meridian API."""
16
16
 
17
- __version__ = "1.1.0"
17
+ __version__ = "1.1.1"
18
18
 
19
19
 
20
20
  from meridian import analysis
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -788,7 +788,7 @@ class Analyzer:
788
788
  tensors are expected to be scaled by their corresponding transformers.
789
789
  dist_tensors: A `DistributionTensors` container with the distribution
790
790
  tensors for media, RF, organic media, organic RF, non-media treatments,
791
- and controls.
791
+ and controls (if available).
792
792
 
793
793
  Returns:
794
794
  Tensor representing computed kpi means.
@@ -803,17 +803,15 @@ class Analyzer:
803
803
  )
804
804
  )
805
805
 
806
- result = (
807
- tau_gt
808
- + tf.einsum(
809
- "...gtm,...gm->...gt", combined_media_transformed, combined_beta
810
- )
811
- + tf.einsum(
812
- "...gtc,...gc->...gt",
813
- data_tensors.controls,
814
- dist_tensors.gamma_gc,
815
- )
806
+ result = tau_gt + tf.einsum(
807
+ "...gtm,...gm->...gt", combined_media_transformed, combined_beta
816
808
  )
809
+ if self._meridian.controls is not None:
810
+ result += tf.einsum(
811
+ "...gtc,...gc->...gt",
812
+ data_tensors.controls,
813
+ dist_tensors.gamma_gc,
814
+ )
817
815
  if data_tensors.non_media_treatments is not None:
818
816
  result += tf.einsum(
819
817
  "...gtm,...gm->...gt",
@@ -1464,11 +1462,14 @@ class Analyzer:
1464
1462
  (n_chains, 0, self._meridian.n_geos, self._meridian.n_times)
1465
1463
  )
1466
1464
  batch_starting_indices = np.arange(n_draws, step=batch_size)
1467
- param_list = [
1468
- constants.MU_T,
1469
- constants.TAU_G,
1470
- constants.GAMMA_GC,
1471
- ] + self._get_causal_param_names(include_non_paid_channels=True)
1465
+ param_list = (
1466
+ [
1467
+ constants.MU_T,
1468
+ constants.TAU_G,
1469
+ ]
1470
+ + ([constants.GAMMA_GC] if self._meridian.n_controls else [])
1471
+ + self._get_causal_param_names(include_non_paid_channels=True)
1472
+ )
1472
1473
  outcome_means_temps = []
1473
1474
  for start_index in batch_starting_indices:
1474
1475
  stop_index = np.min([n_draws, start_index + batch_size])
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1493,8 +1493,7 @@ class MediaSummary:
1493
1493
  Returns:
1494
1494
  An `xarray.Dataset` containing the following:
1495
1495
  - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1496
- `ci_hi`),
1497
- `distribution` (`prior`, `posterior`)
1496
+ `ci_hi`), `distribution` (`prior`, `posterior`)
1498
1497
  - **Data variables:** `incremental_outcome`, `pct_of_contribution`,
1499
1498
  `effectiveness`.
1500
1499
  """
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -72,10 +72,10 @@ REVENUE = 'revenue'
72
72
  NON_REVENUE = 'non_revenue'
73
73
  REQUIRED_INPUT_DATA_ARRAY_NAMES = (
74
74
  KPI,
75
- CONTROLS,
76
75
  POPULATION,
77
76
  )
78
77
  OPTIONAL_INPUT_DATA_ARRAY_NAMES = (
78
+ CONTROLS,
79
79
  REVENUE_PER_KPI,
80
80
  ORGANIC_MEDIA,
81
81
  ORGANIC_REACH,
@@ -148,7 +148,6 @@ REQUIRED_INPUT_DATA_COORD_NAMES = (
148
148
  GEO,
149
149
  TIME,
150
150
  MEDIA_TIME,
151
- CONTROL_VARIABLE,
152
151
  )
153
152
  NON_PAID_MEDIA_INPUT_DATA_COORD_NAMES = (
154
153
  ORGANIC_MEDIA_CHANNEL,
@@ -159,6 +158,7 @@ MEDIA_INPUT_DATA_COORD_NAMES = (MEDIA_CHANNEL,)
159
158
  RF_INPUT_DATA_COORD_NAMES = (RF_CHANNEL,)
160
159
  POSSIBLE_INPUT_DATA_COORD_NAMES = (
161
160
  REQUIRED_INPUT_DATA_COORD_NAMES
161
+ + (CONTROL_VARIABLE,)
162
162
  + NON_PAID_MEDIA_INPUT_DATA_COORD_NAMES
163
163
  + MEDIA_INPUT_DATA_COORD_NAMES
164
164
  + RF_INPUT_DATA_COORD_NAMES
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -121,11 +121,11 @@ class InputData:
121
121
  `revenue_per_kpi` exists, ROI calibration is used and the analysis is run
122
122
  on revenue. When the `revenue_per_kpi` doesn't exist for the same
123
123
  `kpi_type`, custom ROI calibration is used and the analysis is run on KPI.
124
- controls: A DataArray of dimensions `(n_geos, n_times, n_controls)`
125
- containing control variable values.
126
124
  population: A DataArray of dimensions `(n_geos,)` containing the population
127
125
  of each group. This variable is used to scale the KPI and media for
128
126
  modeling.
127
+ controls: An optional DataArray of dimensions `(n_geos, n_times,
128
+ n_controls)` containing control variable values.
129
129
  revenue_per_kpi: An optional DataArray of dimensions `(n_geos, n_times)`
130
130
  containing the average revenue amount per KPI unit. Although modeling is
131
131
  done on `kpi`, model analysis and optimization are done on `KPI *
@@ -275,8 +275,8 @@ class InputData:
275
275
 
276
276
  kpi: xr.DataArray
277
277
  kpi_type: str
278
- controls: xr.DataArray
279
278
  population: xr.DataArray
279
+ controls: xr.DataArray | None = None
280
280
  revenue_per_kpi: xr.DataArray | None = None
281
281
  media: xr.DataArray | None = None
282
282
  media_spend: xr.DataArray | None = None
@@ -409,9 +409,12 @@ class InputData:
409
409
  return None
410
410
 
411
411
  @property
412
- def control_variable(self) -> xr.DataArray:
412
+ def control_variable(self) -> xr.DataArray | None:
413
413
  """Returns the control variable dimension."""
414
- return self.controls[constants.CONTROL_VARIABLE]
414
+ if self.controls is not None:
415
+ return self.controls[constants.CONTROL_VARIABLE]
416
+ else:
417
+ return None
415
418
 
416
419
  @property
417
420
  def media_spend_has_geo_dimension(self) -> bool:
@@ -502,8 +505,8 @@ class InputData:
502
505
  # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
503
506
  arrays = (
504
507
  self.kpi,
505
- self.controls,
506
508
  self.population,
509
+ self.controls,
507
510
  self.revenue_per_kpi,
508
511
  self.organic_media,
509
512
  self.organic_reach,
@@ -786,9 +789,10 @@ class InputData:
786
789
  """Returns data as a single `xarray.Dataset` object."""
787
790
  data = [
788
791
  self.kpi,
789
- self.controls,
790
792
  self.population,
791
793
  ]
794
+ if self.controls is not None:
795
+ data.append(self.controls)
792
796
  if self.revenue_per_kpi is not None:
793
797
  data.append(self.revenue_per_kpi)
794
798
  if self.media is not None:
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ class XrDatasetDataLoader(InputDataLoader):
79
79
  """Constructor.
80
80
 
81
81
  The coordinates of the input dataset should be: `time`, `media_time`,
82
- `control_variable`, `geo` (optional for a national model),
82
+ `control_variable` (optional), `geo` (optional for a national model),
83
83
  `non_media_channel` (optional), `organic_media_channel` (optional),
84
84
  `organic_rf_channel` (optional), and
85
85
  either `media_channel`, `rf_channel`, or both.
@@ -93,7 +93,7 @@ class XrDatasetDataLoader(InputDataLoader):
93
93
 
94
94
  * `kpi`: `(geo, time)`
95
95
  * `revenue_per_kpi`: `(geo, time)`
96
- * `controls`: `(geo, time, control_variable)`
96
+ * `controls`: `(geo, time, control_variable)` - optional
97
97
  * `population`: `(geo)`
98
98
  * `media`: `(geo, media_time, media_channel)` - optional
99
99
  * `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
@@ -113,7 +113,7 @@ class XrDatasetDataLoader(InputDataLoader):
113
113
 
114
114
  * `kpi`: `([1,] time)`
115
115
  * `revenue_per_kpi`: `([1,] time)`
116
- * `controls`: `([1,] time, control_variable)`
116
+ * `controls`: `([1,] time, control_variable)` - optional
117
117
  * `population`: `([1],)` - this array is optional for national data
118
118
  * `media`: `([1,] media_time, media_channel)` - optional
119
119
  * `media_spend`: `([1,] time, media_channel)` or
@@ -198,7 +198,7 @@ class XrDatasetDataLoader(InputDataLoader):
198
198
  self.dataset = dataset.rename(name_mapping)
199
199
 
200
200
  # Add a `geo` dimension if it is not already present.
201
- if (constants.GEO) not in self.dataset.dims.keys():
201
+ if (constants.GEO) not in self.dataset.sizes.keys():
202
202
  self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
203
203
 
204
204
  if len(self.dataset.coords[constants.GEO]) == 1:
@@ -228,7 +228,7 @@ class XrDatasetDataLoader(InputDataLoader):
228
228
  compat='override',
229
229
  )
230
230
 
231
- if constants.MEDIA_TIME not in self.dataset.dims.keys():
231
+ if constants.MEDIA_TIME not in self.dataset.sizes.keys():
232
232
  self._add_media_time()
233
233
  self._normalize_time_coordinates(constants.TIME)
234
234
  self._normalize_time_coordinates(constants.MEDIA_TIME)
@@ -349,14 +349,17 @@ class XrDatasetDataLoader(InputDataLoader):
349
349
  # Arrays in which NAs are expected in the lagged-media period.
350
350
  na_arrays = [
351
351
  constants.KPI,
352
- constants.CONTROLS,
353
352
  ]
354
353
 
355
- na_mask = self.dataset[constants.KPI].isnull().any(
356
- dim=constants.GEO
357
- ) | self.dataset[constants.CONTROLS].isnull().any(
358
- dim=[constants.GEO, constants.CONTROL_VARIABLE]
359
- )
354
+ na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
355
+
356
+ if constants.CONTROLS in self.dataset.data_vars.keys():
357
+ na_arrays.append(constants.CONTROLS)
358
+ na_mask |= (
359
+ self.dataset[constants.CONTROLS]
360
+ .isnull()
361
+ .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
362
+ )
360
363
 
361
364
  if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
362
365
  na_arrays.append(constants.NON_MEDIA_TREATMENTS)
@@ -427,11 +430,12 @@ class XrDatasetDataLoader(InputDataLoader):
427
430
  .dropna(dim=constants.TIME)
428
431
  .rename({constants.TIME: new_time})
429
432
  )
430
- new_dataset[constants.CONTROLS] = (
431
- new_dataset[constants.CONTROLS]
432
- .dropna(dim=constants.TIME)
433
- .rename({constants.TIME: new_time})
434
- )
433
+ if constants.CONTROLS in new_dataset.data_vars.keys():
434
+ new_dataset[constants.CONTROLS] = (
435
+ new_dataset[constants.CONTROLS]
436
+ .dropna(dim=constants.TIME)
437
+ .rename({constants.TIME: new_time})
438
+ )
435
439
  if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
436
440
  new_dataset[constants.NON_MEDIA_TREATMENTS] = (
437
441
  new_dataset[constants.NON_MEDIA_TREATMENTS]
@@ -466,6 +470,11 @@ class XrDatasetDataLoader(InputDataLoader):
466
470
 
467
471
  def load(self) -> input_data.InputData:
468
472
  """Returns an `InputData` object containing the data from the dataset."""
473
+ controls = (
474
+ self.dataset.controls
475
+ if constants.CONTROLS in self.dataset.data_vars.keys()
476
+ else None
477
+ )
469
478
  revenue_per_kpi = (
470
479
  self.dataset.revenue_per_kpi
471
480
  if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
@@ -519,9 +528,9 @@ class XrDatasetDataLoader(InputDataLoader):
519
528
  return input_data.InputData(
520
529
  kpi=self.dataset.kpi,
521
530
  kpi_type=self.kpi_type,
522
- revenue_per_kpi=revenue_per_kpi,
523
- controls=self.dataset.controls,
524
531
  population=self.dataset.population,
532
+ controls=controls,
533
+ revenue_per_kpi=revenue_per_kpi,
525
534
  media=media,
526
535
  media_spend=media_spend,
527
536
  reach=reach,
@@ -539,14 +548,14 @@ class CoordToColumns:
539
548
  """A mapping between the desired and actual column names in the input data.
540
549
 
541
550
  Attributes:
542
- controls: List of column names containing `controls` values in the input
543
- data.
544
551
  time: Name of column containing `time` values in the input data.
545
- kpi: Name of column containing `kpi` values in the input data.
546
- revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
547
- input data.
548
552
  geo: Name of column containing `geo` values in the input data. This field
549
553
  is optional for a national model.
554
+ kpi: Name of column containing `kpi` values in the input data.
555
+ controls: List of column names containing `controls` values in the input
556
+ data. Optional.
557
+ revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
558
+ input data. Optional. Will be overridden if model KPI type is "revenue".
550
559
  population: Name of column containing `population` values in the input data.
551
560
  This field is optional for a national model.
552
561
  media: List of column names containing `media` values in the input data.
@@ -567,11 +576,11 @@ class CoordToColumns:
567
576
  values in the input data.
568
577
  """
569
578
 
570
- controls: Sequence[str]
571
579
  time: str = constants.TIME
580
+ geo: str = constants.GEO
572
581
  kpi: str = constants.KPI
582
+ controls: Sequence[str] | None = None
573
583
  revenue_per_kpi: str | None = None
574
- geo: str = constants.GEO
575
584
  population: str = constants.POPULATION
576
585
  # Media data
577
586
  media: Sequence[str] | None = None
@@ -607,7 +616,7 @@ class DataFrameDataLoader(InputDataLoader):
607
616
  to the DataFrame column names if they are different. The fields are:
608
617
 
609
618
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
610
- * `controls` (multiple columns)
619
+ * `controls` (multiple columns, optional)
611
620
  * (1) `media`, `media_spend` (multiple columns)
612
621
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
613
622
  * `non_media_treatments` (multiple columns, optional)
@@ -953,9 +962,10 @@ class DataFrameDataLoader(InputDataLoader):
953
962
  not_lagged_columns = []
954
963
  coords = [
955
964
  constants.KPI,
956
- constants.CONTROLS,
957
965
  constants.POPULATION,
958
966
  ]
967
+ if self.coord_to_columns.controls is not None:
968
+ coords.append(constants.CONTROLS)
959
969
  if self.coord_to_columns.revenue_per_kpi is not None:
960
970
  coords.append(constants.REVENUE_PER_KPI)
961
971
  if self.coord_to_columns.media_spend is not None:
@@ -1042,17 +1052,20 @@ class DataFrameDataLoader(InputDataLoader):
1042
1052
  .to_frame()
1043
1053
  .to_xarray()
1044
1054
  )
1045
- controls_xr = (
1046
- df_indexed[self.coord_to_columns.controls]
1047
- .stack()
1048
- .rename(constants.CONTROLS)
1049
- .rename_axis(
1050
- [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1051
- )
1052
- .to_frame()
1053
- .to_xarray()
1054
- )
1055
- dataset = xr.combine_by_coords([kpi_xr, population_xr, controls_xr])
1055
+ dataset = xr.combine_by_coords([kpi_xr, population_xr])
1056
+
1057
+ if self.coord_to_columns.controls is not None:
1058
+ controls_xr = (
1059
+ df_indexed[self.coord_to_columns.controls]
1060
+ .stack()
1061
+ .rename(constants.CONTROLS)
1062
+ .rename_axis(
1063
+ [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1064
+ )
1065
+ .to_frame()
1066
+ .to_xarray()
1067
+ )
1068
+ dataset = xr.combine_by_coords([dataset, controls_xr])
1056
1069
 
1057
1070
  if self.coord_to_columns.non_media_treatments is not None:
1058
1071
  non_media_xr = (
@@ -1224,7 +1237,7 @@ class CsvDataLoader(InputDataLoader):
1224
1237
  CSV column names, if they are different. The fields are:
1225
1238
 
1226
1239
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
1227
- * `controls` (multiple columns)
1240
+ * `controls` (multiple columns, optional)
1228
1241
  * (1) `media`, `media_spend` (multiple columns)
1229
1242
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
1230
1243
  * `non_media_treatments` (multiple columns, optional)