google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
schema/serde/hyperparameters.py
CHANGED
|
@@ -16,12 +16,14 @@
|
|
|
16
16
|
|
|
17
17
|
import warnings
|
|
18
18
|
|
|
19
|
+
import bidict
|
|
19
20
|
from meridian import backend
|
|
20
21
|
from meridian import constants as c
|
|
21
22
|
from meridian.model import spec
|
|
22
23
|
from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
|
|
23
24
|
from schema.serde import constants as sc
|
|
24
25
|
from schema.serde import serde
|
|
26
|
+
from schema.utils import proto_enum_converter
|
|
25
27
|
import numpy as np
|
|
26
28
|
|
|
27
29
|
# Short module-level alias for the `MediaEffectsDistribution` proto enum type
# from `meridian_model_pb2`; its values (e.g. `LOG_NORMAL`, `NORMAL`,
# `MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED`) are referenced below.
_MediaEffectsDist = meridian_pb.MediaEffectsDistribution
|
|
@@ -31,101 +33,44 @@ _NonMediaBaselineFunction = (
|
|
|
31
33
|
meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
|
|
32
34
|
)
|
|
33
35
|
|
|
36
|
+
# Bidirectional converter between Meridian's media-effects distribution
# strings (`c.MEDIA_EFFECTS_*`) and `_MediaEffectsDist` enum value names.
# `default_when_unspecified` is log-normal; presumably that is what
# `from_proto` yields for the UNSPECIFIED enum value -- confirm against
# `ProtoEnumConverter`.
media_effects_converter = proto_enum_converter.ProtoEnumConverter(
    enum_display_name="Media effects distribution",
    enum_message=_MediaEffectsDist,
    mapping=bidict.bidict([
        (c.MEDIA_EFFECTS_LOG_NORMAL, "LOG_NORMAL"),
        (c.MEDIA_EFFECTS_NORMAL, "NORMAL"),
    ]),
    enum_unspecified=_MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED,
    default_when_unspecified=c.MEDIA_EFFECTS_LOG_NORMAL,
)
|
|
34
46
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def _proto_enum_to_media_effects_dist(
    proto_enum: _MediaEffectsDist,
) -> str:
  """Maps a `_MediaEffectsDist` enum value back to its Meridian string.

  Args:
    proto_enum: The serialized media effects distribution enum value.

  Returns:
    `c.MEDIA_EFFECTS_LOG_NORMAL` or `c.MEDIA_EFFECTS_NORMAL`. The UNSPECIFIED
    value also resolves to log-normal, after emitting a warning.

  Raises:
    ValueError: If `proto_enum` is none of the handled enum values.
  """
  if proto_enum == _MediaEffectsDist.LOG_NORMAL:
    return c.MEDIA_EFFECTS_LOG_NORMAL
  if proto_enum == _MediaEffectsDist.NORMAL:
    return c.MEDIA_EFFECTS_NORMAL
  if proto_enum == _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED:
    # Older protos may omit the field; fall back to the log-normal default.
    warnings.warn(
        "Media effects distribution is unspecified. Resolving to"
        " 'log-normal'."
    )
    return c.MEDIA_EFFECTS_LOG_NORMAL
  raise ValueError(
      "Unsupported MediaEffectsDistribution proto enum value:"
      f" {proto_enum}."
  )
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def _paid_media_prior_type_to_proto_enum(
    paid_media_prior_type: str | None,
) -> _PaidMediaPriorType:
  """Maps a paid media prior type string onto its `_PaidMediaPriorType` value.

  The string is upper-cased and looked up by enum value name.

  Args:
    paid_media_prior_type: The prior type string, or None.

  Returns:
    The matching enum value. Both None and an unrecognized name resolve to
    `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED`; the latter also emits a warning.
  """
  unspecified = _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
  if paid_media_prior_type is None:
    return unspecified
  enum_name = paid_media_prior_type.upper()
  try:
    return _PaidMediaPriorType.Value(enum_name)
  except ValueError:
    # `Value` raises ValueError for names absent from the enum.
    warnings.warn(
        f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
        " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
    )
  return unspecified
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
def _proto_enum_to_paid_media_prior_type(
    proto_enum: _PaidMediaPriorType,
) -> str | None:
  """Maps a `_PaidMediaPriorType` value back to its lower-cased name.

  Args:
    proto_enum: The serialized paid media prior type enum value.

  Returns:
    The lower-cased enum value name, or None for
    `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED`.
  """
  if proto_enum != _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED:
    return _PaidMediaPriorType.Name(proto_enum).lower()
  return None
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _non_paid_prior_type_to_proto_enum(
    non_paid_prior_type: str,
) -> _NonPaidTreatmentsPriorType:
  """Maps a non-paid prior type string onto `_NonPaidTreatmentsPriorType`.

  The enum value name is formed by upper-casing the string and prefixing it
  with "NON_PAID_TREATMENTS_PRIOR_TYPE_".

  Args:
    non_paid_prior_type: The prior type string.

  Returns:
    The matching enum value; an unrecognized name emits a warning and
    resolves to `NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION`.
  """
  enum_name = f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
  try:
    return _NonPaidTreatmentsPriorType.Value(enum_name)
  except ValueError:
    warnings.warn(
        f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
        " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
    )
    return (
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    )
|
|
110
|
-
|
|
47
|
+
# Bidirectional converter between paid-media prior type strings
# (`c.TREATMENT_PRIOR_TYPE_*`) and `_PaidMediaPriorType` enum value names.
# Used for the `media_prior_type`, `rf_prior_type` and
# `paid_media_prior_type` hyperparameter fields. `default_when_unspecified`
# is None; presumably the UNSPECIFIED enum value deserializes to None --
# confirm against `ProtoEnumConverter`.
paid_media_prior_type_converter = proto_enum_converter.ProtoEnumConverter(
    enum_display_name="Paid media prior type",
    enum_message=_PaidMediaPriorType,
    mapping=bidict.bidict([
        (c.TREATMENT_PRIOR_TYPE_ROI, "ROI"),
        (c.TREATMENT_PRIOR_TYPE_MROI, "MROI"),
        (c.TREATMENT_PRIOR_TYPE_COEFFICIENT, "COEFFICIENT"),
        (c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, "CONTRIBUTION"),
    ]),
    enum_unspecified=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED,
    default_when_unspecified=None,
)
|
|
111
59
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
.replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "")
|
|
127
|
-
.lower()
|
|
128
|
-
)
|
|
60
|
+
# Bidirectional converter between non-paid treatment prior type strings and
# `_NonPaidTreatmentsPriorType` enum value names (which carry the full
# "NON_PAID_TREATMENTS_PRIOR_TYPE_" prefix). Used for the
# `organic_media_prior_type`, `organic_rf_prior_type` and
# `non_media_treatments_prior_type` hyperparameter fields.
non_paid_treatments_prior_type_converter = (
    proto_enum_converter.ProtoEnumConverter(
        enum_display_name="Non-paid treatments prior type",
        enum_message=_NonPaidTreatmentsPriorType,
        mapping=bidict.bidict([
            (
                c.TREATMENT_PRIOR_TYPE_COEFFICIENT,
                "NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT",
            ),
            (
                c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
                "NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION",
            ),
        ]),
        enum_unspecified=(
            _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
        ),
        default_when_unspecified=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
    )
)
|
|
129
74
|
|
|
130
75
|
|
|
131
76
|
class HyperparametersSerde(
|
|
@@ -141,25 +86,27 @@ class HyperparametersSerde(
|
|
|
141
86
|
def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters:
|
|
142
87
|
"""Serializes the given ModelSpec into a `Hyperparameters` proto."""
|
|
143
88
|
hyperparameters_proto = meridian_pb.Hyperparameters(
|
|
144
|
-
media_effects_dist=
|
|
89
|
+
media_effects_dist=media_effects_converter.to_proto(
|
|
145
90
|
obj.media_effects_dist
|
|
146
91
|
),
|
|
147
92
|
hill_before_adstock=obj.hill_before_adstock,
|
|
148
93
|
unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo,
|
|
149
|
-
media_prior_type=
|
|
94
|
+
media_prior_type=paid_media_prior_type_converter.to_proto(
|
|
150
95
|
obj.media_prior_type
|
|
151
96
|
),
|
|
152
|
-
rf_prior_type=
|
|
153
|
-
|
|
97
|
+
rf_prior_type=paid_media_prior_type_converter.to_proto(
|
|
98
|
+
obj.rf_prior_type
|
|
99
|
+
),
|
|
100
|
+
paid_media_prior_type=paid_media_prior_type_converter.to_proto(
|
|
154
101
|
obj.paid_media_prior_type
|
|
155
102
|
),
|
|
156
|
-
organic_media_prior_type=
|
|
103
|
+
organic_media_prior_type=non_paid_treatments_prior_type_converter.to_proto(
|
|
157
104
|
obj.organic_media_prior_type
|
|
158
105
|
),
|
|
159
|
-
organic_rf_prior_type=
|
|
106
|
+
organic_rf_prior_type=non_paid_treatments_prior_type_converter.to_proto(
|
|
160
107
|
obj.organic_rf_prior_type
|
|
161
108
|
),
|
|
162
|
-
non_media_treatments_prior_type=
|
|
109
|
+
non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.to_proto(
|
|
163
110
|
obj.non_media_treatments_prior_type
|
|
164
111
|
),
|
|
165
112
|
enable_aks=obj.enable_aks,
|
|
@@ -326,28 +273,28 @@ class HyperparametersSerde(
|
|
|
326
273
|
adstock_decay_spec = sc.DEFAULT_DECAY
|
|
327
274
|
|
|
328
275
|
return spec.ModelSpec(
|
|
329
|
-
media_effects_dist=
|
|
276
|
+
media_effects_dist=media_effects_converter.from_proto(
|
|
330
277
|
serialized.media_effects_dist
|
|
331
278
|
),
|
|
332
279
|
hill_before_adstock=serialized.hill_before_adstock,
|
|
333
280
|
max_lag=max_lag,
|
|
334
281
|
unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
|
|
335
|
-
media_prior_type=
|
|
282
|
+
media_prior_type=paid_media_prior_type_converter.from_proto(
|
|
336
283
|
serialized.media_prior_type
|
|
337
284
|
),
|
|
338
|
-
rf_prior_type=
|
|
285
|
+
rf_prior_type=paid_media_prior_type_converter.from_proto(
|
|
339
286
|
serialized.rf_prior_type
|
|
340
287
|
),
|
|
341
|
-
paid_media_prior_type=
|
|
288
|
+
paid_media_prior_type=paid_media_prior_type_converter.from_proto(
|
|
342
289
|
serialized.paid_media_prior_type
|
|
343
290
|
),
|
|
344
|
-
organic_media_prior_type=
|
|
291
|
+
organic_media_prior_type=non_paid_treatments_prior_type_converter.from_proto(
|
|
345
292
|
serialized.organic_media_prior_type
|
|
346
293
|
),
|
|
347
|
-
organic_rf_prior_type=
|
|
294
|
+
organic_rf_prior_type=non_paid_treatments_prior_type_converter.from_proto(
|
|
348
295
|
serialized.organic_rf_prior_type
|
|
349
296
|
),
|
|
350
|
-
non_media_treatments_prior_type=
|
|
297
|
+
non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.from_proto(
|
|
351
298
|
serialized.non_media_treatments_prior_type
|
|
352
299
|
),
|
|
353
300
|
non_media_baseline_values=non_media_baseline_values,
|
schema/serde/meridian_serde.py
CHANGED
|
@@ -43,6 +43,7 @@ import dataclasses
|
|
|
43
43
|
import os
|
|
44
44
|
import warnings
|
|
45
45
|
|
|
46
|
+
import arviz as az
|
|
46
47
|
from google.protobuf import text_format
|
|
47
48
|
import meridian
|
|
48
49
|
from meridian import backend
|
|
@@ -165,6 +166,7 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
|
|
|
165
166
|
inference_data=inference_data.InferenceDataSerde().serialize(
|
|
166
167
|
mmm.inference_data
|
|
167
168
|
),
|
|
169
|
+
arviz_version=az.__version__,
|
|
168
170
|
)
|
|
169
171
|
# For backwards compatibility, only serialize EDA spec if it exists.
|
|
170
172
|
if hasattr(mmm, 'eda_spec'):
|
|
@@ -190,7 +192,10 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
|
|
|
190
192
|
# NotFittedModelError can be raised below. If raised,
|
|
191
193
|
# return None. Otherwise, set convergence status based on
|
|
192
194
|
# MCMCSamplingError (caught in the except block).
|
|
193
|
-
rhats = analyzer.Analyzer(
|
|
195
|
+
rhats = analyzer.Analyzer(
|
|
196
|
+
model_context=mmm.model_context,
|
|
197
|
+
inference_data=mmm.inference_data,
|
|
198
|
+
).get_rhat()
|
|
194
199
|
rhat_proto = meridian_pb.RHatDiagnostic()
|
|
195
200
|
for name, tensor in rhats.items():
|
|
196
201
|
rhat_proto.parameter_r_hats.add(
|
schema/test_data.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Test data for MMM proto generator."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Sequence
|
|
18
|
+
import datetime
|
|
19
|
+
|
|
20
|
+
from mmm.v1 import mmm_pb2 as mmm_pb
|
|
21
|
+
from mmm.v1.common import date_interval_pb2 as date_interval_pb
|
|
22
|
+
from mmm.v1.fit import model_fit_pb2 as fit_pb
|
|
23
|
+
from mmm.v1.marketing.analysis import marketing_analysis_pb2
|
|
24
|
+
from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
|
|
25
|
+
from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
|
|
26
|
+
from schema.processors import budget_optimization_processor
|
|
27
|
+
from schema.processors import marketing_processor
|
|
28
|
+
from schema.processors import model_fit_processor
|
|
29
|
+
from schema.processors import model_processor
|
|
30
|
+
from schema.processors import reach_frequency_optimization_processor as rf_opt_processor
|
|
31
|
+
|
|
32
|
+
from google.type import date_pb2
|
|
33
|
+
|
|
34
|
+
# Weekly time coordinates from 2022-11-21 through 2024-01-01 inclusive:
# 59 ISO-8601 ('YYYY-MM-DD') strings spaced exactly seven days apart.
ALL_TIMES_IN_MERIDIAN = tuple(
    (datetime.date(2022, 11, 21) + datetime.timedelta(weeks=week)).isoformat()
    for week in range(59)
)
|
|
96
|
+
|
|
97
|
+
# Expected date buckets, in order: the whole span ('ALL'), then monthly,
# quarterly and yearly buckets. Each row is (start_date, end_date, tag).
# End dates look exclusive: 'ALL' ends on 2024-01-08, one week after the last
# weekly time, and each bucket ends where the next one starts. This is
# presumably DatedSpec's convention -- verify against model_processor.
ALL_TIME_BUCKET_DATED_SPECS = tuple(
    model_processor.DatedSpec(
        start_date=start_date,
        end_date=end_date,
        date_interval_tag=tag,
    )
    for start_date, end_date, tag in (
        # All
        (datetime.date(2022, 11, 21), datetime.date(2024, 1, 8), 'ALL'),
        # Monthly buckets
        (datetime.date(2022, 12, 5), datetime.date(2023, 1, 2), 'Y2022 Dec'),
        (datetime.date(2023, 1, 2), datetime.date(2023, 2, 6), 'Y2023 Jan'),
        (datetime.date(2023, 2, 6), datetime.date(2023, 3, 6), 'Y2023 Feb'),
        (datetime.date(2023, 3, 6), datetime.date(2023, 4, 3), 'Y2023 Mar'),
        (datetime.date(2023, 4, 3), datetime.date(2023, 5, 1), 'Y2023 Apr'),
        (datetime.date(2023, 5, 1), datetime.date(2023, 6, 5), 'Y2023 May'),
        (datetime.date(2023, 6, 5), datetime.date(2023, 7, 3), 'Y2023 Jun'),
        (datetime.date(2023, 7, 3), datetime.date(2023, 8, 7), 'Y2023 Jul'),
        (datetime.date(2023, 8, 7), datetime.date(2023, 9, 4), 'Y2023 Aug'),
        (datetime.date(2023, 9, 4), datetime.date(2023, 10, 2), 'Y2023 Sep'),
        (datetime.date(2023, 10, 2), datetime.date(2023, 11, 6), 'Y2023 Oct'),
        (datetime.date(2023, 11, 6), datetime.date(2023, 12, 4), 'Y2023 Nov'),
        (datetime.date(2023, 12, 4), datetime.date(2024, 1, 1), 'Y2023 Dec'),
        # Quarterly buckets
        (datetime.date(2023, 1, 2), datetime.date(2023, 4, 3), 'Y2023 Q1'),
        (datetime.date(2023, 4, 3), datetime.date(2023, 7, 3), 'Y2023 Q2'),
        (datetime.date(2023, 7, 3), datetime.date(2023, 10, 2), 'Y2023 Q3'),
        (datetime.date(2023, 10, 2), datetime.date(2024, 1, 1), 'Y2023 Q4'),
        # Yearly buckets
        (datetime.date(2023, 1, 2), datetime.date(2024, 1, 1), 'Y2023'),
    )
)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _dated_spec_to_date_interval(
    spec: model_processor.DatedSpec,
) -> date_interval_pb.DateInterval:
  """Builds a `DateInterval` proto from a `DatedSpec`.

  Args:
    spec: The dated spec whose `start_date`, `end_date` and
      `date_interval_tag` are copied into the proto.

  Returns:
    A `DateInterval` with `google.type.Date` start/end and the spec's tag.

  Raises:
    ValueError: If `spec.start_date` or `spec.end_date` is None.
  """
  start, end = spec.start_date, spec.end_date
  if start is None or end is None:
    raise ValueError('Start date or end date is None.')

  def to_date_proto(value: datetime.date) -> date_pb2.Date:
    # `google.type.Date` holds year, month and day as separate fields.
    return date_pb2.Date(year=value.year, month=value.month, day=value.day)

  return date_interval_pb.DateInterval(
      start_date=to_date_proto(start),
      end_date=to_date_proto(end),
      tag=spec.date_interval_tag,
  )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class FakeModelFitProcessor(
    model_processor.ModelProcessor[
        model_fit_processor.ModelFitSpec, fit_pb.ModelFit
    ]
):
  """Fake ModelFitProcessor for testing.

  `execute` ignores its specs and always produces an empty `ModelFit`.
  """

  def __init__(self, trained_model: model_processor.TrainedModel):
    # Stored but never read by this fake.
    self._trained_model = trained_model

  @classmethod
  def spec_type(cls):
    """Returns the spec class accepted by `execute`."""
    return model_fit_processor.ModelFitSpec

  @classmethod
  def output_type(cls):
    """Returns the proto message class produced by `execute`."""
    return fit_pb.ModelFit

  def execute(
      self, specs: Sequence[model_fit_processor.ModelFitSpec]
  ) -> fit_pb.ModelFit:
    """Returns a new, empty `ModelFit` regardless of `specs`."""
    del specs  # Unused by this fake.
    return fit_pb.ModelFit()

  def _set_output(self, output: mmm_pb.Mmm, result: fit_pb.ModelFit):
    """Copies `result` into `output.model_fit`."""
    output.model_fit.CopyFrom(result)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class FakeBudgetOptimizationProcessor(
    model_processor.ModelProcessor[
        budget_optimization_processor.BudgetOptimizationSpec,
        budget_pb.BudgetOptimization,
    ]
):
  """Fake BudgetOptimizationProcessor for testing.

  No optimization is performed. `execute` emits one skeletal result per spec
  that echoes the spec's optimization name, date interval, grid name and
  (when truthy) group id.
  """

  def __init__(self, trained_model: model_processor.TrainedModel):
    # Stored but never read by this fake.
    self._trained_model = trained_model

  @classmethod
  def spec_type(cls):
    """Returns the spec class accepted by `execute`."""
    return budget_optimization_processor.BudgetOptimizationSpec

  @classmethod
  def output_type(cls):
    """Returns the proto message class produced by `execute`."""
    return budget_pb.BudgetOptimization

  @staticmethod
  def _result_for(
      spec: budget_optimization_processor.BudgetOptimizationSpec,
  ) -> budget_pb.BudgetOptimizationResult:
    """Builds the skeletal result proto for a single spec."""
    result = budget_pb.BudgetOptimizationResult(
        name=spec.optimization_name,
        spec=budget_pb.BudgetOptimizationSpec(
            date_interval=_dated_spec_to_date_interval(spec)
        ),
        incremental_outcome_grid=budget_pb.IncrementalOutcomeGrid(
            name=spec.grid_name
        ),
    )
    # A falsy group_id leaves the proto field at its default.
    if spec.group_id:
      result.group_id = spec.group_id
    return result

  def execute(
      self,
      specs: Sequence[budget_optimization_processor.BudgetOptimizationSpec],
  ) -> budget_pb.BudgetOptimization:
    """Returns one skeletal result per spec, in input order.

    Raises:
      ValueError: If a spec lacks a start or end date (raised while building
        its date interval).
    """
    return budget_pb.BudgetOptimization(
        results=[self._result_for(spec) for spec in specs]
    )

  def _set_output(
      self, output: mmm_pb.Mmm, result: budget_pb.BudgetOptimization
  ):
    """Copies `result` into the budget optimization slot of `output`."""
    output.marketing_optimization.budget_optimization.CopyFrom(result)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class FakeReachFrequencyOptimizationProcessor(
    model_processor.ModelProcessor[
        rf_opt_processor.ReachFrequencyOptimizationSpec,
        rf_pb.ReachFrequencyOptimization,
    ]
):
  """Fake ReachFrequencyOptimizationProcessor for testing.

  No optimization is performed. `execute` emits one skeletal result per spec
  that echoes the spec's optimization name, date interval, grid name and
  (when truthy) group id.
  """

  def __init__(self, trained_model: model_processor.TrainedModel):
    # Stored but never read by this fake.
    self._trained_model = trained_model

  @classmethod
  def spec_type(cls):
    """Returns the spec class accepted by `execute`."""
    return rf_opt_processor.ReachFrequencyOptimizationSpec

  @classmethod
  def output_type(cls):
    """Returns the proto message class produced by `execute`."""
    return rf_pb.ReachFrequencyOptimization

  @staticmethod
  def _result_for(
      spec: rf_opt_processor.ReachFrequencyOptimizationSpec,
  ) -> rf_pb.ReachFrequencyOptimizationResult:
    """Builds the skeletal result proto for a single spec."""
    result = rf_pb.ReachFrequencyOptimizationResult(
        name=spec.optimization_name,
        spec=rf_pb.ReachFrequencyOptimizationSpec(
            date_interval=_dated_spec_to_date_interval(spec)
        ),
        frequency_outcome_grid=rf_pb.FrequencyOutcomeGrid(
            name=spec.grid_name
        ),
    )
    # A falsy group_id leaves the proto field at its default.
    if spec.group_id:
      result.group_id = spec.group_id
    return result

  def execute(
      self,
      specs: Sequence[rf_opt_processor.ReachFrequencyOptimizationSpec],
  ) -> rf_pb.ReachFrequencyOptimization:
    """Returns one skeletal result per spec, in input order.

    Raises:
      ValueError: If a spec lacks a start or end date (raised while building
        its date interval).
    """
    return rf_pb.ReachFrequencyOptimization(
        results=[self._result_for(spec) for spec in specs]
    )

  def _set_output(
      self,
      output: mmm_pb.Mmm,
      result: rf_pb.ReachFrequencyOptimization,
  ):
    """Copies `result` into the reach/frequency optimization slot."""
    output.marketing_optimization.reach_frequency_optimization.CopyFrom(result)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class FakeMarketingProcessor(
    model_processor.ModelProcessor[
        marketing_processor.MarketingAnalysisSpec,
        marketing_analysis_pb2.MarketingAnalysisList,
    ]
):
  """Fake MarketingProcessor for testing.

  `execute` emits one `MarketingAnalysis` per spec, with only its date
  interval populated.
  """

  def __init__(self, trained_model: model_processor.TrainedModel):
    # Stored but never read by this fake.
    self._trained_model = trained_model

  @classmethod
  def spec_type(cls):
    """Returns the spec class accepted by `execute`."""
    return marketing_processor.MarketingAnalysisSpec

  @classmethod
  def output_type(cls):
    """Returns the proto message class produced by `execute`."""
    return marketing_analysis_pb2.MarketingAnalysisList

  def execute(
      self, specs: Sequence[marketing_processor.MarketingAnalysisSpec]
  ) -> marketing_analysis_pb2.MarketingAnalysisList:
    """Returns one date-only `MarketingAnalysis` per spec, in input order.

    Raises:
      ValueError: If a spec lacks a start or end date (raised while building
        its date interval).
    """
    analyses = [
        marketing_analysis_pb2.MarketingAnalysis(
            date_interval=_dated_spec_to_date_interval(spec)
        )
        for spec in specs
    ]
    return marketing_analysis_pb2.MarketingAnalysisList(
        marketing_analyses=analyses
    )

  def _set_output(
      self,
      output: mmm_pb.Mmm,
      result: marketing_analysis_pb2.MarketingAnalysisList,
  ):
    """Copies `result` into `output.marketing_analysis_list`."""
    output.marketing_analysis_list.CopyFrom(result)
|