google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,412 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines common and base classes for processing trained Meridian model to an MMM schema."""
16
+
17
+ import abc
18
+ from collections.abc import Sequence
19
+ import dataclasses
20
+ import datetime
21
+ import functools
22
+ from typing import Generic, TypeVar
23
+
24
+ from google.protobuf import message
25
+ from meridian import constants as c
26
+ from meridian.analysis import analyzer
27
+ from meridian.analysis import optimizer
28
+ from meridian.analysis import visualizer
29
+ from meridian.data import time_coordinates as tc
30
+ from meridian.model import model
31
+ from mmm.v1 import mmm_pb2 as pb
32
+ from mmm.v1.common import date_interval_pb2
33
+ from schema.utils import time_record
34
+ from typing_extensions import override
35
+
36
+
37
# Public API of this module. Note that `Spec` (the base class that bounds
# `ModelProcessor`'s spec type parameter) and the `ModelType` alias are defined
# in this module but are not listed here.
# NOTE(review): confirm whether `Spec` is meant to be exported as well.
__all__ = [
    'ModelProcessor',
    'TrainedModel',
    'DatedSpec',
    'DatedSpecResolver',
    'OptimizationSpec',
    'ensure_trained_model',
]
45
+
46
+
47
class TrainedModel(abc.ABC):
  """Encapsulates a trained MMM model.

  Wraps a `model.Meridian` instance whose prior and posterior samples have been
  drawn, and lazily builds analysis helpers (`Analyzer`, `BudgetOptimizer`,
  `ModelDiagnostics`) bound to it.
  """

  def __init__(self, mmm: model.Meridian):
    """Initializes the TrainedModel with a Meridian model.

    Args:
      mmm: A Meridian model that has been fitted (prior and posterior samples
        drawn).

    Raises:
      ValueError: If the model has not been fitted (its prior or posterior
        samples are missing).
    """
    # Ideally, this could be encoded in the model type itself, and we won't need
    # this extra runtime check.
    inference_data = mmm.inference_data
    # A group that was never sampled may be absent from the inference data
    # altogether rather than set to `None`. Looking it up with a `None` default
    # covers both cases, so callers always get the documented `ValueError`
    # instead of an `AttributeError`.
    if (
        getattr(inference_data, 'prior', None) is None
        or getattr(inference_data, 'posterior', None) is None
    ):
      raise ValueError('MMM model has not been fitted.')
    self._mmm = mmm

  @property
  def mmm(self) -> model.Meridian:
    """Returns the wrapped Meridian model."""
    return self._mmm

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    """Returns the time coordinates of the wrapped model's input data."""
    return self._mmm.input_data.time_coordinates

  @functools.cached_property
  def internal_analyzer(self) -> analyzer.Analyzer:
    """Returns an internal `Analyzer` bound to this trained model."""
    return analyzer.Analyzer(self.mmm)

  @functools.cached_property
  def internal_optimizer(self) -> optimizer.BudgetOptimizer:
    """Returns an internal `BudgetOptimizer` bound to this trained model."""
    return optimizer.BudgetOptimizer(self.mmm)

  @functools.cached_property
  def internal_model_diagnostics(self) -> visualizer.ModelDiagnostics:
    """Returns an internal `ModelDiagnostics` bound to this trained model."""
    return visualizer.ModelDiagnostics(self.mmm)
87
+
88
+
89
# Either a raw Meridian model or one already wrapped in `TrainedModel`.
ModelType = model.Meridian | TrainedModel


def ensure_trained_model(model_input: ModelType) -> TrainedModel:
  """Returns `model_input` as a `TrainedModel`.

  A value that is already a `TrainedModel` is returned as-is (same object).
  Anything else is wrapped in a new `TrainedModel`, whose constructor checks
  that the model has been fitted.

  Args:
    model_input: A Meridian model, or an existing `TrainedModel`.

  Returns:
    A `TrainedModel` for `model_input`.

  Raises:
    ValueError: If `model_input` is a raw model that the `TrainedModel`
      constructor rejects as not fitted.
  """
  already_wrapped = isinstance(model_input, TrainedModel)
  return model_input if already_wrapped else TrainedModel(model_input)
97
+
98
+
99
class Spec(abc.ABC):
  """Base class for the parameter bundles used by model analysis/optimization.

  The subclasses in this module are dataclasses. The dataclass-generated
  `__init__` invokes `__post_init__`, which runs `validate()`, so an invalid
  spec is rejected as soon as it is constructed.
  """

  @abc.abstractmethod
  def validate(self) -> None:
    """Validates every parameter held by this spec.

    Raises:
      ValueError: If some parameter of the spec has an invalid value.
    """

  def __post_init__(self) -> None:
    """Runs `validate()` right after dataclass field initialization."""
    self.validate()
112
+
113
+
114
@dataclasses.dataclass(frozen=True)
class DatedSpec(Spec):
  """A spec covering a closed-open `[start_date, end_date)` date range.

  Attributes:
    start_date: First date (inclusive) of the analysis/optimization. When
      `None`, it is eventually resolved from a model's first time coordinate.
    end_date: End date of the analysis/optimization. When given, it is
      exclusive. When `None`, it is eventually resolved from a model's last
      time coordinate.
    date_interval_tag: Optional tag identifying the date interval.
  """

  start_date: datetime.date | None = None
  end_date: datetime.date | None = None
  date_interval_tag: str = ''

  @override
  def validate(self):
    """Rejects a range whose start date falls after its end date.

    Raises:
      ValueError: If both dates are set and `start_date` is later than
        `end_date`.
    """
    start, end = self.start_date, self.end_date
    if start is None or end is None:
      # An open-ended range cannot be out of order.
      return
    if start > end:
      raise ValueError('Start date must be before end date.')

  def resolver(
      self, time_coordinates: tc.TimeCoordinates
  ) -> 'DatedSpecResolver':
    """Binds this spec to `time_coordinates` in a new `DatedSpecResolver`."""
    return DatedSpecResolver(spec=self, time_coordinates=time_coordinates)
146
+
147
+
148
class DatedSpecResolver:
  """Resolves date parameters in specs based on a model's time coordinates.

  The bound `DatedSpec` describes a closed-open `[start_date, end_date)` range
  in which either bookend may be `None`. The methods below translate that
  range into closed date-string tuples, enumerated or boolean time selections,
  and `DateInterval` protos. Unbounded ends are filled in from the bound time
  coordinates.
  """

  def __init__(self, spec: DatedSpec, time_coordinates: tc.TimeCoordinates):
    """Binds `spec` to `time_coordinates`.

    Args:
      spec: The dated spec whose date range should be resolved.
      time_coordinates: The time coordinates used to fill in unbounded ends of
        the spec's range and to enumerate the selected times.
    """
    self._spec = spec
    self._time_coordinates = time_coordinates

  @property
  def _interval_days(self) -> int:
    """The length, in days, of one time-coordinate interval."""
    return self._time_coordinates.interval_days

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    """The time coordinates this resolver is bound to."""
    return self._time_coordinates

  def _exclusive_end(self, last_date: datetime.date) -> datetime.date:
    """Shifts `last_date` forward by one time-coordinate interval.

    This turns the date of the last included time coordinate into the
    exclusive end of a closed-open `[start, end)` interval.
    """
    return last_date + datetime.timedelta(days=self._interval_days)

  def to_closed_date_interval_tuple(
      self,
  ) -> tuple[str | None, str | None]:
    """Transforms given spec into a closed `[start, end]` date interval tuple.

    For each of the bookends in the tuple, `None` value indicates a time
    coordinate default (first or last time coordinate, respectively).

    Returns:
      A **closed** `[start, end]` date interval tuple of `c.DATE_FORMAT`
      strings (or `None`s).
    """
    start, end = (None, None)

    if self._spec.start_date is not None:
      start = self._spec.start_date.strftime(c.DATE_FORMAT)
    if self._spec.end_date is not None:
      # The spec's end date is exclusive; step back one interval to obtain the
      # inclusive end of the closed range.
      inclusive_end_date = self._spec.end_date - datetime.timedelta(
          days=self._interval_days
      )
      end = inclusive_end_date.strftime(c.DATE_FORMAT)

    return (start, end)

  def resolve_to_enumerated_selected_times(self) -> list[str] | None:
    """Resolves the given spec into an enumerated list of time coordinates.

    Returns:
      An enumerated list of time coordinates formatted with `c.DATE_FORMAT`,
      or `None` (semantic "All") when the time coordinates'
      `expand_selected_time_dims` returns `None` for the spec's range.
    """
    start, end = self.to_closed_date_interval_tuple()
    expanded = self._time_coordinates.expand_selected_time_dims(
        start_date=start, end_date=end
    )
    if expanded is None:
      return None
    return [date.strftime(c.DATE_FORMAT) for date in expanded]

  def resolve_to_bool_selected_times(self) -> list[bool] | None:
    """Resolves the given spec into a list of booleans indicating selected times.

    Returns:
      One boolean per entry of the time coordinates' `all_dates_str`, `True`
      where that date is selected by the spec; or `None` (semantic "All") when
      no explicit selection was resolved.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      return None
    # A set makes each membership test O(1) instead of scanning the list once
    # per time coordinate.
    selected = set(selected_times)
    return [time in selected for time in self._time_coordinates.all_dates_str]

  def collapse_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Collapses the given spec into a single `DateInterval` proto.

    If the spec's date range is unbounded, then the DateInterval proto will have
    the semantic "All", and we resolve it by consulting the time coordinates of
    the model bound to this resolver.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A `DateInterval` proto that represents the date interval specified by the
      spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      first_date = self._time_coordinates.all_dates[0]
      last_date = self._time_coordinates.all_dates[-1]
    else:
      # Only the bookends of the selection are needed to collapse it.
      first_date = tc.normalize_date(selected_times[0])
      last_date = tc.normalize_date(selected_times[-1])

    return time_record.create_date_interval_pb(
        first_date,
        self._exclusive_end(last_date),
        tag=self._spec.date_interval_tag,
    )

  def transform_to_date_interval_protos(
      self,
  ) -> list[date_interval_pb2.DateInterval]:
    """Transforms the given spec into one `DateInterval` proto per time.

    If the spec's date range is unbounded, then every time coordinate of the
    model bound to this resolver is used.

    Note that the exclusive end date semantic will be preserved in the returned
    protos: each one spans exactly one time-coordinate interval.

    Returns:
      A list of `DateInterval` protos that represent the date intervals
      specified by the spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      times_list = self._time_coordinates.all_dates
    else:
      times_list = [tc.normalize_date(date) for date in selected_times]

    return [
        time_record.create_date_interval_pb(
            start_date=start_date,
            end_date=self._exclusive_end(start_date),
            tag=self._spec.date_interval_tag,
        )
        for start_date in times_list
    ]

  def resolve_to_date_interval_open_end(
      self,
  ) -> tuple[datetime.date, datetime.date]:
    """Resolves given spec into an open-ended `[start, end)` date interval.

    Returns:
      A `(start, end)` tuple. A missing start resolves to the first time
      coordinate. A missing end resolves to the last time coordinate plus one
      interval, so that it is exclusive.
    """
    start = self._spec.start_date
    if start is None:
      start = self._time_coordinates.all_dates[0]
    end = self._spec.end_date
    if end is None:
      # Adjust `end` to make it exclusive, but only if we pulled it from the
      # time coordinates; a spec-provided end date is already exclusive.
      end = self._exclusive_end(self._time_coordinates.all_dates[-1])
    return (start, end)

  def resolve_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Resolves the given spec into a fully specified `DateInterval` proto.

    If either `start_date` or `end_date` is None in the bound spec, then we
    resolve it by consulting the time coordinates of the model bound to this
    resolver. They are resolved to the first and last time coordinates (plus
    interval length), respectively.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A resolved `DateInterval` proto that represents the date interval
      specified by the bound spec.
    """
    start_date, end_date = self.resolve_to_date_interval_open_end()
    return time_record.create_date_interval_pb(
        start_date, end_date, tag=self._spec.date_interval_tag
    )
312
+
313
+
314
@dataclasses.dataclass(frozen=True, kw_only=True)
class OptimizationSpec(DatedSpec):
  """A dated spec for optimization.

  Attributes:
    optimization_name: The name of the optimization in this spec. Must not be
      empty or blank.
    grid_name: The name of the optimization grid. Must not be empty or blank.
    group_id: An optional group ID for linking related optimizations.
    confidence_level: The threshold for computing confidence intervals. Defaults
      to `c.DEFAULT_CONFIDENCE_LEVEL` (0.9). Must be a number between 0 and 1
      inclusive.
  """

  optimization_name: str
  grid_name: str
  group_id: str | None = None
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL

  @override
  def validate(self):
    """Checks that the dates and optimization parameters are valid.

    Raises:
      ValueError: If the date range is invalid, a name is empty or blank, or
        `confidence_level` is not within `[0, 1]` (including NaN).
    """
    super().validate()

    if not self.optimization_name or self.optimization_name.isspace():
      raise ValueError('Optimization name must not be empty or blank.')

    if not self.grid_name or self.grid_name.isspace():
      raise ValueError('Grid name must not be empty or blank.')

    # Written as a negated chained comparison so that NaN, for which every
    # ordering comparison is False, is rejected instead of silently accepted.
    if not 0 <= self.confidence_level <= 1:
      raise ValueError('Confidence level must be between 0 and 1.')
344
+
345
+
346
# Type parameter for the concrete `Spec` subclass a `ModelProcessor` consumes.
S = TypeVar('S', bound=Spec)
# Type parameter for the concrete protobuf message a `ModelProcessor` produces.
M = TypeVar('M', bound=message.Message)


class ModelProcessor(abc.ABC, Generic[S, M]):
  """Performs model-based analysis or optimization.

  Concrete processors declare:
  - the spec type they accept (`spec_type()`);
  - the message type they produce (`output_type()`);
  - how to compute that message (`execute()`);
  - where to attach it in an `Mmm` output proto (`_set_output()`).

  Calling a processor checks the spec types, runs `execute()`, and attaches
  the result via `_set_output()`.
  """

  @classmethod
  @abc.abstractmethod
  def spec_type(cls) -> type[S]:
    """Returns the concrete Spec type that this ModelProcessor operates on."""
    raise NotImplementedError()

  @classmethod
  @abc.abstractmethod
  def output_type(cls) -> type[M]:
    """Returns the concrete output type that this ModelProcessor produces."""
    raise NotImplementedError()

  @abc.abstractmethod
  def execute(self, specs: Sequence[S]) -> M:
    """Runs an analysis/optimization on the model using the given specs.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.

    Returns:
      A proto containing the results of the analysis/optimization.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def _set_output(self, output: pb.Mmm, result: M):
    """Sets the output field in the given `Mmm` output proto.

    A model consumer that orchestrated this processor will indirectly call this
    method (via `__call__`) to attach the output of `execute()` (a
    processor-defined message `M`) into a partially built `Mmm` output proto
    that the model consumer manages.

    Args:
      output: The container output proto to which the given result message
        should be attached.
      result: An output of `execute()`.
    """
    raise NotImplementedError()

  def __call__(self, specs: Sequence[S], output: pb.Mmm):
    """Runs an analysis/optimization on the model using the given specs.

    This also sets the appropriate output field in the given `Mmm` output
    proto.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.
      output: The output proto to which the results of the analysis/optimization
        should be attached.

    Raises:
      ValueError: If any spec is not of the same type as `self.spec_type()`.
    """
    # Resolve the expected type once rather than once per spec.
    expected_type = self.spec_type()
    if not all(isinstance(spec, expected_type) for spec in specs):
      raise ValueError(f'Not all specs are of type {expected_type}')
    self._set_output(output, self.execute(specs))