google-meridian 1.3.1__tar.gz → 1.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. google_meridian-1.3.2/MANIFEST.in +2 -0
  2. {google_meridian-1.3.1/google_meridian.egg-info → google_meridian-1.3.2}/PKG-INFO +7 -7
  3. {google_meridian-1.3.1 → google_meridian-1.3.2}/README.md +1 -1
  4. {google_meridian-1.3.1 → google_meridian-1.3.2/google_meridian.egg-info}/PKG-INFO +7 -7
  5. {google_meridian-1.3.1 → google_meridian-1.3.2}/google_meridian.egg-info/SOURCES.txt +10 -10
  6. {google_meridian-1.3.1 → google_meridian-1.3.2}/google_meridian.egg-info/requires.txt +5 -5
  7. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/__init__.py +1 -2
  8. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/analyzer.py +0 -1
  9. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/optimizer.py +5 -3
  10. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/checks.py +81 -30
  11. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/constants.py +4 -0
  12. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/results.py +40 -9
  13. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/summarizer.py +1 -1
  14. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/visualizer.py +1 -1
  15. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/backend/__init__.py +53 -5
  16. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/backend/test_utils.py +72 -0
  17. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/constants.py +1 -0
  18. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/load.py +2 -0
  19. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/eda/__init__.py +0 -1
  20. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/eda/constants.py +12 -2
  21. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/eda/eda_engine.py +299 -37
  22. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/eda/eda_outcome.py +21 -1
  23. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/knots.py +17 -0
  24. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/card.html.jinja +1 -1
  25. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/chart.html.jinja +1 -1
  26. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/chips.html.jinja +1 -1
  27. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian/templates}/formatter.py +12 -1
  28. google_meridian-1.3.2/meridian/templates/formatter_test.py +216 -0
  29. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/insights.html.jinja +1 -1
  30. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/stats.html.jinja +1 -1
  31. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/style.scss +1 -1
  32. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/summary.html.jinja +4 -2
  33. {google_meridian-1.3.1/meridian/analysis → google_meridian-1.3.2/meridian}/templates/table.html.jinja +1 -1
  34. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/version.py +1 -1
  35. {google_meridian-1.3.1 → google_meridian-1.3.2}/pyproject.toml +5 -5
  36. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/__init__.py +12 -0
  37. {google_meridian-1.3.1 → google_meridian-1.3.2}/setup.py +1 -1
  38. google_meridian-1.3.1/MANIFEST.in +0 -2
  39. google_meridian-1.3.1/meridian/model/eda/meridian_eda.py +0 -220
  40. {google_meridian-1.3.1 → google_meridian-1.3.2}/LICENSE +0 -0
  41. {google_meridian-1.3.1 → google_meridian-1.3.2}/google_meridian.egg-info/dependency_links.txt +0 -0
  42. {google_meridian-1.3.1 → google_meridian-1.3.2}/google_meridian.egg-info/top_level.txt +0 -0
  43. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/__init__.py +0 -0
  44. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/__init__.py +0 -0
  45. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/configs.py +0 -0
  46. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/review/reviewer.py +0 -0
  47. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/summary_text.py +0 -0
  48. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/analysis/test_utils.py +0 -0
  49. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/backend/config.py +0 -0
  50. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/__init__.py +0 -0
  51. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/arg_builder.py +0 -0
  52. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/data_frame_input_data_builder.py +0 -0
  53. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/input_data.py +0 -0
  54. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/input_data_builder.py +0 -0
  55. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/nd_array_input_data_builder.py +0 -0
  56. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/test_utils.py +0 -0
  57. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/data/time_coordinates.py +0 -0
  58. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/mlflow/__init__.py +0 -0
  59. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/mlflow/autolog.py +0 -0
  60. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/__init__.py +0 -0
  61. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/adstock_hill.py +0 -0
  62. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/eda/eda_spec.py +0 -0
  63. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/media.py +0 -0
  64. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/model.py +0 -0
  65. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/model_test_data.py +0 -0
  66. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/posterior_sampler.py +0 -0
  67. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/prior_distribution.py +0 -0
  68. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/prior_sampler.py +0 -0
  69. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/spec.py +0 -0
  70. {google_meridian-1.3.1 → google_meridian-1.3.2}/meridian/model/transformers.py +0 -0
  71. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/__init__.py +0 -0
  72. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/constants.py +0 -0
  73. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/distribution.py +0 -0
  74. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/eda_spec.py +0 -0
  75. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/function_registry.py +0 -0
  76. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/hyperparameters.py +0 -0
  77. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/inference_data.py +0 -0
  78. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/marketing_data.py +0 -0
  79. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/meridian_serde.py +0 -0
  80. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/serde.py +0 -0
  81. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/serde/test_data.py +0 -0
  82. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/utils/__init__.py +0 -0
  83. {google_meridian-1.3.1 → google_meridian-1.3.2}/schema/utils/time_record.py +0 -0
  84. {google_meridian-1.3.1 → google_meridian-1.3.2}/setup.cfg +0 -0
@@ -0,0 +1,2 @@
1
+ global-exclude *_test.py
2
+ include meridian/templates/*
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.3.1
3
+ Version: 1.3.2
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  Project-URL: homepage, https://github.com/google/meridian
@@ -25,9 +25,9 @@ Requires-Dist: pandas<3,>=2.2.2
25
25
  Requires-Dist: patsy<1,>=0.5.3
26
26
  Requires-Dist: scipy<2,>=1.13.1
27
27
  Requires-Dist: statsmodels>=0.14.5
28
- Requires-Dist: tensorflow<2.19,>=2.18
28
+ Requires-Dist: tensorflow<2.21,>=2.18
29
29
  Requires-Dist: tensorflow-probability<0.26,>=0.25
30
- Requires-Dist: tf-keras<2.19,>=2.18
30
+ Requires-Dist: tf-keras<2.21,>=2.18
31
31
  Requires-Dist: xarray
32
32
  Provides-Extra: dev
33
33
  Requires-Dist: pytest>=8.0.0; extra == "dev"
@@ -38,12 +38,12 @@ Provides-Extra: colab
38
38
  Requires-Dist: psutil; extra == "colab"
39
39
  Requires-Dist: python-calamine; extra == "colab"
40
40
  Provides-Extra: and-cuda
41
- Requires-Dist: tensorflow[and-cuda]<2.19,>=2.18; extra == "and-cuda"
41
+ Requires-Dist: tensorflow[and-cuda]<2.21,>=2.18; extra == "and-cuda"
42
42
  Provides-Extra: mlflow
43
43
  Requires-Dist: mlflow; extra == "mlflow"
44
44
  Provides-Extra: jax
45
- Requires-Dist: jax==0.4.26; extra == "jax"
46
- Requires-Dist: jaxlib==0.4.26; extra == "jax"
45
+ Requires-Dist: jax==0.5.3; extra == "jax"
46
+ Requires-Dist: jaxlib==0.5.3; extra == "jax"
47
47
  Requires-Dist: tensorflow-probability[substrates-jax]==0.25.0; extra == "jax"
48
48
  Provides-Extra: schema
49
49
  Requires-Dist: mmm-proto-schema; extra == "schema"
@@ -203,7 +203,7 @@ To cite this repository:
203
203
  author = {Google Meridian Marketing Mix Modeling Team},
204
204
  title = {Meridian: Marketing Mix Modeling},
205
205
  url = {https://github.com/google/meridian},
206
- version = {1.3.1},
206
+ version = {1.3.2},
207
207
  year = {2025},
208
208
  }
209
209
  ```
@@ -151,7 +151,7 @@ To cite this repository:
151
151
  author = {Google Meridian Marketing Mix Modeling Team},
152
152
  title = {Meridian: Marketing Mix Modeling},
153
153
  url = {https://github.com/google/meridian},
154
- version = {1.3.1},
154
+ version = {1.3.2},
155
155
  year = {2025},
156
156
  }
157
157
  ```
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.3.1
3
+ Version: 1.3.2
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  Project-URL: homepage, https://github.com/google/meridian
@@ -25,9 +25,9 @@ Requires-Dist: pandas<3,>=2.2.2
25
25
  Requires-Dist: patsy<1,>=0.5.3
26
26
  Requires-Dist: scipy<2,>=1.13.1
27
27
  Requires-Dist: statsmodels>=0.14.5
28
- Requires-Dist: tensorflow<2.19,>=2.18
28
+ Requires-Dist: tensorflow<2.21,>=2.18
29
29
  Requires-Dist: tensorflow-probability<0.26,>=0.25
30
- Requires-Dist: tf-keras<2.19,>=2.18
30
+ Requires-Dist: tf-keras<2.21,>=2.18
31
31
  Requires-Dist: xarray
32
32
  Provides-Extra: dev
33
33
  Requires-Dist: pytest>=8.0.0; extra == "dev"
@@ -38,12 +38,12 @@ Provides-Extra: colab
38
38
  Requires-Dist: psutil; extra == "colab"
39
39
  Requires-Dist: python-calamine; extra == "colab"
40
40
  Provides-Extra: and-cuda
41
- Requires-Dist: tensorflow[and-cuda]<2.19,>=2.18; extra == "and-cuda"
41
+ Requires-Dist: tensorflow[and-cuda]<2.21,>=2.18; extra == "and-cuda"
42
42
  Provides-Extra: mlflow
43
43
  Requires-Dist: mlflow; extra == "mlflow"
44
44
  Provides-Extra: jax
45
- Requires-Dist: jax==0.4.26; extra == "jax"
46
- Requires-Dist: jaxlib==0.4.26; extra == "jax"
45
+ Requires-Dist: jax==0.5.3; extra == "jax"
46
+ Requires-Dist: jaxlib==0.5.3; extra == "jax"
47
47
  Requires-Dist: tensorflow-probability[substrates-jax]==0.25.0; extra == "jax"
48
48
  Provides-Extra: schema
49
49
  Requires-Dist: mmm-proto-schema; extra == "schema"
@@ -203,7 +203,7 @@ To cite this repository:
203
203
  author = {Google Meridian Marketing Mix Modeling Team},
204
204
  title = {Meridian: Marketing Mix Modeling},
205
205
  url = {https://github.com/google/meridian},
206
- version = {1.3.1},
206
+ version = {1.3.2},
207
207
  year = {2025},
208
208
  }
209
209
  ```
@@ -13,7 +13,6 @@ meridian/constants.py
13
13
  meridian/version.py
14
14
  meridian/analysis/__init__.py
15
15
  meridian/analysis/analyzer.py
16
- meridian/analysis/formatter.py
17
16
  meridian/analysis/optimizer.py
18
17
  meridian/analysis/summarizer.py
19
18
  meridian/analysis/summary_text.py
@@ -25,14 +24,6 @@ meridian/analysis/review/configs.py
25
24
  meridian/analysis/review/constants.py
26
25
  meridian/analysis/review/results.py
27
26
  meridian/analysis/review/reviewer.py
28
- meridian/analysis/templates/card.html.jinja
29
- meridian/analysis/templates/chart.html.jinja
30
- meridian/analysis/templates/chips.html.jinja
31
- meridian/analysis/templates/insights.html.jinja
32
- meridian/analysis/templates/stats.html.jinja
33
- meridian/analysis/templates/style.scss
34
- meridian/analysis/templates/summary.html.jinja
35
- meridian/analysis/templates/table.html.jinja
36
27
  meridian/backend/__init__.py
37
28
  meridian/backend/config.py
38
29
  meridian/backend/test_utils.py
@@ -63,7 +54,16 @@ meridian/model/eda/constants.py
63
54
  meridian/model/eda/eda_engine.py
64
55
  meridian/model/eda/eda_outcome.py
65
56
  meridian/model/eda/eda_spec.py
66
- meridian/model/eda/meridian_eda.py
57
+ meridian/templates/card.html.jinja
58
+ meridian/templates/chart.html.jinja
59
+ meridian/templates/chips.html.jinja
60
+ meridian/templates/formatter.py
61
+ meridian/templates/formatter_test.py
62
+ meridian/templates/insights.html.jinja
63
+ meridian/templates/stats.html.jinja
64
+ meridian/templates/style.scss
65
+ meridian/templates/summary.html.jinja
66
+ meridian/templates/table.html.jinja
67
67
  schema/__init__.py
68
68
  schema/serde/__init__.py
69
69
  schema/serde/constants.py
@@ -8,13 +8,13 @@ pandas<3,>=2.2.2
8
8
  patsy<1,>=0.5.3
9
9
  scipy<2,>=1.13.1
10
10
  statsmodels>=0.14.5
11
- tensorflow<2.19,>=2.18
11
+ tensorflow<2.21,>=2.18
12
12
  tensorflow-probability<0.26,>=0.25
13
- tf-keras<2.19,>=2.18
13
+ tf-keras<2.21,>=2.18
14
14
  xarray
15
15
 
16
16
  [and-cuda]
17
- tensorflow[and-cuda]<2.19,>=2.18
17
+ tensorflow[and-cuda]<2.21,>=2.18
18
18
 
19
19
  [colab]
20
20
  psutil
@@ -27,8 +27,8 @@ pylint>=2.6.0
27
27
  pyink
28
28
 
29
29
  [jax]
30
- jax==0.4.26
31
- jaxlib==0.4.26
30
+ jax==0.5.3
31
+ jaxlib==0.5.3
32
32
  tensorflow-probability[substrates-jax]==0.25.0
33
33
 
34
34
  [mlflow]
@@ -15,9 +15,8 @@
15
15
  """Meridian analysis API for trained models."""
16
16
 
17
17
  from meridian.analysis import analyzer
18
- from meridian.analysis import formatter
19
18
  from meridian.analysis import optimizer
20
19
  from meridian.analysis import review
21
20
  from meridian.analysis import summarizer
22
21
  from meridian.analysis import visualizer
23
-
22
+ from meridian.templates import formatter
@@ -53,7 +53,6 @@ def _validate_non_media_baseline_values_numbers(
53
53
  )
54
54
 
55
55
 
56
- # TODO: Refactor the related unit tests to be under DataTensors.
57
56
  @dataclasses.dataclass
58
57
  class DataTensors(backend.ExtensionType):
59
58
  """Container for data variable arguments of Analyzer methods.
@@ -27,10 +27,10 @@ import jinja2
27
27
  from meridian import backend
28
28
  from meridian import constants as c
29
29
  from meridian.analysis import analyzer as analyzer_module
30
- from meridian.analysis import formatter
31
30
  from meridian.analysis import summary_text
32
31
  from meridian.data import time_coordinates as tc
33
32
  from meridian.model import model
33
+ from meridian.templates import formatter
34
34
  import numpy as np
35
35
  import pandas as pd
36
36
  import xarray as xr
@@ -1174,10 +1174,12 @@ class OptimizationResults:
1174
1174
  diff = self.optimized_data.total_cpik - self.nonoptimized_data.total_cpik
1175
1175
  non_optimized_performance_title = summary_text.NON_OPTIMIZED_CPIK_LABEL
1176
1176
  non_optimized_performance_stat = (
1177
- f'${self.nonoptimized_data.total_cpik:.2f}'
1177
+ f'{currency}{self.nonoptimized_data.total_cpik:.2f}'
1178
1178
  )
1179
1179
  optimized_performance_title = summary_text.OPTIMIZED_CPIK_LABEL
1180
- optimized_performance_stat = f'${self.optimized_data.total_cpik:.2f}'
1180
+ optimized_performance_stat = (
1181
+ f'{currency}{self.optimized_data.total_cpik:.2f}'
1182
+ )
1181
1183
  optimized_performance_diff = formatter.compact_number(diff, 2, currency)
1182
1184
  non_optimized_performance = formatter.StatsSpec(
1183
1185
  title=non_optimized_performance_title,
@@ -28,6 +28,7 @@ from meridian.analysis.review import constants as review_constants
28
28
  from meridian.analysis.review import results
29
29
  from meridian.model import model
30
30
  import numpy as np
31
+ import pandas as pd
31
32
 
32
33
  ConfigType = TypeVar("ConfigType", bound=configs.BaseConfig)
33
34
  ResultType = TypeVar("ResultType", bound=results.CheckResult)
@@ -207,6 +208,51 @@ class BayesianPPPCheck(
207
208
  # ==============================================================================
208
209
  # Check: Goodness of Fit
209
210
  # ==============================================================================
211
+ def _set_details_from_gof_dataframe(
212
+ details: dict[str, float],
213
+ gof_df: pd.DataFrame,
214
+ geo_granularity: str,
215
+ suffix: str | None = None,
216
+ ) -> None:
217
+ """Sets the `details` variable of the GoodnessOfFitCheckResult.
218
+
219
+ This method takes a DataFrame containing goodness of fit metrics and pivots it
220
+ to a Series, which is then added to the `details` variable of the
221
+ `GoodnessOfFitCheckResult`.
222
+
223
+ Args:
224
+ details: A dictionary to store the goodness of fit metrics in.
225
+ gof_df: A DataFrame containing predictive accuracy of the whole data (if
226
+ holdout set is not used) of filtered to a single evaluation set ("all",
227
+ "train", or "test").
228
+ geo_granularity: The geo granularity of the data ("geo" or "national").
229
+ suffix: A suffix to add to the metric names (e.g., "all", "train", "test").
230
+ If None, the metrics are added without a suffix.
231
+ """
232
+ gof_metrics_pivoted = gof_df.pivot(
233
+ index=constants.GEO_GRANULARITY,
234
+ columns=constants.METRIC,
235
+ values=constants.VALUE,
236
+ )
237
+ gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
238
+ if suffix is not None:
239
+ details[f"{review_constants.R_SQUARED}_{suffix}"] = gof_metrics_series[
240
+ constants.R_SQUARED
241
+ ]
242
+ details[f"{review_constants.MAPE}_{suffix}"] = gof_metrics_series[
243
+ constants.MAPE
244
+ ]
245
+ details[f"{review_constants.WMAPE}_{suffix}"] = gof_metrics_series[
246
+ constants.WMAPE
247
+ ]
248
+ else:
249
+ details[review_constants.R_SQUARED] = gof_metrics_series[
250
+ constants.R_SQUARED
251
+ ]
252
+ details[review_constants.MAPE] = gof_metrics_series[constants.MAPE]
253
+ details[review_constants.WMAPE] = gof_metrics_series[constants.WMAPE]
254
+
255
+
210
256
  class GoodnessOfFitCheck(
211
257
  BaseCheck[configs.GoodnessOfFitConfig, results.GoodnessOfFitCheckResult]
212
258
  ):
@@ -221,38 +267,43 @@ class GoodnessOfFitCheck(
221
267
  )
222
268
 
223
269
  gof_metrics = gof_df[gof_df[constants.GEO_GRANULARITY] == geo_granularity]
224
- if constants.EVALUATION_SET_VAR in gof_df.columns:
225
- gof_metrics = gof_metrics[
226
- gof_metrics[constants.EVALUATION_SET_VAR] == constants.ALL_DATA
227
- ]
228
-
229
- gof_metrics_pivoted = gof_metrics.pivot(
230
- index=constants.GEO_GRANULARITY,
231
- columns=constants.METRIC,
232
- values=constants.VALUE,
233
- )
234
- gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
235
-
236
- r_squared = gof_metrics_series[constants.R_SQUARED]
237
- mape = gof_metrics_series[constants.MAPE]
238
- wmape = gof_metrics_series[constants.WMAPE]
239
-
240
- details = {
241
- review_constants.R_SQUARED: r_squared,
242
- review_constants.MAPE: mape,
243
- review_constants.WMAPE: wmape,
244
- }
245
-
246
- if r_squared > 0:
247
- return results.GoodnessOfFitCheckResult(
248
- case=results.GoodnessOfFitCases.PASS,
249
- details=details,
250
- )
251
- else: # r_squared <= 0
252
- return results.GoodnessOfFitCheckResult(
253
- case=results.GoodnessOfFitCases.REVIEW,
270
+ is_holdout = constants.EVALUATION_SET_VAR in gof_df.columns
271
+
272
+ details = {}
273
+ case = results.GoodnessOfFitCases.PASS
274
+
275
+ if is_holdout:
276
+ for evaluation_set, suffix in [
277
+ (constants.ALL_DATA, review_constants.ALL_SUFFIX),
278
+ (constants.TRAIN, review_constants.TRAIN_SUFFIX),
279
+ (constants.TEST, review_constants.TEST_SUFFIX),
280
+ ]:
281
+ set_metrics = gof_metrics[
282
+ gof_metrics[constants.EVALUATION_SET_VAR] == evaluation_set
283
+ ]
284
+ _set_details_from_gof_dataframe(
285
+ details=details,
286
+ gof_df=set_metrics,
287
+ geo_granularity=geo_granularity,
288
+ suffix=suffix,
289
+ )
290
+ if details[f"{review_constants.R_SQUARED}_{suffix}"] <= 0:
291
+ case = results.GoodnessOfFitCases.REVIEW
292
+ else:
293
+ _set_details_from_gof_dataframe(
254
294
  details=details,
295
+ gof_df=gof_metrics,
296
+ geo_granularity=geo_granularity,
297
+ suffix=None,
255
298
  )
299
+ if details[review_constants.R_SQUARED] <= 0:
300
+ case = results.GoodnessOfFitCases.REVIEW
301
+
302
+ return results.GoodnessOfFitCheckResult(
303
+ case=case,
304
+ details=details,
305
+ is_holdout=is_holdout,
306
+ )
256
307
 
257
308
 
258
309
  # ==============================================================================
@@ -32,6 +32,10 @@ NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
32
32
  R_SQUARED = "r_squared"
33
33
  MAPE = "mape"
34
34
  WMAPE = "wmape"
35
+ ALL_SUFFIX = "all"
36
+ TRAIN_SUFFIX = "train"
37
+ TEST_SUFFIX = "test"
38
+ EVALUATION_SET_SUFFIXES = (ALL_SUFFIX, TRAIN_SUFFIX, TEST_SUFFIX)
35
39
  MEAN = "mean"
36
40
  VARIANCE = "variance"
37
41
  MEDIAN = "median"
@@ -319,18 +319,12 @@ class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
319
319
 
320
320
  PASS = (
321
321
  Status.PASS,
322
- (
323
- "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE ="
324
- " {wmape:.4f}."
325
- ),
322
+ "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE = {wmape:.4f}",
326
323
  _GOODNESS_OF_FIT_PASS_RECOMMENDATION,
327
324
  )
328
325
  REVIEW = (
329
326
  Status.REVIEW,
330
- (
331
- "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE ="
332
- " {wmape:.4f}."
333
- ),
327
+ "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE = {wmape:.4f}",
334
328
  _GOODNESS_OF_FIT_REVIEW_RECOMMENDATION,
335
329
  )
336
330
 
@@ -348,9 +342,28 @@ class GoodnessOfFitCheckResult(CheckResult):
348
342
  """The immutable result of the Goodness of Fit Check."""
349
343
 
350
344
  case: GoodnessOfFitCases
345
+ is_holdout: bool = False
351
346
 
352
347
  def __post_init__(self):
353
- if any(
348
+ if self.is_holdout:
349
+ required_keys = []
350
+ for suffix in [
351
+ constants.ALL_SUFFIX,
352
+ constants.TRAIN_SUFFIX,
353
+ constants.TEST_SUFFIX,
354
+ ]:
355
+ required_keys.extend([
356
+ f"{constants.R_SQUARED}_{suffix}",
357
+ f"{constants.MAPE}_{suffix}",
358
+ f"{constants.WMAPE}_{suffix}",
359
+ ])
360
+ if any(key not in self.details for key in required_keys):
361
+ raise ValueError(
362
+ "The message template is missing required formatting arguments for"
363
+ f" holdout case. Required keys: {required_keys}. Details:"
364
+ f" {self.details}."
365
+ )
366
+ elif any(
354
367
  key not in self.details
355
368
  for key in (
356
369
  constants.R_SQUARED,
@@ -364,6 +377,24 @@ class GoodnessOfFitCheckResult(CheckResult):
364
377
  f" {self.details}."
365
378
  )
366
379
 
380
+ @property
381
+ def recommendation(self) -> str:
382
+ """Returns the check result message."""
383
+ if self.is_holdout:
384
+ report_str = (
385
+ "R-squared = {r_squared_all:.4f} (All),"
386
+ " {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
387
+ " = {mape_all:.4f} (All), {mape_train:.4f} (Train),"
388
+ " {mape_test:.4f} (Test); wMAPE = {wmape_all:.4f} (All),"
389
+ " {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)".format(
390
+ **self.details
391
+ )
392
+ )
393
+ else:
394
+ report_str = self.case.message_template.format(**self.details)
395
+
396
+ return f"{report_str}. {self.case.recommendation}"
397
+
367
398
 
368
399
  # ==============================================================================
369
400
  # Check: ROI Consistency
@@ -21,11 +21,11 @@ import os
21
21
  import jinja2
22
22
  from meridian import constants as c
23
23
  from meridian.analysis import analyzer
24
- from meridian.analysis import formatter
25
24
  from meridian.analysis import summary_text
26
25
  from meridian.analysis import visualizer
27
26
  from meridian.data import time_coordinates as tc
28
27
  from meridian.model import model
28
+ from meridian.templates import formatter
29
29
  import pandas as pd
30
30
  import xarray as xr
31
31
 
@@ -22,9 +22,9 @@ import altair as alt
22
22
  from meridian import backend
23
23
  from meridian import constants as c
24
24
  from meridian.analysis import analyzer
25
- from meridian.analysis import formatter
26
25
  from meridian.analysis import summary_text
27
26
  from meridian.model import model
27
+ from meridian.templates import formatter
28
28
  import numpy as np
29
29
  import pandas as pd
30
30
  import xarray as xr
@@ -609,7 +609,37 @@ def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
609
609
 
610
610
 
611
611
  def _jax_convert_to_tensor(data, dtype=None):
612
- """Converts data to a JAX array, handling strings as NumPy arrays."""
612
+ """Converts data to a JAX array, handling strings as NumPy arrays.
613
+
614
+ This function explicitly unwraps objects with a `.values` attribute (e.g.,
615
+ pandas.DataFrame, xarray.DataArray) to access the underlying NumPy array,
616
+ provided that `.values` is not a method. This takes precedence over the
617
+ `__array__` protocol.
618
+
619
+ It also handles precision mismatches: if `data` is float64 and `dtype` is
620
+ not specified, and JAX x64 mode is disabled (default), it issues a warning
621
+ and explicitly casts to float32 to match the backend default and prevent
622
+ silent precision loss or type errors in downstream operations.
623
+
624
+ Args:
625
+ data: The data to convert.
626
+ dtype: The desired data type.
627
+
628
+ Returns:
629
+ A JAX array, or a NumPy array if the dtype is a string type.
630
+ """
631
+ # Unwrap xarray.DataArray, pandas.Series, and pandas.DataFrame objects.
632
+ # These objects wrap the underlying NumPy array in a .values attribute.
633
+ if hasattr(data, "values") and not callable(data.values):
634
+ data = data.values
635
+
636
+ # Convert to numpy array upfront to simplify dtype inspection below.
637
+ # A standard Python float is 64-bit, and this conversion allows the
638
+ # subsequent logic to correctly detect and handle potential float64
639
+ # downcasting for scalar inputs.
640
+ if isinstance(data, (list, tuple, float)):
641
+ data = np.array(data)
642
+
613
643
  # JAX does not natively support string tensors in the same way TF does.
614
644
  # If a string dtype is requested, or if the data is inherently strings,
615
645
  # we fall back to a standard NumPy array.
@@ -623,13 +653,31 @@ def _jax_convert_to_tensor(data, dtype=None):
623
653
  # let jax.asarray handle it.
624
654
  pass
625
655
 
626
- is_string_data = isinstance(data, (list, np.ndarray)) and np.array(
627
- data
628
- ).dtype.kind in ("S", "U")
656
+ is_string_data = isinstance(data, np.ndarray) and data.dtype.kind in (
657
+ "S",
658
+ "U",
659
+ )
629
660
 
630
- if is_string_target or (dtype is None and is_string_data):
661
+ if is_string_target:
631
662
  return np.array(data, dtype=dtype)
632
663
 
664
+ if dtype is None and is_string_data:
665
+ return data
666
+
667
+ # If the user provides float64 data but does not request a specific dtype,
668
+ # and JAX 64-bit mode is disabled (default), JAX would implicitly truncate.
669
+ # We cast to float32 and warn the user to prevent silent mismatches.
670
+ if dtype is None:
671
+ is_float64_input = hasattr(data, "dtype") and data.dtype == np.float64
672
+ if is_float64_input:
673
+ if not jax.config.jax_enable_x64:
674
+ warnings.warn(
675
+ "Input data is float64. Casting to float32 to match backend "
676
+ "default precision.",
677
+ UserWarning,
678
+ )
679
+ dtype = jax_ops.float32
680
+
633
681
  return jax_ops.asarray(data, dtype=dtype)
634
682
 
635
683
 
@@ -14,6 +14,7 @@
14
14
 
15
15
  """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
16
 
17
+ import dataclasses
17
18
  from typing import Any, Optional
18
19
 
19
20
  from absl.testing import parameterized
@@ -26,6 +27,8 @@ import numpy as np
26
27
  from tensorflow.python.util.protobuf import compare
27
28
  # pylint: disable=g-direct-tensorflow-import
28
29
  from tensorflow.core.framework import tensor_pb2
30
+
31
+
29
32
  # pylint: enable=g-direct-tensorflow-import
30
33
 
31
34
  FieldDescriptor = descriptor.FieldDescriptor
@@ -80,6 +83,75 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
80
83
  np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
81
84
 
82
85
 
86
+ def assert_deep_equals(
87
+ test_case,
88
+ obj1: Any,
89
+ obj2: Any,
90
+ msg: str = "",
91
+ rtol: float = 1e-5,
92
+ atol: float = 1e-5,
93
+ ):
94
+ """Recursive equality check handling Dataclasses, Lists, and Backend Tensors.
95
+
96
+ Args:
97
+ test_case: The unittest.TestCase instance (self) to use for assertions.
98
+ obj1: The first object to compare.
99
+ obj2: The second object to compare.
100
+ msg: Optional error message prefix.
101
+ rtol: Relative tolerance for float comparison.
102
+ atol: Absolute tolerance for float comparison.
103
+ """
104
+ if obj1 is None or obj2 is None:
105
+ test_case.assertEqual(obj1, obj2, msg=msg)
106
+ return
107
+
108
+ if (
109
+ hasattr(obj1, "__array__")
110
+ or hasattr(obj1, "numpy")
111
+ or isinstance(obj1, (np.ndarray, backend.Tensor))
112
+ ):
113
+ arr1 = np.array(obj1)
114
+ arr2 = np.array(obj2)
115
+
116
+ # Check for non-numeric types where atol/rtol don't apply
117
+ if arr1.dtype.kind in ("U", "S", "O", "b"):
118
+ np.testing.assert_array_equal(arr1, arr2, err_msg=msg)
119
+ else:
120
+ np.testing.assert_allclose(arr1, arr2, err_msg=msg, rtol=rtol, atol=atol)
121
+ return
122
+
123
+ if dataclasses.is_dataclass(obj1):
124
+ test_case.assertIs(
125
+ type(obj1),
126
+ type(obj2),
127
+ msg=f"{msg} Type mismatch: {type(obj1)} vs {type(obj2)}",
128
+ )
129
+ for field in dataclasses.fields(obj1):
130
+ val1 = getattr(obj1, field.name)
131
+ val2 = getattr(obj2, field.name)
132
+ assert_deep_equals(
133
+ test_case,
134
+ val1,
135
+ val2,
136
+ msg=f"{msg}.{field.name}",
137
+ rtol=rtol,
138
+ atol=atol,
139
+ )
140
+ return
141
+
142
+ if isinstance(obj1, (list, tuple)):
143
+ test_case.assertIsInstance(obj2, (list, tuple), msg=f"{msg} Type mismatch")
144
+ test_case.assertEqual(len(obj1), len(obj2), msg=f"{msg} Length mismatch")
145
+ for i, (item1, item2) in enumerate(zip(obj1, obj2)):
146
+ assert_deep_equals(
147
+ test_case, item1, item2, msg=f"{msg}[{i}]", rtol=rtol, atol=atol
148
+ )
149
+ return
150
+
151
+ # Fallback to standard equality for primitives (int, str, float, etc.)
152
+ test_case.assertEqual(obj1, obj2, msg=msg)
153
+
154
+
83
155
  def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
84
156
  """Backend-agnostic assertion to check if two seed objects are equal."""
85
157
  data_a = backend.get_seed_data(a)
@@ -662,6 +662,7 @@ CURRENT_SPEND = 'current_spend'
662
662
 
663
663
  # Media summary metrics.
664
664
  SPEND = 'spend'
665
+ COST = 'cost'
665
666
  IMPRESSIONS = 'impressions'
666
667
  ROI = 'roi'
667
668
  OPTIMIZED_ROI = 'optimized_roi'
@@ -35,6 +35,8 @@ __all__ = [
35
35
  'InputDataLoader',
36
36
  'XrDatasetDataLoader',
37
37
  'DataFrameDataLoader',
38
+ 'CoordToColumns',
39
+ 'CsvDataLoader',
38
40
  ]
39
41
 
40
42
 
@@ -17,4 +17,3 @@
17
17
  from meridian.model.eda import eda_engine
18
18
  from meridian.model.eda import eda_outcome
19
19
  from meridian.model.eda import eda_spec
20
- from meridian.model.eda import meridian_eda