google-meridian 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/reviewer.py +4 -1
  7. meridian/analysis/summarizer.py +6 -1
  8. meridian/analysis/test_utils.py +2898 -2538
  9. meridian/analysis/visualizer.py +28 -9
  10. meridian/backend/__init__.py +106 -0
  11. meridian/constants.py +1 -0
  12. meridian/data/input_data.py +30 -52
  13. meridian/data/input_data_builder.py +2 -9
  14. meridian/data/test_utils.py +25 -41
  15. meridian/data/validator.py +48 -0
  16. meridian/mlflow/autolog.py +19 -9
  17. meridian/model/adstock_hill.py +3 -5
  18. meridian/model/context.py +134 -0
  19. meridian/model/eda/constants.py +334 -4
  20. meridian/model/eda/eda_engine.py +723 -312
  21. meridian/model/eda/eda_outcome.py +177 -33
  22. meridian/model/model.py +159 -110
  23. meridian/model/model_test_data.py +38 -0
  24. meridian/model/posterior_sampler.py +103 -62
  25. meridian/model/prior_sampler.py +114 -94
  26. meridian/model/spec.py +23 -14
  27. meridian/templates/card.html.jinja +9 -7
  28. meridian/templates/chart.html.jinja +1 -6
  29. meridian/templates/finding.html.jinja +19 -0
  30. meridian/templates/findings.html.jinja +33 -0
  31. meridian/templates/formatter.py +41 -5
  32. meridian/templates/formatter_test.py +127 -0
  33. meridian/templates/style.css +66 -9
  34. meridian/templates/style.scss +85 -4
  35. meridian/templates/table.html.jinja +1 -0
  36. meridian/version.py +1 -1
  37. scenarioplanner/linkingapi/constants.py +1 -1
  38. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  39. schema/processors/marketing_processor.py +11 -10
  40. schema/processors/model_processor.py +4 -1
  41. schema/serde/distribution.py +12 -7
  42. schema/serde/hyperparameters.py +54 -107
  43. schema/serde/meridian_serde.py +6 -1
  44. schema/utils/__init__.py +1 -0
  45. schema/utils/proto_enum_converter.py +127 -0
  46. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
  47. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
@@ -14,26 +14,32 @@
14
14
 
15
15
  """Meridian EDA Outcome."""
16
16
 
17
+ from collections.abc import Sequence
17
18
  import dataclasses
18
19
  import enum
19
20
  import typing
20
21
  import pandas as pd
21
22
  import xarray as xr
22
23
 
23
- __all__ = [
24
+ __all__ = (
24
25
  "EDASeverity",
25
26
  "EDAFinding",
26
27
  "AnalysisLevel",
27
28
  "AnalysisArtifact",
29
+ "FindingCause",
28
30
  "PairwiseCorrArtifact",
29
31
  "StandardDeviationArtifact",
30
32
  "VIFArtifact",
31
33
  "KpiInvariabilityArtifact",
32
34
  "CostPerMediaUnitArtifact",
35
+ "VariableGeoTimeCollinearityArtifact",
36
+ "PopulationCorrelationArtifact",
37
+ "PriorProbabilityArtifact",
33
38
  "EDACheckType",
34
39
  "ArtifactType",
35
40
  "EDAOutcome",
36
- ]
41
+ "CriticalCheckEDAOutcomes",
42
+ )
37
43
 
38
44
 
39
45
  @enum.unique
@@ -48,19 +54,31 @@ class EDASeverity(enum.Enum):
48
54
  ERROR = enum.auto()
49
55
 
50
56
 
51
- @dataclasses.dataclass(frozen=True)
52
- class EDAFinding:
53
- """Encapsulates a single, specific finding from an EDA check.
57
+ @enum.unique
58
+ class FindingCause(enum.Enum):
59
+ """Enumeration for the type of finding, mapping to specific data tables.
54
60
 
55
61
  Attributes:
56
- severity: The severity level of the finding.
57
- explanation: A human-readable description about the EDA check and a
58
- potential actionable guidance on how to address or interpret this
59
- specific finding.
62
+ NONE: For informational findings that do not indicate a data issue.
63
+ MULTICOLLINEARITY: For findings related to multicollinearity between
64
+ variables (e.g. from VIF or pairwise correlation checks).
65
+ VARIABILITY: For findings related to variables with extreme variability
66
+ issues, such as no variation (e.g. KPI invariability check or standard
67
+ deviation checks).
68
+ INCONSISTENT_DATA: For findings related to inconsistent data points (e.g.
69
+ zero cost with positive media units, from cost per media unit check).
70
+ RUNTIME_ERROR: For findings that indicate a runtime error during an EDA
71
+ check.
72
+ OUTLIER: For findings related to outliers in data (e.g. cost per media unit
73
+ outlier check).
60
74
  """
61
75
 
62
- severity: EDASeverity
63
- explanation: str
76
+ NONE = enum.auto()
77
+ MULTICOLLINEARITY = enum.auto()
78
+ VARIABILITY = enum.auto()
79
+ INCONSISTENT_DATA = enum.auto()
80
+ RUNTIME_ERROR = enum.auto()
81
+ OUTLIER = enum.auto()
64
82
 
65
83
 
66
84
  @enum.unique
@@ -95,9 +113,30 @@ class AnalysisArtifact:
95
113
  level: AnalysisLevel
96
114
 
97
115
 
116
+ @dataclasses.dataclass(frozen=True, kw_only=True)
117
+ class EDAFinding:
118
+ """A single, specific finding from an EDA check.
119
+
120
+ Attributes:
121
+ severity: The severity level of the finding.
122
+ explanation: A human-readable description about the EDA check and a
123
+ potential actionable guidance on how to address or interpret this
124
+ specific finding.
125
+ finding_cause: The type of finding, mapping to specific data tables.
126
+ associated_artifact: The artifact associated with the finding, if any.
127
+ """
128
+
129
+ __hash__ = None
130
+
131
+ severity: EDASeverity
132
+ explanation: str
133
+ finding_cause: FindingCause
134
+ associated_artifact: AnalysisArtifact | None = None
135
+
136
+
98
137
  @dataclasses.dataclass(frozen=True)
99
138
  class PairwiseCorrArtifact(AnalysisArtifact):
100
- """Encapsulates artifacts from a single pairwise correlation analysis.
139
+ """Artifacts from a single pairwise correlation analysis.
101
140
 
102
141
  Attributes:
103
142
  corr_matrix: Pairwise correlation matrix.
@@ -115,7 +154,7 @@ class PairwiseCorrArtifact(AnalysisArtifact):
115
154
 
116
155
  @dataclasses.dataclass(frozen=True)
117
156
  class StandardDeviationArtifact(AnalysisArtifact):
118
- """Encapsulates artifacts from a standard deviation analysis.
157
+ """Artifacts from a standard deviation analysis.
119
158
 
120
159
  Attributes:
121
160
  variable: The variable for which standard deviation is calculated.
@@ -130,7 +169,7 @@ class StandardDeviationArtifact(AnalysisArtifact):
130
169
 
131
170
  @dataclasses.dataclass(frozen=True)
132
171
  class VIFArtifact(AnalysisArtifact):
133
- """Encapsulates artifacts from a single VIF analysis.
172
+ """Artifacts from a single VIF analysis.
134
173
 
135
174
  Attributes:
136
175
  vif_da: DataArray with VIF values.
@@ -138,12 +177,13 @@ class VIFArtifact(AnalysisArtifact):
138
177
  """
139
178
 
140
179
  vif_da: xr.DataArray
180
+ # TODO: change this naming
141
181
  outlier_df: pd.DataFrame
142
182
 
143
183
 
144
184
  @dataclasses.dataclass(frozen=True)
145
185
  class KpiInvariabilityArtifact(AnalysisArtifact):
146
- """Encapsulates artifacts from a KPI invariability analysis.
186
+ """Artifacts from a KPI invariability analysis.
147
187
 
148
188
  Attributes:
149
189
  kpi_da: DataArray of the KPI that is examined for variability.
@@ -157,14 +197,15 @@ class KpiInvariabilityArtifact(AnalysisArtifact):
157
197
 
158
198
  @dataclasses.dataclass(frozen=True)
159
199
  class CostPerMediaUnitArtifact(AnalysisArtifact):
160
- """Encapsulates artifacts from a Cost per Media Unit analysis.
200
+ """Artifacts from a Cost per Media Unit analysis.
161
201
 
162
202
  Attributes:
163
203
  cost_per_media_unit_da: DataArray of cost per media unit.
164
204
  cost_media_unit_inconsistency_df: DataFrame of time periods where cost and
165
205
  media units are inconsistent (e.g., zero cost with positive media units,
166
206
  or positive cost with zero media units).
167
- outlier_df: DataFrame with outliers of cost per media unit.
207
+ outlier_df: DataFrame with outliers of cost per media unit, along with the
208
+ spend, and media units.
168
209
  """
169
210
 
170
211
  cost_per_media_unit_da: xr.DataArray
@@ -172,6 +213,47 @@ class CostPerMediaUnitArtifact(AnalysisArtifact):
172
213
  outlier_df: pd.DataFrame
173
214
 
174
215
 
216
+ @dataclasses.dataclass(frozen=True)
217
+ class VariableGeoTimeCollinearityArtifact(AnalysisArtifact):
218
+ """Artifacts from a Geo/Time Collinearity analysis for Treatment/Control variables.
219
+
220
+ Attributes:
221
+ rsquared_ds: Dataset containing adjusted R-squared values for treatments and
222
+ controls regressed against 'geo' and 'time'.
223
+ """
224
+
225
+ rsquared_ds: xr.Dataset
226
+
227
+
228
+ @dataclasses.dataclass(frozen=True)
229
+ class PopulationCorrelationArtifact(AnalysisArtifact):
230
+ """Artifacts from population correlation analysis.
231
+
232
+ Attributes:
233
+ correlation_ds: Dataset with Spearman correlation coefficients between
234
+ population and time-averaged treatments/controls. Each data variable in
235
+ the dataset corresponds to a variable in treatment_control_scaled_ds, and
236
+ its dimensions reflect the non-geo, non-time dimensions (e.g., 'channel').
237
+ """
238
+
239
+ correlation_ds: xr.Dataset
240
+
241
+
242
+ @dataclasses.dataclass(frozen=True)
243
+ class PriorProbabilityArtifact(AnalysisArtifact):
244
+ """Artifact for prior probability check.
245
+
246
+ Attributes:
247
+ prior_negative_baseline_prob: A float value for prior probability of
248
+ negative baseline.
249
+ mean_prior_contribution_da: The array containing the prior mean of each
250
+ treatment's contribution.
251
+ """
252
+
253
+ prior_negative_baseline_prob: float
254
+ mean_prior_contribution_da: xr.DataArray
255
+
256
+
175
257
  @enum.unique
176
258
  class EDACheckType(enum.Enum):
177
259
  """Enumeration for the type of an EDA check."""
@@ -181,9 +263,12 @@ class EDACheckType(enum.Enum):
181
263
  MULTICOLLINEARITY = enum.auto()
182
264
  KPI_INVARIABILITY = enum.auto()
183
265
  COST_PER_MEDIA_UNIT = enum.auto()
266
+ VARIABLE_GEO_TIME_COLLINEARITY = enum.auto()
267
+ POPULATION_CORRELATION = enum.auto()
268
+ PRIOR_PROBABILITY = enum.auto()
184
269
 
185
270
 
186
- ArtifactType = typing.TypeVar("ArtifactType", bound="AnalysisArtifact")
271
+ ArtifactType = typing.TypeVar("ArtifactType", bound=AnalysisArtifact)
187
272
 
188
273
 
189
274
  @dataclasses.dataclass(frozen=True)
@@ -203,18 +288,77 @@ class EDAOutcome(typing.Generic[ArtifactType]):
203
288
  findings: list[EDAFinding]
204
289
  analysis_artifacts: list[ArtifactType]
205
290
 
206
- @property
207
- def get_geo_artifact(self) -> ArtifactType | None:
208
- """Returns the geo-level analysis artifact."""
209
- for artifact in self.analysis_artifacts:
210
- if artifact.level == AnalysisLevel.GEO:
211
- return artifact
212
- return None
213
-
214
- @property
215
- def get_national_artifact(self) -> ArtifactType | None:
216
- """Returns the national-level analysis artifact."""
217
- for artifact in self.analysis_artifacts:
218
- if artifact.level == AnalysisLevel.NATIONAL:
219
- return artifact
220
- return None
291
+ def _get_artifacts_by_level(self, level: AnalysisLevel) -> list[ArtifactType]:
292
+ """Helper method to retrieve artifacts by level.
293
+
294
+ Args:
295
+ level: The AnalysisLevel to filter artifacts by.
296
+
297
+ Returns:
298
+ A list of AnalysisArtifacts at the specified level.
299
+
300
+ Raises:
301
+ ValueError: If no artifacts of the specified level are found.
302
+ """
303
+ artifacts = [
304
+ artifact
305
+ for artifact in self.analysis_artifacts
306
+ if artifact.level == level
307
+ ]
308
+
309
+ if not artifacts:
310
+ raise ValueError(
311
+ f"The EDAOutcome for {self.check_type.name} check does not have "
312
+ f"{level.name.lower()} artifacts."
313
+ )
314
+ return artifacts
315
+
316
+ def get_geo_artifacts(self) -> list[ArtifactType]:
317
+ """Returns the geo-level analysis artifacts.
318
+
319
+ Returns a list to account for checks that produce multiple artifacts
320
+ at the same level (e.g. Standard Deviation check).
321
+ """
322
+ return self._get_artifacts_by_level(AnalysisLevel.GEO)
323
+
324
+ def get_national_artifacts(self) -> list[ArtifactType]:
325
+ """Returns the national-level analysis artifacts.
326
+
327
+ Returns a list to account for checks that produce multiple artifacts
328
+ at the same level.
329
+ """
330
+ return self._get_artifacts_by_level(AnalysisLevel.NATIONAL)
331
+
332
+ def get_overall_artifacts(self) -> list[ArtifactType]:
333
+ """Returns the overall-level analysis artifacts.
334
+
335
+ Returns a list to account for checks that produce multiple artifacts
336
+ at the same level.
337
+ """
338
+ return self._get_artifacts_by_level(AnalysisLevel.OVERALL)
339
+
340
+ def get_findings_by_cause_and_severity(
341
+ self, finding_cause: FindingCause, severity: EDASeverity
342
+ ) -> Sequence[EDAFinding]:
343
+ """Helper method to retrieve findings by cause and severity."""
344
+ return [
345
+ finding
346
+ for finding in self.findings
347
+ if finding.finding_cause == finding_cause
348
+ and finding.severity == severity
349
+ ]
350
+
351
+
352
+ @dataclasses.dataclass(frozen=True, kw_only=True)
353
+ class CriticalCheckEDAOutcomes:
354
+ """Outcomes of all critical EDA checks.
355
+
356
+ Attributes:
357
+ kpi_invariability: Outcome of the KPI invariability check.
358
+ multicollinearity: Outcome of the multicollinearity (VIF) check.
359
+ pairwise_correlation: Outcome of the pairwise correlation check.
360
+ """
361
+
362
+ kpi_invariability: EDAOutcome[KpiInvariabilityArtifact]
363
+ multicollinearity: EDAOutcome[VIFArtifact]
364
+ pairwise_correlation: EDAOutcome[PairwiseCorrArtifact]