google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
schema/serde/hyperparameters.py CHANGED
@@ -16,12 +16,14 @@
 
 import warnings
 
+import bidict
 from meridian import backend
 from meridian import constants as c
 from meridian.model import spec
 from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
 from schema.serde import constants as sc
 from schema.serde import serde
+from schema.utils import proto_enum_converter
 import numpy as np
 
 _MediaEffectsDist = meridian_pb.MediaEffectsDistribution
@@ -31,101 +33,44 @@ _NonMediaBaselineFunction = (
     meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
 )
 
+media_effects_converter = proto_enum_converter.ProtoEnumConverter(
+    enum_display_name="Media effects distribution",
+    enum_message=_MediaEffectsDist,
+    mapping=bidict.bidict({
+        c.MEDIA_EFFECTS_LOG_NORMAL: "LOG_NORMAL",
+        c.MEDIA_EFFECTS_NORMAL: "NORMAL",
+    }),
+    enum_unspecified=_MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED,
+    default_when_unspecified=c.MEDIA_EFFECTS_LOG_NORMAL,
+)
 
-def _media_effects_dist_to_proto_enum(
-    media_effect_dict: str,
-) -> _MediaEffectsDist:
-  match media_effect_dict:
-    case c.MEDIA_EFFECTS_LOG_NORMAL:
-      return _MediaEffectsDist.LOG_NORMAL
-    case c.MEDIA_EFFECTS_NORMAL:
-      return _MediaEffectsDist.NORMAL
-    case _:
-      return _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED
-
-
-def _proto_enum_to_media_effects_dist(
-    proto_enum: _MediaEffectsDist,
-) -> str:
-  """Converts a `_MediaEffectsDist` enum to its string representation."""
-  match proto_enum:
-    case _MediaEffectsDist.LOG_NORMAL:
-      return c.MEDIA_EFFECTS_LOG_NORMAL
-    case _MediaEffectsDist.NORMAL:
-      return c.MEDIA_EFFECTS_NORMAL
-    case _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED:
-      warnings.warn(
-          "Media effects distribution is unspecified. Resolving to"
-          " 'log-normal'."
-      )
-      return c.MEDIA_EFFECTS_LOG_NORMAL
-    case _:
-      raise ValueError(
-          "Unsupported MediaEffectsDistribution proto enum value:"
-          f" {proto_enum}."
-      )
-
-
-def _paid_media_prior_type_to_proto_enum(
-    paid_media_prior_type: str | None,
-) -> _PaidMediaPriorType:
-  """Converts a paid media prior type string to its proto enum."""
-  if paid_media_prior_type is None:
-    return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
-  try:
-    return _PaidMediaPriorType.Value(paid_media_prior_type.upper())
-  except ValueError:
-    warnings.warn(
-        f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
-        " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
-    )
-    return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
-
-
-def _proto_enum_to_paid_media_prior_type(
-    proto_enum: _PaidMediaPriorType,
-) -> str | None:
-  """Converts a `_PaidMediaPriorType` enum to its string representation."""
-  if proto_enum == _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED:
-    return None
-  return _PaidMediaPriorType.Name(proto_enum).lower()
-
-
-def _non_paid_prior_type_to_proto_enum(
-    non_paid_prior_type: str,
-) -> _NonPaidTreatmentsPriorType:
-  """Converts a non-paid prior type string to its proto enum."""
-  try:
-    return _NonPaidTreatmentsPriorType.Value(
-        f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
-    )
-  except ValueError:
-    warnings.warn(
-        f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
-        " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
-    )
-    return (
-        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
-    )
-
+paid_media_prior_type_converter = proto_enum_converter.ProtoEnumConverter(
+    enum_display_name="Paid media prior type",
+    enum_message=_PaidMediaPriorType,
+    mapping=bidict.bidict({
+        c.TREATMENT_PRIOR_TYPE_ROI: "ROI",
+        c.TREATMENT_PRIOR_TYPE_MROI: "MROI",
+        c.TREATMENT_PRIOR_TYPE_COEFFICIENT: "COEFFICIENT",
+        c.TREATMENT_PRIOR_TYPE_CONTRIBUTION: "CONTRIBUTION",
+    }),
+    enum_unspecified=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED,
+    default_when_unspecified=None,
+)
 
-def _proto_enum_to_non_paid_prior_type(
-    proto_enum: _NonPaidTreatmentsPriorType,
-) -> str:
-  """Converts a `_NonPaidTreatmentsPriorType` enum to its string representation."""
-  if (
-      proto_enum
-      == _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
-  ):
-    warnings.warn(
-        "Non-paid prior type is unspecified. Resolving to 'contribution'."
-    )
-    return c.TREATMENT_PRIOR_TYPE_CONTRIBUTION
-  return (
-      _NonPaidTreatmentsPriorType.Name(proto_enum)
-      .replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "")
-      .lower()
-  )
+non_paid_treatments_prior_type_converter = proto_enum_converter.ProtoEnumConverter(
+    enum_display_name="Non-paid treatments prior type",
+    enum_message=_NonPaidTreatmentsPriorType,
+    mapping=bidict.bidict({
+        c.TREATMENT_PRIOR_TYPE_COEFFICIENT: (
+            "NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT"
+        ),
+        c.TREATMENT_PRIOR_TYPE_CONTRIBUTION: (
+            "NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION"
+        ),
+    }),
+    enum_unspecified=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED,
+    default_when_unspecified=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
+)
 
 
 class HyperparametersSerde(
@@ -141,25 +86,27 @@ class HyperparametersSerde(
   def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters:
     """Serializes the given ModelSpec into a `Hyperparameters` proto."""
     hyperparameters_proto = meridian_pb.Hyperparameters(
-        media_effects_dist=_media_effects_dist_to_proto_enum(
+        media_effects_dist=media_effects_converter.to_proto(
             obj.media_effects_dist
         ),
         hill_before_adstock=obj.hill_before_adstock,
         unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo,
-        media_prior_type=_paid_media_prior_type_to_proto_enum(
+        media_prior_type=paid_media_prior_type_converter.to_proto(
            obj.media_prior_type
         ),
-        rf_prior_type=_paid_media_prior_type_to_proto_enum(obj.rf_prior_type),
-        paid_media_prior_type=_paid_media_prior_type_to_proto_enum(
+        rf_prior_type=paid_media_prior_type_converter.to_proto(
+            obj.rf_prior_type
+        ),
+        paid_media_prior_type=paid_media_prior_type_converter.to_proto(
             obj.paid_media_prior_type
         ),
-        organic_media_prior_type=_non_paid_prior_type_to_proto_enum(
+        organic_media_prior_type=non_paid_treatments_prior_type_converter.to_proto(
             obj.organic_media_prior_type
         ),
-        organic_rf_prior_type=_non_paid_prior_type_to_proto_enum(
+        organic_rf_prior_type=non_paid_treatments_prior_type_converter.to_proto(
             obj.organic_rf_prior_type
         ),
-        non_media_treatments_prior_type=_non_paid_prior_type_to_proto_enum(
+        non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.to_proto(
            obj.non_media_treatments_prior_type
         ),
         enable_aks=obj.enable_aks,
@@ -326,28 +273,28 @@ class HyperparametersSerde(
     adstock_decay_spec = sc.DEFAULT_DECAY
 
     return spec.ModelSpec(
-        media_effects_dist=_proto_enum_to_media_effects_dist(
+        media_effects_dist=media_effects_converter.from_proto(
            serialized.media_effects_dist
         ),
         hill_before_adstock=serialized.hill_before_adstock,
         max_lag=max_lag,
         unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
-        media_prior_type=_proto_enum_to_paid_media_prior_type(
+        media_prior_type=paid_media_prior_type_converter.from_proto(
            serialized.media_prior_type
         ),
-        rf_prior_type=_proto_enum_to_paid_media_prior_type(
+        rf_prior_type=paid_media_prior_type_converter.from_proto(
            serialized.rf_prior_type
         ),
-        paid_media_prior_type=_proto_enum_to_paid_media_prior_type(
+        paid_media_prior_type=paid_media_prior_type_converter.from_proto(
            serialized.paid_media_prior_type
         ),
-        organic_media_prior_type=_proto_enum_to_non_paid_prior_type(
+        organic_media_prior_type=non_paid_treatments_prior_type_converter.from_proto(
            serialized.organic_media_prior_type
         ),
-        organic_rf_prior_type=_proto_enum_to_non_paid_prior_type(
+        organic_rf_prior_type=non_paid_treatments_prior_type_converter.from_proto(
            serialized.organic_rf_prior_type
         ),
-        non_media_treatments_prior_type=_proto_enum_to_non_paid_prior_type(
+        non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.from_proto(
            serialized.non_media_treatments_prior_type
         ),
         non_media_baseline_values=non_media_baseline_values,
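
Note: `schema/utils/proto_enum_converter.py` is new in this release (+127 lines) but its body is not shown in this diff, so the converter's exact implementation is unknown. The following is a minimal sketch only, inferred from the constructor arguments and the `to_proto`/`from_proto` calls above; names and behavior are assumptions, not the actual module:

import warnings

import bidict


class ProtoEnumConverter:
  """Sketch: maps Meridian constant strings to proto enum values and back."""

  def __init__(
      self,
      enum_display_name,
      enum_message,
      mapping: bidict.bidict,
      enum_unspecified,
      default_when_unspecified,
  ):
    self._display_name = enum_display_name
    self._enum = enum_message
    self._mapping = mapping  # constant string -> enum value name
    self._unspecified = enum_unspecified
    self._default = default_when_unspecified

  def to_proto(self, value):
    # Unmapped values (including None) fall back to the UNSPECIFIED enum,
    # as the removed `*_to_proto_enum` helpers did.
    if value is None or value not in self._mapping:
      return self._unspecified
    return self._enum.Value(self._mapping[value])

  def from_proto(self, proto_value):
    # UNSPECIFIED resolves to the configured default, mirroring the removed
    # `_proto_enum_to_*` helpers that warned and returned a default.
    if proto_value == self._unspecified:
      warnings.warn(
          f"{self._display_name} is unspecified. Resolving to"
          f" {self._default!r}."
      )
      return self._default
    return self._mapping.inverse[self._enum.Name(proto_value)]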
schema/serde/meridian_serde.py CHANGED
@@ -43,6 +43,7 @@ import dataclasses
 import os
 import warnings
 
+import arviz as az
 from google.protobuf import text_format
 import meridian
 from meridian import backend
@@ -165,6 +166,7 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
         inference_data=inference_data.InferenceDataSerde().serialize(
             mmm.inference_data
         ),
+        arviz_version=az.__version__,
     )
     # For backwards compatibility, only serialize EDA spec if it exists.
     if hasattr(mmm, 'eda_spec'):
@@ -190,7 +192,10 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
       # NotFittedModelError can be raised below. If raised,
      # return None. Otherwise, set convergence status based on
      # MCMCSamplingError (caught in the except block).
-      rhats = analyzer.Analyzer(mmm).get_rhat()
+      rhats = analyzer.Analyzer(
+          model_context=mmm.model_context,
+          inference_data=mmm.inference_data,
+      ).get_rhat()
       rhat_proto = meridian_pb.RHatDiagnostic()
       for name, tensor in rhats.items():
         rhat_proto.parameter_r_hats.add(
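
The hunk above also reflects an API change in 1.5.0: `analyzer.Analyzer` is now constructed from an explicit model context and inference data instead of the `Meridian` model object itself. A minimal sketch of how a caller might adapt, assuming a fitted model `mmm` that exposes `model_context` and `inference_data` as used above:

from meridian.analysis import analyzer


def get_rhats(mmm):
  """Returns per-parameter R-hat diagnostics for a fitted model (sketch)."""
  # Old form: analyzer.Analyzer(mmm).get_rhat()
  return analyzer.Analyzer(
      model_context=mmm.model_context,
      inference_data=mmm.inference_data,
  ).get_rhat()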
schema/test_data.py ADDED
@@ -0,0 +1,380 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test data for MMM proto generator."""
+
+from collections.abc import Sequence
+import datetime
+
+from mmm.v1 import mmm_pb2 as mmm_pb
+from mmm.v1.common import date_interval_pb2 as date_interval_pb
+from mmm.v1.fit import model_fit_pb2 as fit_pb
+from mmm.v1.marketing.analysis import marketing_analysis_pb2
+from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
+from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
+from schema.processors import budget_optimization_processor
+from schema.processors import marketing_processor
+from schema.processors import model_fit_processor
+from schema.processors import model_processor
+from schema.processors import reach_frequency_optimization_processor as rf_opt_processor
+
+from google.type import date_pb2
+
+# Weekly dates from 2022-11-21 to 2024-01-01.
+ALL_TIMES_IN_MERIDIAN = (
+    '2022-11-21',
+    '2022-11-28',
+    '2022-12-05',
+    '2022-12-12',
+    '2022-12-19',
+    '2022-12-26',
+    '2023-01-02',
+    '2023-01-09',
+    '2023-01-16',
+    '2023-01-23',
+    '2023-01-30',
+    '2023-02-06',
+    '2023-02-13',
+    '2023-02-20',
+    '2023-02-27',
+    '2023-03-06',
+    '2023-03-13',
+    '2023-03-20',
+    '2023-03-27',
+    '2023-04-03',
+    '2023-04-10',
+    '2023-04-17',
+    '2023-04-24',
+    '2023-05-01',
+    '2023-05-08',
+    '2023-05-15',
+    '2023-05-22',
+    '2023-05-29',
+    '2023-06-05',
+    '2023-06-12',
+    '2023-06-19',
+    '2023-06-26',
+    '2023-07-03',
+    '2023-07-10',
+    '2023-07-17',
+    '2023-07-24',
+    '2023-07-31',
+    '2023-08-07',
+    '2023-08-14',
+    '2023-08-21',
+    '2023-08-28',
+    '2023-09-04',
+    '2023-09-11',
+    '2023-09-18',
+    '2023-09-25',
+    '2023-10-02',
+    '2023-10-09',
+    '2023-10-16',
+    '2023-10-23',
+    '2023-10-30',
+    '2023-11-06',
+    '2023-11-13',
+    '2023-11-20',
+    '2023-11-27',
+    '2023-12-04',
+    '2023-12-11',
+    '2023-12-18',
+    '2023-12-25',
+    '2024-01-01',
+)
+
+ALL_TIME_BUCKET_DATED_SPECS = (
+    # All
+    model_processor.DatedSpec(
+        start_date=datetime.date(2022, 11, 21),
+        end_date=datetime.date(2024, 1, 8),
+        date_interval_tag='ALL',
+    ),
+    # Monthly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2022, 12, 5),
+        end_date=datetime.date(2023, 1, 2),
+        date_interval_tag='Y2022 Dec',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2023, 2, 6),
+        date_interval_tag='Y2023 Jan',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 2, 6),
+        end_date=datetime.date(2023, 3, 6),
+        date_interval_tag='Y2023 Feb',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 3, 6),
+        end_date=datetime.date(2023, 4, 3),
+        date_interval_tag='Y2023 Mar',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 4, 3),
+        end_date=datetime.date(2023, 5, 1),
+        date_interval_tag='Y2023 Apr',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 5, 1),
+        end_date=datetime.date(2023, 6, 5),
+        date_interval_tag='Y2023 May',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 6, 5),
+        end_date=datetime.date(2023, 7, 3),
+        date_interval_tag='Y2023 Jun',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 7, 3),
+        end_date=datetime.date(2023, 8, 7),
+        date_interval_tag='Y2023 Jul',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 8, 7),
+        end_date=datetime.date(2023, 9, 4),
+        date_interval_tag='Y2023 Aug',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 9, 4),
+        end_date=datetime.date(2023, 10, 2),
+        date_interval_tag='Y2023 Sep',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 10, 2),
+        end_date=datetime.date(2023, 11, 6),
+        date_interval_tag='Y2023 Oct',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 11, 6),
+        end_date=datetime.date(2023, 12, 4),
+        date_interval_tag='Y2023 Nov',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 12, 4),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023 Dec',
+    ),
+    # Quarterly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2023, 4, 3),
+        date_interval_tag='Y2023 Q1',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 4, 3),
+        end_date=datetime.date(2023, 7, 3),
+        date_interval_tag='Y2023 Q2',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 7, 3),
+        end_date=datetime.date(2023, 10, 2),
+        date_interval_tag='Y2023 Q3',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 10, 2),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023 Q4',
+    ),
+    # Yearly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023',
+    ),
+)
+
+
+def _dated_spec_to_date_interval(
+    spec: model_processor.DatedSpec,
+) -> date_interval_pb.DateInterval:
+  if spec.start_date is None or spec.end_date is None:
+    raise ValueError('Start date or end date is None.')
+
+  return date_interval_pb.DateInterval(
+      start_date=date_pb2.Date(
+          year=spec.start_date.year,
+          month=spec.start_date.month,
+          day=spec.start_date.day,
+      ),
+      end_date=date_pb2.Date(
+          year=spec.end_date.year,
+          month=spec.end_date.month,
+          day=spec.end_date.day,
+      ),
+      tag=spec.date_interval_tag,
+  )
+
+
+class FakeModelFitProcessor(
+    model_processor.ModelProcessor[
+        model_fit_processor.ModelFitSpec, fit_pb.ModelFit
+    ]
+):
+  """Fake ModelFitProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return model_fit_processor.ModelFitSpec
+
+  @classmethod
+  def output_type(cls):
+    return fit_pb.ModelFit
+
+  def execute(
+      self, specs: Sequence[model_fit_processor.ModelFitSpec]
+  ) -> fit_pb.ModelFit:
+    return fit_pb.ModelFit()
+
+  def _set_output(self, output: mmm_pb.Mmm, result: fit_pb.ModelFit):
+    output.model_fit.CopyFrom(result)
+
+
+class FakeBudgetOptimizationProcessor(
+    model_processor.ModelProcessor[
+        budget_optimization_processor.BudgetOptimizationSpec,
+        budget_pb.BudgetOptimization,
+    ]
+):
+  """Fake BudgetOptimizationProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return budget_optimization_processor.BudgetOptimizationSpec
+
+  @classmethod
+  def output_type(cls):
+    return budget_pb.BudgetOptimization
+
+  def execute(
+      self,
+      specs: Sequence[budget_optimization_processor.BudgetOptimizationSpec],
+  ) -> budget_pb.BudgetOptimization:
+    results = []
+    for spec in specs:
+      result = budget_pb.BudgetOptimizationResult(
+          name=spec.optimization_name,
+          spec=budget_pb.BudgetOptimizationSpec(
+              date_interval=_dated_spec_to_date_interval(spec)
+          ),
+          incremental_outcome_grid=budget_pb.IncrementalOutcomeGrid(
+              name=spec.grid_name
+          ),
+      )
+      if spec.group_id:
+        result.group_id = spec.group_id
+      results.append(result)
+
+    return budget_pb.BudgetOptimization(results=results)
+
+  def _set_output(
+      self, output: mmm_pb.Mmm, result: budget_pb.BudgetOptimization
+  ):
+    output.marketing_optimization.budget_optimization.CopyFrom(result)
+
+
+class FakeReachFrequencyOptimizationProcessor(
+    model_processor.ModelProcessor[
+        rf_opt_processor.ReachFrequencyOptimizationSpec,
+        rf_pb.ReachFrequencyOptimization,
+    ]
+):
+  """Fake ReachFrequencyOptimizationProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return rf_opt_processor.ReachFrequencyOptimizationSpec
+
+  @classmethod
+  def output_type(cls):
+    return rf_pb.ReachFrequencyOptimization
+
+  def execute(
+      self,
+      specs: Sequence[rf_opt_processor.ReachFrequencyOptimizationSpec],
+  ) -> rf_pb.ReachFrequencyOptimization:
+    results = []
+    for spec in specs:
+      result = rf_pb.ReachFrequencyOptimizationResult(
+          name=spec.optimization_name,
+          spec=rf_pb.ReachFrequencyOptimizationSpec(
+              date_interval=_dated_spec_to_date_interval(spec)
+          ),
+          frequency_outcome_grid=rf_pb.FrequencyOutcomeGrid(
+              name=spec.grid_name
+          ),
+      )
+      if spec.group_id:
+        result.group_id = spec.group_id
+      results.append(result)
+
+    return rf_pb.ReachFrequencyOptimization(results=results)
+
+  def _set_output(
+      self,
+      output: mmm_pb.Mmm,
+      result: rf_pb.ReachFrequencyOptimization,
+  ):
+    output.marketing_optimization.reach_frequency_optimization.CopyFrom(result)
+
+
+class FakeMarketingProcessor(
+    model_processor.ModelProcessor[
+        marketing_processor.MarketingAnalysisSpec,
+        marketing_analysis_pb2.MarketingAnalysisList,
+    ]
+):
+  """Fake MarketingProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return marketing_processor.MarketingAnalysisSpec
+
+  @classmethod
+  def output_type(cls):
+    return marketing_analysis_pb2.MarketingAnalysisList
+
+  def execute(
+      self, specs: Sequence[marketing_processor.MarketingAnalysisSpec]
+  ) -> marketing_analysis_pb2.MarketingAnalysisList:
+    marketing_analyses = []
+    for spec in specs:
+      marketing_analysis = marketing_analysis_pb2.MarketingAnalysis(
+          date_interval=_dated_spec_to_date_interval(spec)
+      )
+      marketing_analyses.append(marketing_analysis)
+
+    return marketing_analysis_pb2.MarketingAnalysisList(
+        marketing_analyses=marketing_analyses
+    )
+
+  def _set_output(
+      self,
+      output: mmm_pb.Mmm,
+      result: marketing_analysis_pb2.MarketingAnalysisList,
+  ):
+    output.marketing_analysis_list.CopyFrom(result)
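
For orientation, the `DatedSpec` entries and the `_dated_spec_to_date_interval` helper defined above turn plain Python dates into `google.type.Date`-based `DateInterval` protos. A small illustrative usage, hypothetical and not part of the package's tests:

from schema import test_data

# The first entry is the 'ALL' bucket spanning 2022-11-21 to 2024-01-08.
interval = test_data._dated_spec_to_date_interval(
    test_data.ALL_TIME_BUCKET_DATED_SPECS[0]
)
print(interval.tag)              # 'ALL'
print(interval.start_date.year)  # 2022
print(interval.end_date.day)     # 8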
schema/utils/__init__.py CHANGED
@@ -14,4 +14,6 @@
 
 """Module containing MMM schema util functions."""
 
+from schema.utils import date_range_bucketing
+from schema.utils import proto_enum_converter
 from schema.utils import time_record