google-meridian 1.3.0__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,209 @@
1
+ Metadata-Version: 2.4
2
+ Name: google-meridian
3
+ Version: 1.3.1
4
+ Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
+ Author-email: The Meridian Authors <no-reply@google.com>
6
+ Project-URL: homepage, https://github.com/google/meridian
7
+ Project-URL: repository, https://github.com/google/meridian
8
+ Project-URL: changelog, https://github.com/google/meridian/blob/main/CHANGELOG.md
9
+ Project-URL: documentation, https://developers.google.com/meridian
10
+ Keywords: mmm
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3 :: Only
13
+ Classifier: Topic :: Other/Nonlisted Topic
14
+ Classifier: Topic :: Scientific/Engineering :: Information Analysis
15
+ Requires-Python: >=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: arviz
19
+ Requires-Dist: altair>=5
20
+ Requires-Dist: immutabledict
21
+ Requires-Dist: joblib
22
+ Requires-Dist: natsort<8,>=7.1.1
23
+ Requires-Dist: numpy<3,>=2.0.2
24
+ Requires-Dist: pandas<3,>=2.2.2
25
+ Requires-Dist: patsy<1,>=0.5.3
26
+ Requires-Dist: scipy<2,>=1.13.1
27
+ Requires-Dist: statsmodels>=0.14.5
28
+ Requires-Dist: tensorflow<2.19,>=2.18
29
+ Requires-Dist: tensorflow-probability<0.26,>=0.25
30
+ Requires-Dist: tf-keras<2.19,>=2.18
31
+ Requires-Dist: xarray
32
+ Provides-Extra: dev
33
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
34
+ Requires-Dist: pytest-xdist; extra == "dev"
35
+ Requires-Dist: pylint>=2.6.0; extra == "dev"
36
+ Requires-Dist: pyink; extra == "dev"
37
+ Provides-Extra: colab
38
+ Requires-Dist: psutil; extra == "colab"
39
+ Requires-Dist: python-calamine; extra == "colab"
40
+ Provides-Extra: and-cuda
41
+ Requires-Dist: tensorflow[and-cuda]<2.19,>=2.18; extra == "and-cuda"
42
+ Provides-Extra: mlflow
43
+ Requires-Dist: mlflow; extra == "mlflow"
44
+ Provides-Extra: jax
45
+ Requires-Dist: jax==0.4.26; extra == "jax"
46
+ Requires-Dist: jaxlib==0.4.26; extra == "jax"
47
+ Requires-Dist: tensorflow-probability[substrates-jax]==0.25.0; extra == "jax"
48
+ Provides-Extra: schema
49
+ Requires-Dist: mmm-proto-schema; extra == "schema"
50
+ Requires-Dist: semver; extra == "schema"
51
+ Dynamic: license-file
52
+
53
+ # About Meridian
54
+
55
+ Marketing mix modeling (MMM) is a statistical analysis technique that measures
56
+ the impact of marketing campaigns and activities to guide budget planning
57
+ decisions and improve overall media effectiveness. MMM uses aggregated data to
58
+ measure impact across marketing channels and account for non-marketing factors
59
+ that impact sales and other key performance indicators (KPIs). MMM is
60
+ privacy-safe and does not use any cookie or user-level information.
61
+
62
+ Meridian is an MMM framework that enables advertisers to set up and run their
63
+ own in-house models. Meridian helps you answer key questions such as:
64
+
65
+ * How did the marketing channels drive my revenue or other KPI?
66
+ * What was my marketing return on investment (ROI)?
67
+ * How do I optimize my marketing budget allocation for the future?
68
+
69
+ Meridian is a highly customizable modeling framework that is based on
70
+ [Bayesian causal inference](https://developers.google.com/meridian/docs/causal-inference/bayesian-inference).
71
+ It can handle large-scale geo-level data, which is encouraged if
72
+ available, but it can also be used for national-level modeling. Meridian
73
+ provides clear insights and visualizations to inform business decisions around
74
+ marketing budget and planning. Additionally, Meridian provides methodologies to
75
+ support calibration of MMM with experiments and other prior information, and to
76
+ optimize target ad frequency by utilizing reach and frequency data.
77
+
78
+ If you are using LightweightMMM, see the
79
+ [migration guide](https://developers.google.com/meridian/docs/migrate) to help
80
+ you understand the differences between these MMM projects.
81
+
82
+ ## Install Meridian
83
+
84
+ Python 3.11 or 3.12 is required to use Meridian. We also recommend using a
85
+ minimum of 1 GPU.
86
+
87
+ Note: This project has been tested on a T4 GPU with 16 GB of RAM.
88
+
89
+ To install Meridian, run one of the following commands to automatically install the
90
+ latest release from PyPI.
91
+
92
+ * For Linux-GPU users:
93
+
94
+ Note: A CUDA toolchain and a compatible GPU device are necessary for the
95
+ `[and-cuda]` extra to activate.
96
+
97
+ ```sh
98
+ $ pip install --upgrade google-meridian[and-cuda]
99
+ ```
100
+
101
+ * For macOS and general CPU users:
102
+
103
+ Note: There is no official GPU support for macOS.
104
+
105
+ ```sh
106
+ $ pip install --upgrade google-meridian
107
+ ```
108
+
109
+ Alternatively, run one of the following commands to install the most recent, unreleased
110
+ version from GitHub.
111
+
112
+ * For GPU users:
113
+
114
+ ```sh
115
+ $ pip install --upgrade "google-meridian[and-cuda] @ git+https://github.com/google/meridian.git"
116
+ ```
117
+
118
+ * For CPU users:
119
+
120
+ ```sh
121
+ $ pip install --upgrade git+https://github.com/google/meridian.git
122
+ ```
123
+
124
+ We recommend installing Meridian in a fresh
125
+ [virtual environment](https://virtualenv.pypa.io/en/latest/user_guide.html#quick-start)
126
+ to make sure that the correct versions of all dependencies are installed, as
127
+ defined in [pyproject.toml](https://github.com/google/meridian/blob/main/pyproject.toml).
128
+
129
+ ## How to use the Meridian library
130
+
131
+ To get started with Meridian, you can run the code programmatically using sample
132
+ data with the [Getting Started Colab][3].
133
+
134
+ The Meridian model uses a holistic MCMC sampling approach called the
135
+ [No-U-Turn Sampler (NUTS)](https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/NoUTurnSampler),
136
+ which can be compute-intensive. To help with this, GPU support is built into
137
+ the library out of the box through its tensor operations. We recommend
138
+ running your Meridian model on GPUs to get real-time optimization results and
139
+ significantly reduce training time.
140
+
141
+ ## Meridian Documentation & Tutorials
142
+
143
+ The following documentation, colab, and video resources will help you get
144
+ started with Meridian quickly:
145
+
146
+ | Resource | Description |
147
+ | --------------------------- | ---------------------------------------------- |
148
+ | [Meridian documentation][1] | Main landing page for Meridian documentation. |
149
+ | [Meridian basics][2] | Learn about Meridian features, methodologies, and the model math. |
150
+ | [Getting started colab][3] | Install and quickly learn how to use Meridian with this colab tutorial using sample data. |
151
+ | [User guide][4] | A detailed walk-through of how to use Meridian and generate visualizations with your own data. |
152
+ | [Pre-modeling][5] | Prepare and analyze your data before modeling. |
153
+ | [Modeling][6] | Modeling guidance for model refinement and edge cases. |
154
+ | [Post-modeling][7] | Post-modeling guidance for model fit, visualizations, optimizations, refreshing the model, and debugging. |
155
+ | [Migrate from LMMM][8] | Learn about the differences between Meridian and LightweightMMM as you consider migrating. |
156
+ | [API Reference][9] | API reference documentation for the Meridian package. |
157
+ | [Reference list][10] | White papers and other referenced material. |
158
+
159
+ [1]: https://developers.google.com/meridian
160
+ [2]: https://developers.google.com/meridian/docs/basics/meridian-introduction
161
+ [3]: https://developers.google.com/meridian/notebook/meridian-getting-started
162
+ [4]: https://developers.google.com/meridian/docs/user-guide/installing
163
+ [5]: https://developers.google.com/meridian/docs/pre-modeling/collect-data
164
+ [6]: https://developers.google.com/meridian/docs/advanced-modeling/control-variables
165
+ [7]: https://developers.google.com/meridian/docs/post-modeling/model-fit
166
+ [8]: https://developers.google.com/meridian/docs/migrate
167
+ [9]: https://developers.google.com/meridian/reference/api/meridian
168
+ [10]: https://developers.google.com/meridian/docs/reference-list
169
+
170
+ ## Support
171
+
172
+ **Questions about methodology**: Please see the [Modeling](https://developers.google.com/meridian/docs/basics/meridian-introduction) tab in the technical documentation.
173
+
174
+ **Issues installing or using Meridian**: Feel free to post questions in the
175
+ [Discussions](https://github.com/google/meridian/discussions) or [Issues](https://github.com/google/meridian/issues) tabs of the Meridian GitHub repository. The Meridian team responds to
176
+ these questions weekly in batches, so please be patient and don't reach out
177
+ directly to your Google Account teams.
178
+
179
+ **Bug reports**: Please post bug reports to the [Issues](https://github.com/google/meridian/issues)
180
+ tab of the Meridian GitHub repository. We also encourage the community to share
181
+ tips and advice with each other on the [Issues](https://github.com/google/meridian/issues)
182
+ tab. When our team addresses or resolves a new bug, we will notify you through
183
+ the comments on the issue.
184
+
185
+ **Feature requests**: Please post these to the [Discussions](https://github.com/google/meridian/discussions)
186
+ tab of the Meridian GitHub repository. We have an internal roadmap for Meridian
187
+ development, but would love your input on new feature requests so that we can
188
+ prioritize them based on the roadmap.
189
+
190
+ **Pull requests**: These are appreciated but are very difficult for us to merge
191
+ because the code in this repository is linked to Google internal systems and has
192
+ to pass internal review. If you submit a pull request and we believe that we can
193
+ incorporate a change in the base code, we will reach out to you directly about
194
+ this.
195
+
196
+ ## Citing Meridian
197
+
198
+ To cite this repository:
199
+
200
+ <!-- mdlint off(SNIPPET_INVALID_LANGUAGE) -->
201
+ ```BibTeX
202
+ @software{meridian_github,
203
+ author = {Google Meridian Marketing Mix Modeling Team},
204
+ title = {Meridian: Marketing Mix Modeling},
205
+ url = {https://github.com/google/meridian},
206
+ version = {1.3.1},
207
+ year = {2025},
208
+ }
209
+ ```
@@ -1,7 +1,7 @@
1
- google_meridian-1.3.0.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
1
+ google_meridian-1.3.1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
2
2
  meridian/__init__.py,sha256=0fOT5oNZF7-pbiWWGUefV-ysafttieG079m1ijMFQO8,861
3
3
  meridian/constants.py,sha256=ZmMIoJDFQvKIOVG9oPOQ7Cj16wt4HDS5fCPBrz_KiLE,20308
4
- meridian/version.py,sha256=PmfWvor_VonhwnCn2anG7FaY2z7fcdLbE_3OeBHekCg,644
4
+ meridian/version.py,sha256=mmCkrGRWB8mI33apVnverT2ysfJofSyhlXSCVNotj9U,644
5
5
  meridian/analysis/__init__.py,sha256=NWDhtRkKs-n66C6746rXT7Wk8tLdkrT9NDvuPG_B_5c,874
6
6
  meridian/analysis/analyzer.py,sha256=EUKoN69EEz4vt4BpyRvrwEoHt7eLZ8UM6emNe14Tbj0,219975
7
7
  meridian/analysis/formatter.py,sha256=AN2M4jdUV88XjNdi5-dK_mITES1J1Dk3Zs7DYo3OTKg,7290
@@ -25,9 +25,9 @@ meridian/analysis/templates/style.css,sha256=RODTWc2pXcG9zW3q9SEJpVXgeD-WwQgzLpm
25
25
  meridian/analysis/templates/style.scss,sha256=nSrZOpcIrVyiL4eC9jLUlxIZtAKZ0Rt8pwfk4H1nMrs,5076
26
26
  meridian/analysis/templates/summary.html.jinja,sha256=LuENVDHYIpNo4pzloYaCR2K9XN1Ow6_9oQOcOwD9nGg,1707
27
27
  meridian/analysis/templates/table.html.jinja,sha256=mvLMZx92RcD2JAS2w2eZtfYG-6WdfwYVo7pM8TbHp4g,1176
28
- meridian/backend/__init__.py,sha256=EzsRE_E3tmxNkRPuNannz6Ro0IddvEXaNZ365GG5X8o,34341
28
+ meridian/backend/__init__.py,sha256=ftXcdb3tIky_m8exhD9RRhaEqTEvMcqZ8lGkkR1PjsE,39495
29
29
  meridian/backend/config.py,sha256=B9VQnhBfg9RW04GNbt7F5uCugByenoJzt-keFLLYEp8,3561
30
- meridian/backend/test_utils.py,sha256=XyYZ61o7fRCQq6rX75eON9AudvPHHdWzxHYK8_dJJvE,7753
30
+ meridian/backend/test_utils.py,sha256=DYU5IpWiyM26aTY9Q84mOGRw0dQ9XmsZRsAJNFuZDp0,11667
31
31
  meridian/data/__init__.py,sha256=StIe-wfYnnbfUbKtZHwnAQcRQUS8XCZk_PCaEzw90Ww,929
32
32
  meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
33
33
  meridian/data/data_frame_input_data_builder.py,sha256=_hexZMFAuAowgo6FaOGElHSFHqhGnHQwEEBcwnT3zUE,27295
@@ -44,7 +44,7 @@ meridian/model/adstock_hill.py,sha256=HoRKjyL03pCTBz6Utof9wEvlQCFM43BvrEW_oupj7N
44
44
  meridian/model/knots.py,sha256=87kw5oa3T1k9GgT_aWXTqQx5XCxLsS2w1hnzc581XL0,26677
45
45
  meridian/model/media.py,sha256=skjy4Vd8LfDQWlqR_2lJ1qbG9UcS1dow5W45BAu4qk8,14599
46
46
  meridian/model/model.py,sha256=jMtfl7woWtJ8M8AX42QeZ5hUS8hlhPdZ-9OU8KahjKA,68984
47
- meridian/model/model_test_data.py,sha256=s8G1NjAYcF0-nCOPBZp6rxnKMLGdGo97Bg14mKumStE,23359
47
+ meridian/model/model_test_data.py,sha256=XGBz8RGdCsjAUOmgxX3CfWSj-_hdq2Lc8saFCqmImwM,23901
48
48
  meridian/model/posterior_sampler.py,sha256=f3MayglIgBeBjWeXJU_RgT9cCugcjJ3aEjHqaWPsTbg,26806
49
49
  meridian/model/prior_distribution.py,sha256=ZArW4uXIPPQL6hRWiGZUzcHktbkjE_vOklvlbp9LR64,57662
50
50
  meridian/model/prior_sampler.py,sha256=iLvCefhA4WY0ENcnLK9471WUZPPyzQ1je58MRjxKv74,25460
@@ -52,11 +52,25 @@ meridian/model/spec.py,sha256=VlK6WJiPo2lzOF0O2judtJ6O3uEw7wYL5AT8bioq4gE,19188
52
52
  meridian/model/transformers.py,sha256=HxlVJitxP-wu-NOHU0tArFUZ4NAO3c7adAYj4Zvqnvo,8363
53
53
  meridian/model/eda/__init__.py,sha256=w3p7ZUZLq5TOEHm8n2P1CWjGrzuNrkqSSnVFdlw17Dk,812
54
54
  meridian/model/eda/constants.py,sha256=V9aOHQDvB3WAEyT0NE4gE8rqbStaGVh3XDlBPOKpuLc,739
55
- meridian/model/eda/eda_engine.py,sha256=5LCOtRSCLZf2ch70SN177UZl00VUWoswS_9Px3xUD9E,63672
55
+ meridian/model/eda/eda_engine.py,sha256=5Ikgiz-6d3uTBan71WwsgPHOqk0fAy390ZLCq9L6HoY,64846
56
56
  meridian/model/eda/eda_outcome.py,sha256=P-0kNIbNXcyqMaNvFxiL3x6fhtYOL2trw-zPPZGXh5w,5670
57
57
  meridian/model/eda/eda_spec.py,sha256=diieYyZH0ee3ZLy0rGFMcWrrgiUrz2HctMwOrmtJR6w,2871
58
58
  meridian/model/eda/meridian_eda.py,sha256=GTdBaAtfsHS5s6P5ZESaeh1ElKV_o7dSqQosiMnFBKg,7537
59
- google_meridian-1.3.0.dist-info/METADATA,sha256=FHLqp_ZLBXWJHpZWW2gdWg8Fy37cUC7qAlsv_-cIrc8,22470
60
- google_meridian-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
- google_meridian-1.3.0.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
62
- google_meridian-1.3.0.dist-info/RECORD,,
59
+ schema/__init__.py,sha256=Df2XKjMKa0ry7CTlDuJnyxuTdSgAd5za922wMlE98cg,681
60
+ schema/serde/__init__.py,sha256=xyydIcWB5IUpcn3wu1m9HL1fK4gMWURbwTyRsQtolF0,975
61
+ schema/serde/constants.py,sha256=aYtD_RuA0GCkpC4TIQq3VjMqEc837Wn-TlJNm-yn_4Y,1842
62
+ schema/serde/distribution.py,sha256=jy3h6JD1TSs4gwociMis814sz_Fm2kFQ2UbkgjYJW9k,19347
63
+ schema/serde/eda_spec.py,sha256=uOqBeZpUU3Dzzc19rU1LjHWmUhRmVcx8oIZvZfVJHT8,7180
64
+ schema/serde/function_registry.py,sha256=GbgC5_9NDcA9Y7nqmdJ-4-LK5JPhhfI50Lmfy5ZBJOg,4858
65
+ schema/serde/hyperparameters.py,sha256=Igm-PZmIozrsKZH6c-XkrU_Nlf8OAuxpnJJfv7W1SfQ,13524
66
+ schema/serde/inference_data.py,sha256=DrwE9hU8LMrl0z8W_sUSIaPrRdym_lu0iOqpT4KZxsA,3623
67
+ schema/serde/marketing_data.py,sha256=yb-fRTe84Sjg7-v3wsvYRRXvrxLSFWSenO0_ikMvUpk,44845
68
+ schema/serde/meridian_serde.py,sha256=ZG05JaBG4LW8mhl-Cunje9Q6xyR4tyNTtLYedzMBYjA,15985
69
+ schema/serde/serde.py,sha256=8vUqhJxvZgX9UY3rXTyWJznRgapwDzzaHXDHwV_kKTA,1612
70
+ schema/serde/test_data.py,sha256=7hfEWyvZ9WcAkVAOXt6elX8stJlsfhfd-ASlHo9SRb8,107342
71
+ schema/utils/__init__.py,sha256=AkC4NMbmXC3PFBY9dFYxlf3qFsxt5OOBVdc9zmFXsC8,675
72
+ schema/utils/time_record.py,sha256=-KzHFjvSBUUXsfESPAfcJP_VFxaFLqj90Ac0kgKWfpI,4624
73
+ google_meridian-1.3.1.dist-info/METADATA,sha256=wu5D6r6v46vd5g4uKLfXEVuhxkO3jwEgS68wn6m0jR4,9547
74
+ google_meridian-1.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
75
+ google_meridian-1.3.1.dist-info/top_level.txt,sha256=yWkWDLV_UUanhKmk_xNPiKNdPDl1oyU1sBYwEnhaSf4,16
76
+ google_meridian-1.3.1.dist-info/RECORD,,
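The RECORD entries above follow the wheel `RECORD` format, `path,sha256=<digest>,<size in bytes>`, where the digest is the URL-safe base64 encoding of the file's SHA-256 hash with the trailing `=` padding removed; the `RECORD` file lists itself with empty hash and size fields. The sketch below computes such an entry for arbitrary bytes; `record_hash` and `record_entry` are hypothetical helper names, not part of Meridian. The sizes also hint at what changed: `meridian/version.py` keeps its 644-byte size but gets a new hash, which fits a same-length version-string bump, and `top_level.txt` grows from 9 to 16 bytes, which is consistent with the new top-level `schema` package being listed alongside `meridian`.

```python
import base64
import hashlib


def record_hash(data: bytes) -> str:
    """Returns a wheel-style RECORD hash: SHA-256, URL-safe base64, no padding."""
    digest = hashlib.sha256(data).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return f"sha256={encoded}"


def record_entry(path: str, data: bytes) -> str:
    """Formats one RECORD line: path, hash, and size in bytes."""
    return f"{path},{record_hash(data)},{len(data)}"


if __name__ == "__main__":
    # Hash of empty input, in RECORD encoding.
    print(record_entry("empty.txt", b""))
    # empty.txt,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

    # One top-level package name per line: 9 bytes vs. 16 bytes.
    print(len(b"meridian\n"), len(b"meridian\nschema\n"))  # 9 16
```

To verify an installed file against its RECORD line, read the file in binary mode and compare `record_entry(path, data)` with the recorded line.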
@@ -19,6 +19,7 @@ import functools
19
19
  import os
20
20
  from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
21
21
  import warnings
22
+
22
23
  from meridian.backend import config
23
24
  import numpy as np
24
25
  from typing_extensions import Literal
@@ -220,7 +221,7 @@ def _tf_arange(
220
221
 
221
222
  def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
222
223
  """JAX implementation for cast."""
223
- return x.astype(dtype)
224
+ return jax_ops.asarray(x, dtype=dtype)
224
225
 
225
226
 
226
227
  def _jax_divide_no_nan(x, y):
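The `_jax_cast` change replaces `x.astype(dtype)` with `jax_ops.asarray(x, dtype=dtype)`. A likely benefit is that an `asarray`-style call accepts any array-like input (Python scalars, lists, NumPy arrays), whereas `.astype` exists only on array objects. The sketch below assumes `jax_ops` behaves like `jax.numpy` and uses NumPy's `np.asarray`, which takes the same `dtype` keyword, as a stand-in; the two helper functions are hypothetical illustrations, not Meridian code.

```python
import numpy as np


def cast_with_astype(x, dtype):
    # Old approach: only works if `x` is already an array with an `.astype` method.
    return x.astype(dtype)


def cast_with_asarray(x, dtype):
    # New approach: converts any array-like (scalar, list, array), then casts it.
    return np.asarray(x, dtype=dtype)


if __name__ == "__main__":
    print(cast_with_asarray([1, 2, 3], np.float32).dtype)  # float32
    try:
        cast_with_astype([1, 2, 3], np.float32)
    except AttributeError as err:
        print("astype failed:", err)  # 'list' object has no attribute 'astype'
```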
@@ -305,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
305
306
  )
306
307
 
307
308
 
308
- def _jax_make_tensor_proto(*args, **kwargs): # pylint: disable=unused-argument
309
- raise NotImplementedError(
310
- "backend.make_tensor_proto is not implemented for the JAX backend."
311
- )
309
+ def _jax_make_tensor_proto(values, dtype=None, shape=None): # pylint: disable=unused-argument
310
+ """JAX implementation for make_tensor_proto."""
311
+ # pylint: disable=g-direct-tensorflow-import
312
+ from tensorflow.core.framework import tensor_pb2
313
+ from tensorflow.core.framework import tensor_shape_pb2
314
+ from tensorflow.core.framework import types_pb2
315
+ # pylint: enable=g-direct-tensorflow-import
312
316
 
317
+ if not isinstance(values, np.ndarray):
318
+ values = np.array(values)
313
319
 
314
- def _jax_make_ndarray(*args, **kwargs): # pylint: disable=unused-argument
315
- raise NotImplementedError(
316
- "backend.make_ndarray is not implemented for the JAX backend."
320
+ if dtype:
321
+ numpy_dtype = np.dtype(dtype)
322
+ values = values.astype(numpy_dtype)
323
+ else:
324
+ numpy_dtype = values.dtype
325
+
326
+ dtype_map = {
327
+ np.dtype(np.float16): types_pb2.DT_HALF,
328
+ np.dtype(np.float32): types_pb2.DT_FLOAT,
329
+ np.dtype(np.float64): types_pb2.DT_DOUBLE,
330
+ np.dtype(np.int32): types_pb2.DT_INT32,
331
+ np.dtype(np.uint8): types_pb2.DT_UINT8,
332
+ np.dtype(np.uint16): types_pb2.DT_UINT16,
333
+ np.dtype(np.uint32): types_pb2.DT_UINT32,
334
+ np.dtype(np.uint64): types_pb2.DT_UINT64,
335
+ np.dtype(np.int16): types_pb2.DT_INT16,
336
+ np.dtype(np.int8): types_pb2.DT_INT8,
337
+ np.dtype(np.int64): types_pb2.DT_INT64,
338
+ np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
339
+ np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
340
+ np.dtype(np.bool_): types_pb2.DT_BOOL,
341
+ # Note: String types are handled outside the map.
342
+ }
343
+ proto_dtype = dtype_map.get(numpy_dtype)
344
+ if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
345
+ proto_dtype = types_pb2.DT_STRING
346
+
347
+ if proto_dtype is None:
348
+ raise TypeError(
349
+ f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
350
+ )
351
+
352
+ proto = tensor_pb2.TensorProto(
353
+ dtype=proto_dtype,
354
+ tensor_shape=tensor_shape_pb2.TensorShapeProto(
355
+ dim=[
356
+ tensor_shape_pb2.TensorShapeProto.Dim(size=d)
357
+ for d in values.shape
358
+ ]
359
+ ),
317
360
  )
318
361
 
362
+ proto.tensor_content = values.tobytes()
363
+ return proto
364
+
365
+
366
+ def _jax_make_ndarray(proto):
367
+ """JAX implementation for make_ndarray."""
368
+ # pylint: disable=g-direct-tensorflow-import
369
+ from tensorflow.core.framework import types_pb2
370
+ # pylint: enable=g-direct-tensorflow-import
371
+
372
+ dtype_map = {
373
+ types_pb2.DT_HALF: np.float16,
374
+ types_pb2.DT_FLOAT: np.float32,
375
+ types_pb2.DT_DOUBLE: np.float64,
376
+ types_pb2.DT_INT32: np.int32,
377
+ types_pb2.DT_UINT8: np.uint8,
378
+ types_pb2.DT_UINT16: np.uint16,
379
+ types_pb2.DT_UINT32: np.uint32,
380
+ types_pb2.DT_UINT64: np.uint64,
381
+ types_pb2.DT_INT16: np.int16,
382
+ types_pb2.DT_INT8: np.int8,
383
+ types_pb2.DT_INT64: np.int64,
384
+ types_pb2.DT_COMPLEX64: np.complex64,
385
+ types_pb2.DT_COMPLEX128: np.complex128,
386
+ types_pb2.DT_BOOL: np.bool_,
387
+ types_pb2.DT_STRING: np.bytes_,
388
+ }
389
+ if proto.dtype not in dtype_map:
390
+ raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")
391
+
392
+ shape = [d.size for d in proto.tensor_shape.dim]
393
+ dtype = dtype_map[proto.dtype]
394
+
395
+ if proto.tensor_content:
396
+ num_elements = np.prod(shape).item() if shape else 0
397
+ # When deserializing a string from tensor_content, the itemsize is not
398
+ # explicitly stored. We must infer it from the content length and shape.
399
+ if dtype == np.bytes_ and num_elements > 0:
400
+ content_len = len(proto.tensor_content)
401
+ itemsize = content_len // num_elements
402
+ if itemsize * num_elements != content_len:
403
+ raise ValueError(
404
+ "Tensor content size is not a multiple of the number of elements"
405
+ " for string dtype."
406
+ )
407
+ dtype = np.dtype(f"S{itemsize}")
408
+
409
+ return (
410
+ np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
411
+ )
412
+
413
+ # Fallback for protos that store data in val fields instead of tensor_content.
414
+ if dtype == np.float32:
415
+ val_field = proto.float_val
416
+ elif dtype == np.float64:
417
+ val_field = proto.double_val
418
+ elif dtype == np.int32:
419
+ val_field = proto.int_val
420
+ elif dtype == np.int64:
421
+ val_field = proto.int64_val
422
+ elif dtype == np.bool_:
423
+ val_field = proto.bool_val
424
+ else:
425
+ if proto.string_val:
426
+ return np.array(proto.string_val, dtype=np.bytes_).reshape(shape)
427
+ if not any(shape):
428
+ return np.array([], dtype=dtype).reshape(shape)
429
+ raise TypeError(
430
+ f"Unsupported dtype for TensorProto value field fallback: {dtype}"
431
+ )
432
+
433
+ return np.array(val_field, dtype=dtype).reshape(shape)
434
+
319
435
 
320
436
  def _jax_get_indices_where(condition):
321
437
  """JAX implementation for get_indices_where."""
@@ -497,12 +613,23 @@ def _jax_convert_to_tensor(data, dtype=None):
497
613
  # JAX does not natively support string tensors in the same way TF does.
498
614
  # If a string dtype is requested, or if the data is inherently strings,
499
615
  # we fall back to a standard NumPy array.
500
- if dtype == np.str_ or (
501
- dtype is None
502
- and isinstance(data, (list, np.ndarray))
503
- and np.array(data).dtype.kind in ("S", "U")
504
- ):
505
- return np.array(data, dtype=np.str_)
616
+ is_string_target = False
617
+ if dtype is not None:
618
+ try:
619
+ if np.dtype(dtype).kind in ("S", "U"):
620
+ is_string_target = True
621
+ except TypeError:
622
+ # This can happen if dtype is not a valid dtype specifier,
623
+ # let jax.asarray handle it.
624
+ pass
625
+
626
+ is_string_data = isinstance(data, (list, np.ndarray)) and np.array(
627
+ data
628
+ ).dtype.kind in ("S", "U")
629
+
630
+ if is_string_target or (dtype is None and is_string_data):
631
+ return np.array(data, dtype=dtype)
632
+
506
633
  return jax_ops.asarray(data, dtype=dtype)
507
634
 
508
635
 
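The new `_jax_make_tensor_proto` and `_jax_make_ndarray` functions serialize an array into `TensorProto.tensor_content` as raw bytes (`values.tobytes()`) plus a dtype and shape, and read it back with `np.frombuffer(...).copy().reshape(shape)`. For byte-string tensors the per-element width is not stored, so `_jax_make_ndarray` infers it as the content length divided by the element count. The NumPy-only sketch below reproduces that round trip without TensorFlow; `to_content` and `from_content` are hypothetical helpers, and a plain `(dtype, shape, bytes)` tuple stands in for the proto.

```python
import math

import numpy as np


def to_content(values):
    """Hypothetical helper: flatten an array into (dtype, shape, raw bytes)."""
    arr = np.asarray(values)
    return arr.dtype, list(arr.shape), arr.tobytes()


def from_content(dtype, shape, content):
    """Hypothetical helper: rebuild an array from (dtype, shape, raw bytes)."""
    dtype = np.dtype(dtype)
    num_elements = math.prod(shape)  # 1 for a scalar shape []
    if dtype.kind == "S" and num_elements > 0:
        # Fixed-width byte strings: the width is not stored, so infer it.
        itemsize, remainder = divmod(len(content), num_elements)
        if remainder:
            raise ValueError("content length is not a multiple of the element count")
        dtype = np.dtype(f"S{itemsize}")
    # frombuffer returns a read-only view; copy() gives an independent array.
    return np.frombuffer(content, dtype=dtype).copy().reshape(shape)


if __name__ == "__main__":
    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    print(np.array_equal(from_content(*to_content(x)), x))  # True

    names = np.array([b"tv", b"radio"])  # dtype S5, "tv" padded with NUL bytes
    _, shape, raw = to_content(names)
    restored = from_content(np.bytes_, shape, raw)  # width inferred as 10 // 2
    print(restored.tolist(), restored.dtype.itemsize)  # [b'tv', b'radio'] 5
```

One design difference: the sketch counts a scalar shape `[]` as one element (`math.prod([]) == 1`), whereas the diff computes `0` for an empty shape, so its width inference only runs for non-scalar string tensors.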
@@ -535,18 +662,48 @@ def _tf_nanvar(a, axis=None, keepdims=False):
535
662
  return tf.convert_to_tensor(var)
536
663
 
537
664
 
538
- def _jax_one_hot(*args, **kwargs): # pylint: disable=unused-argument
665
+ def _jax_one_hot(
666
+ indices, depth, on_value=None, off_value=None, axis=None, dtype=None
667
+ ):
539
668
  """JAX implementation for one_hot."""
540
- raise NotImplementedError(
541
- "backend.one_hot is not implemented for the JAX backend."
669
+ import jax.numpy as jnp
670
+
671
+ resolved_dtype = _resolve_dtype(dtype, on_value, off_value, 1, 0)
672
+ jax_axis = -1 if axis is None else axis
673
+
674
+ one_hot_result = jax.nn.one_hot(
675
+ indices, num_classes=depth, dtype=jnp.dtype(resolved_dtype), axis=jax_axis
542
676
  )
543
677
 
678
+ on_val = 1 if on_value is None else on_value
679
+ off_val = 0 if off_value is None else off_value
680
+
681
+ if on_val == 1 and off_val == 0:
682
+ return one_hot_result
683
+
684
+ on_tensor = jnp.array(on_val, dtype=jnp.dtype(resolved_dtype))
685
+ off_tensor = jnp.array(off_val, dtype=jnp.dtype(resolved_dtype))
686
+
687
+ return jnp.where(one_hot_result == 1, on_tensor, off_tensor)
688
+
544
689
 
545
- def _jax_roll(*args, **kwargs): # pylint: disable=unused-argument
690
+ def _jax_roll(a, shift, axis=None):
546
691
  """JAX implementation for roll."""
547
- raise NotImplementedError(
548
- "backend.roll is not implemented for the JAX backend."
549
- )
692
+ import jax.numpy as jnp
693
+
694
+ return jnp.roll(a, shift, axis=axis)
695
+
696
+
697
+ def _tf_roll(a, shift: Sequence[int], axis=None):
698
+ """TensorFlow implementation for roll that handles axis=None."""
699
+ import tensorflow as tf
700
+
701
+ if axis is None:
702
+ original_shape = tf.shape(a)
703
+ flat_tensor = tf.reshape(a, [-1])
704
+ rolled_flat = tf.roll(flat_tensor, shift=shift, axis=0)
705
+ return tf.reshape(rolled_flat, original_shape)
706
+ return tf.roll(a, shift, axis=axis)
550
707
 
551
708
 
552
709
  def _jax_enable_op_determinism():
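`_jax_one_hot` builds the indicator tensor with `jax.nn.one_hot` and, only when a custom `on_value` or `off_value` is given, swaps the 1s and 0s through `jnp.where`. The sketch below shows the same semantics in NumPy: a boolean match against `arange(depth)`, moved to the requested `axis`, then mapped to the on/off values. The `one_hot` function here is a hypothetical illustration; in it, as in `tf.one_hot`, an index outside `[0, depth)` produces a row of off values.

```python
import numpy as np


def one_hot(indices, depth, on_value=1, off_value=0, axis=-1, dtype=np.float32):
    """Hypothetical NumPy sketch of one-hot encoding with custom on/off values."""
    indices = np.asarray(indices)
    # matches[..., k] is True where indices == k; out-of-range indices match nothing.
    matches = indices[..., np.newaxis] == np.arange(depth)
    # Move the new class axis to the requested position (default: last).
    matches = np.moveaxis(matches, -1, axis)
    return np.where(matches, on_value, off_value).astype(dtype)


if __name__ == "__main__":
    print(one_hot([0, 2], 3).tolist())  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
    print(one_hot([1, -1], 3, on_value=5, off_value=-1).tolist())
    # [[-1.0, 5.0, -1.0], [-1.0, -1.0, -1.0]]
```

Skipping the `where` step in the default case, as the diff does, avoids an extra elementwise pass when the plain 0/1 encoding is all that is needed.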
@@ -772,7 +929,7 @@ if _BACKEND == config.Backend.JAX:
772
929
  newaxis = _ops.newaxis
773
930
  TensorShape = _jax_tensor_shape
774
931
  int32 = _ops.int32
775
- string = np.str_
932
+ string = np.bytes_
776
933
 
777
934
  stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid
778
935
 
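In the JAX branch, `backend.string` now aliases `np.bytes_` instead of `np.str_`, and `_jax_convert_to_tensor` returns a NumPy array for any string dtype (kind `"S"` or `"U"`) rather than only for `np.str_`. One plausible reason for preferring bytes is that NumPy stores each `str_` character as a 4-byte code point, while `bytes_` uses one byte per character, which lines up with the byte-oriented `DT_STRING` handling in `_jax_make_tensor_proto` and `_jax_make_ndarray`. The sketch below mirrors the detection branch as a standalone, hypothetical `is_string_request` helper and compares the two dtypes' storage.

```python
import numpy as np


def is_string_request(data, dtype=None):
    """Hypothetical mirror of the string-detection branch in _jax_convert_to_tensor."""
    is_string_target = False
    if dtype is not None:
        try:
            is_string_target = np.dtype(dtype).kind in ("S", "U")
        except TypeError:
            # Not a NumPy dtype specifier; the real code defers to jax.asarray.
            pass
    is_string_data = isinstance(data, (list, np.ndarray)) and (
        np.array(data).dtype.kind in ("S", "U")
    )
    return is_string_target or (dtype is None and is_string_data)


if __name__ == "__main__":
    print(is_string_request(["tv", "radio"]))  # True: inferred from the data
    print(is_string_request([1, 2], dtype=np.bytes_))  # True: string target dtype
    print(is_string_request(["tv"], dtype=np.float32))  # False: explicit numeric dtype

    # Storage per element: str_ uses 4 bytes per character, bytes_ uses 1.
    as_str = np.array(["tv", "radio"], dtype=np.str_)
    as_bytes = np.array(["tv", "radio"], dtype=np.bytes_)
    print(as_str.dtype.itemsize, as_bytes.dtype.itemsize)  # 20 5
```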
@@ -904,7 +1061,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
904
1061
  reduce_sum = _ops.reduce_sum
905
1062
  repeat = _ops.repeat
906
1063
  reshape = _ops.reshape
907
- roll = _ops.roll
1064
+ roll = _tf_roll
908
1065
  set_random_seed = tf_backend.keras.utils.set_random_seed
909
1066
  split = _ops.split
910
1067
  stack = _ops.stack
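With the TensorFlow backend, `backend.roll` now points to `_tf_roll`, which handles `axis=None` by flattening, rolling along axis 0, and restoring the original shape, because `tf.roll` takes an explicit `axis`. That matches `np.roll(a, shift)` (and `jnp.roll`), which flattens the input when no axis is given. The NumPy sketch below checks the flatten-roll-reshape approach against `np.roll`; `roll_flattened` is a hypothetical name.

```python
import numpy as np


def roll_flattened(a, shift):
    """Hypothetical sketch of the axis=None path: flatten, roll, restore shape."""
    a = np.asarray(a)
    flat = a.reshape(-1)
    rolled = np.roll(flat, shift, axis=0)
    return rolled.reshape(a.shape)


if __name__ == "__main__":
    a = np.arange(6).reshape(2, 3)
    print(roll_flattened(a, 1).tolist())  # [[5, 0, 1], [2, 3, 4]]
    print(np.array_equal(roll_flattened(a, 1), np.roll(a, 1)))  # True
```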
@@ -15,11 +15,21 @@
15
15
  """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
16
 
17
17
  from typing import Any, Optional
18
+
18
19
  from absl.testing import parameterized
20
+ from google.protobuf import descriptor
21
+ from google.protobuf import message
19
22
  from meridian import backend
20
23
  from meridian.backend import config
21
24
  import numpy as np
22
25
 
26
+ from tensorflow.python.util.protobuf import compare
27
+ # pylint: disable=g-direct-tensorflow-import
28
+ from tensorflow.core.framework import tensor_pb2
29
+ # pylint: enable=g-direct-tensorflow-import
30
+
31
+ FieldDescriptor = descriptor.FieldDescriptor
32
+
23
33
  # A type alias for backend-agnostic array-like objects.
24
34
  # We use `Any` here to avoid circular dependencies with the backend module
25
35
  # while still allowing the function to accept backend-specific tensor types.
@@ -131,6 +141,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
131
141
  raise AssertionError(err_msg or "Array contains negative values.")
132
142
 
133
143
 
144
+ # --- Proto Utilities ---
145
+ def normalize_tensor_protos(proto: message.Message):
146
+ """Recursively normalizes TensorProto messages within a proto (In-place).
147
+
148
+ This ensures a consistent serialization format across different backends
149
+ (e.g., JAX vs TF) by repacking TensorProtos using the current backend's
150
+ canonical method (backend.make_tensor_proto). This handles differences
151
+ like using `bool_val` versus `tensor_content` for boolean tensors.
152
+
153
+ Args:
154
+ proto: The protobuf message object to normalize. This object is modified in
155
+ place.
156
+ """
157
+ if not isinstance(proto, message.Message):
158
+ return
159
+
160
+ for desc, value in proto.ListFields():
161
+ if desc.type != FieldDescriptor.TYPE_MESSAGE:
162
+ continue
163
+
164
+ # A map is defined as a repeated field whose message type has the
165
+ # map_entry option set.
166
+ is_map = (
167
+ desc.label == FieldDescriptor.LABEL_REPEATED
168
+ and desc.message_type.has_options
169
+ and desc.message_type.GetOptions().map_entry
170
+ )
171
+
172
+ if is_map:
173
+ for item in value.values():
174
+ # Helper checks if values are scalars or messages.
175
+ _process_message_for_normalization(item)
176
+
177
+ elif desc.label == FieldDescriptor.LABEL_REPEATED:
178
+ # Handle standard repeated message fields.
179
+ for item in value:
180
+ _process_message_for_normalization(item)
181
+ else:
182
+ # Handle singular message fields.
183
+ _process_message_for_normalization(value)
184
+
185
+
186
+ def _process_message_for_normalization(msg: Any):
187
+ """Helper to process a potential message during normalization traversal."""
188
+ # Ensure we only process message objects.
189
+ # If msg is a scalar (e.g., string from map<string, string>), stop recursion.
190
+ if not isinstance(msg, message.Message):
191
+ return
192
+
193
+ if isinstance(msg, tensor_pb2.TensorProto):
194
+ _repack_tensor_proto(msg)
195
+ else:
196
+ # If it's another message type, recurse into its fields.
197
+ normalize_tensor_protos(msg)
198
+
199
+
200
+ def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
201
+ """Repacks a TensorProto in place to use a consistent serialization format."""
202
+ if not tensor_proto.ByteSize():
203
+ return
204
+
205
+ try:
206
+ data_array = backend.make_ndarray(tensor_proto)
207
+ except Exception as e:
208
+ raise ValueError(
209
+ "Failed to deserialize TensorProto during normalization:"
210
+ f" {e}\nProto content:\n{tensor_proto}"
211
+ ) from e
212
+
213
+ new_tensor_proto = backend.make_tensor_proto(data_array)
214
+
215
+ tensor_proto.Clear()
216
+ tensor_proto.CopyFrom(new_tensor_proto)
217
+
218
+
219
+ def assert_normalized_proto_equal(
220
+ test_case: parameterized.TestCase,
221
+ expected: message.Message,
222
+ actual: message.Message,
223
+ msg: Optional[str] = None,
224
+ **kwargs: Any,
225
+ ):
226
+ """Compares two protos after normalizing TensorProto fields.
227
+
228
+ Use this instead of compare.assertProtoEqual when protos contain tensors
229
+ to ensure backend-agnostic comparison.
230
+
231
+ Args:
232
+ test_case: The TestCase instance (self).
233
+ expected: The expected protobuf message.
234
+ actual: The actual protobuf message.
235
+ msg: An optional message to display on failure.
236
+ **kwargs: Additional keyword arguments passed to assertProto2Equal (e.g.,
237
+ precision).
238
+ """
239
+ # Work on copies to avoid mutating the original objects
240
+ expected_copy = expected.__class__()
241
+ expected_copy.CopyFrom(expected)
242
+ actual_copy = actual.__class__()
243
+ actual_copy.CopyFrom(actual)
244
+
245
+ try:
246
+ normalize_tensor_protos(expected_copy)
247
+ normalize_tensor_protos(actual_copy)
248
+ except ValueError as e:
249
+ test_case.fail(f"Proto normalization failed: {e}. {msg}")
250
+
251
+ compare.assertProtoEqual(
252
+ test_case, expected_copy, actual_copy, msg=msg, **kwargs
253
+ )
254
+
255
+
134
256
  class MeridianTestCase(parameterized.TestCase):
135
257
  """Base test class for Meridian providing backend-aware utilities.
136
258
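`normalize_tensor_protos` walks a message's map, repeated, and singular message fields, and each nested `TensorProto` is decoded with `backend.make_ndarray` and re-encoded with `backend.make_tensor_proto`. Two protos that hold the same values then compare equal even if one backend wrote `bool_val` and the other wrote `tensor_content`. The sketch below shows the idea without protobuf, under stated assumptions: plain dicts and lists stand in for messages and repeated fields, a dict with `dtype` and `shape` keys stands in for a `TensorProto`, and `decode`, `encode`, and `normalize` are hypothetical helpers. Unlike the diff, it returns a new tree instead of editing in place.

```python
import numpy as np


def decode(tensor):
    """Reads either stand-in encoding (bool_val list or raw bytes) into an array."""
    if "tensor_content" in tensor:
        flat = np.frombuffer(tensor["tensor_content"], dtype=tensor["dtype"])
    else:
        flat = np.array(tensor["bool_val"], dtype=tensor["dtype"])
    return flat.reshape(tensor["shape"])


def encode(arr):
    """Canonical stand-in encoding: always raw bytes in tensor_content."""
    return {
        "dtype": arr.dtype.str,
        "shape": list(arr.shape),
        "tensor_content": arr.tobytes(),
    }


def normalize(node):
    """Recursively re-encodes every stand-in tensor in a nested dict/list tree."""
    if isinstance(node, dict) and "dtype" in node and "shape" in node:
        return encode(decode(node))
    if isinstance(node, dict):
        return {key: normalize(value) for key, value in node.items()}
    if isinstance(node, list):
        return [normalize(item) for item in node]
    return node  # Scalars (strings, numbers) are left unchanged.


if __name__ == "__main__":
    from_vals = {"mask": {"dtype": "|b1", "shape": [2], "bool_val": [True, False]}}
    from_bytes = {"mask": encode(np.array([True, False]))}
    print(from_vals == from_bytes)  # False: different encodings
    print(normalize(from_vals) == normalize(from_bytes))  # True: same values
```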