google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -642,6 +642,7 @@ def random_media_da(
642
642
  channel_variable_name: str = 'media_channel',
643
643
  channel_prefix: str = 'ch_',
644
644
  integer_geos: bool = False,
645
+ nonzero_shift: float = 0.0,
645
646
  ) -> xr.DataArray:
646
647
  """Generates a sample `media` DataArray.
647
648
 
@@ -662,6 +663,7 @@ def random_media_da(
662
663
  channel_variable_name: The name of the channel variable
663
664
  channel_prefix: The prefix of the channel names
664
665
  integer_geos: If True, the geos will be integers.
666
+ nonzero_shift: A scalar value to add to the generated data.
665
667
 
666
668
  Returns:
667
669
  A DataArray containing random data.
@@ -695,6 +697,8 @@ def random_media_da(
695
697
  )
696
698
  )
697
699
 
700
+ media = media + nonzero_shift
701
+
698
702
  if explicit_geo_names is None:
699
703
  geos = sample_geos(n_geos, integer_geos)
700
704
  else:
@@ -736,6 +740,7 @@ def random_organic_media_da(
736
740
  explicit_time_index: Sequence[str] | None = None,
737
741
  explicit_media_channel_names: Sequence[str] | None = None,
738
742
  integer_geos: bool = False,
743
+ nonzero_shift: float = 0.0,
739
744
  ) -> xr.DataArray:
740
745
  """Generates a sample `organic_media` DataArray."""
741
746
  return random_media_da(
@@ -751,6 +756,7 @@ def random_organic_media_da(
751
756
  channel_variable_name='organic_media_channel',
752
757
  channel_prefix='organic_media_',
753
758
  integer_geos=integer_geos,
759
+ nonzero_shift=nonzero_shift,
754
760
  )
755
761
 
756
762
 
@@ -761,6 +767,7 @@ def random_media_spend_nd_da(
761
767
  seed=0,
762
768
  integer_geos: bool = False,
763
769
  explicit_media_channel_names: Sequence[str] | None = None,
770
+ nonzero_shift: float = 0.0,
764
771
  ) -> xr.DataArray:
765
772
  """Generates a sample N-dimensional `media_spend` DataArray.
766
773
 
@@ -781,6 +788,7 @@ def random_media_spend_nd_da(
781
788
  integer_geos: If True, the geos will be integers.
782
789
  explicit_media_channel_names: If given, ignore `n_media_channels` and use
783
790
  this as is.
791
+ nonzero_shift: A scalar value to add to the generated data.
784
792
 
785
793
  Returns:
786
794
  A DataArray containing the generated `media_spend` data with the given
@@ -818,7 +826,7 @@ def random_media_spend_nd_da(
818
826
  f'Shape {dims} not supported by the random_media_spend_nd_da function.'
819
827
  )
820
828
 
821
- media_spend = abs(np.random.normal(1, 1, size=shape))
829
+ media_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
822
830
 
823
831
  return xr.DataArray(
824
832
  media_spend,
@@ -1007,8 +1015,27 @@ def random_reach_da(
1007
1015
  channel_variable_name: str = 'rf_channel',
1008
1016
  channel_prefix: str = 'rf_ch_',
1009
1017
  integer_geos: bool = False,
1018
+ nonzero_shift: float = 0.0,
1010
1019
  ) -> xr.DataArray:
1011
- """Generates a sample `reach` DataArray."""
1020
+ """Generates a sample `reach` DataArray.
1021
+
1022
+ Args:
1023
+ n_geos: Number of geos
1024
+ n_times: Number of time periods
1025
+ n_media_times: Number of media time periods
1026
+ n_rf_channels: Number of reach and frequency channels
1027
+ seed: Random seed used by `np.random.seed()`
1028
+ explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
1029
+ is
1030
+ array_name: The name of the array to be created
1031
+ channel_variable_name: The name of the channel variable
1032
+ channel_prefix: The prefix of the channel names
1033
+ integer_geos: If True, the geos will be integers.
1034
+ nonzero_shift: A scalar value to add to the generated data.
1035
+
1036
+ Returns:
1037
+ A DataArray containing random data.
1038
+ """
1012
1039
 
1013
1040
  np.random.seed(seed)
1014
1041
 
@@ -1016,12 +1043,15 @@ def random_reach_da(
1016
1043
  if n_times < n_media_times:
1017
1044
  start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
1018
1045
 
1019
- reach = np.round(
1020
- abs(
1021
- np.random.normal(
1022
- 3000, 100, size=(n_geos, n_media_times, n_rf_channels)
1046
+ reach = (
1047
+ np.round(
1048
+ abs(
1049
+ np.random.normal(
1050
+ 3000, 100, size=(n_geos, n_media_times, n_rf_channels)
1051
+ )
1023
1052
  )
1024
1053
  )
1054
+ + nonzero_shift
1025
1055
  )
1026
1056
 
1027
1057
  channels = (
@@ -1051,6 +1081,7 @@ def random_organic_reach_da(
1051
1081
  seed: int = 0,
1052
1082
  explicit_organic_rf_channel_names: Sequence[str] | None = None,
1053
1083
  integer_geos: bool = False,
1084
+ nonzero_shift: float = 0.0,
1054
1085
  ) -> xr.DataArray:
1055
1086
  """Generates a sample `organic_reach` DataArray."""
1056
1087
  return random_reach_da(
@@ -1064,6 +1095,7 @@ def random_organic_reach_da(
1064
1095
  channel_variable_name='organic_rf_channel',
1065
1096
  channel_prefix='organic_rf_ch_',
1066
1097
  integer_geos=integer_geos,
1098
+ nonzero_shift=nonzero_shift,
1067
1099
  )
1068
1100
 
1069
1101
 
@@ -1078,8 +1110,27 @@ def random_frequency_da(
1078
1110
  channel_variable_name: str = 'rf_channel',
1079
1111
  channel_prefix: str = 'rf_ch_',
1080
1112
  integer_geos: bool = False,
1113
+ nonzero_shift: float = 0.0,
1081
1114
  ) -> xr.DataArray:
1082
- """Generates a sample `frequency` DataArray."""
1115
+ """Generates a sample `frequency` DataArray.
1116
+
1117
+ Args:
1118
+ n_geos: Number of geos
1119
+ n_times: Number of time periods
1120
+ n_media_times: Number of media time periods
1121
+ n_rf_channels: Number of reach and frequency channels
1122
+ seed: Random seed used by `np.random.seed()`
1123
+ explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
1124
+ is
1125
+ array_name: The name of the array to be created
1126
+ channel_variable_name: The name of the channel variable
1127
+ channel_prefix: The prefix of the channel names
1128
+ integer_geos: If True, the geos will be integers.
1129
+ nonzero_shift: A scalar value to add to the generated data.
1130
+
1131
+ Returns:
1132
+ A DataArray containing random data.
1133
+ """
1083
1134
 
1084
1135
  np.random.seed(seed)
1085
1136
 
@@ -1087,8 +1138,9 @@ def random_frequency_da(
1087
1138
  if n_times < n_media_times:
1088
1139
  start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
1089
1140
 
1090
- frequency = abs(
1091
- np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels))
1141
+ frequency = (
1142
+ abs(np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels)))
1143
+ + nonzero_shift
1092
1144
  )
1093
1145
 
1094
1146
  channels = (
@@ -1119,6 +1171,7 @@ def random_organic_frequency_da(
1119
1171
  seed: int = 0,
1120
1172
  explicit_organic_rf_channel_names: Sequence[str] | None = None,
1121
1173
  integer_geos: bool = False,
1174
+ nonzero_shift: float = 0.0,
1122
1175
  ) -> xr.DataArray:
1123
1176
  """Generates a sample `organic_frequency` DataArray."""
1124
1177
  return random_frequency_da(
@@ -1132,6 +1185,7 @@ def random_organic_frequency_da(
1132
1185
  channel_variable_name='organic_rf_channel',
1133
1186
  channel_prefix='organic_rf_ch_',
1134
1187
  integer_geos=integer_geos,
1188
+ nonzero_shift=nonzero_shift,
1135
1189
  )
1136
1190
 
1137
1191
 
@@ -1141,6 +1195,7 @@ def random_rf_spend_nd_da(
1141
1195
  n_rf_channels: int | None = None,
1142
1196
  seed=0,
1143
1197
  integer_geos: bool = False,
1198
+ nonzero_shift: float = 0.0,
1144
1199
  ) -> xr.DataArray:
1145
1200
  """Generates a sample N-dimensional `rf_spend` DataArray.
1146
1201
 
@@ -1157,6 +1212,7 @@ def random_rf_spend_nd_da(
1157
1212
  n_rf_channels: Number of channels in the created `rf_spend` array.
1158
1213
  seed: Random seed used by `np.random.seed()`.
1159
1214
  integer_geos: If True, the geos will be integers.
1215
+ nonzero_shift: A scalar value to add to the generated data.
1160
1216
 
1161
1217
  Returns:
1162
1218
  A DataArray containing the generated `rf_spend` data with the given
@@ -1187,7 +1243,7 @@ def random_rf_spend_nd_da(
1187
1243
  f'Shape {dims} not supported by the random_rf_spend_nd_da function.'
1188
1244
  )
1189
1245
 
1190
- rf_spend = abs(np.random.normal(1, 1, size=shape))
1246
+ rf_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
1191
1247
 
1192
1248
  return xr.DataArray(
1193
1249
  rf_spend,
@@ -1206,6 +1262,7 @@ def random_non_media_treatments_da(
1206
1262
  date_format: str = c.DATE_FORMAT,
1207
1263
  explicit_time_index: Sequence[str] | None = None,
1208
1264
  integer_geos: bool = False,
1265
+ nonzero_shift: float = 0.0,
1209
1266
  ) -> xr.DataArray:
1210
1267
  """Generates a sample `non_media_treatments` DataArray.
1211
1268
 
@@ -1218,6 +1275,7 @@ def random_non_media_treatments_da(
1218
1275
  date_format: The date format to use for time coordinate labels
1219
1276
  explicit_time_index: If given, ignore `date_format` and use this as is
1220
1277
  integer_geos: If True, the geos will be integers.
1278
+ nonzero_shift: A scalar value to add to the generated data.
1221
1279
 
1222
1280
  Returns:
1223
1281
  A DataArray containing random non-media variable.
@@ -1232,6 +1290,8 @@ def random_non_media_treatments_da(
1232
1290
  non_media_channel,
1233
1291
  size=(n_geos, n_times, n_non_media_channels),
1234
1292
  )
1293
+ non_media_treatments = non_media_treatments + nonzero_shift
1294
+
1235
1295
  return xr.DataArray(
1236
1296
  non_media_treatments,
1237
1297
  dims=['geo', 'time', 'non_media_channel'],
@@ -1268,6 +1328,7 @@ def random_dataset(
1268
1328
  remove_media_time: bool = False,
1269
1329
  integer_geos: bool = False,
1270
1330
  kpi_data_pattern: str = '',
1331
+ nonzero_shift: float = 0.0,
1271
1332
  ) -> xr.Dataset:
1272
1333
  """Generates a random dataset."""
1273
1334
  if n_media_channels:
@@ -1280,6 +1341,7 @@ def random_dataset(
1280
1341
  integer_geos=integer_geos,
1281
1342
  explicit_media_channel_names=explicit_media_channel_names,
1282
1343
  media_value_scales=media_value_scales,
1344
+ nonzero_shift=nonzero_shift,
1283
1345
  )
1284
1346
  media_spend = random_media_spend_nd_da(
1285
1347
  n_geos=n_geos,
@@ -1288,6 +1350,7 @@ def random_dataset(
1288
1350
  explicit_media_channel_names=explicit_media_channel_names,
1289
1351
  seed=seed,
1290
1352
  integer_geos=integer_geos,
1353
+ nonzero_shift=nonzero_shift,
1291
1354
  )
1292
1355
  else:
1293
1356
  media = None
@@ -1301,6 +1364,7 @@ def random_dataset(
1301
1364
  n_rf_channels=n_rf_channels,
1302
1365
  seed=seed,
1303
1366
  integer_geos=integer_geos,
1367
+ nonzero_shift=nonzero_shift,
1304
1368
  )
1305
1369
  frequency = random_frequency_da(
1306
1370
  n_geos=n_geos,
@@ -1309,6 +1373,7 @@ def random_dataset(
1309
1373
  n_rf_channels=n_rf_channels,
1310
1374
  seed=seed,
1311
1375
  integer_geos=integer_geos,
1376
+ nonzero_shift=nonzero_shift,
1312
1377
  )
1313
1378
  rf_spend = random_rf_spend_nd_da(
1314
1379
  n_geos=n_geos,
@@ -1316,6 +1381,7 @@ def random_dataset(
1316
1381
  n_rf_channels=n_rf_channels,
1317
1382
  seed=seed,
1318
1383
  integer_geos=integer_geos,
1384
+ nonzero_shift=nonzero_shift,
1319
1385
  )
1320
1386
  else:
1321
1387
  reach = None
@@ -1352,6 +1418,7 @@ def random_dataset(
1352
1418
  n_non_media_channels=n_non_media_channels,
1353
1419
  seed=seed,
1354
1420
  integer_geos=integer_geos,
1421
+ nonzero_shift=nonzero_shift,
1355
1422
  )
1356
1423
  else:
1357
1424
  non_media_treatments = None
@@ -1364,6 +1431,7 @@ def random_dataset(
1364
1431
  n_organic_media_channels=n_organic_media_channels,
1365
1432
  seed=seed,
1366
1433
  integer_geos=integer_geos,
1434
+ nonzero_shift=nonzero_shift,
1367
1435
  )
1368
1436
  else:
1369
1437
  organic_media = None
@@ -1376,6 +1444,7 @@ def random_dataset(
1376
1444
  n_organic_rf_channels=n_organic_rf_channels,
1377
1445
  seed=seed,
1378
1446
  integer_geos=integer_geos,
1447
+ nonzero_shift=nonzero_shift,
1379
1448
  )
1380
1449
  organic_frequency = random_organic_frequency_da(
1381
1450
  n_geos=n_geos,
@@ -1384,6 +1453,7 @@ def random_dataset(
1384
1453
  n_organic_rf_channels=n_organic_rf_channels,
1385
1454
  seed=seed,
1386
1455
  integer_geos=integer_geos,
1456
+ nonzero_shift=nonzero_shift,
1387
1457
  )
1388
1458
  else:
1389
1459
  organic_reach = None
@@ -1406,53 +1476,37 @@ def random_dataset(
1406
1476
  constant_value=constant_population_value,
1407
1477
  )
1408
1478
 
1409
- dataset = xr.combine_by_coords(
1410
- [kpi, population] + ([controls] if controls is not None else [])
1411
- )
1479
+ to_merge = [kpi, population]
1480
+ if controls is not None:
1481
+ to_merge.append(controls)
1412
1482
  if revenue_per_kpi is not None:
1413
- dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
1483
+ to_merge.append(revenue_per_kpi)
1414
1484
  if media is not None:
1415
- media_renamed = (
1416
- media.rename({'media_time': 'time'}) if remove_media_time else media
1417
- )
1418
-
1419
- dataset = xr.combine_by_coords([dataset, media_renamed, media_spend])
1485
+ if remove_media_time:
1486
+ media = media.rename({'media_time': 'time'})
1487
+ to_merge.append(media)
1488
+ to_merge.append(media_spend)
1420
1489
  if reach is not None:
1421
- reach_renamed = (
1422
- reach.rename({'media_time': 'time'}) if remove_media_time else reach
1423
- )
1424
- frequency_renamed = (
1425
- frequency.rename({'media_time': 'time'})
1426
- if remove_media_time
1427
- else frequency
1428
- )
1429
- dataset = xr.combine_by_coords(
1430
- [dataset, reach_renamed, frequency_renamed, rf_spend]
1431
- )
1490
+ if remove_media_time:
1491
+ reach = reach.rename({'media_time': 'time'})
1492
+ frequency = frequency.rename({'media_time': 'time'})
1493
+ to_merge.append(reach)
1494
+ to_merge.append(frequency)
1495
+ to_merge.append(rf_spend)
1432
1496
  if non_media_treatments is not None:
1433
- dataset = xr.combine_by_coords([dataset, non_media_treatments])
1497
+ to_merge.append(non_media_treatments)
1434
1498
  if organic_media is not None:
1435
- organic_media_renamed = (
1436
- organic_media.rename({'media_time': 'time'})
1437
- if remove_media_time
1438
- else organic_media
1439
- )
1440
- dataset = xr.combine_by_coords([dataset, organic_media_renamed])
1499
+ if remove_media_time:
1500
+ organic_media = organic_media.rename({'media_time': 'time'})
1501
+ to_merge.append(organic_media)
1441
1502
  if organic_reach is not None:
1442
- organic_reach_renamed = (
1443
- organic_reach.rename({'media_time': 'time'})
1444
- if remove_media_time
1445
- else organic_reach
1446
- )
1447
- organic_frequency_renamed = (
1448
- organic_frequency.rename({'media_time': 'time'})
1449
- if remove_media_time
1450
- else organic_frequency
1451
- )
1452
- dataset = xr.combine_by_coords(
1453
- [dataset, organic_reach_renamed, organic_frequency_renamed]
1454
- )
1455
- return dataset
1503
+ if remove_media_time:
1504
+ organic_reach = organic_reach.rename({'media_time': 'time'})
1505
+ organic_frequency = organic_frequency.rename({'media_time': 'time'})
1506
+ to_merge.append(organic_reach)
1507
+ to_merge.append(organic_frequency)
1508
+
1509
+ return xr.merge(to_merge, join='outer', compat='no_conflicts')
1456
1510
 
1457
1511
 
1458
1512
  def dataset_to_dataframe(
@@ -1794,6 +1848,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1794
1848
  n_organic_media_channels: int | None = None,
1795
1849
  n_organic_rf_channels: int | None = None,
1796
1850
  seed: int = 0,
1851
+ nonzero_shift: float = 0.0,
1797
1852
  ) -> input_data.InputData:
1798
1853
  """Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
1799
1854
  dataset = random_dataset(
@@ -1807,6 +1862,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1807
1862
  n_organic_media_channels=n_organic_media_channels,
1808
1863
  n_organic_rf_channels=n_organic_rf_channels,
1809
1864
  seed=seed,
1865
+ nonzero_shift=nonzero_shift,
1810
1866
  )
1811
1867
  return input_data.InputData(
1812
1868
  kpi=dataset.kpi,
@@ -0,0 +1,48 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """This module contains common validation functions for Meridian data."""
16
+
17
+ import datetime as dt
18
+ from meridian import constants
19
+ import xarray as xr
20
+
21
+
22
+ def validate_time_coord_format(array: xr.DataArray | None):
23
+ """Validates the `time` dimensions format of the selected DataArray.
24
+
25
+ The `time` dimension of the selected array must have labels that are
26
+ formatted in the Meridian conventional `"yyyy-mm-dd"` format.
27
+
28
+ Args:
29
+ array: An optional DataArray to validate.
30
+ """
31
+ if array is None:
32
+ return
33
+
34
+ # The component data arrays from the input data builders that call this helper
35
+ # method should only have one of either `media_time` or `time` as its time
36
+ # dimension.
37
+ target_coords = [constants.TIME, constants.MEDIA_TIME]
38
+
39
+ for coord_name in target_coords:
40
+ if (values := array.coords.get(coord_name)) is not None:
41
+ for time in values:
42
+ try:
43
+ dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
44
+ except (TypeError, ValueError) as exc:
45
+ raise ValueError(
46
+ f"Invalid {coord_name} label: {time.item()!r}. "
47
+ f"Expected format: '{constants.DATE_FORMAT}'"
48
+ ) from exc
@@ -70,6 +70,7 @@ import dataclasses
70
70
  import inspect
71
71
  import json
72
72
  from typing import Any, Callable
73
+ import warnings
73
74
 
74
75
  import arviz as az
75
76
  from meridian import backend
@@ -180,16 +181,25 @@ def autolog(
180
181
  f"sample_posterior.{param}", kwargs.get(param, "default")
181
182
  )
182
183
 
183
- original(self, *args, **kwargs)
184
+ result = original(self, *args, **kwargs)
184
185
  if log_metrics:
185
- model_diagnostics = visualizer.ModelDiagnostics(self.model)
186
- df_diag = model_diagnostics.predictive_accuracy_table()
187
-
188
- get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
189
-
190
- mlflow.log_metric("R_Squared", get_metric("R_Squared"))
191
- mlflow.log_metric("MAPE", get_metric("MAPE"))
192
- mlflow.log_metric("wMAPE", get_metric("wMAPE"))
186
+ # TODO: Direct injection of `model.Meridian` object into
187
+ # `PosteriorMCMCSampler` is deprecated. Revisit patching method here.
188
+ if self.model is not None:
189
+ model_diagnostics = visualizer.ModelDiagnostics(self.model)
190
+ df_diag = model_diagnostics.predictive_accuracy_table()
191
+
192
+ get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
193
+
194
+ mlflow.log_metric("R_Squared", get_metric("R_Squared"))
195
+ mlflow.log_metric("MAPE", get_metric("MAPE"))
196
+ mlflow.log_metric("wMAPE", get_metric("wMAPE"))
197
+ else:
198
+ warnings.warn(
199
+ "log_metrics=True is not supported when PosteriorMCMCSampler is"
200
+ " initialized with model_context."
201
+ )
202
+ return result
193
203
 
194
204
  safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
195
205
  safe_patch(
@@ -15,7 +15,9 @@
15
15
  """The Meridian API module that models the data."""
16
16
 
17
17
  from meridian.model import adstock_hill
18
+ from meridian.model import context
18
19
  from meridian.model import eda
20
+ from meridian.model import equations
19
21
  from meridian.model import knots
20
22
  from meridian.model import media
21
23
  from meridian.model import model
@@ -279,10 +279,6 @@ def _adstock(
279
279
  media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
280
280
 
281
281
  # Adstock calculation.
282
- window_list = [None] * window_size
283
- for i in range(window_size):
284
- window_list[i] = media[..., i : i + n_times_output, :]
285
- windowed = backend.stack(window_list)
286
282
  l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
287
283
  weights = compute_decay_weights(
288
284
  alpha=alpha,
@@ -291,7 +287,9 @@ def _adstock(
291
287
  decay_functions=decay_functions,
292
288
  normalize=True,
293
289
  )
294
- return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
290
+ return backend.adstock_process(
291
+ media=media, weights=weights, n_times_output=n_times_output
292
+ )
295
293
 
296
294
 
297
295
  def _map_alpha_for_binomial_decay(x: backend.Tensor):