google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,743 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Provides wrappers for the `Mmm` proto.
|
|
16
|
+
|
|
17
|
+
This module defines a set of dataclasses that act as high-level wrappers around
|
|
18
|
+
the `Mmm` protocol buffer and its nested messages. The primary goal is to offer
|
|
19
|
+
a more intuitive API for accessing and manipulating MMM data, abstracting away
|
|
20
|
+
the verbosity of the raw protobuf structures.
|
|
21
|
+
|
|
22
|
+
The main entry point is the `Mmm` class, which wraps the top-level `mmm_pb2.Mmm`
|
|
23
|
+
proto. From an instance of this class, you can navigate through the model's
|
|
24
|
+
different components, such as marketing data, model fit results, and various
|
|
25
|
+
analyses, using simple properties and methods.
|
|
26
|
+
|
|
27
|
+
Typical Usage:
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
from mmm.v1 import mmm_pb2
|
|
31
|
+
from scenarioplanner.converters import mmm
|
|
32
|
+
|
|
33
|
+
# Assume `mmm_proto` is a populated instance of the Mmm proto
|
|
34
|
+
mmm_proto = mmm_pb2.Mmm()
|
|
35
|
+
# ...
|
|
36
|
+
|
|
37
|
+
# Create the main wrapper instance
|
|
38
|
+
mmm_wrapper = mmm.Mmm(mmm_proto)
|
|
39
|
+
|
|
40
|
+
# Access marketing data and calculate total spends for a given period
|
|
41
|
+
marketing_data = mmm_wrapper.marketing_data
|
|
42
|
+
total_spends = marketing_data.all_channel_spends(
|
|
43
|
+
date_interval=('2025-01-01', '2025-03-31')
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Access budget optimization results
|
|
47
|
+
for budget_result in mmm_wrapper.budget_optimization_results:
|
|
48
|
+
print(f"Name: {budget_result.name}, Max: {budget_result.spec.max_budget}")
|
|
49
|
+
```
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
import abc
|
|
53
|
+
import dataclasses
|
|
54
|
+
import datetime
|
|
55
|
+
import functools
|
|
56
|
+
from typing import TypeAlias
|
|
57
|
+
|
|
58
|
+
from meridian import constants as c
|
|
59
|
+
from meridian.data import time_coordinates as tc
|
|
60
|
+
from mmm.v1 import mmm_pb2 as mmm_pb
|
|
61
|
+
from mmm.v1.common import date_interval_pb2 as date_interval_pb
|
|
62
|
+
from mmm.v1.common import estimate_pb2 as estimate_pb
|
|
63
|
+
from mmm.v1.common import kpi_type_pb2 as kpi_type_pb
|
|
64
|
+
from mmm.v1.common import target_metric_pb2 as target_metric_pb
|
|
65
|
+
from mmm.v1.fit import model_fit_pb2 as fit_pb
|
|
66
|
+
from mmm.v1.marketing import marketing_data_pb2 as marketing_data_pb
|
|
67
|
+
from mmm.v1.marketing.analysis import marketing_analysis_pb2 as marketing_pb
|
|
68
|
+
from mmm.v1.marketing.analysis import media_analysis_pb2 as media_pb
|
|
69
|
+
from mmm.v1.marketing.analysis import non_media_analysis_pb2 as non_media_pb
|
|
70
|
+
from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb
|
|
71
|
+
from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb
|
|
72
|
+
from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
|
|
73
|
+
from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
|
|
74
|
+
from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
|
|
75
|
+
|
|
76
|
+
from google.type import date_pb2 as date_pb
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
_DateIntervalTuple: TypeAlias = tuple[datetime.date, datetime.date]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclasses.dataclass(frozen=True)
|
|
83
|
+
class DateInterval:
|
|
84
|
+
"""A dataclass wrapper around a tuple of `(start, end)` dates."""
|
|
85
|
+
|
|
86
|
+
date_interval: _DateIntervalTuple
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def start(self) -> datetime.date:
|
|
90
|
+
return self.date_interval[0]
|
|
91
|
+
|
|
92
|
+
@property
|
|
93
|
+
def end(self) -> datetime.date:
|
|
94
|
+
return self.date_interval[1]
|
|
95
|
+
|
|
96
|
+
def __contains__(self, date: datetime.date) -> bool:
|
|
97
|
+
"""Returns whether this date interval contains the given date."""
|
|
98
|
+
return self.start <= date < self.end
|
|
99
|
+
|
|
100
|
+
def __lt__(self, other: "DateInterval") -> bool:
|
|
101
|
+
return self.start < other.start
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _to_datetime_date(
    date_proto: date_pb.Date,
) -> datetime.date:
  """Builds a `datetime.date` from the year/month/day fields of a `Date` proto.

  Args:
    date_proto: the `Date` proto to convert.

  Returns:
    A `datetime.date` with the same year, month and day.

  Raises:
    ValueError: (from `datetime.date`) if the fields do not form a valid
      calendar date, e.g. when a field is zero.
  """
  year, month, day = date_proto.year, date_proto.month, date_proto.day
  return datetime.date(year, month, day)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _to_date_interval_dc(
    date_interval: date_interval_pb.DateInterval,
) -> DateInterval:
  """Wraps the start/end dates of a `DateInterval` proto in the dataclass.

  Args:
    date_interval: the proto whose `start_date` and `end_date` are converted.

  Returns:
    A `DateInterval` dataclass holding `(start_date, end_date)` as
    `datetime.date` values.
  """
  start = _to_datetime_date(date_interval.start_date)
  end = _to_datetime_date(date_interval.end_date)
  return DateInterval((start, end))
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@dataclasses.dataclass(frozen=True)
class Outcome:
  """Wraps an `Outcome` proto, exposing its KPI type and metric sub-messages."""

  outcome_proto: outcome_pb.Outcome

  @property
  def is_revenue_kpi(self) -> bool:
    """True when the wrapped proto's `kpi_type` is `REVENUE`."""
    kpi_type = self.outcome_proto.kpi_type
    return kpi_type == kpi_type_pb.REVENUE

  @property
  def is_nonrevenue_kpi(self) -> bool:
    """True when the wrapped proto's `kpi_type` is `NON_REVENUE`."""
    kpi_type = self.outcome_proto.kpi_type
    return kpi_type == kpi_type_pb.NON_REVENUE

  @property
  def contribution_pb(self) -> outcome_pb.Contribution:
    """The `contribution` field of the wrapped proto."""
    proto = self.outcome_proto
    return proto.contribution

  @property
  def effectiveness_pb(self) -> outcome_pb.Effectiveness:
    """The `effectiveness` field of the wrapped proto."""
    proto = self.outcome_proto
    return proto.effectiveness

  @property
  def roi_pb(self) -> estimate_pb.Estimate:
    """The `roi` estimate of the wrapped proto."""
    proto = self.outcome_proto
    return proto.roi

  @property
  def marginal_roi_pb(self) -> estimate_pb.Estimate:
    """The `marginal_roi` estimate of the wrapped proto."""
    proto = self.outcome_proto
    return proto.marginal_roi

  @property
  def cost_per_contribution_pb(self) -> estimate_pb.Estimate:
    """The `cost_per_contribution` estimate of the wrapped proto."""
    proto = self.outcome_proto
    return proto.cost_per_contribution
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class _OutcomeMixin(abc.ABC):
  """Mixin for (non-)media analysis with typed KPI outcome property getters.

  A `MediaAnalysis` or `NonMediaAnalysis` proto is configured with multiple
  polymorphic `Outcome`s. In Meridian processors, both types (revenue and
  non-revenue) may be present in the analysis container. However, for each type
  there should be at most one `Outcome` value.

  Subclasses only provide `_outcome_pbs`; the getters below scan it and return
  the first `Outcome` of the requested KPI type.
  """

  @property
  @abc.abstractmethod
  def _outcome_pbs(self) -> list[outcome_pb.Outcome]:
    """Returns the `Outcome` protos held by the wrapped analysis proto."""
    raise NotImplementedError()

  def _first_outcome_matching(self, predicate) -> Outcome | None:
    """Returns the first wrapped `Outcome` for which `predicate` is true.

    Args:
      predicate: a callable taking an `Outcome` wrapper and returning a bool.

    Returns:
      The first matching `Outcome` in `_outcome_pbs` order, or None.
    """
    candidates = (Outcome(proto) for proto in self._outcome_pbs)
    return next((outcome for outcome in candidates if predicate(outcome)), None)

  @functools.cached_property
  def maybe_revenue_outcome(self) -> Outcome | None:
    """The revenue-type `Outcome`, or None when the analysis has none."""
    return self._first_outcome_matching(lambda outcome: outcome.is_revenue_kpi)

  @property
  def revenue_outcome(self) -> Outcome:
    """The revenue-type `Outcome`.

    Raises:
      ValueError: if the analysis has no revenue-type `Outcome`.
    """
    found = self.maybe_revenue_outcome
    if found is not None:
      return found
    raise ValueError(
        "No revenue-type `Outcome` found in an expected analysis proto."
    )

  @functools.cached_property
  def maybe_non_revenue_outcome(self) -> Outcome | None:
    """The nonrevenue-type `Outcome`, or None when the analysis has none."""
    return self._first_outcome_matching(
        lambda outcome: outcome.is_nonrevenue_kpi
    )

  @property
  def non_revenue_outcome(self) -> Outcome:
    """The nonrevenue-type `Outcome`.

    Raises:
      ValueError: if the analysis has no nonrevenue-type `Outcome`.
    """
    found = self.maybe_non_revenue_outcome
    if found is not None:
      return found
    raise ValueError(
        "No nonrevenue-type `Outcome` found in an expected analysis proto."
    )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@dataclasses.dataclass(frozen=True)
class MediaAnalysis(_OutcomeMixin):
  """Wraps a `MediaAnalysis` proto for a single media channel.

  The revenue / non-revenue outcome getters are inherited from
  `_OutcomeMixin`, fed by this proto's `media_outcomes`.
  """

  analysis_proto: media_pb.MediaAnalysis

  @property
  def channel_name(self) -> str:
    """Name of the media channel described by this analysis."""
    proto = self.analysis_proto
    return proto.channel_name

  @property
  def spend_info_pb(self) -> media_pb.SpendInfo:
    """The `spend_info` field of the wrapped proto."""
    proto = self.analysis_proto
    return proto.spend_info

  @property
  def _outcome_pbs(self) -> list[outcome_pb.Outcome]:
    """The wrapped proto's `media_outcomes`, copied into a list."""
    outcomes = self.analysis_proto.media_outcomes
    return list(outcomes)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@dataclasses.dataclass(frozen=True)
class NonMediaAnalysis(_OutcomeMixin):
  """Wraps a `NonMediaAnalysis` proto for a single non-media variable.

  The revenue / non-revenue outcome getters are inherited from
  `_OutcomeMixin`, fed by this proto's `non_media_outcomes`.
  """

  analysis_proto: non_media_pb.NonMediaAnalysis

  @property
  def non_media_name(self) -> str:
    """Name of the non-media variable described by this analysis."""
    proto = self.analysis_proto
    return proto.non_media_name

  @property
  def _outcome_pbs(self) -> list[outcome_pb.Outcome]:
    """The wrapped proto's `non_media_outcomes`, copied into a list."""
    outcomes = self.analysis_proto.non_media_outcomes
    return list(outcomes)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@dataclasses.dataclass(frozen=True)
class ResponseCurve:
  """Pairs a channel name with its `ResponseCurve` proto."""

  channel_name: str
  response_curve_proto: response_curve_pb.ResponseCurve

  @property
  def input_name(self) -> str:
    """The `input_name` field of the wrapped curve proto."""
    curve = self.response_curve_proto
    return curve.input_name

  @property
  def response_points(self) -> list[tuple[float, float]]:
    """`(spend, incremental outcome)` pairs along this channel's curve.

    Each pair is `(input_value, incremental_kpi)` of one proto response point,
    in proto order.
    """
    proto_points = self.response_curve_proto.response_points
    return [(p.input_value, p.incremental_kpi) for p in proto_points]
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@dataclasses.dataclass(frozen=True)
class MarketingAnalysis:
  """Wraps a `MarketingAnalysis` proto with keyed access to its sub-analyses."""

  marketing_analysis_proto: marketing_pb.MarketingAnalysis

  @property
  def tag(self) -> str:
    """The `tag` of this analysis' date interval."""
    interval_proto = self.marketing_analysis_proto.date_interval
    return interval_proto.tag

  @functools.cached_property
  def analysis_date_interval(
      self,
  ) -> DateInterval:
    """This analysis' date interval as a `DateInterval` dataclass."""
    interval_proto = self.marketing_analysis_proto.date_interval
    return _to_date_interval_dc(interval_proto)

  @property
  def analysis_date_interval_str(self) -> tuple[str, str]:
    """`(date_start, date_end)` of this analysis, formatted as strings.

    Both dates are rendered with `c.DATE_FORMAT`.
    """
    interval = self.analysis_date_interval
    start_text = interval.start.strftime(c.DATE_FORMAT)
    end_text = interval.end.strftime(c.DATE_FORMAT)
    return start_text, end_text

  @functools.cached_property
  def channel_mapped_media_analyses(self) -> dict[str, MediaAnalysis]:
    """Media analyses keyed by channel name.

    If a channel name repeats, the later analysis replaces the earlier one.
    """
    media_protos = self.marketing_analysis_proto.media_analyses
    return {proto.channel_name: MediaAnalysis(proto) for proto in media_protos}

  @functools.cached_property
  def channel_mapped_non_media_analyses(self) -> dict[str, NonMediaAnalysis]:
    """Non-media analyses keyed by non-media name.

    If a name repeats, the later analysis replaces the earlier one.
    """
    non_media_protos = self.marketing_analysis_proto.non_media_analyses
    return {
        proto.non_media_name: NonMediaAnalysis(proto)
        for proto in non_media_protos
    }

  @functools.cached_property
  def baseline_analysis(self) -> NonMediaAnalysis:
    """The non-media analysis whose name is `c.BASELINE`.

    The first such analysis in proto order is returned.

    Raises:
      ValueError: if there is no "baseline" analysis
    """
    baseline_proto = next(
        (
            proto
            for proto in self.marketing_analysis_proto.non_media_analyses
            if proto.non_media_name == c.BASELINE
        ),
        None,
    )
    if baseline_proto is None:
      raise ValueError(
          f"No '{c.BASELINE}' found in the set of `NonMediaAnalysis` for this"
          " `MarketingAnalysis`."
      )
    return NonMediaAnalysis(baseline_proto)

  @functools.cached_property
  def response_curves(self) -> list[ResponseCurve]:
    """One `ResponseCurve` per media analysis, in proto order."""
    media_protos = self.marketing_analysis_proto.media_analyses
    return [
        ResponseCurve(proto.channel_name, proto.response_curve)
        for proto in media_protos
    ]
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@dataclasses.dataclass(frozen=True)
class IncrementalOutcomeGrid:
  """Wraps an `IncrementalOutcomeGrid` proto with per-channel grid access."""

  incremental_outcome_grid_proto: budget_pb.IncrementalOutcomeGrid

  @property
  def name(self) -> str:
    """The `name` field of the wrapped grid proto."""
    grid_proto = self.incremental_outcome_grid_proto
    return grid_proto.name

  @property
  def channel_spend_grids(self) -> dict[str, list[tuple[float, float]]]:
    """Per-channel `(spend, incremental outcome)` grid cells.

    Keys are channel names; if a channel appears more than once, its last
    entry wins. Cell order follows the proto.
    """
    return {
        channel.channel_name: [
            (grid_cell.spend, grid_cell.incremental_outcome.value)
            for grid_cell in channel.cells
        ]
        for channel in self.incremental_outcome_grid_proto.channel_cells
    }
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class _SpecMixin(abc.ABC):
  """Mixin for both budget and R&F optimization specs.

  Subclasses supply the raw date interval proto through `_date_interval_proto`.
  This mixin exposes it as a `DateInterval` dataclass.
  """

  @property
  @abc.abstractmethod
  def _date_interval_proto(self) -> date_interval_pb.DateInterval:
    """Returns the date interval proto."""
    raise NotImplementedError()

  @functools.cached_property
  def date_interval(self) -> DateInterval:
    """Returns the spec's date interval as a `DateInterval` dataclass."""
    # Reuse the module-level proto -> dataclass converter rather than
    # re-implementing the year/month/day unpacking here, so every wrapper in
    # this module converts `DateInterval` protos identically.
    return _to_date_interval_dc(self._date_interval_proto)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@dataclasses.dataclass(frozen=True)
class BudgetOptimizationSpec(_SpecMixin):
  """Wraps a `BudgetOptimizationSpec` proto (fixed or flexible scenario)."""

  budget_optimization_spec_proto: budget_pb.BudgetOptimizationSpec

  @property
  def _date_interval_proto(self) -> date_interval_pb.DateInterval:
    spec = self.budget_optimization_spec_proto
    return spec.date_interval

  @property
  def date_interval_tag(self) -> str:
    """The `tag` of this spec's date interval."""
    interval_proto = self._date_interval_proto
    return interval_proto.tag

  @property
  def objective(self) -> target_metric_pb.TargetMetric:
    """The `objective` target metric of the wrapped spec."""
    spec = self.budget_optimization_spec_proto
    return spec.objective

  @property
  def is_fixed_scenario(self) -> bool:
    """Whether the `scenario` oneof is set to `fixed_budget_scenario`."""
    active = self.budget_optimization_spec_proto.WhichOneof("scenario")
    return active == "fixed_budget_scenario"

  @property
  def max_budget(self) -> float:
    """The maximum budget of this spec.

    For a fixed scenario this is its total budget; otherwise it is the upper
    bound of the flexible scenario's total budget constraint.
    """
    spec = self.budget_optimization_spec_proto
    if not self.is_fixed_scenario:
      flexible = spec.flexible_budget_scenario
      return flexible.total_budget_constraint.max_budget
    return spec.fixed_budget_scenario.total_budget

  @functools.cached_property
  def channel_constraints(self) -> list[budget_pb.ChannelConstraint]:
    """The spec's `ChannelConstraint`s, copied into a list.

    An empty list means the spec carries no per-channel constraints, in which
    case its maximum budget is implied to apply to the channels; handling that
    case is left to the caller.
    """
    constraints = self.budget_optimization_spec_proto.channel_constraints
    return list(constraints)
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
@dataclasses.dataclass(frozen=True)
class RfOptimizationSpec(_SpecMixin):
  """Wraps a `ReachFrequencyOptimizationSpec` proto with derived properties."""

  rf_optimization_spec_proto: rf_pb.ReachFrequencyOptimizationSpec

  @property
  def _date_interval_proto(self) -> date_interval_pb.DateInterval:
    spec = self.rf_optimization_spec_proto
    return spec.date_interval

  @property
  def objective(self) -> target_metric_pb.TargetMetric:
    """The `objective` target metric of the wrapped spec."""
    spec = self.rf_optimization_spec_proto
    return spec.objective

  @property
  def total_budget_constraint(self) -> constraints_pb.BudgetConstraint:
    """The `total_budget_constraint` of the wrapped spec."""
    spec = self.rf_optimization_spec_proto
    return spec.total_budget_constraint

  @functools.cached_property
  def channel_constraints(self) -> list[rf_pb.RfChannelConstraint]:
    """The spec's `rf_channel_constraints`, copied into a list."""
    constraints = self.rf_optimization_spec_proto.rf_channel_constraints
    return list(constraints)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
class _NamedResultMixin(abc.ABC):
  """Abstract base for optimization results that carry a name and group ID.

  Concrete subclasses must override both read-only properties below; until
  they do, `abc` refuses to instantiate them.
  """

  @property
  @abc.abstractmethod
  def name(self) -> str:
    """The result's name."""
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def group_id(self) -> str:
    """The group ID assigned to the result."""
    raise NotImplementedError()
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
@dataclasses.dataclass(frozen=True)
class BudgetOptimizationResult(_NamedResultMixin):
  """A wrapper for `BudgetOptimizationResult` proto with derived properties."""

  budget_optimization_result_proto: budget_pb.BudgetOptimizationResult

  @property
  def name(self) -> str:
    """The `name` field of the wrapped result proto."""
    return self.budget_optimization_result_proto.name

  @property
  def group_id(self) -> str:
    """The `group_id` field of the wrapped result proto."""
    return self.budget_optimization_result_proto.group_id

  @functools.cached_property
  def spec(self) -> BudgetOptimizationSpec:
    """The result's `spec` field, wrapped as a `BudgetOptimizationSpec`."""
    return BudgetOptimizationSpec(self.budget_optimization_result_proto.spec)

  @functools.cached_property
  def optimized_marketing_analysis(self) -> MarketingAnalysis:
    """The result's `optimized_marketing_analysis`, wrapped."""
    return MarketingAnalysis(
        self.budget_optimization_result_proto.optimized_marketing_analysis
    )

  @functools.cached_property
  def incremental_outcome_grid(self) -> IncrementalOutcomeGrid:
    """The result's `incremental_outcome_grid`, wrapped."""
    return IncrementalOutcomeGrid(
        self.budget_optimization_result_proto.incremental_outcome_grid
    )

  @functools.cached_property
  def response_curves(self) -> list[ResponseCurve]:
    """Response curves of the optimized marketing analysis.

    Delegates to the cached `optimized_marketing_analysis` wrapper instead of
    building a second `MarketingAnalysis` around the same proto, so the
    wrapper (and its own cached properties) is shared.
    """
    return self.optimized_marketing_analysis.response_curves
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
@dataclasses.dataclass(frozen=True)
class FrequencyOutcomeGrid:
  """Wraps a `FrequencyOutcomeGrid` proto with per-channel grid access."""

  frequency_outcome_grid_proto: rf_pb.FrequencyOutcomeGrid

  @property
  def name(self) -> str:
    """The `name` field of the wrapped grid proto."""
    grid_proto = self.frequency_outcome_grid_proto
    return grid_proto.name

  @property
  def channel_frequency_grids(self) -> dict[str, list[tuple[float, float]]]:
    """Per-channel `(frequency, outcome)` grid cells.

    The frequency is each cell's `reach_frequency.average_frequency`. Keys are
    channel names; if a channel appears more than once, its last entry wins.
    """
    return {
        channel.channel_name: [
            (grid_cell.reach_frequency.average_frequency, grid_cell.outcome.value)
            for grid_cell in channel.cells
        ]
        for channel in self.frequency_outcome_grid_proto.channel_cells
    }
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@dataclasses.dataclass(frozen=True)
class ReachFrequencyOptimizationResult(_NamedResultMixin):
  """Wraps a `ReachFrequencyOptimizationResult` proto with derived properties."""

  rf_optimization_result_proto: rf_pb.ReachFrequencyOptimizationResult

  @property
  def name(self) -> str:
    """The `name` field of the wrapped result proto."""
    proto = self.rf_optimization_result_proto
    return proto.name

  @property
  def group_id(self) -> str:
    """The `group_id` field of the wrapped result proto."""
    proto = self.rf_optimization_result_proto
    return proto.group_id

  @functools.cached_property
  def spec(self) -> RfOptimizationSpec:
    """The result's `spec` field, wrapped as an `RfOptimizationSpec`."""
    spec_proto = self.rf_optimization_result_proto.spec
    return RfOptimizationSpec(spec_proto)

  @functools.cached_property
  def channel_mapped_optimized_frequencies(self) -> dict[str, float]:
    """Optimal average frequency per channel, keyed by channel name."""
    optimized = self.rf_optimization_result_proto.optimized_channel_frequencies
    return {
        entry.channel_name: entry.optimal_average_frequency
        for entry in optimized
    }

  @functools.cached_property
  def optimized_marketing_analysis(self) -> MarketingAnalysis:
    """The result's `optimized_marketing_analysis`, wrapped."""
    analysis_proto = self.rf_optimization_result_proto.optimized_marketing_analysis
    return MarketingAnalysis(analysis_proto)

  @functools.cached_property
  def frequency_outcome_grid(self) -> FrequencyOutcomeGrid:
    """The result's `frequency_outcome_grid`, wrapped."""
    grid_proto = self.rf_optimization_result_proto.frequency_outcome_grid
    return FrequencyOutcomeGrid(grid_proto)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
@dataclasses.dataclass(frozen=True)
class MarketingData:
  """A wrapper for `MarketingData` proto with derived properties."""

  marketing_data_proto: marketing_data_pb.MarketingData

  @property
  def _marketing_data_points(
      self,
  ) -> list[marketing_data_pb.MarketingDataPoint]:
    """Returns a list of `MarketingDataPoint`s."""
    return list(self.marketing_data_proto.marketing_data_points)

  @functools.cached_property
  def media_channels(self) -> list[str]:
    """Returns unique (non-R&F) media channel names in the marketing data."""
    channels = set()
    for data_point in self._marketing_data_points:
      for var in data_point.media_variables:
        channels.add(var.channel_name)
    return sorted(channels)  # For deterministic order in iterating.

  @functools.cached_property
  def rf_channels(self) -> list[str]:
    """Returns unique R&F channel names in the marketing data."""
    channels = set()
    for data_point in self._marketing_data_points:
      for var in data_point.reach_frequency_variables:
        channels.add(var.channel_name)
    return sorted(channels)  # For deterministic order in iterating.

  @functools.cached_property
  def date_intervals(self) -> list[DateInterval]:
    """Returns all distinct date intervals in the marketing data, sorted.

    Intervals are ordered by `(start, end)`, so the result is deterministic
    even when several intervals share a start date.
    """
    date_intervals = {
        _to_date_interval_dc(data_point.date_interval)
        for data_point in self._marketing_data_points
    }
    return sorted(
        date_intervals, key=lambda interval: (interval.start, interval.end)
    )

  def _sum_channel_spends(
      self,
      date_interval: tc.DateInterval,
      channels: list[str],
      get_variables,
      get_spend,
  ) -> dict[str, float]:
    """Sums per-channel spends over data points inside `date_interval`.

    Shared implementation of `media_channel_spends` and `rf_channel_spends`.

    Args:
      date_interval: the date interval to query for; normalized with
        `tc.normalize_date_interval`.
      channels: channel names to report, each starting from a 0.0 total.
      get_variables: returns the per-channel variables of a data point.
      get_spend: returns the spend value of one such variable.

    Returns:
      A dict of channel names mapped to their summed spend values.
    """
    interval = DateInterval(tc.normalize_date_interval(date_interval))
    channel_spends = {channel: 0.0 for channel in channels}
    for data_point in self._marketing_data_points:
      # The time coordinate for a marketing data point is the start date of its
      # date interval: test that it is contained within the given interval.
      # Only the start date is converted; the end date plays no role here.
      data_point_date = _to_datetime_date(data_point.date_interval.start_date)
      if data_point_date not in interval:
        continue
      for var in get_variables(data_point):
        channel_spends[var.channel_name] += get_spend(var)
    return channel_spends

  def media_channel_spends(
      self, date_interval: tc.DateInterval
  ) -> dict[str, float]:
    """Returns non-RF media channel names mapped to their total spend values, for the given date interval.

    All channel spends in time coordinates between `[start, end)` of the given
    date interval are summed up.

    Args:
      date_interval: the date interval to query for

    Returns:
      A dict of channel names mapped to their total spend values, for the given
      date interval.
    """
    return self._sum_channel_spends(
        date_interval,
        self.media_channels,
        lambda data_point: data_point.media_variables,
        lambda var: var.media_spend,
    )

  def rf_channel_spends(
      self, date_interval: tc.DateInterval
  ) -> dict[str, float]:
    """Returns *Reach and Frequency* channel names mapped to their total spend values, for the given date interval.

    All channel spends in time coordinates between `[start, end)` of the given
    date interval are summed up.

    Args:
      date_interval: the date interval to query for

    Returns:
      A dict of channel names mapped to their total spend values, for the given
      date interval.
    """
    return self._sum_channel_spends(
        date_interval,
        self.rf_channels,
        lambda data_point: data_point.reach_frequency_variables,
        lambda var: var.spend,
    )

  def all_channel_spends(
      self, date_interval: tc.DateInterval
  ) -> dict[str, float]:
    """Returns *all* channel names mapped to their total spend values, for the given date interval.

    All channel spends in time coordinates between `[start, end)` of the given
    date interval are summed up. If a channel name appears as both a media and
    an R&F channel, the media spend value takes precedence.

    Args:
      date_interval: the date interval to query for

    Returns:
      A dict of channel names mapped to their total spend values, for the given
      date interval.
    """
    spends = self.rf_channel_spends(date_interval)
    spends.update(self.media_channel_spends(date_interval))
    return spends
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
@dataclasses.dataclass(frozen=True)
class Mmm:
  """Top-level wrapper around an `Mmm` proto; entry point to its sub-wrappers."""

  mmm_proto: mmm_pb.Mmm

  @functools.cached_property
  def marketing_data(self) -> MarketingData:
    """`MarketingData` wrapper around the model kernel's marketing data."""
    kernel = self.mmm_proto.mmm_kernel
    return MarketingData(kernel.marketing_data)

  @property
  def model_fit(self) -> fit_pb.ModelFit:
    """The raw `model_fit` proto."""
    proto = self.mmm_proto
    return proto.model_fit

  @functools.cached_property
  def model_fit_results(self) -> dict[str, fit_pb.Result]:
    """Model fit `Result` protos keyed by their dataset `name`."""
    fit_results = self.model_fit.results
    return {fit_result.name: fit_result for fit_result in fit_results}

  @functools.cached_property
  def marketing_analyses(self) -> list[MarketingAnalysis]:
    """`MarketingAnalysis` wrappers, one per analysis proto, in proto order."""
    analysis_protos = self.mmm_proto.marketing_analysis_list.marketing_analyses
    return list(map(MarketingAnalysis, analysis_protos))

  @functools.cached_property
  def tagged_marketing_analyses(
      self,
  ) -> dict[str, MarketingAnalysis]:
    """Marketing analyses keyed by tag (later duplicates replace earlier)."""
    return {wrapper.tag: wrapper for wrapper in self.marketing_analyses}

  @functools.cached_property
  def budget_optimization_results(
      self,
  ) -> list[BudgetOptimizationResult]:
    """`BudgetOptimizationResult` wrappers, in proto order."""
    optimization = self.mmm_proto.marketing_optimization
    return list(
        map(BudgetOptimizationResult, optimization.budget_optimization.results)
    )

  @functools.cached_property
  def reach_frequency_optimization_results(
      self,
  ) -> list[ReachFrequencyOptimizationResult]:
    """`ReachFrequencyOptimizationResult` wrappers, in proto order."""
    optimization = self.mmm_proto.marketing_optimization
    return list(
        map(
            ReachFrequencyOptimizationResult,
            optimization.reach_frequency_optimization.results,
        )
    )
|