google-meridian 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/reviewer.py +4 -1
  7. meridian/analysis/summarizer.py +6 -1
  8. meridian/analysis/test_utils.py +2898 -2538
  9. meridian/analysis/visualizer.py +28 -9
  10. meridian/backend/__init__.py +106 -0
  11. meridian/constants.py +1 -0
  12. meridian/data/input_data.py +30 -52
  13. meridian/data/input_data_builder.py +2 -9
  14. meridian/data/test_utils.py +25 -41
  15. meridian/data/validator.py +48 -0
  16. meridian/mlflow/autolog.py +19 -9
  17. meridian/model/adstock_hill.py +3 -5
  18. meridian/model/context.py +134 -0
  19. meridian/model/eda/constants.py +334 -4
  20. meridian/model/eda/eda_engine.py +723 -312
  21. meridian/model/eda/eda_outcome.py +177 -33
  22. meridian/model/model.py +159 -110
  23. meridian/model/model_test_data.py +38 -0
  24. meridian/model/posterior_sampler.py +103 -62
  25. meridian/model/prior_sampler.py +114 -94
  26. meridian/model/spec.py +23 -14
  27. meridian/templates/card.html.jinja +9 -7
  28. meridian/templates/chart.html.jinja +1 -6
  29. meridian/templates/finding.html.jinja +19 -0
  30. meridian/templates/findings.html.jinja +33 -0
  31. meridian/templates/formatter.py +41 -5
  32. meridian/templates/formatter_test.py +127 -0
  33. meridian/templates/style.css +66 -9
  34. meridian/templates/style.scss +85 -4
  35. meridian/templates/table.html.jinja +1 -0
  36. meridian/version.py +1 -1
  37. scenarioplanner/linkingapi/constants.py +1 -1
  38. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  39. schema/processors/marketing_processor.py +11 -10
  40. schema/processors/model_processor.py +4 -1
  41. schema/serde/distribution.py +12 -7
  42. schema/serde/hyperparameters.py +54 -107
  43. schema/serde/meridian_serde.py +6 -1
  44. schema/utils/__init__.py +1 -0
  45. schema/utils/proto_enum_converter.py +127 -0
  46. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
  47. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
@@ -16,12 +16,14 @@
16
16
 
17
17
  import warnings
18
18
 
19
+ import bidict
19
20
  from meridian import backend
20
21
  from meridian import constants as c
21
22
  from meridian.model import spec
22
23
  from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
23
24
  from schema.serde import constants as sc
24
25
  from schema.serde import serde
26
+ from schema.utils import proto_enum_converter
25
27
  import numpy as np
26
28
 
27
29
  _MediaEffectsDist = meridian_pb.MediaEffectsDistribution
@@ -31,101 +33,44 @@ _NonMediaBaselineFunction = (
31
33
  meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
32
34
  )
33
35
 
36
# Converts `ModelSpec.media_effects_dist` strings to and from the
# `MediaEffectsDistribution` proto enum. An unspecified proto value is read
# back as `c.MEDIA_EFFECTS_LOG_NORMAL`.
media_effects_converter = proto_enum_converter.ProtoEnumConverter(
    enum_message=_MediaEffectsDist,
    enum_display_name="Media effects distribution",
    enum_unspecified=_MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED,
    default_when_unspecified=c.MEDIA_EFFECTS_LOG_NORMAL,
    mapping=bidict.bidict([
        (c.MEDIA_EFFECTS_LOG_NORMAL, "LOG_NORMAL"),
        (c.MEDIA_EFFECTS_NORMAL, "NORMAL"),
    ]),
)
34
46
 
35
- def _media_effects_dist_to_proto_enum(
36
- media_effect_dict: str,
37
- ) -> _MediaEffectsDist:
38
- match media_effect_dict:
39
- case c.MEDIA_EFFECTS_LOG_NORMAL:
40
- return _MediaEffectsDist.LOG_NORMAL
41
- case c.MEDIA_EFFECTS_NORMAL:
42
- return _MediaEffectsDist.NORMAL
43
- case _:
44
- return _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED
45
-
46
-
47
- def _proto_enum_to_media_effects_dist(
48
- proto_enum: _MediaEffectsDist,
49
- ) -> str:
50
- """Converts a `_MediaEffectsDist` enum to its string representation."""
51
- match proto_enum:
52
- case _MediaEffectsDist.LOG_NORMAL:
53
- return c.MEDIA_EFFECTS_LOG_NORMAL
54
- case _MediaEffectsDist.NORMAL:
55
- return c.MEDIA_EFFECTS_NORMAL
56
- case _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED:
57
- warnings.warn(
58
- "Media effects distribution is unspecified. Resolving to"
59
- " 'log-normal'."
60
- )
61
- return c.MEDIA_EFFECTS_LOG_NORMAL
62
- case _:
63
- raise ValueError(
64
- "Unsupported MediaEffectsDistribution proto enum value:"
65
- f" {proto_enum}."
66
- )
67
-
68
-
69
- def _paid_media_prior_type_to_proto_enum(
70
- paid_media_prior_type: str | None,
71
- ) -> _PaidMediaPriorType:
72
- """Converts a paid media prior type string to its proto enum."""
73
- if paid_media_prior_type is None:
74
- return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
75
- try:
76
- return _PaidMediaPriorType.Value(paid_media_prior_type.upper())
77
- except ValueError:
78
- warnings.warn(
79
- f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
80
- " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
81
- )
82
- return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
83
-
84
-
85
- def _proto_enum_to_paid_media_prior_type(
86
- proto_enum: _PaidMediaPriorType,
87
- ) -> str | None:
88
- """Converts a `_PaidMediaPriorType` enum to its string representation."""
89
- if proto_enum == _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED:
90
- return None
91
- return _PaidMediaPriorType.Name(proto_enum).lower()
92
-
93
-
94
- def _non_paid_prior_type_to_proto_enum(
95
- non_paid_prior_type: str,
96
- ) -> _NonPaidTreatmentsPriorType:
97
- """Converts a non-paid prior type string to its proto enum."""
98
- try:
99
- return _NonPaidTreatmentsPriorType.Value(
100
- f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
101
- )
102
- except ValueError:
103
- warnings.warn(
104
- f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
105
- " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
106
- )
107
- return (
108
- _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
109
- )
110
-
47
# Shared by the `media_prior_type`, `rf_prior_type` and
# `paid_media_prior_type` hyperparameters. An unspecified proto value is read
# back as `None`.
paid_media_prior_type_converter = proto_enum_converter.ProtoEnumConverter(
    enum_message=_PaidMediaPriorType,
    enum_display_name="Paid media prior type",
    enum_unspecified=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED,
    default_when_unspecified=None,
    mapping=bidict.bidict([
        (c.TREATMENT_PRIOR_TYPE_ROI, "ROI"),
        (c.TREATMENT_PRIOR_TYPE_MROI, "MROI"),
        (c.TREATMENT_PRIOR_TYPE_COEFFICIENT, "COEFFICIENT"),
        (c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, "CONTRIBUTION"),
    ]),
)
111
59
 
112
- def _proto_enum_to_non_paid_prior_type(
113
- proto_enum: _NonPaidTreatmentsPriorType,
114
- ) -> str:
115
- """Converts a `_NonPaidTreatmentsPriorType` enum to its string representation."""
116
- if (
117
- proto_enum
118
- == _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
119
- ):
120
- warnings.warn(
121
- "Non-paid prior type is unspecified. Resolving to 'contribution'."
122
- )
123
- return c.TREATMENT_PRIOR_TYPE_CONTRIBUTION
124
- return (
125
- _NonPaidTreatmentsPriorType.Name(proto_enum)
126
- .replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "")
127
- .lower()
128
- )
60
# Shared by the `organic_media_prior_type`, `organic_rf_prior_type` and
# `non_media_treatments_prior_type` hyperparameters. An unspecified proto
# value is read back as `c.TREATMENT_PRIOR_TYPE_CONTRIBUTION`.
non_paid_treatments_prior_type_converter = (
    proto_enum_converter.ProtoEnumConverter(
        enum_message=_NonPaidTreatmentsPriorType,
        enum_display_name="Non-paid treatments prior type",
        enum_unspecified=(
            _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
        ),
        default_when_unspecified=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
        mapping=bidict.bidict([
            (
                c.TREATMENT_PRIOR_TYPE_COEFFICIENT,
                "NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT",
            ),
            (
                c.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
                "NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION",
            ),
        ]),
    )
)
129
74
 
130
75
 
131
76
  class HyperparametersSerde(
@@ -141,25 +86,27 @@ class HyperparametersSerde(
141
86
  def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters:
142
87
  """Serializes the given ModelSpec into a `Hyperparameters` proto."""
143
88
  hyperparameters_proto = meridian_pb.Hyperparameters(
144
- media_effects_dist=_media_effects_dist_to_proto_enum(
89
+ media_effects_dist=media_effects_converter.to_proto(
145
90
  obj.media_effects_dist
146
91
  ),
147
92
  hill_before_adstock=obj.hill_before_adstock,
148
93
  unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo,
149
- media_prior_type=_paid_media_prior_type_to_proto_enum(
94
+ media_prior_type=paid_media_prior_type_converter.to_proto(
150
95
  obj.media_prior_type
151
96
  ),
152
- rf_prior_type=_paid_media_prior_type_to_proto_enum(obj.rf_prior_type),
153
- paid_media_prior_type=_paid_media_prior_type_to_proto_enum(
97
+ rf_prior_type=paid_media_prior_type_converter.to_proto(
98
+ obj.rf_prior_type
99
+ ),
100
+ paid_media_prior_type=paid_media_prior_type_converter.to_proto(
154
101
  obj.paid_media_prior_type
155
102
  ),
156
- organic_media_prior_type=_non_paid_prior_type_to_proto_enum(
103
+ organic_media_prior_type=non_paid_treatments_prior_type_converter.to_proto(
157
104
  obj.organic_media_prior_type
158
105
  ),
159
- organic_rf_prior_type=_non_paid_prior_type_to_proto_enum(
106
+ organic_rf_prior_type=non_paid_treatments_prior_type_converter.to_proto(
160
107
  obj.organic_rf_prior_type
161
108
  ),
162
- non_media_treatments_prior_type=_non_paid_prior_type_to_proto_enum(
109
+ non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.to_proto(
163
110
  obj.non_media_treatments_prior_type
164
111
  ),
165
112
  enable_aks=obj.enable_aks,
@@ -326,28 +273,28 @@ class HyperparametersSerde(
326
273
  adstock_decay_spec = sc.DEFAULT_DECAY
327
274
 
328
275
  return spec.ModelSpec(
329
- media_effects_dist=_proto_enum_to_media_effects_dist(
276
+ media_effects_dist=media_effects_converter.from_proto(
330
277
  serialized.media_effects_dist
331
278
  ),
332
279
  hill_before_adstock=serialized.hill_before_adstock,
333
280
  max_lag=max_lag,
334
281
  unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
335
- media_prior_type=_proto_enum_to_paid_media_prior_type(
282
+ media_prior_type=paid_media_prior_type_converter.from_proto(
336
283
  serialized.media_prior_type
337
284
  ),
338
- rf_prior_type=_proto_enum_to_paid_media_prior_type(
285
+ rf_prior_type=paid_media_prior_type_converter.from_proto(
339
286
  serialized.rf_prior_type
340
287
  ),
341
- paid_media_prior_type=_proto_enum_to_paid_media_prior_type(
288
+ paid_media_prior_type=paid_media_prior_type_converter.from_proto(
342
289
  serialized.paid_media_prior_type
343
290
  ),
344
- organic_media_prior_type=_proto_enum_to_non_paid_prior_type(
291
+ organic_media_prior_type=non_paid_treatments_prior_type_converter.from_proto(
345
292
  serialized.organic_media_prior_type
346
293
  ),
347
- organic_rf_prior_type=_proto_enum_to_non_paid_prior_type(
294
+ organic_rf_prior_type=non_paid_treatments_prior_type_converter.from_proto(
348
295
  serialized.organic_rf_prior_type
349
296
  ),
350
- non_media_treatments_prior_type=_proto_enum_to_non_paid_prior_type(
297
+ non_media_treatments_prior_type=non_paid_treatments_prior_type_converter.from_proto(
351
298
  serialized.non_media_treatments_prior_type
352
299
  ),
353
300
  non_media_baseline_values=non_media_baseline_values,
@@ -43,6 +43,7 @@ import dataclasses
43
43
  import os
44
44
  import warnings
45
45
 
46
+ import arviz as az
46
47
  from google.protobuf import text_format
47
48
  import meridian
48
49
  from meridian import backend
@@ -165,6 +166,7 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
165
166
  inference_data=inference_data.InferenceDataSerde().serialize(
166
167
  mmm.inference_data
167
168
  ),
169
+ arviz_version=az.__version__,
168
170
  )
169
171
  # For backwards compatibility, only serialize EDA spec if it exists.
170
172
  if hasattr(mmm, 'eda_spec'):
@@ -190,7 +192,10 @@ class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
190
192
  # NotFittedModelError can be raised below. If raised,
191
193
  # return None. Otherwise, set convergence status based on
192
194
  # MCMCSamplingError (caught in the except block).
193
- rhats = analyzer.Analyzer(mmm).get_rhat()
195
+ rhats = analyzer.Analyzer(
196
+ model_context=mmm.model_context,
197
+ inference_data=mmm.inference_data,
198
+ ).get_rhat()
194
199
  rhat_proto = meridian_pb.RHatDiagnostic()
195
200
  for name, tensor in rhats.items():
196
201
  rhat_proto.parameter_r_hats.add(
schema/utils/__init__.py CHANGED
@@ -15,4 +15,5 @@
15
15
  """Module containing MMM schema util functions."""
16
16
 
17
17
  from schema.utils import date_range_bucketing
18
+ from schema.utils import proto_enum_converter
18
19
  from schema.utils import time_record
@@ -0,0 +1,127 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A generic class for converting between Protobuf enums and strings."""
16
+
17
+ from typing import Generic, Type, TypeVar
18
+ import warnings
19
+
20
+ import bidict
21
+
22
+
23
+ EnumType = TypeVar("EnumType")
24
+ DefaultType = TypeVar("DefaultType")
25
+
26
+
27
+ class ProtoEnumConverter(Generic[EnumType, DefaultType]):
28
+ """Class for converting between proto enums and strings."""
29
+
30
+ def __init__(
31
+ self,
32
+ enum_message: Type[EnumType],
33
+ enum_display_name: str,
34
+ mapping: bidict.bidict,
35
+ enum_unspecified: EnumType,
36
+ default_when_unspecified: DefaultType,
37
+ ):
38
+ """Initializes the ProtoEnumConverter.
39
+
40
+ Arguments:
41
+ enum_message: The proto enum message definition.
42
+ enum_display_name: The loggable proto enum message name.
43
+ mapping: The mapping between the proto enum name and the string
44
+ representation.
45
+ enum_unspecified: The enum value that corresponds to unspecified.
46
+ default_when_unspecified: The default value that should be returned when
47
+ the proto enum is unspecified.
48
+ """
49
+ self.enum_message = enum_message
50
+ self.enum_display_name = enum_display_name
51
+ self.mapping = mapping
52
+ self.enum_unspecified = enum_unspecified
53
+ self.default_when_unspecified = default_when_unspecified
54
+
55
+ def to_proto(self, string_value: str | None) -> EnumType:
56
+ """Converts a string to its corresponding proto enum.
57
+
58
+ Args:
59
+ string_value: The string to convert to a proto enum.
60
+
61
+ Returns:
62
+ The corresponding proto enum or enum_unspecified when the enum message
63
+ doesn't exist.
64
+
65
+ Raises:
66
+ ValueError when given string is not found in the mapping.
67
+ """
68
+ if string_value is None:
69
+ return self.enum_unspecified
70
+
71
+ proto_name = self.mapping.get(string_value)
72
+ if proto_name:
73
+ try:
74
+ return self.enum_message.Value(proto_name)
75
+ except ValueError:
76
+ warnings.warn(
77
+ "Invalid %s value: %s. Resolving to %s."
78
+ % (
79
+ self.enum_message.DESCRIPTOR.name,
80
+ string_value,
81
+ self.enum_message.Name(self.enum_unspecified),
82
+ )
83
+ )
84
+ return self.enum_unspecified
85
+ else:
86
+ raise ValueError(
87
+ f"Unmatched {self.enum_message.DESCRIPTOR.name} value:"
88
+ f" {string_value}."
89
+ )
90
+
91
+ def from_proto(self, proto_enum: EnumType) -> str | DefaultType:
92
+ """Converts a proto enum to its string representation.
93
+
94
+ Args:
95
+ proto_enum: The enum value to convert to its string representation
96
+
97
+ Returns:
98
+ The string representation of the given proto_enum or the default value
99
+ when the proto enum is unspecified.
100
+
101
+ Raises:
102
+ ValueError when given proto enum is not found in the mapping.
103
+ """
104
+ if proto_enum == self.enum_unspecified:
105
+ warnings.warn(
106
+ "%s is unspecified. Resolving to default: %s."
107
+ % (
108
+ self.enum_display_name,
109
+ self.enum_message.Name(self.enum_unspecified),
110
+ )
111
+ )
112
+ return self.default_when_unspecified
113
+
114
+ try:
115
+ proto_name = self.enum_message.Name(proto_enum)
116
+ except ValueError as e:
117
+ raise ValueError(
118
+ f"Invalid {self.enum_message.DESCRIPTOR.name} proto enum value:"
119
+ f" {proto_enum}."
120
+ ) from e
121
+
122
+ try:
123
+ return self.mapping.inv[proto_name]
124
+ except KeyError as e:
125
+ raise KeyError(
126
+ f"Protobuf enum name '{proto_name}' is not configured in the mapping."
127
+ ) from e