google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1059 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines ModelContext class for Meridian."""
16
+
17
+ from collections.abc import Mapping, Sequence
18
+ import functools
19
+ import warnings
20
+
21
+ from meridian import backend
22
+ from meridian import constants
23
+ from meridian.data import input_data as data
24
+ from meridian.data import time_coordinates as tc
25
+ from meridian.model import adstock_hill
26
+ from meridian.model import knots
27
+ from meridian.model import media
28
+ from meridian.model import prior_distribution
29
+ from meridian.model import spec
30
+ from meridian.model import transformers
31
+ import numpy as np
32
+
33
+ __all__ = [
34
+ "ModelContext",
35
+ ]
36
+
37
+
38
+ class ModelContext:
39
+ """Model context for Meridian.
40
+
41
+ This class contains all model parameters that do not change between the runs
42
+ of Meridian.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ input_data: data.InputData,
48
+ model_spec: spec.ModelSpec,
49
+ ):
50
+ self._input_data = input_data
51
+ self._model_spec = model_spec
52
+
53
+ self._validate_data_dependent_model_spec()
54
+ self._validate_model_spec_shapes()
55
+
56
+ self._set_total_media_contribution_prior = False
57
+ self._warn_setting_ignored_priors()
58
+ self._validate_mroi_priors_non_revenue()
59
+ self._validate_roi_priors_non_revenue()
60
+ self._check_media_prior_support()
61
+ self._validate_geo_invariants()
62
+ self._validate_time_invariants()
63
+ self._validate_media_spend_for_paid_channels()
64
+ self._validate_rf_spend_for_paid_channels()
65
+
66
+ def _validate_data_dependent_model_spec(self):
67
+ """Validates that the data dependent model specs have correct shapes."""
68
+
69
+ if self._model_spec.roi_calibration_period is not None and (
70
+ self._model_spec.roi_calibration_period.shape
71
+ != (
72
+ self.n_media_times,
73
+ self.n_media_channels,
74
+ )
75
+ ):
76
+ raise ValueError(
77
+ "The shape of `roi_calibration_period`"
78
+ f" {self._model_spec.roi_calibration_period.shape} is different from"
79
+ f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
80
+ f" {self.n_media_channels})`."
81
+ )
82
+
83
+ if self._model_spec.rf_roi_calibration_period is not None and (
84
+ self._model_spec.rf_roi_calibration_period.shape
85
+ != (
86
+ self.n_media_times,
87
+ self.n_rf_channels,
88
+ )
89
+ ):
90
+ raise ValueError(
91
+ "The shape of `rf_roi_calibration_period`"
92
+ f" {self._model_spec.rf_roi_calibration_period.shape} is different"
93
+ f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
94
+ f" {self.n_rf_channels})`."
95
+ )
96
+
97
+ if self._model_spec.holdout_id is not None:
98
+ if self.is_national and (
99
+ self._model_spec.holdout_id.shape != (self.n_times,)
100
+ ):
101
+ raise ValueError(
102
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
103
+ f" different from `(n_times,) = ({self.n_times},)`."
104
+ )
105
+ elif not self.is_national and (
106
+ self._model_spec.holdout_id.shape
107
+ != (
108
+ self.n_geos,
109
+ self.n_times,
110
+ )
111
+ ):
112
+ raise ValueError(
113
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
114
+ f" different from `(n_geos, n_times) = ({self.n_geos},"
115
+ f" {self.n_times})`."
116
+ )
117
+
118
+ if self._model_spec.control_population_scaling_id is not None and (
119
+ self._model_spec.control_population_scaling_id.shape
120
+ != (self.n_controls,)
121
+ ):
122
+ raise ValueError(
123
+ "The shape of `control_population_scaling_id`"
124
+ f" {self._model_spec.control_population_scaling_id.shape} is"
125
+ f" different from `(n_controls,) = ({self.n_controls},)`."
126
+ )
127
+
128
+ if self._model_spec.non_media_population_scaling_id is not None and (
129
+ self._model_spec.non_media_population_scaling_id.shape
130
+ != (self.n_non_media_channels,)
131
+ ):
132
+ raise ValueError(
133
+ "The shape of `non_media_population_scaling_id`"
134
+ f" {self._model_spec.non_media_population_scaling_id.shape} is"
135
+ " different from `(n_non_media_channels,) ="
136
+ f" ({self.n_non_media_channels},)`."
137
+ )
138
+
139
+ def _validate_model_spec_shapes(self):
140
+ """Validate shapes of model_spec attributes."""
141
+ if self._model_spec.roi_calibration_period is not None:
142
+ if self._model_spec.roi_calibration_period.shape != (
143
+ self.n_media_times,
144
+ self.n_media_channels,
145
+ ):
146
+ raise ValueError(
147
+ "The shape of `roi_calibration_period`"
148
+ f" {self._model_spec.roi_calibration_period.shape} is different"
149
+ f" from `(n_media_times, n_media_channels) = ({self.n_media_times},"
150
+ f" {self.n_media_channels})`."
151
+ )
152
+
153
+ if self._model_spec.rf_roi_calibration_period is not None:
154
+ if self._model_spec.rf_roi_calibration_period.shape != (
155
+ self.n_media_times,
156
+ self.n_rf_channels,
157
+ ):
158
+ raise ValueError(
159
+ "The shape of `rf_roi_calibration_period`"
160
+ f" {self._model_spec.rf_roi_calibration_period.shape} is different"
161
+ f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
162
+ f" {self.n_rf_channels})`."
163
+ )
164
+
165
+ if self._model_spec.holdout_id is not None:
166
+ expected_shape = (
167
+ (self.n_times,) if self.is_national else (self.n_geos, self.n_times)
168
+ )
169
+ if self._model_spec.holdout_id.shape != expected_shape:
170
+ raise ValueError(
171
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
172
+ " different from"
173
+ f" {'`(n_times,)`' if self.is_national else '`(n_geos, n_times)`'} ="
174
+ f" {expected_shape}."
175
+ )
176
+
177
+ if self._model_spec.control_population_scaling_id is not None:
178
+ if self._model_spec.control_population_scaling_id.shape != (
179
+ self.n_controls,
180
+ ):
181
+ raise ValueError(
182
+ "The shape of `control_population_scaling_id`"
183
+ f" {self._model_spec.control_population_scaling_id.shape} is"
184
+ f" different from `(n_controls,) = ({self.n_controls},)`."
185
+ )
186
+
187
+ def _validate_geo_invariants(self):
188
+ """Validates non-national model invariants."""
189
+ if self.is_national:
190
+ return
191
+
192
+ if self._input_data.controls is not None:
193
+ self._check_if_no_geo_variation(
194
+ self.controls_scaled,
195
+ constants.CONTROLS,
196
+ self._input_data.controls.coords[constants.CONTROL_VARIABLE].values,
197
+ )
198
+ if self._input_data.non_media_treatments is not None:
199
+ self._check_if_no_geo_variation(
200
+ self.non_media_treatments_normalized,
201
+ constants.NON_MEDIA_TREATMENTS,
202
+ self._input_data.non_media_treatments.coords[
203
+ constants.NON_MEDIA_CHANNEL
204
+ ].values,
205
+ )
206
+ if self._input_data.media is not None:
207
+ self._check_if_no_geo_variation(
208
+ self.media_tensors.media_scaled,
209
+ constants.MEDIA,
210
+ self._input_data.media.coords[constants.MEDIA_CHANNEL].values,
211
+ )
212
+ if self._input_data.reach is not None:
213
+ self._check_if_no_geo_variation(
214
+ self.rf_tensors.reach_scaled,
215
+ constants.REACH,
216
+ self._input_data.reach.coords[constants.RF_CHANNEL].values,
217
+ )
218
+ if self._input_data.organic_media is not None:
219
+ self._check_if_no_geo_variation(
220
+ self.organic_media_tensors.organic_media_scaled,
221
+ "organic_media",
222
+ self._input_data.organic_media.coords[
223
+ constants.ORGANIC_MEDIA_CHANNEL
224
+ ].values,
225
+ )
226
+ if self._input_data.organic_reach is not None:
227
+ self._check_if_no_geo_variation(
228
+ self.organic_rf_tensors.organic_reach_scaled,
229
+ "organic_reach",
230
+ self._input_data.organic_reach.coords[
231
+ constants.ORGANIC_RF_CHANNEL
232
+ ].values,
233
+ )
234
+
235
+ def _check_if_no_geo_variation(
236
+ self,
237
+ scaled_data: backend.Tensor,
238
+ data_name: str,
239
+ data_dims: Sequence[str],
240
+ epsilon=1e-4,
241
+ ):
242
+ """Raise an error if `n_knots == n_time` and data lacks geo variation."""
243
+
244
+ # Result shape: [n, d], where d is the number of axes of condition.
245
+ col_idx_full = backend.get_indices_where(
246
+ backend.reduce_std(scaled_data, axis=0) < epsilon
247
+ )[:, 1]
248
+ col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
249
+ # We use the shape of scaled_data (instead of `n_time`) because the data may
250
+ # be padded to account for lagged effects.
251
+ data_n_time = scaled_data.shape[1]
252
+ mask = backend.equal(counts, data_n_time)
253
+ col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
254
+ dims_bad = backend.gather(data_dims, col_idx_bad)
255
+
256
+ if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
257
+ raise ValueError(
258
+ f"The following {data_name} variables do not vary across geos, making"
259
+ f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
260
+ " lead to poor model convergence. Since these variables only vary"
261
+ " across time and not across geo, they are collinear with time and"
262
+ " redundant in a model with a parameter for each time period. To"
263
+ " address this, you can either: (1) decrease the number of knots"
264
+ " (n_knots < n_time), or (2) drop the listed variables that do not"
265
+ " vary across geos."
266
+ )
267
+
268
+ def _validate_time_invariants(self):
269
+ """Validates model time invariants."""
270
+ if self._input_data.controls is not None:
271
+ self._check_if_no_time_variation(
272
+ self.controls_scaled,
273
+ constants.CONTROLS,
274
+ self._input_data.controls.coords[constants.CONTROL_VARIABLE].values,
275
+ )
276
+ if self._input_data.non_media_treatments is not None:
277
+ self._check_if_no_time_variation(
278
+ self.non_media_treatments_normalized,
279
+ constants.NON_MEDIA_TREATMENTS,
280
+ self._input_data.non_media_treatments.coords[
281
+ constants.NON_MEDIA_CHANNEL
282
+ ].values,
283
+ )
284
+ if self._input_data.media is not None:
285
+ self._check_if_no_time_variation(
286
+ self.media_tensors.media_scaled,
287
+ constants.MEDIA,
288
+ self._input_data.media.coords[constants.MEDIA_CHANNEL].values,
289
+ )
290
+ if self._input_data.reach is not None:
291
+ self._check_if_no_time_variation(
292
+ self.rf_tensors.reach_scaled,
293
+ constants.REACH,
294
+ self._input_data.reach.coords[constants.RF_CHANNEL].values,
295
+ )
296
+ if self._input_data.organic_media is not None:
297
+ self._check_if_no_time_variation(
298
+ self.organic_media_tensors.organic_media_scaled,
299
+ constants.ORGANIC_MEDIA,
300
+ self._input_data.organic_media.coords[
301
+ constants.ORGANIC_MEDIA_CHANNEL
302
+ ].values,
303
+ )
304
+ if self._input_data.organic_reach is not None:
305
+ self._check_if_no_time_variation(
306
+ self.organic_rf_tensors.organic_reach_scaled,
307
+ constants.ORGANIC_REACH,
308
+ self._input_data.organic_reach.coords[
309
+ constants.ORGANIC_RF_CHANNEL
310
+ ].values,
311
+ )
312
+
313
+ def _validate_media_spend_for_paid_channels(self) -> None:
314
+ self._validate_spend_for_paid_channels(
315
+ self.input_data.aggregate_media_spend(), constants.MEDIA_CHANNEL
316
+ )
317
+
318
+ def _validate_rf_spend_for_paid_channels(self) -> None:
319
+ self._validate_spend_for_paid_channels(
320
+ self.input_data.aggregate_rf_spend(), constants.RF_CHANNEL
321
+ )
322
+
323
+ def _validate_spend_for_paid_channels(
324
+ self,
325
+ spend: np.ndarray | None,
326
+ dim: str,
327
+ ) -> None:
328
+ """Validates non-zero media spend for paid media channels.
329
+
330
+ Args:
331
+ spend: The media spend data to validate.
332
+ dim: The dimension name of the spend data.
333
+
334
+ Raises:
335
+ ValueError if any paid media channel has zero total spend.
336
+ """
337
+ if spend is None:
338
+ return
339
+ zero_spend_channels = spend.coords[dim].where(spend == 0, drop=True).values
340
+
341
+ if zero_spend_channels.size > 0:
342
+ raise ValueError(
343
+ "Zero total spend detected for paid channels:"
344
+ f" {', '.join(zero_spend_channels)}. If data is correct and this is"
345
+ " expected, please consider modeling the data as organic media."
346
+ )
347
+
348
+ def _check_if_no_time_variation(
349
+ self,
350
+ scaled_data: backend.Tensor,
351
+ data_name: str,
352
+ data_dims: Sequence[str],
353
+ epsilon=1e-4,
354
+ ):
355
+ """Raise an error if data lacks time variation."""
356
+
357
+ # Result shape: [n, d], where d is the number of axes of condition.
358
+ col_idx_full = backend.get_indices_where(
359
+ backend.reduce_std(scaled_data, axis=1) < epsilon
360
+ )[:, 1]
361
+ col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
362
+ mask = backend.equal(counts, self.n_geos)
363
+ col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
364
+ dims_bad = backend.gather(data_dims, col_idx_bad)
365
+ if col_idx_bad.shape[0]:
366
+ if self.is_national:
367
+ raise ValueError(
368
+ f"The following {data_name} variables do not vary across time,"
369
+ " which is equivalent to no signal at all in a national model:"
370
+ f" {dims_bad}. This can lead to poor model convergence. To address"
371
+ " this, drop the listed variables that do not vary across time."
372
+ )
373
+ else:
374
+ raise ValueError(
375
+ f"The following {data_name} variables do not vary across time,"
376
+ f" making a model with geo main effects unidentifiable: {dims_bad}."
377
+ " This can lead to poor model convergence. Since these variables"
378
+ " only vary across geo and not across time, they are collinear"
379
+ " with geo and redundant in a model with geo main effects. To"
380
+ " address this, drop the listed variables that do not vary across"
381
+ " time."
382
+ )
383
+
384
+ @property
385
+ def input_data(self) -> data.InputData:
386
+ return self._input_data
387
+
388
+ @property
389
+ def model_spec(self) -> spec.ModelSpec:
390
+ return self._model_spec
391
+
392
+ @functools.cached_property
393
+ def media_tensors(self) -> media.MediaTensors:
394
+ return media.build_media_tensors(self._input_data, self._model_spec)
395
+
396
+ @functools.cached_property
397
+ def rf_tensors(self) -> media.RfTensors:
398
+ return media.build_rf_tensors(self._input_data, self._model_spec)
399
+
400
+ @functools.cached_property
401
+ def organic_media_tensors(self) -> media.OrganicMediaTensors:
402
+ return media.build_organic_media_tensors(self._input_data)
403
+
404
+ @functools.cached_property
405
+ def organic_rf_tensors(self) -> media.OrganicRfTensors:
406
+ return media.build_organic_rf_tensors(self._input_data)
407
+
408
+ @functools.cached_property
409
+ def kpi(self) -> backend.Tensor:
410
+ return backend.to_tensor(self._input_data.kpi, dtype=backend.float32)
411
+
412
+ @functools.cached_property
413
+ def revenue_per_kpi(self) -> backend.Tensor | None:
414
+ if self._input_data.revenue_per_kpi is None:
415
+ return None
416
+ return backend.to_tensor(
417
+ self._input_data.revenue_per_kpi, dtype=backend.float32
418
+ )
419
+
420
+ @functools.cached_property
421
+ def controls(self) -> backend.Tensor | None:
422
+ if self._input_data.controls is None:
423
+ return None
424
+ return backend.to_tensor(self._input_data.controls, dtype=backend.float32)
425
+
426
+ @functools.cached_property
427
+ def non_media_treatments(self) -> backend.Tensor | None:
428
+ if self._input_data.non_media_treatments is None:
429
+ return None
430
+ return backend.to_tensor(
431
+ self._input_data.non_media_treatments, dtype=backend.float32
432
+ )
433
+
434
+ @functools.cached_property
435
+ def population(self) -> backend.Tensor:
436
+ return backend.to_tensor(self._input_data.population, dtype=backend.float32)
437
+
438
+ @functools.cached_property
439
+ def total_spend(self) -> backend.Tensor:
440
+ return backend.to_tensor(
441
+ self._input_data.get_total_spend(), dtype=backend.float32
442
+ )
443
+
444
+ @functools.cached_property
445
+ def total_outcome(self) -> backend.Tensor:
446
+ return backend.to_tensor(
447
+ self._input_data.get_total_outcome(), dtype=backend.float32
448
+ )
449
+
450
+ @property
451
+ def n_geos(self) -> int:
452
+ return len(self._input_data.geo)
453
+
454
+ @property
455
+ def n_media_channels(self) -> int:
456
+ if self._input_data.media_channel is None:
457
+ return 0
458
+ return len(self._input_data.media_channel)
459
+
460
+ @property
461
+ def n_rf_channels(self) -> int:
462
+ if self._input_data.rf_channel is None:
463
+ return 0
464
+ return len(self._input_data.rf_channel)
465
+
466
+ @property
467
+ def n_organic_media_channels(self) -> int:
468
+ if self._input_data.organic_media_channel is None:
469
+ return 0
470
+ return len(self._input_data.organic_media_channel)
471
+
472
+ @property
473
+ def n_organic_rf_channels(self) -> int:
474
+ if self._input_data.organic_rf_channel is None:
475
+ return 0
476
+ return len(self._input_data.organic_rf_channel)
477
+
478
+ @property
479
+ def n_controls(self) -> int:
480
+ if self._input_data.control_variable is None:
481
+ return 0
482
+ return len(self._input_data.control_variable)
483
+
484
+ @property
485
+ def n_non_media_channels(self) -> int:
486
+ if self._input_data.non_media_channel is None:
487
+ return 0
488
+ return len(self._input_data.non_media_channel)
489
+
490
+ @property
491
+ def n_times(self) -> int:
492
+ return len(self._input_data.time)
493
+
494
+ @property
495
+ def n_media_times(self) -> int:
496
+ return len(self._input_data.media_time)
497
+
498
+ @property
499
+ def is_national(self) -> bool:
500
+ return self.n_geos == 1
501
+
502
+ @functools.cached_property
503
+ def knot_info(self) -> knots.KnotInfo:
504
+ return knots.get_knot_info(
505
+ n_times=self.n_times,
506
+ knots=self._model_spec.knots,
507
+ enable_aks=self._model_spec.enable_aks,
508
+ data=self._input_data,
509
+ is_national=self.is_national,
510
+ )
511
+
512
+ @functools.cached_property
513
+ def controls_transformer(
514
+ self,
515
+ ) -> transformers.CenteringAndScalingTransformer | None:
516
+ """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
517
+ if self.controls is None:
518
+ return None
519
+
520
+ if self._model_spec.control_population_scaling_id is not None:
521
+ controls_population_scaling_id = backend.to_tensor(
522
+ self._model_spec.control_population_scaling_id, dtype=backend.bool_
523
+ )
524
+ else:
525
+ controls_population_scaling_id = None
526
+
527
+ return transformers.CenteringAndScalingTransformer(
528
+ tensor=self.controls,
529
+ population=self.population,
530
+ population_scaling_id=controls_population_scaling_id,
531
+ )
532
+
533
+ @functools.cached_property
534
+ def non_media_transformer(
535
+ self,
536
+ ) -> transformers.CenteringAndScalingTransformer | None:
537
+ """Returns a `CenteringAndScalingTransformer` for non-media treatments."""
538
+ if self.non_media_treatments is None:
539
+ return None
540
+ if self._model_spec.non_media_population_scaling_id is not None:
541
+ non_media_population_scaling_id = backend.to_tensor(
542
+ self._model_spec.non_media_population_scaling_id, dtype=backend.bool_
543
+ )
544
+ else:
545
+ non_media_population_scaling_id = None
546
+
547
+ return transformers.CenteringAndScalingTransformer(
548
+ tensor=self.non_media_treatments,
549
+ population=self.population,
550
+ population_scaling_id=non_media_population_scaling_id,
551
+ )
552
+
553
+ @functools.cached_property
554
+ def kpi_transformer(self) -> transformers.KpiTransformer:
555
+ return transformers.KpiTransformer(self.kpi, self.population)
556
+
557
+ @functools.cached_property
558
+ def controls_scaled(self) -> backend.Tensor | None:
559
+ if self.controls is not None:
560
+ # If `controls` is defined, then `controls_transformer` is also defined.
561
+ return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
562
+ else:
563
+ return None
564
+
565
+ @functools.cached_property
566
+ def non_media_treatments_normalized(self) -> backend.Tensor | None:
567
+ """Normalized non-media treatments.
568
+
569
+ The non-media treatments values are scaled by population (for channels where
570
+ `non_media_population_scaling_id` is `True`) and normalized by centering and
571
+ scaling with means and standard deviations.
572
+ """
573
+ if self.non_media_transformer is not None:
574
+ return self.non_media_transformer.forward(
575
+ self.non_media_treatments
576
+ ) # pytype: disable=attribute-error
577
+ else:
578
+ return None
579
+
580
+ @functools.cached_property
581
+ def kpi_scaled(self) -> backend.Tensor:
582
+ return self.kpi_transformer.forward(self.kpi)
583
+
584
+ @functools.cached_property
585
+ def media_effects_dist(self) -> str:
586
+ if self.is_national:
587
+ return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
588
+ else:
589
+ return self._model_spec.media_effects_dist
590
+
591
+ @functools.cached_property
592
+ def unique_sigma_for_each_geo(self) -> bool:
593
+ if self.is_national:
594
+ # Should evaluate to False.
595
+ return constants.NATIONAL_MODEL_SPEC_ARGS[
596
+ constants.UNIQUE_SIGMA_FOR_EACH_GEO
597
+ ]
598
+ else:
599
+ return self._model_spec.unique_sigma_for_each_geo
600
+
601
+ @functools.cached_property
602
+ def baseline_geo_idx(self) -> int:
603
+ """Returns the index of the baseline geo."""
604
+ if isinstance(self._model_spec.baseline_geo, int):
605
+ if (
606
+ self._model_spec.baseline_geo < 0
607
+ or self._model_spec.baseline_geo >= self.n_geos
608
+ ):
609
+ raise ValueError(
610
+ f"Baseline geo index {self._model_spec.baseline_geo} out of range"
611
+ f" [0, {self.n_geos - 1}]."
612
+ )
613
+ return self._model_spec.baseline_geo
614
+ elif isinstance(self._model_spec.baseline_geo, str):
615
+ # np.where returns a 1-D tuple, its first element is an array of found
616
+ # elements.
617
+ index = np.where(self._input_data.geo == self._model_spec.baseline_geo)[0]
618
+ if index.size == 0:
619
+ raise ValueError(
620
+ f"Baseline geo '{self._model_spec.baseline_geo}' not found."
621
+ )
622
+ # Geos are unique, so index is a 1-element array.
623
+ return index[0]
624
+ else:
625
+ return backend.argmax(self.population)
626
+
627
+ @functools.cached_property
628
+ def holdout_id(self) -> backend.Tensor | None:
629
+ if self._model_spec.holdout_id is None:
630
+ return None
631
+ tensor = backend.to_tensor(self._model_spec.holdout_id, dtype=backend.bool_)
632
+ return tensor[backend.newaxis, ...] if self.is_national else tensor
633
+
634
+ def _warn_setting_ignored_priors(self):
635
+ """Raises a warning if ignored priors are set."""
636
+ default_distribution = prior_distribution.PriorDistribution()
637
+ for ignored_priors_dict, prior_type, prior_type_name in (
638
+ (
639
+ constants.IGNORED_PRIORS_MEDIA,
640
+ self._model_spec.effective_media_prior_type,
641
+ "media_prior_type",
642
+ ),
643
+ (
644
+ constants.IGNORED_PRIORS_RF,
645
+ self._model_spec.effective_rf_prior_type,
646
+ "rf_prior_type",
647
+ ),
648
+ ):
649
+ ignored_custom_priors = []
650
+ for prior in ignored_priors_dict.get(prior_type, []):
651
+ self_prior = getattr(self._model_spec.prior, prior)
652
+ default_prior = getattr(default_distribution, prior)
653
+ if not prior_distribution.distributions_are_equal(
654
+ self_prior, default_prior
655
+ ):
656
+ ignored_custom_priors.append(prior)
657
+ if ignored_custom_priors:
658
+ ignored_priors_str = ", ".join(ignored_custom_priors)
659
+ warnings.warn(
660
+ f"Custom prior(s) `{ignored_priors_str}` are ignored when"
661
+ f' `{prior_type_name}` is set to "{prior_type}".'
662
+ )
663
+
664
+ def _validate_mroi_priors_non_revenue(self):
665
+ """Validates mroi priors in the non-revenue outcome case."""
666
+ if (
667
+ self._input_data.kpi_type == constants.NON_REVENUE
668
+ and self._input_data.revenue_per_kpi is None
669
+ ):
670
+ default_distribution = prior_distribution.PriorDistribution()
671
+ if (
672
+ self.n_media_channels > 0
673
+ and (
674
+ self._model_spec.effective_media_prior_type
675
+ == constants.TREATMENT_PRIOR_TYPE_MROI
676
+ )
677
+ and prior_distribution.distributions_are_equal(
678
+ self._model_spec.prior.mroi_m, default_distribution.mroi_m
679
+ )
680
+ ):
681
+ raise ValueError(
682
+ f"Custom priors should be set on `{constants.MROI_M}` when"
683
+ ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
684
+ " kpi data is missing."
685
+ )
686
+ if (
687
+ self.n_rf_channels > 0
688
+ and (
689
+ self._model_spec.effective_rf_prior_type
690
+ == constants.TREATMENT_PRIOR_TYPE_MROI
691
+ )
692
+ and prior_distribution.distributions_are_equal(
693
+ self._model_spec.prior.mroi_rf, default_distribution.mroi_rf
694
+ )
695
+ ):
696
+ raise ValueError(
697
+ f"Custom priors should be set on `{constants.MROI_RF}` when"
698
+ ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
699
+ " data is missing."
700
+ )
701
+
702
+ def _validate_roi_priors_non_revenue(self):
703
+ """Validates roi priors in the non-revenue outcome case."""
704
+ if (
705
+ self._input_data.kpi_type == constants.NON_REVENUE
706
+ and self._input_data.revenue_per_kpi is None
707
+ ):
708
+ default_distribution = prior_distribution.PriorDistribution()
709
+ default_roi_m_used = (
710
+ self._model_spec.effective_media_prior_type
711
+ == constants.TREATMENT_PRIOR_TYPE_ROI
712
+ and prior_distribution.distributions_are_equal(
713
+ self._model_spec.prior.roi_m, default_distribution.roi_m
714
+ )
715
+ )
716
+ default_roi_rf_used = (
717
+ self._model_spec.effective_rf_prior_type
718
+ == constants.TREATMENT_PRIOR_TYPE_ROI
719
+ and prior_distribution.distributions_are_equal(
720
+ self._model_spec.prior.roi_rf, default_distribution.roi_rf
721
+ )
722
+ )
723
+ # If ROI priors are used with the default prior distribution for all paid
724
+ # channels (media and RF), then use the "total paid media contribution
725
+ # prior" procedure.
726
+ if (
727
+ (default_roi_m_used and default_roi_rf_used)
728
+ or (self.n_media_channels == 0 and default_roi_rf_used)
729
+ or (self.n_rf_channels == 0 and default_roi_m_used)
730
+ ):
731
+ self._set_total_media_contribution_prior = True
732
+ warnings.warn(
733
+ "Consider setting custom ROI priors, as kpi_type was specified as"
734
+ " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
735
+ " total media contribution prior will be used with"
736
+ f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
737
+ " documentation available at "
738
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
739
+ )
740
+ elif self.n_media_channels > 0 and default_roi_m_used:
741
+ raise ValueError(
742
+ f"Custom priors should be set on `{constants.ROI_M}` when"
743
+ ' `media_prior_type` is "roi", custom priors are assigned on'
744
+ ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
745
+ " non-revenue and revenue per kpi data is missing."
746
+ )
747
+ elif self.n_rf_channels > 0 and default_roi_rf_used:
748
+ raise ValueError(
749
+ f"Custom priors should be set on `{constants.ROI_RF}` when"
750
+ ' `rf_prior_type` is "roi", custom priors are assigned on'
751
+ ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
752
+ " non-revenue and revenue per kpi data is missing."
753
+ )
754
+
755
def _check_media_prior_support(self) -> None:
  """Checks ROI, mROI, and Contribution prior support when random effects are log-normal.

  Priors for ROI, mROI, and Contribution can only have negative support if the
  random effects follow a normal distribution. This check enforces that priors
  have non-negative support when `self.media_effects_dist` is log-normal.
  Per the original design, national models have no random effects; presumably
  `self.media_effects_dist` already reflects that (not verifiable here).

  Each paid channel group is checked against the prior type actually in
  effect for that group: media priors against
  `effective_media_prior_type`, RF priors against `effective_rf_prior_type`.

  Raises:
    ValueError: If a prior of the in-effect type has probability mass below
      zero while `media_effects_dist` is log-normal.
  """
  prior = self._model_spec.prior
  if self.n_media_channels > 0:
    media_prior_type = self._model_spec.effective_media_prior_type
    for dist, prior_type in (
        (prior.roi_m, constants.TREATMENT_PRIOR_TYPE_ROI),
        (prior.mroi_m, constants.TREATMENT_PRIOR_TYPE_MROI),
        (prior.contribution_m, constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION),
    ):
      self._check_for_negative_support(
          dist,
          self.media_effects_dist,
          prior_type,
          channel_prior_type=media_prior_type,
      )
  if self.n_rf_channels > 0:
    # RF priors must be compared against the RF prior type, not the media
    # prior type: the two can differ, in which case comparing against the
    # media type checks the wrong RF prior.
    rf_prior_type = self._model_spec.effective_rf_prior_type
    for dist, prior_type in (
        (prior.roi_rf, constants.TREATMENT_PRIOR_TYPE_ROI),
        (prior.mroi_rf, constants.TREATMENT_PRIOR_TYPE_MROI),
        (prior.contribution_rf, constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION),
    ):
      self._check_for_negative_support(
          dist,
          self.media_effects_dist,
          prior_type,
          channel_prior_type=rf_prior_type,
      )
  # NOTE(review): organic priors are still compared against
  # `self._model_spec.media_prior_type` (the default used by
  # `_check_for_negative_support`), exactly as before. If the model spec has
  # dedicated organic prior types, they should likely be passed here -- confirm.
  if self.n_organic_media_channels > 0:
    self._check_for_negative_support(
        prior.contribution_om,
        self.media_effects_dist,
        constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
    )
  if self.n_organic_rf_channels > 0:
    self._check_for_negative_support(
        prior.contribution_orf,
        self.media_effects_dist,
        constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
    )

def _check_for_negative_support(
    self,
    dist: backend.tfd.Distribution,
    media_effects_dist: str,
    prior_type: str,
    channel_prior_type: str | None = None,
) -> None:
  """Checks for negative support in prior distributions.

  When `media_effects_dist` is `MEDIA_EFFECTS_LOG_NORMAL`, prior distributions
  for media effects must be non-negative. This function raises a ValueError if
  the distribution's CDF evaluated at 0 is greater than 0 for any element,
  indicating some probability mass below zero. The check only runs when
  `prior_type` is the prior type in effect for the channel group.

  Args:
    dist: The distribution to check.
    media_effects_dist: The type of media effects distribution.
    prior_type: The prior type that corresponds with current prior under test.
    channel_prior_type: The prior type in effect for the channel group that
      `dist` belongs to. If None, falls back to
      `self._model_spec.media_prior_type` (the previous behavior).

  Raises:
    ValueError: If the prior distribution has negative support when
      `media_effects_dist` is `MEDIA_EFFECTS_LOG_NORMAL`.
  """
  if channel_prior_type is None:
    channel_prior_type = self._model_spec.media_prior_type
  if (
      prior_type == channel_prior_type
      and media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
      and np.any(dist.cdf(0) > 0)
  ):
    raise ValueError(
        "Media priors must have non-negative support when"
        f' `media_effects_dist`="{media_effects_dist}". Found negative prior'
        f" distribution support for {dist.name}."
    )
842
+
843
@functools.cached_property
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
  """Returns the model spec's prior broadcast to this model's dimensions.

  Total spend from the input data is first collapsed to one value per channel.
  It is then passed to `PriorDistribution.broadcast`, together with the total
  KPI and the geo/channel/control/knot counts.
  """
  spend = self._input_data.get_total_spend()
  spend_rank = len(spend.shape)
  # Total spend can have 1, 2 or 3 dimensions; reduce to per-channel values.
  if spend_rank == 1:
    spend_per_channel = spend  # Already one value per channel.
  else:
    leading_axes = (0,) if spend_rank == 2 else (0, 1)
    spend_per_channel = np.sum(spend, axis=leading_axes)

  return self._model_spec.prior.broadcast(
      n_geos=self.n_geos,
      n_media_channels=self.n_media_channels,
      n_rf_channels=self.n_rf_channels,
      n_organic_media_channels=self.n_organic_media_channels,
      n_organic_rf_channels=self.n_organic_rf_channels,
      n_controls=self.n_controls,
      n_non_media_channels=self.n_non_media_channels,
      unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
      n_knots=self.knot_info.n_knots,
      is_national=self.is_national,
      set_total_media_contribution_prior=self._set_total_media_contribution_prior,
      kpi=np.sum(self._input_data.kpi.values),
      total_spend=spend_per_channel,
  )
871
+
872
@functools.cached_property
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
  """Returns the `AdstockDecaySpec`, resolving channel names to groups.

  A string spec is applied uniformly via `AdstockDecaySpec.from_consistent_type`.
  Any other spec is treated as a channel-to-function mapping.

  Raises:
    ValueError: If the mapping names a channel that is not in the input data.
  """
  decay_spec = self._model_spec.adstock_decay_spec
  if not isinstance(decay_spec, str):
    try:
      return self._create_adstock_decay_functions_from_channel_map(decay_spec)
    except KeyError as e:
      # Surface unknown channel keys as a user-facing ValueError.
      raise ValueError(
          "Unrecognized channel names found in `adstock_decay_spec` keys"
          f" {tuple(decay_spec.keys())}. Keys should"
          " either contain only channel_names"
          f" {tuple(self._input_data.get_all_adstock_hill_channels().tolist())} or"
          " be one or more of {'media', 'rf', 'organic_media',"
          " 'organic_rf'}."
      ) from e
  return adstock_hill.AdstockDecaySpec.from_consistent_type(decay_spec)
893
+
894
def _create_adstock_decay_functions_from_channel_map(
    self, channel_function_map: Mapping[str, str]
) -> adstock_hill.AdstockDecaySpec:
  """Create `AdstockDecaySpec` from mapping from channels to decay functions.

  Args:
    channel_function_map: Mapping from channel name to decay function name.
      Every key must be one of
      `self._input_data.get_all_adstock_hill_channels()`.

  Returns:
    An `AdstockDecaySpec` with one entry per channel group (media, rf,
    organic_media, organic_rf). For a group with channels in the input data,
    the entry is built from that group's argument builder, defaulting to
    `constants.GEOMETRIC_DECAY`. A group absent from the input data gets
    `constants.GEOMETRIC_DECAY`.

  Raises:
    KeyError: If a key of `channel_function_map` is not a channel in the input
      data.
  """
  if channel_function_map:
    # Fetch the known channels once instead of once per key.
    known_channels = self._input_data.get_all_adstock_hill_channels()
    for channel in channel_function_map:
      if channel not in known_channels:
        raise KeyError(f"Channel {channel} not found in data.")

  def _group_decay_functions(channel_names, make_builder):
    """Returns one group's decay functions, or the geometric default."""
    if channel_names is None:
      return constants.GEOMETRIC_DECAY
    builder = make_builder().with_default_value(constants.GEOMETRIC_DECAY)
    return builder(**channel_function_map)

  data = self._input_data
  return adstock_hill.AdstockDecaySpec(
      media=_group_decay_functions(
          data.media_channel,
          data.get_paid_media_channels_argument_builder,
      ),
      rf=_group_decay_functions(
          data.rf_channel,
          data.get_paid_rf_channels_argument_builder,
      ),
      organic_media=_group_decay_functions(
          data.organic_media_channel,
          data.get_organic_media_channels_argument_builder,
      ),
      organic_rf=_group_decay_functions(
          data.organic_rf_channel,
          data.get_organic_rf_channels_argument_builder,
      ),
  )
945
+
946
def create_inference_data_coords(
    self, n_chains: int, n_draws: int
) -> Mapping[str, np.ndarray | Sequence[str]]:
  """Builds the coordinate mapping used to assemble inference data.

  Args:
    n_chains: Number of chains; the chain coordinate is `np.arange(n_chains)`.
    n_draws: Number of draws; the draw coordinate is `np.arange(n_draws)`.

  Returns:
    Mapping from coordinate name to coordinate values, covering chain, draw,
    geo, time, media time, knots and every channel/variable group. Groups
    that are absent from the input data map to an empty array.
  """
  data = self.input_data

  def _or_empty(names):
    # Optional groups become an empty coordinate array when missing.
    return np.array([]) if names is None else names

  media_names = _or_empty(data.media_channel)
  rf_names = _or_empty(data.rf_channel)
  organic_media_names = _or_empty(data.organic_media_channel)
  organic_rf_names = _or_empty(data.organic_rf_channel)
  non_media_names = _or_empty(data.non_media_channel)
  control_names = _or_empty(data.control_variable)

  coords = {
      constants.CHAIN: np.arange(n_chains),
      constants.DRAW: np.arange(n_draws),
      constants.GEO: data.geo,
      constants.TIME: data.time,
      constants.MEDIA_TIME: data.media_time,
      constants.KNOTS: np.arange(self.knot_info.n_knots),
      constants.CONTROL_VARIABLE: control_names,
      constants.NON_MEDIA_CHANNEL: non_media_names,
      constants.MEDIA_CHANNEL: media_names,
      constants.RF_CHANNEL: rf_names,
      constants.ORGANIC_MEDIA_CHANNEL: organic_media_names,
      constants.ORGANIC_RF_CHANNEL: organic_rf_names,
  }
  return coords
994
+
995
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
  """Builds the dimension names of each inference-data parameter.

  Starts from `constants.INFERENCE_DIMS`. `sigma` is indexed by geo when
  `unique_sigma_for_each_geo` is set and is scalar otherwise. Every
  parameter's dimensions are prefixed with the chain and draw dimensions.
  """
  # Copy so the shared constant is never mutated.
  dims_by_param = dict(constants.INFERENCE_DIMS)
  dims_by_param[constants.SIGMA] = (
      [constants.GEO] if self.unique_sigma_for_each_geo else []
  )
  sample_dims = [constants.CHAIN, constants.DRAW]
  return {
      name: sample_dims + list(dims) for name, dims in dims_by_param.items()
  }
1007
+
1008
def populate_cached_properties(self):
  """Eagerly evaluates every `functools.cached_property` on this object.

  Accessing each cached property once stores its value on the instance. Later
  reads then cause no internal state mutation, which matters when this object
  is captured in a `tf.function` computation graph closure.
  """
  owner = self.__class__
  for name in dir(self):
    # Look the name up on the class: only class-level descriptors can be
    # cached properties.
    if isinstance(getattr(owner, name, None), functools.cached_property):
      getattr(self, name)
1026
+
1027
def expand_selected_time_dims(
    self,
    start_date: tc.Date = None,
    end_date: tc.Date = None,
) -> list[str] | None:
  """Validates the selected time range and returns its time coordinates.

  Selection is delegated to
  `self.input_data.time_coordinates.expand_selected_time_dims`. Per the
  original contract, `start_date` and `end_date` are inclusive and must be
  present in the input data's time coordinates.

  Args:
    start_date: Start date of the selected time period. If None, implies the
      earliest time dimension value in the input data.
    end_date: End date of the selected time period. If None, implies the
      latest time dimension value in the input data.

  Returns:
    The selected dates formatted with `constants.DATE_FORMAT`. Returns None
    when the delegate returns None, which per the original contract happens
    when both arguments are None or when they span the entire time range.

  Raises:
    ValueError: Per the original contract, if `start_date` or `end_date` is
      not in the input data time dimensions.
  """
  selected = self.input_data.time_coordinates.expand_selected_time_dims(
      start_date=start_date, end_date=end_date
  )
  if selected is not None:
    return [day.strftime(constants.DATE_FORMAT) for day in selected]
  return None