google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78):
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,415 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines common and base classes for processing trained Meridian model to an MMM schema."""
16
+
17
+ import abc
18
+ from collections.abc import Sequence
19
+ import dataclasses
20
+ import datetime
21
+ import functools
22
+ from typing import Generic, TypeVar
23
+
24
+ from google.protobuf import message
25
+ from meridian import constants as c
26
+ from meridian.analysis import analyzer
27
+ from meridian.analysis import optimizer
28
+ from meridian.analysis import visualizer
29
+ from meridian.data import time_coordinates as tc
30
+ from meridian.model import model
31
+ from mmm.v1 import mmm_pb2 as pb
32
+ from mmm.v1.common import date_interval_pb2
33
+ from schema.utils import time_record
34
+ from typing_extensions import override
35
+
36
+
37
# Public names of this module (what `from ... import *` exposes). The abstract
# `Spec` base class and the `ModelType` alias are defined below but are not
# listed here.
__all__ = [
    'ModelProcessor',
    'TrainedModel',
    'DatedSpec',
    'DatedSpecResolver',
    'OptimizationSpec',
    'ensure_trained_model',
]
45
+
46
+
47
class TrainedModel(abc.ABC):
  """Encapsulates a trained MMM model.

  Construction verifies that the wrapped Meridian model has been sampled. The
  cached properties expose lazily built analysis helpers bound to that model.
  """

  def __init__(self, mmm: model.Meridian):
    """Initializes the TrainedModel with a Meridian model.

    Args:
      mmm: A Meridian model that has been fitted (prior and posterior samples
        drawn).

    Raises:
      ValueError: If the model has not been fitted (prior or posterior samples
        not drawn).
    """
    # Ideally, this could be encoded in the model type itself, and we won't need
    # this extra runtime check.
    # A group that was never sampled may be entirely absent from
    # `inference_data`. Plain attribute access would then raise
    # `AttributeError` instead of the documented `ValueError`. `getattr` with a
    # `None` default treats "missing" and "None" the same way.
    inference_data = mmm.inference_data
    if (
        getattr(inference_data, 'prior', None) is None
        or getattr(inference_data, 'posterior', None) is None
    ):
      raise ValueError('MMM model has not been fitted.')
    self._mmm = mmm

  @property
  def mmm(self) -> model.Meridian:
    """The wrapped, fitted Meridian model."""
    return self._mmm

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    """Time coordinates of the wrapped model's input data."""
    return self._mmm.input_data.time_coordinates

  @functools.cached_property
  def internal_analyzer(self) -> analyzer.Analyzer:
    """Returns an internal `Analyzer` bound to this trained model.

    The analyzer is built on first access and cached on this instance.
    """
    return analyzer.Analyzer(
        model_context=self.mmm.model_context,
        inference_data=self.mmm.inference_data,
    )

  @functools.cached_property
  def internal_optimizer(self) -> optimizer.BudgetOptimizer:
    """Returns an internal `BudgetOptimizer` bound to this trained model.

    The optimizer is built on first access and cached on this instance.
    """
    return optimizer.BudgetOptimizer(self.mmm)

  @functools.cached_property
  def internal_model_diagnostics(self) -> visualizer.ModelDiagnostics:
    """Returns an internal `ModelDiagnostics` bound to this trained model.

    The diagnostics object is built on first access and cached on this
    instance.
    """
    return visualizer.ModelDiagnostics(self.mmm)
90
+
91
+
92
# Accepted model inputs: either a raw Meridian model or one already wrapped in
# `TrainedModel`. `ensure_trained_model` normalizes both to `TrainedModel`.
ModelType = model.Meridian | TrainedModel
93
+
94
+
95
def ensure_trained_model(model_input: ModelType) -> TrainedModel:
  """Returns `model_input` as a `TrainedModel`, wrapping it only if needed.

  Args:
    model_input: Either a `TrainedModel`, which is returned unchanged (same
      object), or a raw Meridian model, which is wrapped in a new
      `TrainedModel`.

  Returns:
    A `TrainedModel` for the given input.

  Any error raised by the `TrainedModel` constructor (for example, for a model
  that has not been fitted) propagates to the caller.
  """
  already_wrapped = isinstance(model_input, TrainedModel)
  return model_input if already_wrapped else TrainedModel(model_input)
100
+
101
+
102
class Spec(abc.ABC):
  """Abstract bundle of parameters for a model-based analysis/optimization.

  The concrete specs in this module are dataclasses. The dataclass machinery
  invokes `__post_init__` after field initialization, and that hook runs
  `validate()`, so an invalid spec fails at construction time.
  """

  def __post_init__(self):
    # Dataclass hook: validate eagerly, right after the fields are set.
    self.validate()

  @abc.abstractmethod
  def validate(self):
    """Verifies that every parameter held by this spec has a valid value.

    Raises:
      ValueError: If some parameter of the spec has an invalid value.
    """
115
+
116
+
117
@dataclasses.dataclass(frozen=True)
class DatedSpec(Spec):
  """A spec bound to a closed-open `[start_date, end_date)` date range.

  Attrs:
    start_date: Inclusive first date of the analysis/optimization. When `None`,
      it eventually resolves to a model's first time coordinate.
    end_date: Exclusive end date of the analysis/optimization. When `None`, it
      eventually resolves to a model's last time coordinate.
    date_interval_tag: An optional tag that identifies the date interval.
  """

  start_date: datetime.date | None = None
  end_date: datetime.date | None = None
  date_interval_tag: str = ''

  @override
  def validate(self):
    """Rejects a date range whose start comes after its end.

    Raises:
      ValueError: If both dates are set and `start_date` is after `end_date`.
    """
    start, end = self.start_date, self.end_date
    if start is None or end is None:
      # An unset bookend is always acceptable; it is resolved later.
      return
    # NOTE(review): `start == end` passes, even though `[d, d)` is an empty
    # range. Confirm this is intended.
    if start > end:
      raise ValueError('Start date must be before end date.')

  def resolver(
      self, time_coordinates: tc.TimeCoordinates
  ) -> 'DatedSpecResolver':
    """Returns a `DatedSpecResolver` binding this spec to `time_coordinates`."""
    return DatedSpecResolver(spec=self, time_coordinates=time_coordinates)
149
+
150
+
151
class DatedSpecResolver:
  """Resolves date parameters in specs based on a model's time coordinates.

  The bound `DatedSpec` may leave either bookend unset (`None`). The methods
  here fill such gaps from the time coordinates and render the resulting range
  as date strings, boolean masks, or `DateInterval` protos.
  """

  def __init__(self, spec: DatedSpec, time_coordinates: tc.TimeCoordinates):
    """Binds `spec` to `time_coordinates`.

    Args:
      spec: The dated spec whose date range is resolved.
      time_coordinates: Time coordinates used to fill unset bookends and to
        enumerate the selected times.
    """
    self._spec = spec
    self._time_coordinates = time_coordinates

  @property
  def _interval_days(self) -> int:
    # Length of one time-coordinate step in days. Used to convert between
    # inclusive and exclusive end dates.
    return self._time_coordinates.interval_days

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    """The time coordinates this resolver is bound to."""
    return self._time_coordinates

  def to_closed_date_interval_tuple(
      self,
  ) -> tuple[str | None, str | None]:
    """Transforms given spec into a closed `[start, end]` date interval tuple.

    For each of the bookends in the tuple, `None` value indicates a time
    coordinate default (first or last time coordinate, respectively).

    Returns:
      A **closed** `[start, end]` date interval tuple of strings formatted with
      `c.DATE_FORMAT`.
    """
    start, end = None, None

    if self._spec.start_date is not None:
      start = self._spec.start_date.strftime(c.DATE_FORMAT)
    if self._spec.end_date is not None:
      # The spec's end date is exclusive. Step back one interval to get the
      # last included time coordinate.
      inclusive_end_date = self._spec.end_date - datetime.timedelta(
          days=self._interval_days
      )
      end = inclusive_end_date.strftime(c.DATE_FORMAT)

    return start, end

  def resolve_to_enumerated_selected_times(self) -> list[str] | None:
    """Resolves the given spec into an enumerated list of time coordinates.

    Returns:
      Selected time coordinates as strings formatted with `c.DATE_FORMAT`, or
      None (semantic "All") when `expand_selected_time_dims` reports no
      restriction (presumably when the bound spec leaves both bookends unset).
    """
    start, end = self.to_closed_date_interval_tuple()
    expanded = self._time_coordinates.expand_selected_time_dims(
        start_date=start, end_date=end
    )
    if expanded is None:
      return None
    return [date.strftime(c.DATE_FORMAT) for date in expanded]

  def resolve_to_bool_selected_times(self) -> list[bool] | None:
    """Resolves the given spec into a list of booleans indicating selected times.

    Returns:
      One boolean per entry of `time_coordinates.all_dates_str` that is True iff
      that time is selected. Returns None (semantic "All") when no selection
      restriction applies.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      return None
    # A set keeps each membership test O(1) instead of scanning the list.
    selected = set(selected_times)
    return [time in selected for time in self._time_coordinates.all_dates_str]

  def collapse_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Collapses the given spec into a `DateInterval` proto.

    If the spec's date range is unbounded, then the DateInterval proto will have
    the semantic "All", and we resolve it by consulting the time coordinates of
    the model bound to this resolver.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A `DateInterval` proto that represents the date interval specified by the
      spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      start_date = self._time_coordinates.all_dates[0]
      end_date = self._time_coordinates.all_dates[-1]
    else:
      # Only the first and last selected times matter, so normalize just those.
      start_date = tc.normalize_date(selected_times[0])
      end_date = tc.normalize_date(selected_times[-1])

    # Adjust end_date to make it exclusive.
    end_date += datetime.timedelta(days=self._interval_days)

    return time_record.create_date_interval_pb(
        start_date, end_date, tag=self._spec.date_interval_tag
    )

  def transform_to_date_interval_protos(
      self,
  ) -> list[date_interval_pb2.DateInterval]:
    """Transforms the given spec into `DateInterval` protos.

    One proto is produced per selected time coordinate, each spanning a single
    interval. If the spec's date range is unbounded, all time coordinates of the
    model bound to this resolver are used.

    Note that the exclusive end date semantic will be preserved in the returned
    protos.

    Returns:
      A list of `DateInterval` protos that represent the date intervals
      specified by the spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      times_list = self._time_coordinates.all_dates
    else:
      times_list = [tc.normalize_date(date) for date in selected_times]

    interval = datetime.timedelta(days=self._interval_days)
    return [
        time_record.create_date_interval_pb(
            start_date=start_date,
            end_date=start_date + interval,
            tag=self._spec.date_interval_tag,
        )
        for start_date in times_list
    ]

  def resolve_to_date_interval_open_end(
      self,
  ) -> tuple[datetime.date, datetime.date]:
    """Resolves given spec into a half-open `[start, end)` date interval.

    Unset bookends fall back to the time coordinates. A missing start becomes
    the first time coordinate. A missing end becomes the last time coordinate
    plus one interval, so that it is exclusive.
    """
    start = self._spec.start_date
    if start is None:
      start = self._time_coordinates.all_dates[0]
    end = self._spec.end_date
    if end is None:
      # Adjust `end` to make it exclusive, but only if we pulled it from the
      # time coordinates (a spec end date is already exclusive).
      end = self._time_coordinates.all_dates[-1] + datetime.timedelta(
          days=self._interval_days
      )
    return start, end

  def resolve_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Resolves the given spec into a fully specified `DateInterval` proto.

    If either `start_date` or `end_date` is None in the bound spec, then we
    resolve it by consulting the time coordinates of the model bound to this
    resolver. They are resolved to the first and last time coordinates (plus
    interval length), respectively.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A resolved `DateInterval` proto that represents the date interval
      specified by the bound spec.
    """
    start_date, end_date = self.resolve_to_date_interval_open_end()
    return time_record.create_date_interval_pb(
        start_date, end_date, tag=self._spec.date_interval_tag
    )
315
+
316
+
317
@dataclasses.dataclass(frozen=True, kw_only=True)
class OptimizationSpec(DatedSpec):
  """A dated spec for optimization.

  Attrs:
    optimization_name: The name of the optimization in this spec. Must not be
      empty or whitespace-only.
    grid_name: The name of the optimization grid. Must not be empty or
      whitespace-only.
    group_id: An optional group ID for linking related optimizations.
    confidence_level: The threshold for computing confidence intervals. Defaults
      to `c.DEFAULT_CONFIDENCE_LEVEL`. Must be a number in the closed range
      `[0, 1]`; NaN is rejected.
  """

  optimization_name: str
  grid_name: str
  group_id: str | None = None
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL

  @override
  def validate(self):
    """Checks that the optimization parameters are valid.

    Raises:
      ValueError: If the date range is invalid, if either name is empty or
        blank, or if `confidence_level` is outside `[0, 1]` (or NaN).
    """
    super().validate()

    if not self.optimization_name or self.optimization_name.isspace():
      raise ValueError('Optimization name must not be empty or blank.')

    if not self.grid_name or self.grid_name.isspace():
      raise ValueError('Grid name must not be empty or blank.')

    # A chained range check rejects NaN too, since every comparison involving
    # NaN is False.
    if not 0 <= self.confidence_level <= 1:
      raise ValueError('Confidence level must be between 0 and 1.')
347
+
348
+
349
# The concrete `Spec` type a `ModelProcessor` consumes.
S = TypeVar('S', bound=Spec)
# The concrete proto message type a `ModelProcessor` produces.
M = TypeVar('M', bound=message.Message)
351
+
352
+
353
class ModelProcessor(abc.ABC, Generic[S, M]):
  """Performs model-based analysis or optimization.

  Subclasses declare the spec type they consume (`spec_type`) and the proto
  message they produce (`output_type`). They implement `execute` and
  `_set_output`. Calling the processor validates the specs, runs `execute`,
  and attaches the result to an `Mmm` proto.
  """

  @classmethod
  @abc.abstractmethod
  def spec_type(cls) -> type[S]:
    """Returns the concrete Spec type that this ModelProcessor operates on."""
    raise NotImplementedError()

  @classmethod
  @abc.abstractmethod
  def output_type(cls) -> type[M]:
    """Returns the concrete output type that this ModelProcessor produces."""
    raise NotImplementedError()

  @abc.abstractmethod
  def execute(self, specs: Sequence[S]) -> M:
    """Runs an analysis/optimization on the model using the given specs.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.

    Returns:
      A proto containing the results of the analysis/optimization.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def _set_output(self, output: pb.Mmm, result: M):
    """Sets the output field in the given `MmmOutput` proto.

    A model consumer that orchestrated this processor will indirectly call this
    method (via `__call__`) to attach the output of `execute()` (a
    processor-defined message `M`) into a partially built `MmmOutput` proto that
    the model consumer manages.

    Args:
      output: The container output proto to which the given result message
        should be attached.
      result: An output of `execute()`.
    """
    raise NotImplementedError()

  def __call__(self, specs: Sequence[S], output: pb.Mmm):
    """Runs an analysis/optimization on the model using the given specs.

    This also sets the appropriate output field in the given MmmOutput proto.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.
      output: The output proto to which the results of the analysis/optimization
        should be attached.

    Raises:
      ValueError: If any spec is not of the same type as `self.spec_type()`.
    """
    # Resolve the expected type once instead of once per spec.
    spec_type = self.spec_type()
    if not all(isinstance(spec, spec_type) for spec in specs):
      raise ValueError('Not all specs are of type %s' % spec_type)
    self._set_output(output, self.execute(specs))