google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Meridian module for analyzing model fit in a Meridian model.
|
|
16
|
+
|
|
17
|
+
This module provides a `ModelFitProcessor`, which assesses the goodness of fit
|
|
18
|
+
of a trained Meridian model. It compares the model's predictions against the
|
|
19
|
+
actual observed data, generating key performance metrics.
|
|
20
|
+
|
|
21
|
+
Key metrics generated include R-squared, MAPE, and Weighted MAPE. The output
|
|
22
|
+
also includes timeseries data of actual values versus predicted values (with
|
|
23
|
+
confidence intervals) and the predicted baseline.
|
|
24
|
+
|
|
25
|
+
The results are structured into a `ModelFit` protobuf message.
|
|
26
|
+
|
|
27
|
+
Key Classes:
|
|
28
|
+
|
|
29
|
+
- `ModelFitSpec`: Dataclass to specify parameters for the model fit analysis,
|
|
30
|
+
such as whether to split by train/test sets and the confidence level for
|
|
31
|
+
intervals.
|
|
32
|
+
- `ModelFitProcessor`: The processor class that performs the fit analysis.
|
|
33
|
+
|
|
34
|
+
Example Usage:
|
|
35
|
+
|
|
36
|
+
```python
|
|
37
|
+
from schema.processors import model_fit_processor
|
|
38
|
+
from schema.processors import model_processor
|
|
39
|
+
|
|
40
|
+
# Assuming 'mmm' is a trained Meridian model object
|
|
41
|
+
trained_model = model_processor.TrainedModel(mmm)
|
|
42
|
+
|
|
43
|
+
# Default spec: split results by train/test if holdout ID exists
|
|
44
|
+
spec = model_fit_processor.ModelFitSpec()
|
|
45
|
+
|
|
46
|
+
processor = model_fit_processor.ModelFitProcessor(trained_model)
|
|
47
|
+
# result is a model_fit_pb2.ModelFit proto
|
|
48
|
+
result = processor.execute([spec])
|
|
49
|
+
|
|
50
|
+
print("Model Fit Analysis Results:")
|
|
51
|
+
for res in result.results:
|
|
52
|
+
print(f" Dataset: {res.name}")
|
|
53
|
+
print(f" R-squared: {res.performance.r_squared:.3f}")
|
|
54
|
+
print(f" MAPE: {res.performance.mape:.3f}")
|
|
55
|
+
print(f" Weighted MAPE: {res.performance.weighted_mape:.3f}")
|
|
56
|
+
# Prediction data is available in res.predictions
|
|
57
|
+
# Each element in res.predictions corresponds to a time point.
|
|
58
|
+
# e.g., res.predictions[0].actual_value
|
|
59
|
+
# e.g., res.predictions[0].predicted_outcome.value
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
Note: Only one spec is supported per processor execution.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
from collections.abc import Mapping, Sequence
|
|
66
|
+
import dataclasses
|
|
67
|
+
import warnings
|
|
68
|
+
|
|
69
|
+
from meridian import constants
|
|
70
|
+
from mmm.v1 import mmm_pb2
|
|
71
|
+
from mmm.v1.common import date_interval_pb2
|
|
72
|
+
from mmm.v1.common import estimate_pb2
|
|
73
|
+
from mmm.v1.fit import model_fit_pb2
|
|
74
|
+
from schema.processors import model_processor
|
|
75
|
+
from schema.utils import time_record
|
|
76
|
+
import xarray as xr
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Public names exported by this module.
__all__ = [
    "ModelFitSpec",
    "ModelFitProcessor",
]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclasses.dataclass(frozen=True)
class ModelFitSpec(model_processor.Spec):
  """Parameters that control how `ModelFit` protos are generated.

  Attributes:
    split: If `True` and the Meridian model contains holdout IDs, results are
      generated for the `'Train'`, `'Test'`, and `'All Data'` sets.
    confidence_level: Confidence level for prior and posterior credible
      intervals, represented as a value between zero and one.
  """

  split: bool = True
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL

  def validate(self):
    """Checks that the confidence level is usable.

    Raises:
      ValueError: If `confidence_level` is less than or equal to 0, or greater
        than or equal to 1.
    """
    level = self.confidence_level
    out_of_range = level <= 0 or level >= 1
    if out_of_range:
      raise ValueError(
          "Confidence level must be greater than 0 and less than 1."
      )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ModelFitProcessor(
    model_processor.ModelProcessor[ModelFitSpec, model_fit_pb2.ModelFit]
):
  """Generates a ModelFit proto for a given trained Meridian model.

  The proto contains performance metrics for each dataset as well as a list of
  predictions.
  """

  def __init__(
      self,
      trained_model: model_processor.ModelType,
  ):
    """Initializes the processor from a trained model.

    Args:
      trained_model: The model to analyze. It is passed through
        `model_processor.ensure_trained_model`, and the resulting object's
        analyzer and time coordinates are kept for later use.
    """
    trained_model = model_processor.ensure_trained_model(trained_model)
    self._analyzer = trained_model.internal_analyzer
    self._time_coordinates = trained_model.time_coordinates

  @classmethod
  def spec_type(cls) -> type[ModelFitSpec]:
    """Returns the spec class consumed by this processor."""
    return ModelFitSpec

  @classmethod
  def output_type(cls) -> type[model_fit_pb2.ModelFit]:
    """Returns the proto class produced by this processor."""
    return model_fit_pb2.ModelFit

  def _set_output(self, output: mmm_pb2.Mmm, result: model_fit_pb2.ModelFit):
    """Copies `result` into the `model_fit` field of `output`."""
    output.model_fit.CopyFrom(result)

  def execute(self, specs: Sequence[ModelFitSpec]) -> model_fit_pb2.ModelFit:
    """Runs the model fit analysis for the first spec in `specs`.

    Args:
      specs: The specs to execute. Only the first one is used; a warning is
        emitted if more than one is given.

    Returns:
      A `ModelFit` proto with one result per evaluation set (`'Train'`,
      `'Test'`, `'All Data'`) when the expected-vs-actual data is split by
      evaluation set, and a single `'All Data'` result otherwise.

    Raises:
      ValueError: If `specs` is empty.
    """
    if len(specs) == 0:
      raise ValueError("At least one ModelFitSpec must be provided.")

    model_fit_spec = specs[0]
    if len(specs) > 1:
      warnings.warn(
          "Multiple specs were provided. Only the first one will be used.",
          stacklevel=2,
      )

    expected_vs_actual = self._analyzer.expected_vs_actual_data(
        confidence_level=model_fit_spec.confidence_level,
        split_by_holdout_id=model_fit_spec.split,
        aggregate_geos=True,
    )
    metrics = self._analyzer.predictive_accuracy()
    time_to_date_interval = time_record.convert_times_to_date_intervals(
        self._time_coordinates.datetime_index
    )

    results: list[model_fit_pb2.Result] = []

    if constants.EVALUATION_SET_VAR in expected_vs_actual.coords:
      for evaluation_set in (
          constants.TRAIN,
          constants.TEST,
          constants.ALL_DATA,
      ):
        results.append(
            self._create_result(
                result_type=evaluation_set,
                expected_vs_actual=expected_vs_actual.sel(
                    evaluation_set=evaluation_set
                ),
                metrics=metrics.sel(evaluation_set=evaluation_set),
                model_fit_spec=model_fit_spec,
                time_to_date_interval=time_to_date_interval,
            )
        )
    else:
      # The accuracy metrics are computed independently of `spec.split`, so
      # they may still carry an evaluation-set dimension (model with holdout
      # IDs but `split=False`). Narrow them to "All Data" so that each metric
      # is a single scalar when it is read out below.
      if constants.EVALUATION_SET_VAR in metrics.dims:
        metrics = metrics.sel(evaluation_set=constants.ALL_DATA)
      results.append(
          self._create_result(
              result_type=constants.ALL_DATA,
              expected_vs_actual=expected_vs_actual,
              metrics=metrics,
              model_fit_spec=model_fit_spec,
              time_to_date_interval=time_to_date_interval,
          )
      )

    return model_fit_pb2.ModelFit(results=results)

  def _create_result(
      self,
      result_type: str,
      expected_vs_actual: xr.Dataset,
      metrics: xr.Dataset,
      model_fit_spec: ModelFitSpec,
      time_to_date_interval: Mapping[str, date_interval_pb2.DateInterval],
  ) -> model_fit_pb2.Result:
    """Creates a proto that stores the model fit results for an evaluation set.

    Args:
      result_type: The evaluation set (`"Train"`, `"Test"`, or `"All Data"`) for
        the result.
      expected_vs_actual: A dataset containing the expected and actual values
        for the model. This dataset is filtered by the evaluation set in the
        calling code.
      metrics: A dataset containing the performance metrics for the model. This
        dataset is filtered by the evaluation set in the calling code.
      model_fit_spec: An instance of ModelFitSpec.
      time_to_date_interval: A mapping of date strings (in YYYY-MM-DD format) to
        date interval protos.

    Returns:
      A proto containing the results of the model fit analysis for the given
      evaluation set.
    """
    # Look the data variables up once instead of once per date.
    actual_values = expected_vs_actual.data_vars[constants.ACTUAL]
    expected_values = expected_vs_actual[constants.EXPECTED]
    baseline_values = expected_vs_actual[constants.BASELINE]

    predictions: list[model_fit_pb2.Prediction] = []

    for start_date in self._time_coordinates.all_dates_str:
      date_interval = time_to_date_interval[start_date]
      actual = actual_values.sel(time=start_date).item()
      expected_at_date = expected_values.sel(time=start_date)
      baseline_at_date = baseline_values.sel(time=start_date)

      predictions.append(
          self._create_prediction(
              model_fit_spec=model_fit_spec,
              date_interval=date_interval,
              actual_value=actual,
              estimated_value=expected_at_date.sel(
                  metric=constants.MEAN
              ).item(),
              estimated_lower_bound=expected_at_date.sel(
                  metric=constants.CI_LO
              ).item(),
              estimated_upper_bound=expected_at_date.sel(
                  metric=constants.CI_HI
              ).item(),
              baseline_value=baseline_at_date.sel(
                  metric=constants.MEAN
              ).item(),
              baseline_lower_bound=baseline_at_date.sel(
                  metric=constants.CI_LO
              ).item(),
              baseline_upper_bound=baseline_at_date.sel(
                  metric=constants.CI_HI
              ).item(),
          )
      )

    performance = self._evaluate_model_fit(metrics)

    return model_fit_pb2.Result(
        name=result_type, predictions=predictions, performance=performance
    )

  @staticmethod
  def _create_estimate(
      value: float,
      lower_bound: float,
      upper_bound: float,
      probability: float,
  ) -> estimate_pb2.Estimate:
    """Builds an `Estimate` proto with a single uncertainty interval."""
    estimate = estimate_pb2.Estimate(value=value)
    estimate.uncertainties.add(
        probability=probability,
        lowerbound=lower_bound,
        upperbound=upper_bound,
    )
    return estimate

  def _create_prediction(
      self,
      model_fit_spec: ModelFitSpec,
      date_interval: date_interval_pb2.DateInterval,
      actual_value: float,
      estimated_value: float,
      estimated_lower_bound: float,
      estimated_upper_bound: float,
      baseline_value: float,
      baseline_lower_bound: float,
      baseline_upper_bound: float,
  ) -> model_fit_pb2.Prediction:
    """Creates a proto that stores the model's prediction for the given date.

    Args:
      model_fit_spec: An instance of ModelFitSpec.
      date_interval: A DateInterval proto containing the start date and end date
        for this prediction.
      actual_value: The model's actual value for this date.
      estimated_value: The model's estimated value for this date.
      estimated_lower_bound: The lower bound of the estimated value's confidence
        interval.
      estimated_upper_bound: The upper bound of the estimated value's confidence
        interval.
      baseline_value: The baseline value for this date.
      baseline_lower_bound: The lower bound of the baseline value's confidence
        interval.
      baseline_upper_bound: The upper bound of the baseline value's confidence
        interval.

    Returns:
      A proto containing the model's predicted value and actual value for the
      given date.
    """
    confidence_level = model_fit_spec.confidence_level

    return model_fit_pb2.Prediction(
        date_interval=date_interval,
        predicted_outcome=self._create_estimate(
            value=estimated_value,
            lower_bound=estimated_lower_bound,
            upper_bound=estimated_upper_bound,
            probability=confidence_level,
        ),
        predicted_baseline=self._create_estimate(
            value=baseline_value,
            lower_bound=baseline_lower_bound,
            upper_bound=baseline_upper_bound,
            probability=confidence_level,
        ),
        actual_value=actual_value,
    )

  def _evaluate_model_fit(
      self,
      metrics: xr.Dataset,
  ) -> model_fit_pb2.Performance:
    """Creates a proto that stores the model's performance metrics.

    Args:
      metrics: A dataset containing the performance metrics for the model. This
        dataset is filtered by evaluation set before this function is called.

    Returns:
      A proto containing the model's performance metrics for a specific
      evaluation set.
    """
    national_values = metrics[constants.VALUE].sel(
        geo_granularity=constants.NATIONAL
    )

    performance = model_fit_pb2.Performance()
    performance.r_squared = national_values.sel(
        metric=constants.R_SQUARED
    ).item()
    performance.mape = national_values.sel(metric=constants.MAPE).item()
    performance.weighted_mape = national_values.sel(
        metric=constants.WMAPE
    ).item()

    return performance
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module for transforming a Meridian model into a structured MMM schema.
|
|
16
|
+
|
|
17
|
+
This module provides the `ModelKernelProcessor`, which is responsible for
|
|
18
|
+
transforming the internal state of a trained Meridian model object into a
|
|
19
|
+
structured and portable format defined by the `MmmKernel` protobuf message.
|
|
20
|
+
|
|
21
|
+
The "kernel" includes essential information about the model, such as:
|
|
22
|
+
|
|
23
|
+
- Model specifications and hyperparameters.
|
|
24
|
+
- Inferred parameter distributions (as serialized ArviZ inference data).
|
|
25
|
+
- MMM-agnostic marketing data (i.e. input data to the model).
|
|
26
|
+
|
|
27
|
+
This serialized representation allows the model to be saved, loaded, and
|
|
28
|
+
analyzed across different environments or by other tools that understand the
|
|
29
|
+
`MmmKernel` schema.
|
|
30
|
+
|
|
31
|
+
The serialization logic is primarily handled by the `MeridianSerde` class from
|
|
32
|
+
the `schema.serde` package.
|
|
33
|
+
|
|
34
|
+
Key Classes:
|
|
35
|
+
|
|
36
|
+
- `ModelKernelProcessor`: The processor class that takes a Meridian model
|
|
37
|
+
instance and populates an `MmmKernel` message.
|
|
38
|
+
|
|
39
|
+
Example Usage:
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
import meridian
|
|
43
|
+
from meridian.model import model
|
|
44
|
+
from mmm.v1 import mmm_pb2
|
|
45
|
+
from schema.processors import model_kernel_processor
|
|
46
|
+
import semver
|
|
47
|
+
|
|
48
|
+
# Assuming 'mmm' is a `meridian.model.Meridian` object.
|
|
49
|
+
# Example:
|
|
50
|
+
# mmm = meridian.model.Meridian(...)
|
|
51
|
+
# mmm.sample_prior(...)
|
|
52
|
+
# mmm.sample_posterior(...)
|
|
53
|
+
|
|
54
|
+
processor = model_kernel_processor.ModelKernelProcessor(
|
|
55
|
+
meridian_model=mmm,
|
|
56
|
+
model_id="my_model_v1",
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
# Create an output Mmm proto message
|
|
60
|
+
output_proto = mmm_pb2.Mmm()
|
|
61
|
+
|
|
62
|
+
# Populate the mmm_kernel field
|
|
63
|
+
processor(output_proto)
|
|
64
|
+
|
|
65
|
+
# Now output_proto.mmm_kernel contains the serialized model.
|
|
66
|
+
# This can be saved to a file, sent over a network, etc.
|
|
67
|
+
print(f"Model Kernel ID: {output_proto.mmm_kernel.model_id}")
|
|
68
|
+
print(f"Meridian Version: {output_proto.mmm_kernel.meridian_version}")
|
|
69
|
+
# Access other fields within output_proto.mmm_kernel as needed.
|
|
70
|
+
```
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
import abc
|
|
74
|
+
|
|
75
|
+
import meridian
|
|
76
|
+
from meridian.model import model
|
|
77
|
+
from mmm.v1 import mmm_pb2 as pb
|
|
78
|
+
from schema.serde import meridian_serde
|
|
79
|
+
import semver
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ModelKernelProcessor(abc.ABC):
  """Transcribes a model's stats into an `MmmKernel` message."""

  def __init__(
      self,
      meridian_model: model.Meridian,
      model_id: str = '',
      meridian_version: semver.VersionInfo = semver.VersionInfo.parse(
          meridian.__version__
      ),
  ):
    """Initializes this `ModelKernelProcessor` with a Meridian model.

    Args:
      meridian_model: The Meridian model to serialize.
      model_id: An optional identifier for the given model. Defaults to an
        empty string.
      meridian_version: The Meridian framework version to record. Defaults to
        the parsed `meridian.__version__`.
    """
    self._meridian = meridian_model
    self._model_id = model_id
    self._meridian_version = meridian_version

  def __call__(self, output: pb.Mmm):
    """Sets the `mmm_kernel` field in the given `Mmm` proto.

    Args:
      output: The output proto to modify in place.
    """
    serde = meridian_serde.MeridianSerde()
    kernel = serde.serialize(
        self._meridian,
        self._model_id,
        self._meridian_version,
        include_convergence_info=True,
    )
    output.mmm_kernel.CopyFrom(kernel)
|