google-meridian 1.3.0__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/RECORD +24 -10
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/backend/__init__.py +180 -23
- meridian/backend/test_utils.py +122 -0
- meridian/model/eda/eda_engine.py +54 -8
- meridian/model/model_test_data.py +15 -0
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.3.0.dist-info/METADATA +0 -409
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: google-meridian
|
|
3
|
+
Version: 1.3.1
|
|
4
|
+
Summary: Google's open source marketing mix modeling library, which helps you understand your return on investment and direct your ad spend with confidence.
|
|
5
|
+
Author-email: The Meridian Authors <no-reply@google.com>
|
|
6
|
+
Project-URL: homepage, https://github.com/google/meridian
|
|
7
|
+
Project-URL: repository, https://github.com/google/meridian
|
|
8
|
+
Project-URL: changelog, https://github.com/google/meridian/blob/main/CHANGELOG.md
|
|
9
|
+
Project-URL: documentation, https://developers.google.com/meridian
|
|
10
|
+
Keywords: mmm
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
13
|
+
Classifier: Topic :: Other/Nonlisted Topic
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: arviz
|
|
19
|
+
Requires-Dist: altair>=5
|
|
20
|
+
Requires-Dist: immutabledict
|
|
21
|
+
Requires-Dist: joblib
|
|
22
|
+
Requires-Dist: natsort<8,>=7.1.1
|
|
23
|
+
Requires-Dist: numpy<3,>=2.0.2
|
|
24
|
+
Requires-Dist: pandas<3,>=2.2.2
|
|
25
|
+
Requires-Dist: patsy<1,>=0.5.3
|
|
26
|
+
Requires-Dist: scipy<2,>=1.13.1
|
|
27
|
+
Requires-Dist: statsmodels>=0.14.5
|
|
28
|
+
Requires-Dist: tensorflow<2.19,>=2.18
|
|
29
|
+
Requires-Dist: tensorflow-probability<0.26,>=0.25
|
|
30
|
+
Requires-Dist: tf-keras<2.19,>=2.18
|
|
31
|
+
Requires-Dist: xarray
|
|
32
|
+
Provides-Extra: dev
|
|
33
|
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest-xdist; extra == "dev"
|
|
35
|
+
Requires-Dist: pylint>=2.6.0; extra == "dev"
|
|
36
|
+
Requires-Dist: pyink; extra == "dev"
|
|
37
|
+
Provides-Extra: colab
|
|
38
|
+
Requires-Dist: psutil; extra == "colab"
|
|
39
|
+
Requires-Dist: python-calamine; extra == "colab"
|
|
40
|
+
Provides-Extra: and-cuda
|
|
41
|
+
Requires-Dist: tensorflow[and-cuda]<2.19,>=2.18; extra == "and-cuda"
|
|
42
|
+
Provides-Extra: mlflow
|
|
43
|
+
Requires-Dist: mlflow; extra == "mlflow"
|
|
44
|
+
Provides-Extra: jax
|
|
45
|
+
Requires-Dist: jax==0.4.26; extra == "jax"
|
|
46
|
+
Requires-Dist: jaxlib==0.4.26; extra == "jax"
|
|
47
|
+
Requires-Dist: tensorflow-probability[substrates-jax]==0.25.0; extra == "jax"
|
|
48
|
+
Provides-Extra: schema
|
|
49
|
+
Requires-Dist: mmm-proto-schema; extra == "schema"
|
|
50
|
+
Requires-Dist: semver; extra == "schema"
|
|
51
|
+
Dynamic: license-file
|
|
52
|
+
|
|
53
|
+
# About Meridian
|
|
54
|
+
|
|
55
|
+
Marketing mix modeling (MMM) is a statistical analysis technique that measures
|
|
56
|
+
the impact of marketing campaigns and activities to guide budget planning
|
|
57
|
+
decisions and improve overall media effectiveness. MMM uses aggregated data to
|
|
58
|
+
measure impact across marketing channels and account for non-marketing factors
|
|
59
|
+
that impact sales and other key performance indicators (KPIs). MMM is
|
|
60
|
+
privacy-safe and does not use any cookie or user-level information.
|
|
61
|
+
|
|
62
|
+
Meridian is an MMM framework that enables advertisers to set up and run their
|
|
63
|
+
own in-house models. Meridian helps you answer key questions such as:
|
|
64
|
+
|
|
65
|
+
* How did the marketing channels drive my revenue or other KPI?
|
|
66
|
+
* What was my marketing return on investment (ROI)?
|
|
67
|
+
* How do I optimize my marketing budget allocation for the future?
|
|
68
|
+
|
|
69
|
+
Meridian is a highly customizable modeling framework that is based on
|
|
70
|
+
[Bayesian causal inference](https://developers.google.com/meridian/docs/causal-inference/bayesian-inference).
|
|
71
|
+
It is capable of handling large scale geo-level data, which is encouraged if
|
|
72
|
+
available, but it can also be used for national-level modeling. Meridian
|
|
73
|
+
provides clear insights and visualizations to inform business decisions around
|
|
74
|
+
marketing budget and planning. Additionally, Meridian provides methodologies to
|
|
75
|
+
support calibration of MMM with experiments and other prior information, and to
|
|
76
|
+
optimize target ad frequency by utilizing reach and frequency data.
|
|
77
|
+
|
|
78
|
+
If you are using LightweightMMM, see the
|
|
79
|
+
[migration guide](https://developers.google.com/meridian/docs/migrate) to help
|
|
80
|
+
you understand the differences between these MMM projects.
|
|
81
|
+
|
|
82
|
+
## Install Meridian
|
|
83
|
+
|
|
84
|
+
Python 3.11 or 3.12 is required to use Meridian. We also recommend using a
|
|
85
|
+
minimum of 1 GPU.
|
|
86
|
+
|
|
87
|
+
Note: This project has been tested on T4 GPU using 16 GB of RAM.
|
|
88
|
+
|
|
89
|
+
To install Meridian, run the following command to automatically install the
|
|
90
|
+
latest release from PyPI.
|
|
91
|
+
|
|
92
|
+
* For Linux-GPU users:
|
|
93
|
+
|
|
94
|
+
Note: A CUDA toolchain and a compatible GPU device are necessary for the
|
|
95
|
+
`[and-cuda]` extra to activate.
|
|
96
|
+
|
|
97
|
+
```sh
|
|
98
|
+
$ pip install --upgrade google-meridian[and-cuda]
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
* For macOS and general CPU users:
|
|
102
|
+
|
|
103
|
+
Note: There is no official GPU support for macOS.
|
|
104
|
+
|
|
105
|
+
```sh
|
|
106
|
+
$ pip install --upgrade google-meridian
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
Alternatively, run the following command to install the most recent, unreleased
|
|
110
|
+
version from GitHub.
|
|
111
|
+
|
|
112
|
+
* For GPU users:
|
|
113
|
+
|
|
114
|
+
```sh
|
|
115
|
+
$ pip install --upgrade "google-meridian[and-cuda] @ git+https://github.com/google/meridian.git"
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
* For CPU users:
|
|
119
|
+
|
|
120
|
+
```sh
|
|
121
|
+
$ pip install --upgrade git+https://github.com/google/meridian.git
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
We recommend installing Meridian in a fresh
|
|
125
|
+
[virtual environment](https://virtualenv.pypa.io/en/latest/user_guide.html#quick-start)
|
|
126
|
+
to make sure that correct versions of all the dependencies are installed, as
|
|
127
|
+
defined in [pyproject.toml](https://github.com/google/meridian/blob/main/pyproject.toml).
|
|
128
|
+
|
|
129
|
+
## How to use the Meridian library
|
|
130
|
+
|
|
131
|
+
To get started with Meridian, you can run the code programmatically using sample
|
|
132
|
+
data with the [Getting Started Colab][3].
|
|
133
|
+
|
|
134
|
+
The Meridian model uses a holistic MCMC sampling approach called
|
|
135
|
+
[No U Turn Sampler (NUTS)](https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/NoUTurnSampler)
|
|
136
|
+
which can be compute intensive. To help with this, GPU support has been
|
|
137
|
+
developed across the library (out-of-the-box) using tensors. We recommend
|
|
138
|
+
running your Meridian model on GPUs to get real time optimization results and
|
|
139
|
+
significantly reduce training time.
|
|
140
|
+
|
|
141
|
+
## Meridian Documentation & Tutorials
|
|
142
|
+
|
|
143
|
+
The following documentation, colab, and video resources will help you get
|
|
144
|
+
started quickly with using Meridian:
|
|
145
|
+
|
|
146
|
+
| Resource | Description |
|
|
147
|
+
| --------------------------- | ---------------------------------------------- |
|
|
148
|
+
| [Meridian documentation][1] | Main landing page for Meridian documentation. |
|
|
149
|
+
| [Meridian basics][2] | Learn about Meridian features, methodologies, and the model math. |
|
|
150
|
+
| [Getting started colab][3] | Install and quickly learn how to use Meridian with this colab tutorial using sample data. |
|
|
151
|
+
| [User guide][4] | A detailed walk-through of how to use Meridian and generate visualizations using your own data. |
|
|
152
|
+
| [Pre-modeling][5] | Prepare and analyze your data before modeling. |
|
|
153
|
+
| [Modeling][6] | Modeling guidance for model refinement and edge cases. |
|
|
154
|
+
| [Post-modeling][7] | Post-modeling guidance for model fit, visualizations, optimizations, refreshing the model, and debugging. |
|
|
155
|
+
| [Migrate from LMMM][8] | Learn about the differences between Meridian and LightweightMMM as you consider migrating. |
|
|
156
|
+
| [API Reference][9] | API reference documentation for the Meridian package. |
|
|
157
|
+
| [Reference list][10] | White papers and other referenced material. |
|
|
158
|
+
|
|
159
|
+
[1]: https://developers.google.com/meridian
|
|
160
|
+
[2]: https://developers.google.com/meridian/docs/basics/meridian-introduction
|
|
161
|
+
[3]: https://developers.google.com/meridian/notebook/meridian-getting-started
|
|
162
|
+
[4]: https://developers.google.com/meridian/docs/user-guide/installing
|
|
163
|
+
[5]: https://developers.google.com/meridian/docs/pre-modeling/collect-data
|
|
164
|
+
[6]: https://developers.google.com/meridian/docs/advanced-modeling/control-variables
|
|
165
|
+
[7]: https://developers.google.com/meridian/docs/post-modeling/model-fit
|
|
166
|
+
[8]: https://developers.google.com/meridian/docs/migrate
|
|
167
|
+
[9]: https://developers.google.com/meridian/reference/api/meridian
|
|
168
|
+
[10]: https://developers.google.com/meridian/docs/reference-list
|
|
169
|
+
|
|
170
|
+
## Support
|
|
171
|
+
|
|
172
|
+
**Questions about methodology**: Please see the [Modeling](https://developers.google.com/meridian/docs/basics/meridian-introduction) tab in the technical documentation.
|
|
173
|
+
|
|
174
|
+
**Issues installing or using Meridian**: Feel free to post questions in the
|
|
175
|
+
[Discussions](https://github.com/google/meridian/discussions) or [Issues](https://github.com/google/meridian/issues) tabs of the Meridian GitHub repository. The Meridian team responds to
|
|
176
|
+
these questions weekly in batches, so please be patient and don't reach out
|
|
177
|
+
directly to your Google Account teams.
|
|
178
|
+
|
|
179
|
+
**Bug reports**: Please post bug reports to the [Issues](https://github.com/google/meridian/issues)
|
|
180
|
+
tab of the Meridian GitHub repository. We also encourage the community to share
|
|
181
|
+
tips and advice with each other on the [Issues](https://github.com/google/meridian/issues)
|
|
182
|
+
tab. When our team addresses or resolves a new bug, we will notify you through
|
|
183
|
+
the comments on the issue.
|
|
184
|
+
|
|
185
|
+
**Feature requests**: Please post these to the [Discussions](https://github.com/google/meridian/discussions)
|
|
186
|
+
tab of the Meridian GitHub repository. We have an internal roadmap for Meridian
|
|
187
|
+
development, but would love your input on new feature requests so that we can
|
|
188
|
+
prioritize them based on the roadmap.
|
|
189
|
+
|
|
190
|
+
**Pull requests**: These are appreciated but are very difficult for us to merge
|
|
191
|
+
because the code in this repository is linked to Google internal systems and has
|
|
192
|
+
to pass internal review. If you submit a pull request and we believe that we can
|
|
193
|
+
incorporate a change in the base code, we will reach out to you directly about
|
|
194
|
+
this.
|
|
195
|
+
|
|
196
|
+
## Citing Meridian
|
|
197
|
+
|
|
198
|
+
To cite this repository:
|
|
199
|
+
|
|
200
|
+
<!-- mdlint off(SNIPPET_INVALID_LANGUAGE) -->
|
|
201
|
+
```BibTeX
|
|
202
|
+
@software{meridian_github,
|
|
203
|
+
author = {Google Meridian Marketing Mix Modeling Team},
|
|
204
|
+
title = {Meridian: Marketing Mix Modeling},
|
|
205
|
+
url = {https://github.com/google/meridian},
|
|
206
|
+
version = {1.3.1},
|
|
207
|
+
year = {2025},
|
|
208
|
+
}
|
|
209
|
+
```
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
google_meridian-1.3.
|
|
1
|
+
google_meridian-1.3.1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
2
2
|
meridian/__init__.py,sha256=0fOT5oNZF7-pbiWWGUefV-ysafttieG079m1ijMFQO8,861
|
|
3
3
|
meridian/constants.py,sha256=ZmMIoJDFQvKIOVG9oPOQ7Cj16wt4HDS5fCPBrz_KiLE,20308
|
|
4
|
-
meridian/version.py,sha256=
|
|
4
|
+
meridian/version.py,sha256=mmCkrGRWB8mI33apVnverT2ysfJofSyhlXSCVNotj9U,644
|
|
5
5
|
meridian/analysis/__init__.py,sha256=NWDhtRkKs-n66C6746rXT7Wk8tLdkrT9NDvuPG_B_5c,874
|
|
6
6
|
meridian/analysis/analyzer.py,sha256=EUKoN69EEz4vt4BpyRvrwEoHt7eLZ8UM6emNe14Tbj0,219975
|
|
7
7
|
meridian/analysis/formatter.py,sha256=AN2M4jdUV88XjNdi5-dK_mITES1J1Dk3Zs7DYo3OTKg,7290
|
|
@@ -25,9 +25,9 @@ meridian/analysis/templates/style.css,sha256=RODTWc2pXcG9zW3q9SEJpVXgeD-WwQgzLpm
|
|
|
25
25
|
meridian/analysis/templates/style.scss,sha256=nSrZOpcIrVyiL4eC9jLUlxIZtAKZ0Rt8pwfk4H1nMrs,5076
|
|
26
26
|
meridian/analysis/templates/summary.html.jinja,sha256=LuENVDHYIpNo4pzloYaCR2K9XN1Ow6_9oQOcOwD9nGg,1707
|
|
27
27
|
meridian/analysis/templates/table.html.jinja,sha256=mvLMZx92RcD2JAS2w2eZtfYG-6WdfwYVo7pM8TbHp4g,1176
|
|
28
|
-
meridian/backend/__init__.py,sha256=
|
|
28
|
+
meridian/backend/__init__.py,sha256=ftXcdb3tIky_m8exhD9RRhaEqTEvMcqZ8lGkkR1PjsE,39495
|
|
29
29
|
meridian/backend/config.py,sha256=B9VQnhBfg9RW04GNbt7F5uCugByenoJzt-keFLLYEp8,3561
|
|
30
|
-
meridian/backend/test_utils.py,sha256=
|
|
30
|
+
meridian/backend/test_utils.py,sha256=DYU5IpWiyM26aTY9Q84mOGRw0dQ9XmsZRsAJNFuZDp0,11667
|
|
31
31
|
meridian/data/__init__.py,sha256=StIe-wfYnnbfUbKtZHwnAQcRQUS8XCZk_PCaEzw90Ww,929
|
|
32
32
|
meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
|
|
33
33
|
meridian/data/data_frame_input_data_builder.py,sha256=_hexZMFAuAowgo6FaOGElHSFHqhGnHQwEEBcwnT3zUE,27295
|
|
@@ -44,7 +44,7 @@ meridian/model/adstock_hill.py,sha256=HoRKjyL03pCTBz6Utof9wEvlQCFM43BvrEW_oupj7N
|
|
|
44
44
|
meridian/model/knots.py,sha256=87kw5oa3T1k9GgT_aWXTqQx5XCxLsS2w1hnzc581XL0,26677
|
|
45
45
|
meridian/model/media.py,sha256=skjy4Vd8LfDQWlqR_2lJ1qbG9UcS1dow5W45BAu4qk8,14599
|
|
46
46
|
meridian/model/model.py,sha256=jMtfl7woWtJ8M8AX42QeZ5hUS8hlhPdZ-9OU8KahjKA,68984
|
|
47
|
-
meridian/model/model_test_data.py,sha256=
|
|
47
|
+
meridian/model/model_test_data.py,sha256=XGBz8RGdCsjAUOmgxX3CfWSj-_hdq2Lc8saFCqmImwM,23901
|
|
48
48
|
meridian/model/posterior_sampler.py,sha256=f3MayglIgBeBjWeXJU_RgT9cCugcjJ3aEjHqaWPsTbg,26806
|
|
49
49
|
meridian/model/prior_distribution.py,sha256=ZArW4uXIPPQL6hRWiGZUzcHktbkjE_vOklvlbp9LR64,57662
|
|
50
50
|
meridian/model/prior_sampler.py,sha256=iLvCefhA4WY0ENcnLK9471WUZPPyzQ1je58MRjxKv74,25460
|
|
@@ -52,11 +52,25 @@ meridian/model/spec.py,sha256=VlK6WJiPo2lzOF0O2judtJ6O3uEw7wYL5AT8bioq4gE,19188
|
|
|
52
52
|
meridian/model/transformers.py,sha256=HxlVJitxP-wu-NOHU0tArFUZ4NAO3c7adAYj4Zvqnvo,8363
|
|
53
53
|
meridian/model/eda/__init__.py,sha256=w3p7ZUZLq5TOEHm8n2P1CWjGrzuNrkqSSnVFdlw17Dk,812
|
|
54
54
|
meridian/model/eda/constants.py,sha256=V9aOHQDvB3WAEyT0NE4gE8rqbStaGVh3XDlBPOKpuLc,739
|
|
55
|
-
meridian/model/eda/eda_engine.py,sha256=
|
|
55
|
+
meridian/model/eda/eda_engine.py,sha256=5Ikgiz-6d3uTBan71WwsgPHOqk0fAy390ZLCq9L6HoY,64846
|
|
56
56
|
meridian/model/eda/eda_outcome.py,sha256=P-0kNIbNXcyqMaNvFxiL3x6fhtYOL2trw-zPPZGXh5w,5670
|
|
57
57
|
meridian/model/eda/eda_spec.py,sha256=diieYyZH0ee3ZLy0rGFMcWrrgiUrz2HctMwOrmtJR6w,2871
|
|
58
58
|
meridian/model/eda/meridian_eda.py,sha256=GTdBaAtfsHS5s6P5ZESaeh1ElKV_o7dSqQosiMnFBKg,7537
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
59
|
+
schema/__init__.py,sha256=Df2XKjMKa0ry7CTlDuJnyxuTdSgAd5za922wMlE98cg,681
|
|
60
|
+
schema/serde/__init__.py,sha256=xyydIcWB5IUpcn3wu1m9HL1fK4gMWURbwTyRsQtolF0,975
|
|
61
|
+
schema/serde/constants.py,sha256=aYtD_RuA0GCkpC4TIQq3VjMqEc837Wn-TlJNm-yn_4Y,1842
|
|
62
|
+
schema/serde/distribution.py,sha256=jy3h6JD1TSs4gwociMis814sz_Fm2kFQ2UbkgjYJW9k,19347
|
|
63
|
+
schema/serde/eda_spec.py,sha256=uOqBeZpUU3Dzzc19rU1LjHWmUhRmVcx8oIZvZfVJHT8,7180
|
|
64
|
+
schema/serde/function_registry.py,sha256=GbgC5_9NDcA9Y7nqmdJ-4-LK5JPhhfI50Lmfy5ZBJOg,4858
|
|
65
|
+
schema/serde/hyperparameters.py,sha256=Igm-PZmIozrsKZH6c-XkrU_Nlf8OAuxpnJJfv7W1SfQ,13524
|
|
66
|
+
schema/serde/inference_data.py,sha256=DrwE9hU8LMrl0z8W_sUSIaPrRdym_lu0iOqpT4KZxsA,3623
|
|
67
|
+
schema/serde/marketing_data.py,sha256=yb-fRTe84Sjg7-v3wsvYRRXvrxLSFWSenO0_ikMvUpk,44845
|
|
68
|
+
schema/serde/meridian_serde.py,sha256=ZG05JaBG4LW8mhl-Cunje9Q6xyR4tyNTtLYedzMBYjA,15985
|
|
69
|
+
schema/serde/serde.py,sha256=8vUqhJxvZgX9UY3rXTyWJznRgapwDzzaHXDHwV_kKTA,1612
|
|
70
|
+
schema/serde/test_data.py,sha256=7hfEWyvZ9WcAkVAOXt6elX8stJlsfhfd-ASlHo9SRb8,107342
|
|
71
|
+
schema/utils/__init__.py,sha256=AkC4NMbmXC3PFBY9dFYxlf3qFsxt5OOBVdc9zmFXsC8,675
|
|
72
|
+
schema/utils/time_record.py,sha256=-KzHFjvSBUUXsfESPAfcJP_VFxaFLqj90Ac0kgKWfpI,4624
|
|
73
|
+
google_meridian-1.3.1.dist-info/METADATA,sha256=wu5D6r6v46vd5g4uKLfXEVuhxkO3jwEgS68wn6m0jR4,9547
|
|
74
|
+
google_meridian-1.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
75
|
+
google_meridian-1.3.1.dist-info/top_level.txt,sha256=yWkWDLV_UUanhKmk_xNPiKNdPDl1oyU1sBYwEnhaSf4,16
|
|
76
|
+
google_meridian-1.3.1.dist-info/RECORD,,
|
meridian/backend/__init__.py
CHANGED
|
@@ -19,6 +19,7 @@ import functools
|
|
|
19
19
|
import os
|
|
20
20
|
from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
|
21
21
|
import warnings
|
|
22
|
+
|
|
22
23
|
from meridian.backend import config
|
|
23
24
|
import numpy as np
|
|
24
25
|
from typing_extensions import Literal
|
|
@@ -220,7 +221,7 @@ def _tf_arange(
|
|
|
220
221
|
|
|
221
222
|
def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
  """Casts `x` to `dtype` on the JAX backend.

  Args:
    x: The value to convert; anything `jax_ops.asarray` accepts.
    dtype: The target dtype, forwarded unchanged to `jax_ops.asarray`.

  Returns:
    The result of `jax_ops.asarray(x, dtype=dtype)`.
  """
  # NOTE(review): `jax_ops` is defined elsewhere in this module; presumably it
  # is `jax.numpy` -- confirm against the module's import section.
  converted = jax_ops.asarray(x, dtype=dtype)
  return converted
|
|
224
225
|
|
|
225
226
|
|
|
226
227
|
def _jax_divide_no_nan(x, y):
|
|
@@ -305,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
|
|
|
305
306
|
)
|
|
306
307
|
|
|
307
308
|
|
|
308
|
-
def _jax_make_tensor_proto(
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
309
|
+
def _jax_make_tensor_proto(values, dtype=None, shape=None):  # pylint: disable=unused-argument
  """JAX implementation for make_tensor_proto.

  Serializes `values` into a `TensorProto` whose payload is stored in
  `tensor_content` as the raw (C-order) bytes of the NumPy array.

  Args:
    values: Array-like data to serialize. Non-ndarray inputs are converted
      with NumPy first.
    dtype: Optional dtype specifier understood by `np.dtype`. When given, the
      data is cast to it before serialization.
    shape: Accepted for signature compatibility only; it is ignored and the
      proto's shape is always taken from the data itself.

  Returns:
    A `tensor_pb2.TensorProto` describing `values`.

  Raises:
    TypeError: If the resulting NumPy dtype has no TensorProto equivalent.
  """
  # pylint: disable=g-direct-tensorflow-import
  from tensorflow.core.framework import tensor_pb2
  from tensorflow.core.framework import tensor_shape_pb2
  from tensorflow.core.framework import types_pb2
  # pylint: enable=g-direct-tensorflow-import

  values = np.asarray(values)

  # Compare against None explicitly: `np.dtype` instances define `__len__`
  # (0 for unstructured dtypes), so a truthiness test can silently discard a
  # perfectly valid dtype request.
  if dtype is not None:
    numpy_dtype = np.dtype(dtype)
    values = values.astype(numpy_dtype)
  else:
    numpy_dtype = values.dtype

  # Unicode arrays are stored by NumPy as fixed-width UTF-32. Encode them to
  # byte strings so `tensor_content` holds real string bytes that
  # `_jax_make_ndarray` can read back as a fixed-width `S` array.
  if numpy_dtype.kind == "U":
    values = np.char.encode(values, "utf-8")
    numpy_dtype = values.dtype

  dtype_map = {
      np.dtype(np.float16): types_pb2.DT_HALF,
      np.dtype(np.float32): types_pb2.DT_FLOAT,
      np.dtype(np.float64): types_pb2.DT_DOUBLE,
      np.dtype(np.int32): types_pb2.DT_INT32,
      np.dtype(np.uint8): types_pb2.DT_UINT8,
      np.dtype(np.uint16): types_pb2.DT_UINT16,
      np.dtype(np.uint32): types_pb2.DT_UINT32,
      np.dtype(np.uint64): types_pb2.DT_UINT64,
      np.dtype(np.int16): types_pb2.DT_INT16,
      np.dtype(np.int8): types_pb2.DT_INT8,
      np.dtype(np.int64): types_pb2.DT_INT64,
      np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
      np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
      np.dtype(np.bool_): types_pb2.DT_BOOL,
      # Note: String types are handled outside the map.
  }
  proto_dtype = dtype_map.get(numpy_dtype)
  # Fixed-width string dtypes vary by item size, so they cannot be dict keys.
  if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
    proto_dtype = types_pb2.DT_STRING

  if proto_dtype is None:
    raise TypeError(
        f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
    )

  proto = tensor_pb2.TensorProto(
      dtype=proto_dtype,
      tensor_shape=tensor_shape_pb2.TensorShapeProto(
          dim=[
              tensor_shape_pb2.TensorShapeProto.Dim(size=d)
              for d in values.shape
          ]
      ),
  )

  proto.tensor_content = values.tobytes()
  return proto
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _jax_make_ndarray(proto):
  """JAX implementation for make_ndarray.

  Rebuilds a NumPy array from a `TensorProto`. Data is read from
  `tensor_content` when present; otherwise from the typed `*_val` fields.

  Args:
    proto: A `TensorProto` message.

  Returns:
    A NumPy array with the proto's declared shape. String tensors come back
    as fixed-width byte-string (`S`) arrays.

  Raises:
    TypeError: If the proto dtype is unsupported, or if the data lives in a
      value field this fallback does not handle.
    ValueError: If the `tensor_content` of a string tensor cannot be split
      evenly across the declared number of elements.
  """
  # pylint: disable=g-direct-tensorflow-import
  from tensorflow.core.framework import types_pb2
  # pylint: enable=g-direct-tensorflow-import

  dtype_map = {
      types_pb2.DT_HALF: np.float16,
      types_pb2.DT_FLOAT: np.float32,
      types_pb2.DT_DOUBLE: np.float64,
      types_pb2.DT_INT32: np.int32,
      types_pb2.DT_UINT8: np.uint8,
      types_pb2.DT_UINT16: np.uint16,
      types_pb2.DT_UINT32: np.uint32,
      types_pb2.DT_UINT64: np.uint64,
      types_pb2.DT_INT16: np.int16,
      types_pb2.DT_INT8: np.int8,
      types_pb2.DT_INT64: np.int64,
      types_pb2.DT_COMPLEX64: np.complex64,
      types_pb2.DT_COMPLEX128: np.complex128,
      types_pb2.DT_BOOL: np.bool_,
      types_pb2.DT_STRING: np.bytes_,
  }
  if proto.dtype not in dtype_map:
    raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")

  shape = [d.size for d in proto.tensor_shape.dim]
  dtype = dtype_map[proto.dtype]
  # The product over an empty shape is 1: a scalar holds exactly one element.
  num_elements = int(np.prod(shape, dtype=np.int64))

  if proto.tensor_content:
    # When deserializing a string from tensor_content, the itemsize is not
    # explicitly stored. We must infer it from the content length and shape.
    if dtype == np.bytes_ and num_elements > 0:
      content_len = len(proto.tensor_content)
      itemsize = content_len // num_elements
      if itemsize * num_elements != content_len:
        raise ValueError(
            "Tensor content size is not a multiple of the number of elements"
            " for string dtype."
        )
      dtype = np.dtype(f"S{itemsize}")

    # `frombuffer` returns a read-only view of the proto bytes; copy so the
    # caller gets an independent, writable array.
    return (
        np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
    )

  # Fallback for protos that store data in val fields instead of tensor_content.
  val_field_names = {
      np.float32: "float_val",
      np.float64: "double_val",
      np.int32: "int_val",
      np.int64: "int64_val",
      np.bool_: "bool_val",
  }
  field_name = val_field_names.get(dtype)

  if field_name is None:
    if proto.string_val:
      return np.array(proto.string_val, dtype=np.bytes_).reshape(shape)
    # A tensor with any zero-sized dimension carries no data at all.
    if num_elements == 0:
      return np.array([], dtype=dtype).reshape(shape)
    raise TypeError(
        f"Unsupported dtype for TensorProto value field fallback: {dtype}"
    )

  flat = np.array(getattr(proto, field_name), dtype=dtype)
  # A single stored value with a larger declared shape is treated as a
  # constant-filled tensor and broadcast to the full shape.
  # NOTE(review): mirrors TensorFlow's compact encoding -- confirm against
  # `tf.make_ndarray` before relying on it.
  if flat.size == 1 and num_elements > 1:
    return np.full(shape, flat.reshape(-1)[0], dtype=dtype)
  return flat.reshape(shape)
|
|
434
|
+
|
|
319
435
|
|
|
320
436
|
def _jax_get_indices_where(condition):
|
|
321
437
|
"""JAX implementation for get_indices_where."""
|
|
@@ -497,12 +613,23 @@ def _jax_convert_to_tensor(data, dtype=None):
|
|
|
497
613
|
# JAX does not natively support string tensors in the same way TF does.
|
|
498
614
|
# If a string dtype is requested, or if the data is inherently strings,
|
|
499
615
|
# we fall back to a standard NumPy array.
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
616
|
+
is_string_target = False
|
|
617
|
+
if dtype is not None:
|
|
618
|
+
try:
|
|
619
|
+
if np.dtype(dtype).kind in ("S", "U"):
|
|
620
|
+
is_string_target = True
|
|
621
|
+
except TypeError:
|
|
622
|
+
# This can happen if dtype is not a valid dtype specifier,
|
|
623
|
+
# let jax.asarray handle it.
|
|
624
|
+
pass
|
|
625
|
+
|
|
626
|
+
is_string_data = isinstance(data, (list, np.ndarray)) and np.array(
|
|
627
|
+
data
|
|
628
|
+
).dtype.kind in ("S", "U")
|
|
629
|
+
|
|
630
|
+
if is_string_target or (dtype is None and is_string_data):
|
|
631
|
+
return np.array(data, dtype=dtype)
|
|
632
|
+
|
|
506
633
|
return jax_ops.asarray(data, dtype=dtype)
|
|
507
634
|
|
|
508
635
|
|
|
@@ -535,18 +662,48 @@ def _tf_nanvar(a, axis=None, keepdims=False):
|
|
|
535
662
|
return tf.convert_to_tensor(var)
|
|
536
663
|
|
|
537
664
|
|
|
538
|
-
def _jax_one_hot(
|
|
665
|
+
def _jax_one_hot(
    indices, depth, on_value=None, off_value=None, axis=None, dtype=None
):
  """JAX implementation for one_hot.

  Args:
    indices: Indices to encode; forwarded to `jax.nn.one_hot`.
    depth: Number of classes, i.e. the size of the one-hot axis.
    on_value: Value written where the index matches. `None` means 1.
    off_value: Value written everywhere else. `None` means 0.
    axis: Position of the one-hot axis. `None` means the last axis.
    dtype: Optional output dtype; resolved together with `on_value` and
      `off_value` by `_resolve_dtype`.

  Returns:
    The one-hot encoded array, with 1/0 swapped for `on_value`/`off_value`
    when those differ from the defaults.
  """
  import jax.numpy as jnp

  target_dtype = jnp.dtype(_resolve_dtype(dtype, on_value, off_value, 1, 0))
  hot_axis = axis if axis is not None else -1

  encoded = jax.nn.one_hot(
      indices, num_classes=depth, dtype=target_dtype, axis=hot_axis
  )

  hot = on_value if on_value is not None else 1
  cold = off_value if off_value is not None else 0

  # The plain 1/0 encoding is exactly what jax.nn.one_hot already produced.
  if hot == 1 and cold == 0:
    return encoded

  return jnp.where(
      encoded == 1,
      jnp.array(hot, dtype=target_dtype),
      jnp.array(cold, dtype=target_dtype),
  )
|
|
688
|
+
|
|
544
689
|
|
|
545
|
-
def _jax_roll(
|
|
690
|
+
def _jax_roll(a, shift, axis=None):
  """JAX implementation for roll.

  Args:
    a: The input array.
    shift: Number of places by which elements are shifted.
    axis: Axis or axes to roll along, forwarded unchanged to `jnp.roll`.

  Returns:
    The result of `jnp.roll(a, shift, axis=axis)`.
  """
  import jax.numpy as jnp

  rolled = jnp.roll(a, shift, axis=axis)
  return rolled
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _tf_roll(a, shift: Sequence[int], axis=None):
  """TensorFlow implementation for roll that handles axis=None.

  With `axis=None` the tensor is flattened, rolled along its single axis and
  then restored to its original shape. Otherwise the call is forwarded
  directly to `tf.roll`.

  Args:
    a: The input tensor.
    shift: Number of places by which elements are shifted.
    axis: Axis to roll along, or `None` to roll over the flattened tensor.

  Returns:
    The rolled tensor, with the same shape as `a`.
  """
  import tensorflow as tf

  if axis is not None:
    return tf.roll(a, shift, axis=axis)

  # Remember the shape first so the flattened result can be restored.
  input_shape = tf.shape(a)
  rolled = tf.roll(tf.reshape(a, [-1]), shift=shift, axis=0)
  return tf.reshape(rolled, input_shape)
|
|
550
707
|
|
|
551
708
|
|
|
552
709
|
def _jax_enable_op_determinism():
|
|
@@ -772,7 +929,7 @@ if _BACKEND == config.Backend.JAX:
|
|
|
772
929
|
newaxis = _ops.newaxis
|
|
773
930
|
TensorShape = _jax_tensor_shape
|
|
774
931
|
int32 = _ops.int32
|
|
775
|
-
string = np.
|
|
932
|
+
string = np.bytes_
|
|
776
933
|
|
|
777
934
|
stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid
|
|
778
935
|
|
|
@@ -904,7 +1061,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
904
1061
|
reduce_sum = _ops.reduce_sum
|
|
905
1062
|
repeat = _ops.repeat
|
|
906
1063
|
reshape = _ops.reshape
|
|
907
|
-
roll =
|
|
1064
|
+
roll = _tf_roll
|
|
908
1065
|
set_random_seed = tf_backend.keras.utils.set_random_seed
|
|
909
1066
|
split = _ops.split
|
|
910
1067
|
stack = _ops.stack
|
meridian/backend/test_utils.py
CHANGED
|
@@ -15,11 +15,21 @@
|
|
|
15
15
|
"""Common testing utilities for Meridian, designed to be backend-agnostic."""
|
|
16
16
|
|
|
17
17
|
from typing import Any, Optional
|
|
18
|
+
|
|
18
19
|
from absl.testing import parameterized
|
|
20
|
+
from google.protobuf import descriptor
|
|
21
|
+
from google.protobuf import message
|
|
19
22
|
from meridian import backend
|
|
20
23
|
from meridian.backend import config
|
|
21
24
|
import numpy as np
|
|
22
25
|
|
|
26
|
+
from tensorflow.python.util.protobuf import compare
|
|
27
|
+
# pylint: disable=g-direct-tensorflow-import
|
|
28
|
+
from tensorflow.core.framework import tensor_pb2
|
|
29
|
+
# pylint: enable=g-direct-tensorflow-import
|
|
30
|
+
|
|
31
|
+
FieldDescriptor = descriptor.FieldDescriptor
|
|
32
|
+
|
|
23
33
|
# A type alias for backend-agnostic array-like objects.
|
|
24
34
|
# We use `Any` here to avoid circular dependencies with the backend module
|
|
25
35
|
# while still allowing the function to accept backend-specific tensor types.
|
|
@@ -131,6 +141,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
|
|
|
131
141
|
raise AssertionError(err_msg or "Array contains negative values.")
|
|
132
142
|
|
|
133
143
|
|
|
144
|
+
# --- Proto Utilities ---
|
|
145
|
+
def normalize_tensor_protos(proto: message.Message):
  """Rewrites every TensorProto nested inside `proto`, in place.

  Each TensorProto found is repacked through the active backend
  (`backend.make_ndarray` followed by `backend.make_tensor_proto`), so that
  protos produced by different backends (e.g., JAX vs TF) end up in the same
  serialization form. This evens out differences such as `bool_val` versus
  `tensor_content` for boolean tensors.

  Args:
    proto: The protobuf message to normalize. It is modified in place.
      Non-message inputs are ignored.
  """
  if isinstance(proto, message.Message):
    for field, payload in proto.ListFields():
      # Only message-typed fields can contain (or be) a TensorProto.
      if field.type != FieldDescriptor.TYPE_MESSAGE:
        continue

      if field.label != FieldDescriptor.LABEL_REPEATED:
        # Singular message field.
        children = (payload,)
      elif (
          field.message_type.has_options
          and field.message_type.GetOptions().map_entry
      ):
        # A map is a repeated field whose message type has the map_entry
        # option set. Its values may be scalars; the helper filters those.
        children = payload.values()
      else:
        # Standard repeated message field.
        children = payload

      for child in children:
        _process_message_for_normalization(child)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _process_message_for_normalization(msg: Any):
  """Normalizes one value met while traversing a proto.

  Args:
    msg: A candidate value. TensorProtos are repacked in place, other
      messages are traversed recursively, and anything that is not a message
      (e.g., a string value from a map<string, string>) is left untouched.
  """
  if isinstance(msg, tensor_pb2.TensorProto):
    _repack_tensor_proto(msg)
  elif isinstance(msg, message.Message):
    # Any other message type may hold TensorProtos deeper down.
    normalize_tensor_protos(msg)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
  """Re-serializes a TensorProto in place via the active backend.

  Args:
    tensor_proto: The TensorProto to repack. A proto whose serialized size is
      zero is left as is.

  Raises:
    ValueError: If the backend cannot turn the proto into an array.
  """
  if tensor_proto.ByteSize() == 0:
    return

  try:
    array = backend.make_ndarray(tensor_proto)
  except Exception as err:
    raise ValueError(
        "Failed to deserialize TensorProto during normalization:"
        f" {err}\nProto content:\n{tensor_proto}"
    ) from err

  repacked = backend.make_tensor_proto(array)

  # Overwrite the original message so every holder of it sees the new form.
  tensor_proto.Clear()
  tensor_proto.CopyFrom(repacked)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def assert_normalized_proto_equal(
    test_case: parameterized.TestCase,
    expected: message.Message,
    actual: message.Message,
    msg: Optional[str] = None,
    **kwargs: Any,
):
  """Compares two protos after normalizing TensorProto fields.

  Use this instead of compare.assertProtoEqual when protos contain tensors
  to ensure backend-agnostic comparison.

  Args:
    test_case: The TestCase instance (self).
    expected: The expected protobuf message.
    actual: The actual protobuf message.
    msg: An optional message to display on failure.
    **kwargs: Additional keyword arguments passed to compare.assertProtoEqual.
  """
  # Work on copies to avoid mutating the original objects
  expected_copy = expected.__class__()
  expected_copy.CopyFrom(expected)
  actual_copy = actual.__class__()
  actual_copy.CopyFrom(actual)

  try:
    normalize_tensor_protos(expected_copy)
    normalize_tensor_protos(actual_copy)
  except ValueError as e:
    failure = f"Proto normalization failed: {e}."
    # Append the caller's message only when one was supplied, so the report
    # never ends in a literal "None".
    if msg:
      failure = f"{failure} {msg}"
    test_case.fail(failure)

  compare.assertProtoEqual(
      test_case, expected_copy, actual_copy, msg=msg, **kwargs
  )
|
|
254
|
+
|
|
255
|
+
|
|
134
256
|
class MeridianTestCase(parameterized.TestCase):
|
|
135
257
|
"""Base test class for Meridian providing backend-aware utilities.
|
|
136
258
|
|