google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -65,12 +65,24 @@ _REQUIRED_COORDS = immutabledict.immutabledict({
65
65
  c.MEDIA_TIME: _sample_times(n_times=3),
66
66
  c.CONTROL_VARIABLE: ['control_0', 'control_1'],
67
67
  })
68
+ _NON_MEDIA_COORDS = immutabledict.immutabledict(
69
+ {c.NON_MEDIA_CHANNEL: ['non_media_channel_0', 'non_media_channel_1']}
70
+ )
68
71
  _MEDIA_COORDS = immutabledict.immutabledict(
69
72
  {c.MEDIA_CHANNEL: ['media_channel_0', 'media_channel_1', 'media_channel_2']}
70
73
  )
74
+ _ORGANIC_MEDIA_COORDS = immutabledict.immutabledict({
75
+ c.ORGANIC_MEDIA_CHANNEL: [
76
+ 'organic_media_channel_0',
77
+ 'organic_media_channel_1',
78
+ ]
79
+ })
71
80
  _RF_COORDS = immutabledict.immutabledict(
72
81
  {c.RF_CHANNEL: ['rf_channel_0', 'rf_channel_1']}
73
82
  )
83
+ _ORGANIC_RF_COORDS = immutabledict.immutabledict(
84
+ {c.ORGANIC_RF_CHANNEL: ['organic_rf_channel_0', 'organic_rf_channel_1']}
85
+ )
74
86
 
75
87
  _REQUIRED_DATA_VARS = immutabledict.immutabledict({
76
88
  c.KPI: (['geo', 'time'], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
@@ -376,6 +388,67 @@ DATASET_WITHOUT_TIME_VARIATION_IN_REACH = xr.Dataset(
376
388
  },
377
389
  )
378
390
 
391
+ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
392
+ coords=_REQUIRED_COORDS
393
+ | _MEDIA_COORDS
394
+ | _RF_COORDS
395
+ | _ORGANIC_MEDIA_COORDS,
396
+ data_vars=_REQUIRED_DATA_VARS
397
+ | _MEDIA_DATA_VARS
398
+ | _RF_DATA_VARS
399
+ | _OPTIONAL_DATA_VARS
400
+ | {
401
+ c.ORGANIC_MEDIA: (
402
+ ['geo', 'media_time', 'organic_media_channel'],
403
+ [
404
+ [[2.1, 2.2], [2.1, 2.21], [2.1, 2.2]],
405
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
406
+ ],
407
+ ),
408
+ },
409
+ )
410
+
411
+ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
412
+ coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _ORGANIC_RF_COORDS,
413
+ data_vars=_REQUIRED_DATA_VARS
414
+ | _MEDIA_DATA_VARS
415
+ | _RF_DATA_VARS
416
+ | _OPTIONAL_DATA_VARS
417
+ | {
418
+ c.ORGANIC_REACH: (
419
+ ['geo', 'media_time', 'organic_rf_channel'],
420
+ [
421
+ [[2.1, 2.2], [2.11, 2.2], [2.1, 2.2]],
422
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
423
+ ],
424
+ ),
425
+ c.ORGANIC_FREQUENCY: (
426
+ ['geo', 'media_time', 'organic_rf_channel'],
427
+ [
428
+ [[7.1, 7.2], [7.3, 7.4], [7.5, 7.6]],
429
+ [[7.11, 7.21], [7.31, 7.41], [7.51, 7.61]],
430
+ ],
431
+ ),
432
+ },
433
+ )
434
+
435
+ DATASET_WITHOUT_TIME_VARIATION_IN_NON_MEDIA_TREATMENTS = xr.Dataset(
436
+ coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _NON_MEDIA_COORDS,
437
+ data_vars=_REQUIRED_DATA_VARS
438
+ | _MEDIA_DATA_VARS
439
+ | _RF_DATA_VARS
440
+ | _OPTIONAL_DATA_VARS
441
+ | {
442
+ c.NON_MEDIA_TREATMENTS: (
443
+ ['geo', 'time', 'non_media_channel'],
444
+ [
445
+ [[2.1, 2.2], [2.1, 2.2], [2.1, 2.2]],
446
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
447
+ ],
448
+ ),
449
+ },
450
+ )
451
+
379
452
  _NATIONAL_COORDS = immutabledict.immutabledict({
380
453
  c.TIME: [
381
454
  _SAMPLE_START_DATE.strftime(c.DATE_FORMAT),
@@ -742,11 +815,11 @@ def random_controls_da(
742
815
 
743
816
  def random_kpi_da(
744
817
  media: xr.DataArray,
745
- controls: xr.DataArray,
746
818
  n_geos: int,
747
819
  n_times: int,
748
820
  n_media_channels: int,
749
- n_controls: int,
821
+ n_controls: int | None = None,
822
+ controls: xr.DataArray | None = None,
750
823
  seed: int = 0,
751
824
  integer_geos: bool = False,
752
825
  ) -> xr.DataArray:
@@ -762,19 +835,26 @@ def random_kpi_da(
762
835
  n_media_channels,
763
836
  axis=2,
764
837
  )
765
- control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
766
- control_geo_sd = np.repeat(
767
- np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
768
- ..., np.newaxis
769
- ],
770
- n_controls,
771
- axis=2,
772
- )
838
+ if n_controls:
839
+ control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
840
+ control_geo_sd = np.repeat(
841
+ np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
842
+ ..., np.newaxis
843
+ ],
844
+ n_controls,
845
+ axis=2,
846
+ )
847
+ else:
848
+ control_geo_sd = 0
773
849
 
774
850
  # Simulates outcome which is the dependent variable. Typically this is the
775
851
  # number of units sold, but it can be any metric (e.g. revenue).
776
852
  media_portion = np.random.normal(media_common, media_geo_sd).sum(axis=2)
777
- control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
853
+ if controls is not None:
854
+ control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
855
+ else:
856
+ control_portion = 0
857
+
778
858
  error = np.random.normal(0, 2, size=(n_geos, n_times))
779
859
  kpi = abs(media_portion + control_portion + error)
780
860
 
@@ -1085,7 +1165,7 @@ def random_dataset(
1085
1165
  n_geos: int,
1086
1166
  n_times: int,
1087
1167
  n_media_times: int,
1088
- n_controls: int,
1168
+ n_controls: int | None = None,
1089
1169
  n_non_media_channels: int | None = None,
1090
1170
  n_organic_media_channels: int | None = None,
1091
1171
  n_organic_rf_channels: int | None = None,
@@ -1095,7 +1175,7 @@ def random_dataset(
1095
1175
  seed: int = 0,
1096
1176
  remove_media_time: bool = False,
1097
1177
  integer_geos: bool = False,
1098
- ):
1178
+ ) -> xr.Dataset:
1099
1179
  """Generates a random dataset."""
1100
1180
  if n_media_channels:
1101
1181
  media = random_media_da(
@@ -1156,14 +1236,18 @@ def random_dataset(
1156
1236
  else:
1157
1237
  revenue_per_kpi = None
1158
1238
 
1159
- controls = random_controls_da(
1160
- media=media if n_media_channels else reach,
1161
- n_geos=n_geos,
1162
- n_times=n_times,
1163
- n_controls=n_controls,
1164
- seed=seed,
1165
- integer_geos=integer_geos,
1166
- )
1239
+ if n_controls:
1240
+ controls = random_controls_da(
1241
+ media=media if n_media_channels else reach,
1242
+ n_geos=n_geos,
1243
+ n_times=n_times,
1244
+ n_controls=n_controls,
1245
+ seed=seed,
1246
+ integer_geos=integer_geos,
1247
+ )
1248
+ else:
1249
+ controls = None
1250
+
1167
1251
  if n_non_media_channels:
1168
1252
  non_media_treatments = random_non_media_treatments_da(
1169
1253
  media=media if n_media_channels else reach,
@@ -1222,7 +1306,9 @@ def random_dataset(
1222
1306
  n_geos=n_geos, seed=seed, integer_geos=integer_geos
1223
1307
  )
1224
1308
 
1225
- dataset = xr.combine_by_coords([kpi, population, controls])
1309
+ dataset = xr.combine_by_coords(
1310
+ [kpi, population] + ([controls] if controls is not None else [])
1311
+ )
1226
1312
  if revenue_per_kpi is not None:
1227
1313
  dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
1228
1314
  if media is not None:
@@ -1271,7 +1357,7 @@ def random_dataset(
1271
1357
 
1272
1358
  def dataset_to_dataframe(
1273
1359
  dataset: xr.Dataset,
1274
- controls_column_names: list[str],
1360
+ controls_column_names: list[str] | None = None,
1275
1361
  media_column_names: list[str] | None = None,
1276
1362
  media_spend_column_names: list[str] | None = None,
1277
1363
  reach_column_names: list[str] | None = None,
@@ -1320,10 +1406,15 @@ def dataset_to_dataframe(
1320
1406
  )
1321
1407
  population = dataset[c.POPULATION].to_dataframe(name=c.POPULATION)
1322
1408
 
1323
- controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1324
- controls.columns = controls_column_names
1409
+ if controls_column_names is not None:
1410
+ controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1411
+ controls.columns = controls_column_names
1412
+ else:
1413
+ controls = None
1325
1414
 
1326
- result = kpi.join(revenue_per_kpi).join(population).join(controls)
1415
+ result = kpi.join(revenue_per_kpi).join(population)
1416
+ if controls is not None:
1417
+ result = result.join(controls)
1327
1418
 
1328
1419
  if non_media_column_names is not None:
1329
1420
  non_media_treatments = (
@@ -1395,7 +1486,7 @@ def random_dataframe(
1395
1486
  n_geos,
1396
1487
  n_times,
1397
1488
  n_media_times,
1398
- n_controls,
1489
+ n_controls=None,
1399
1490
  n_media_channels=None,
1400
1491
  n_rf_channels=None,
1401
1492
  seed=0,
@@ -1413,7 +1504,9 @@ def random_dataframe(
1413
1504
 
1414
1505
  return dataset_to_dataframe(
1415
1506
  dataset,
1416
- controls_column_names=_sample_names('control_', n_controls),
1507
+ controls_column_names=(
1508
+ _sample_names('control_', n_controls) if n_controls else None
1509
+ ),
1417
1510
  media_column_names=_sample_names('media_', n_media_channels),
1418
1511
  media_spend_column_names=_sample_names('media_spend_', n_media_channels),
1419
1512
  reach_column_names=_sample_names('reach_', n_rf_channels),
@@ -1423,7 +1516,7 @@ def random_dataframe(
1423
1516
 
1424
1517
 
1425
1518
  def sample_coord_to_columns(
1426
- n_controls: int,
1519
+ n_controls: int | None = None,
1427
1520
  n_media_channels: int | None = None,
1428
1521
  n_rf_channels: int | None = None,
1429
1522
  n_non_media_channels: int | None = None,
@@ -1474,7 +1567,7 @@ def sample_coord_to_columns(
1474
1567
  kpi=c.KPI,
1475
1568
  revenue_per_kpi=c.REVENUE_PER_KPI if include_revenue_per_kpi else None,
1476
1569
  population=c.POPULATION,
1477
- controls=_sample_names('control_', n_controls),
1570
+ controls=(_sample_names('control_', n_controls) if n_controls else None),
1478
1571
  media=media,
1479
1572
  media_spend=media_spend,
1480
1573
  reach=reach,
@@ -1491,17 +1584,52 @@ def sample_input_data_from_dataset(
1491
1584
  dataset: xr.Dataset, kpi_type: str
1492
1585
  ) -> input_data.InputData:
1493
1586
  """Generates a sample `InputData` from a full xarray Dataset."""
1587
+ media = dataset.media if c.MEDIA in dataset.data_vars.keys() else None
1588
+ media_spend = (
1589
+ dataset.media_spend if c.MEDIA_SPEND in dataset.data_vars.keys() else None
1590
+ )
1591
+ reach = dataset.reach if c.REACH in dataset.data_vars.keys() else None
1592
+ frequency = (
1593
+ dataset.frequency if c.FREQUENCY in dataset.data_vars.keys() else None
1594
+ )
1595
+ rf_spend = (
1596
+ dataset.rf_spend if c.RF_SPEND in dataset.data_vars.keys() else None
1597
+ )
1598
+ organic_media = (
1599
+ dataset.organic_media
1600
+ if c.ORGANIC_MEDIA in dataset.data_vars.keys()
1601
+ else None
1602
+ )
1603
+ organic_reach = (
1604
+ dataset.organic_reach
1605
+ if c.ORGANIC_REACH in dataset.data_vars.keys()
1606
+ else None
1607
+ )
1608
+ organic_frequency = (
1609
+ dataset.organic_frequency
1610
+ if c.ORGANIC_FREQUENCY in dataset.data_vars.keys()
1611
+ else None
1612
+ )
1613
+ non_media_treatments = (
1614
+ dataset.non_media_treatments
1615
+ if c.NON_MEDIA_TREATMENTS in dataset.data_vars.keys()
1616
+ else None
1617
+ )
1494
1618
  return input_data.InputData(
1495
1619
  kpi=dataset.kpi,
1496
1620
  kpi_type=kpi_type,
1497
1621
  revenue_per_kpi=dataset.revenue_per_kpi,
1498
1622
  population=dataset.population,
1499
1623
  controls=dataset.controls,
1500
- media=dataset.media,
1501
- media_spend=dataset.media_spend,
1502
- reach=dataset.reach,
1503
- frequency=dataset.frequency,
1504
- rf_spend=dataset.rf_spend,
1624
+ media=media,
1625
+ media_spend=media_spend,
1626
+ reach=reach,
1627
+ frequency=frequency,
1628
+ rf_spend=rf_spend,
1629
+ organic_media=organic_media,
1630
+ organic_reach=organic_reach,
1631
+ organic_frequency=organic_frequency,
1632
+ non_media_treatments=non_media_treatments,
1505
1633
  )
1506
1634
 
1507
1635
 
@@ -1557,7 +1685,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1557
1685
  n_geos: int = 10,
1558
1686
  n_times: int = 50,
1559
1687
  n_media_times: int = 53,
1560
- n_controls: int = 2,
1688
+ n_controls: int | None = 2,
1561
1689
  n_non_media_channels: int | None = None,
1562
1690
  n_media_channels: int | None = None,
1563
1691
  n_rf_channels: int | None = None,
@@ -1583,10 +1711,10 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1583
1711
  kpi_type=c.NON_REVENUE,
1584
1712
  revenue_per_kpi=dataset.revenue_per_kpi,
1585
1713
  population=dataset.population,
1586
- controls=dataset.controls,
1587
- non_media_treatments=dataset.non_media_treatments
1588
- if n_non_media_channels
1589
- else None,
1714
+ controls=(dataset.controls if n_controls else None),
1715
+ non_media_treatments=(
1716
+ dataset.non_media_treatments if n_non_media_channels else None
1717
+ ),
1590
1718
  media=dataset.media if n_media_channels else None,
1591
1719
  media_spend=dataset.media_spend if n_media_channels else None,
1592
1720
  reach=dataset.reach if n_rf_channels else None,
@@ -1594,9 +1722,9 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1594
1722
  rf_spend=dataset.rf_spend if n_rf_channels else None,
1595
1723
  organic_media=dataset.organic_media if n_organic_media_channels else None,
1596
1724
  organic_reach=dataset.organic_reach if n_organic_rf_channels else None,
1597
- organic_frequency=dataset.organic_frequency
1598
- if n_organic_rf_channels
1599
- else None,
1725
+ organic_frequency=(
1726
+ dataset.organic_frequency if n_organic_rf_channels else None
1727
+ ),
1600
1728
  )
1601
1729
 
1602
1730
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ __all__ = [
36
36
 
37
37
 
38
38
  # A type alias for a polymorphic "date" type.
39
- Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64
39
+ Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64 | None
40
40
 
41
41
  # A type alias for a polymorphic "date interval" type. In all variants it is
42
42
  # always a tuple of (start_date, end_date).
@@ -236,8 +236,8 @@ class TimeCoordinates:
236
236
 
237
237
  def expand_selected_time_dims(
238
238
  self,
239
- start_date: Date | None = None,
240
- end_date: Date | None = None,
239
+ start_date: Date = None,
240
+ end_date: Date = None,
241
241
  ) -> list[datetime.date] | None:
242
242
  """Validates and returns time dimension values based on the selected times.
243
243
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/knots.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.