google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,48 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants shared across the Meridian serde library."""
16
+
17
# ---------------------------------------------------------------------------
# Hyperparameters proto.
# ---------------------------------------------------------------------------

# Baseline-geo oneof name and its int / string alternatives.
BASELINE_GEO_ONEOF = 'baseline_geo_oneof'
BASELINE_GEO_INT = 'baseline_geo_int'
BASELINE_GEO_STRING = 'baseline_geo_string'

# Identifier field names.
CONTROL_POPULATION_SCALING_ID = 'control_population_scaling_id'
HOLDOUT_ID = 'holdout_id'
NON_MEDIA_POPULATION_SCALING_ID = 'non_media_population_scaling_id'

# Adstock decay specification field names.
ADSTOCK_DECAY_SPEC = 'adstock_decay_spec'
GLOBAL_ADSTOCK_DECAY = 'global_adstock_decay'
ADSTOCK_DECAY_BY_CHANNEL = 'adstock_decay_by_channel'
DEFAULT_DECAY = 'geometric'

# ---------------------------------------------------------------------------
# Marketing data proto.
# ---------------------------------------------------------------------------
GEO_INFO = 'geo_info'
METADATA = 'metadata'
REACH_FREQUENCY = 'reach_frequency'

# ---------------------------------------------------------------------------
# Legacy distribution proto.
# ---------------------------------------------------------------------------

# Name of the oneof that selects the distribution kind, then its variants.
DISTRIBUTION_TYPE = 'distribution_type'
BATCH_BROADCAST_DISTRIBUTION = 'batch_broadcast'
BETA_DISTRIBUTION = 'beta'
DETERMINISTIC_DISTRIBUTION = 'deterministic'
HALF_NORMAL_DISTRIBUTION = 'half_normal'
LOG_NORMAL_DISTRIBUTION = 'log_normal'
NORMAL_DISTRIBUTION = 'normal'
TRANSFORMED_DISTRIBUTION = 'transformed'
TRUNCATED_NORMAL_DISTRIBUTION = 'truncated_normal'
UNIFORM_DISTRIBUTION = 'uniform'

# Name of the oneof that selects the bijector kind, then its variants.
BIJECTOR_TYPE = 'bijector_type'
RECIPROCAL_BIJECTOR = 'reciprocal'
SCALE_BIJECTOR = 'scale'
SHIFT_BIJECTOR = 'shift'
@@ -0,0 +1,515 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serialization and deserialization of `Distribution` objects for priors."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import inspect
20
+ import types
21
+ from typing import Any, Sequence, TypeVar
22
+ import warnings
23
+
24
+ from meridian import backend
25
+ from meridian import constants
26
+ from meridian.model import prior_distribution as pd
27
+ from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
28
+ from schema.serde import constants as sc
29
+ from schema.serde import function_registry as function_registry_utils
30
+ from schema.serde import serde
31
+
32
+ from tensorflow.core.framework import tensor_shape_pb2 # pylint: disable=g-direct-tensorflow-import
33
+
34
# Re-exported so callers can build a registry without importing the utils
# module themselves.
FunctionRegistry = function_registry_utils.FunctionRegistry


# The proto message type that `DistributionSerde` reads and writes.
MeridianPriorDistributions = meridian_pb.PriorTfpDistributions

# Field of `PriorTfpDistributions` that stores the hashed function registry.
_FUNCTION_REGISTRY_NAME = "function_registry"
42
+
43
+
44
+ # TODO: b/436637084 - Delete enumerated schema.
45
+ class DistributionSerde(
46
+ serde.Serde[MeridianPriorDistributions, pd.PriorDistribution]
47
+ ):
48
+ """Serializes and deserializes a Meridian prior distributions container into a `Distribution` proto."""
49
+
50
+ def __init__(self, function_registry: FunctionRegistry):
51
+ """Initializes a `DistributionSerde` instance.
52
+
53
+ Args:
54
+ function_registry: A lookup table containing custom functions used by
55
+ various `backend.tfd` classes. It's recommended to explicitly define the
56
+ custom functions instead of using lambdas, as lambda functions may not
57
+ be registered successfully.
58
+ """
59
+ self._function_registry = function_registry
60
+
61
+ @property
62
+ def function_registry(self) -> FunctionRegistry:
63
+ return self._function_registry
64
+
65
+ def serialize(
66
+ self, obj: pd.PriorDistribution
67
+ ) -> meridian_pb.PriorTfpDistributions:
68
+ """Serializes the given Meridian priors container into a `MeridianPriorDistributions` proto."""
69
+ proto = meridian_pb.PriorTfpDistributions()
70
+ for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS:
71
+ if not hasattr(obj, param):
72
+ continue
73
+ getattr(proto, param).CopyFrom(
74
+ self._to_distribution_proto(getattr(obj, param))
75
+ )
76
+ proto.function_registry.update(self.function_registry.hashed_registry)
77
+ return proto
78
+
79
+ def deserialize(
80
+ self,
81
+ serialized: MeridianPriorDistributions,
82
+ serialized_version: str = "",
83
+ force_deserialization: bool = False,
84
+ ) -> pd.PriorDistribution:
85
+ """Deserializes the `PriorTfpDistributions` proto.
86
+
87
+ WARNING: If any custom functions in the function registry are modified after
88
+ serialization, the deserialized model can differ from the original model, as
89
+ the original function's behavior is no longer guaranteed. This will result
90
+ in an error during deserialization.
91
+
92
+ For users who are intentionally changing functions and are confident that
93
+ the changes will not affect the deserialized model, you can bypass safety
94
+ mechanisms to force deserialization. See example:
95
+
96
+ Args:
97
+ serialized: A serialized `PriorDistributions` object.
98
+ serialized_version: The version of the serialized Meridian model. This is
99
+ used to handle changes in deserialization logic across different
100
+ versions.
101
+ force_deserialization: If True, bypasses the safety check that validates
102
+ whether functions within `function_registry` have changed after
103
+ serialization. Use with caution. This should only be used if you have
104
+ intentionally modified a custom function and are confident that the
105
+ changes will not affect the deserialized model. A safer alternative is
106
+ to first deserialize the model with the original functions and then
107
+ serialize it with the new ones.
108
+
109
+ Returns:
110
+ A deserialized `PriorDistribution` object.
111
+ """
112
+ kwargs = {}
113
+ for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS:
114
+ if not hasattr(serialized, param):
115
+ continue
116
+ # A parameter may be unspecified in a serialized proto message because:
117
+ # (1) It is left unset for Meridian to set its default value.
118
+ # (2) The message was created from a previous Meridian version after
119
+ # introducing a new parameter.
120
+ if not serialized.HasField(param):
121
+ continue
122
+ param_name = getattr(serialized, param)
123
+ if isinstance(serialized, meridian_pb.PriorTfpDistributions):
124
+ if force_deserialization:
125
+ warnings.warn(
126
+ "You're attempting to deserialize a model while ignoring changes"
127
+ " to custom functions. This is a risky operation that can"
128
+ " potentially lead to a deserialized model that behaves"
129
+ " differently from the original, resulting in unexpected behavior"
130
+ " or model failure. We strongly recommend a safer two-step"
131
+ " process: deserialize the model using the original function"
132
+ " registry and reserialize the model using the updated registry."
133
+ " Please proceed with caution."
134
+ )
135
+ else:
136
+ stored_hashed_function_registry = getattr(
137
+ serialized, _FUNCTION_REGISTRY_NAME
138
+ )
139
+ try:
140
+ self.function_registry.validate(stored_hashed_function_registry)
141
+ except ValueError as e:
142
+ raise ValueError(
143
+ f"An issue found during deserializing Distribution: {e}"
144
+ ) from e
145
+
146
+ kwargs[param] = self._from_distribution_proto(param_name)
147
+ # copybara: strip_begin(legacy proto)
148
+ elif isinstance(serialized, meridian_pb.PriorDistributions):
149
+ kwargs[param] = _from_legacy_distribution_proto(param_name)
150
+ # copybara: strip_end
151
+ return pd.PriorDistribution(**kwargs)
152
+
153
+ def _to_distribution_proto(
154
+ self,
155
+ dist: backend.tfd.Distribution,
156
+ ) -> meridian_pb.TfpDistribution:
157
+ """Converts a TensorFlow `Distribution` object to a `TfpDistribution` proto."""
158
+ dist_name = type(dist).__name__
159
+ dist_class = getattr(backend.tfd, dist_name)
160
+ return meridian_pb.TfpDistribution(
161
+ distribution_type=dist_name,
162
+ parameters={
163
+ name: self._to_parameter_value_proto(name, value, dist_class)
164
+ for name, value in dist.parameters.items()
165
+ },
166
+ )
167
+
168
+ def _to_bijector_proto(
169
+ self,
170
+ bijector: backend.bijectors.Bijector,
171
+ ) -> meridian_pb.TfpBijector:
172
+ """Converts a TensorFlow `Bijector` object to a `TfpBijector` proto."""
173
+ bij_name = type(bijector).__name__
174
+ bij_class = getattr(backend.bijectors, bij_name)
175
+ return meridian_pb.TfpBijector(
176
+ bijector_type=bij_name,
177
+ parameters={
178
+ name: self._to_parameter_value_proto(name, value, bij_class)
179
+ for name, value in bijector.parameters.items()
180
+ },
181
+ )
182
+
183
+ def _to_parameter_value_proto(
184
+ self,
185
+ param_name: str,
186
+ value: Any,
187
+ dist: backend.tfd.Distribution | backend.bijectors.Bijector,
188
+ ) -> meridian_pb.TfpParameterValue:
189
+ """Converts a TensorFlow `Distribution` parameter value to a `TfpParameterValue` proto."""
190
+ # Handle built-in types.
191
+ match value:
192
+ case float():
193
+ return meridian_pb.TfpParameterValue(scalar_value=value)
194
+ case int():
195
+ return meridian_pb.TfpParameterValue(int_value=value)
196
+ case bool():
197
+ return meridian_pb.TfpParameterValue(bool_value=value)
198
+ case str():
199
+ return meridian_pb.TfpParameterValue(string_value=value)
200
+ case None:
201
+ return meridian_pb.TfpParameterValue(none_value=True)
202
+ case list():
203
+ value_generator = (
204
+ self._to_parameter_value_proto(param_name, v, dist) for v in value
205
+ )
206
+ return meridian_pb.TfpParameterValue(
207
+ list_value=meridian_pb.TfpParameterValue.List(
208
+ values=value_generator
209
+ )
210
+ )
211
+ case dict():
212
+ dict_value = {
213
+ k: self._to_parameter_value_proto(param_name, v, dist)
214
+ for k, v in value.items()
215
+ }
216
+ return meridian_pb.TfpParameterValue(
217
+ dict_value=meridian_pb.TfpParameterValue.Dict(value_map=dict_value)
218
+ )
219
+ case backend.Tensor():
220
+ return meridian_pb.TfpParameterValue(
221
+ tensor_value=backend.make_tensor_proto(value)
222
+ )
223
+ case backend.tfd.Distribution():
224
+ return meridian_pb.TfpParameterValue(
225
+ distribution_value=self._to_distribution_proto(value)
226
+ )
227
+ case backend.bijectors.Bijector():
228
+ return meridian_pb.TfpParameterValue(
229
+ bijector_value=self._to_bijector_proto(value)
230
+ )
231
+ case backend.tfd.ReparameterizationType():
232
+ fully_reparameterized = value == backend.tfd.FULLY_REPARAMETERIZED
233
+ return meridian_pb.TfpParameterValue(
234
+ fully_reparameterized=fully_reparameterized
235
+ )
236
+ case types.FunctionType():
237
+ # Check for default value
238
+ signature = inspect.signature(dist.__init__)
239
+ param = signature.parameters[param_name]
240
+ if param.default and param.default is value:
241
+ return meridian_pb.TfpParameterValue(
242
+ function_param=meridian_pb.TfpParameterValue.FunctionParam(
243
+ uses_default=True
244
+ )
245
+ )
246
+
247
+ # Check against registry.
248
+ function_key = self.function_registry.get_function_key(value)
249
+ if function_key is not None:
250
+ return meridian_pb.TfpParameterValue(
251
+ function_param=meridian_pb.TfpParameterValue.FunctionParam(
252
+ function_key=function_key
253
+ )
254
+ )
255
+ raise ValueError(
256
+ f"Custom function `{param_name}` detected for"
257
+ f" {type(dist).__name__}, but not found in registry. Please"
258
+ " add custom functions to registry when saving models."
259
+ )
260
+
261
+ # Handle unsupported types.
262
+ raise TypeError(f"Unsupported type: {type(value)}, {value}")
263
+
264
+ def _from_distribution_proto(
265
+ self,
266
+ dist_proto: meridian_pb.TfpDistribution,
267
+ ) -> backend.tfd.Distribution:
268
+ """Converts a `Distribution` proto to a TensorFlow `Distribution` object."""
269
+ dist_class_name = dist_proto.distribution_type
270
+ dist_class = getattr(backend.tfd, dist_class_name)
271
+ dist_parameters = dist_proto.parameters
272
+ input_parameters = {
273
+ k: self._unpack_tfp_parameters(k, v, dist_class)
274
+ for k, v in dist_parameters.items()
275
+ }
276
+ return dist_class(**input_parameters)
277
+
278
+ def _from_bijector_proto(
279
+ self,
280
+ dist_proto: meridian_pb.TfpBijector,
281
+ ) -> backend.bijectors.Bijector:
282
+ """Converts a `Bijector` proto to a TensorFlow `Bijector` object."""
283
+ dist_class_name = dist_proto.bijector_type
284
+ dist_class = getattr(backend.bijectors, dist_class_name)
285
+ dist_parameters = dist_proto.parameters
286
+ input_parameters = {
287
+ name: self._unpack_tfp_parameters(name, value, dist_class)
288
+ for name, value in dist_parameters.items()
289
+ }
290
+
291
+ return dist_class(**input_parameters)
292
+
293
+ def _unpack_tfp_parameters(
294
+ self,
295
+ param_name: str,
296
+ param_value: meridian_pb.TfpParameterValue,
297
+ dist_class: backend.tfd.Distribution,
298
+ ) -> Any:
299
+ """Unpacks a `TfpParameterValue` proto into a Python value."""
300
+ match param_value.WhichOneof("value_type"):
301
+ # Handle built-in types.
302
+ case "scalar_value":
303
+ return param_value.scalar_value
304
+ case "int_value":
305
+ return param_value.int_value
306
+ case "bool_value":
307
+ return param_value.bool_value
308
+ case "string_value":
309
+ return param_value.string_value
310
+ case "none_value":
311
+ return None
312
+ case "list_value":
313
+ return [
314
+ self._unpack_tfp_parameters(param_name, v, dist_class)
315
+ for v in param_value.list_value.values
316
+ ]
317
+ case "dict_value":
318
+ items = param_value.dict_value.value_map.items()
319
+ return {
320
+ key: self._unpack_tfp_parameters(key, value, dist_class)
321
+ for key, value in items
322
+ }
323
+
324
+ # Handle custom types.
325
+ case "tensor_value":
326
+ return backend.to_tensor(backend.make_ndarray(param_value.tensor_value))
327
+ case "distribution_value":
328
+ return self._from_distribution_proto(param_value.distribution_value)
329
+ case "bijector_value":
330
+ return self._from_bijector_proto(param_value.bijector_value)
331
+ case "fully_reparameterized":
332
+ if param_value.fully_reparameterized:
333
+ return backend.tfd.FULLY_REPARAMETERIZED
334
+ else:
335
+ return backend.tfd.NOT_FULLY_REPARAMETERIZED
336
+
337
+ # Handle functions.
338
+ case "function_param":
339
+ function_param = param_value.function_param
340
+ # Check against registry.
341
+ if function_param.HasField("function_key"):
342
+ registry = self.function_registry
343
+ if function_param.function_key in registry:
344
+ return registry.get(function_param.function_key)
345
+ # Check for default value.
346
+ if (
347
+ function_param.HasField("uses_default")
348
+ and function_param.uses_default
349
+ ):
350
+ signature = inspect.signature(dist_class.__init__)
351
+ return signature.parameters[param_name].default
352
+ raise ValueError(f"No function found for {param_name}")
353
+
354
+ # Handle unsupported types.
355
+ case _:
356
+ raise ValueError(
357
+ f"Unsupported TFP distribution parameter type: {type(param_value)}"
358
+ )
359
+
360
+ # copybara: strip_begin
361
+
362
+
363
def _from_legacy_bijector_proto(
    bijector_proto: meridian_pb.Distribution.Bijector,
) -> backend.bijectors.Bijector:
  """Rebuilds a TFP `Bijector` from a legacy enumerated `Bijector` proto.

  Supports the `shift`, `scale` and `reciprocal` variants of the
  `bijector_type` oneof.

  Raises:
    ValueError: If the oneof holds any other variant or is unset.
  """
  kind = bijector_proto.WhichOneof(sc.BIJECTOR_TYPE)
  if kind == sc.SHIFT_BIJECTOR:
    return backend.bijectors.Shift(
        shift=_deserialize_sequence(bijector_proto.shift.shifts)
    )
  if kind == sc.SCALE_BIJECTOR:
    scale_proto = bijector_proto.scale
    return backend.bijectors.Scale(
        scale=_deserialize_sequence(scale_proto.scales),
        log_scale=_deserialize_sequence(scale_proto.log_scales),
    )
  if kind == sc.RECIPROCAL_BIJECTOR:
    return backend.bijectors.Reciprocal()
  raise ValueError(
      f"Unsupported Bijector proto type: {kind};"
      f" Bijector proto:\n{bijector_proto}"
  )
385
+
386
+
387
+ def _from_legacy_distribution_proto(
388
+ dist_proto: meridian_pb.Distribution,
389
+ ) -> backend.tfd.Distribution:
390
+ """Converts a `Distribution` proto to a `Distribution` object."""
391
+ dist_type_field = dist_proto.WhichOneof(sc.DISTRIBUTION_TYPE)
392
+ match dist_type_field:
393
+ case sc.BATCH_BROADCAST_DISTRIBUTION:
394
+ return backend.tfd.BatchBroadcast(
395
+ name=dist_proto.name,
396
+ distribution=_from_legacy_distribution_proto(
397
+ dist_proto.batch_broadcast.distribution
398
+ ),
399
+ with_shape=_from_shape_proto(dist_proto.batch_broadcast.batch_shape),
400
+ )
401
+ case sc.TRANSFORMED_DISTRIBUTION:
402
+ return backend.tfd.TransformedDistribution(
403
+ name=dist_proto.name,
404
+ distribution=_from_legacy_distribution_proto(
405
+ dist_proto.transformed.distribution
406
+ ),
407
+ bijector=_from_legacy_bijector_proto(dist_proto.transformed.bijector),
408
+ )
409
+ case sc.DETERMINISTIC_DISTRIBUTION:
410
+ return backend.tfd.Deterministic(
411
+ name=dist_proto.name,
412
+ loc=_deserialize_sequence(dist_proto.deterministic.locs),
413
+ )
414
+ case sc.HALF_NORMAL_DISTRIBUTION:
415
+ return backend.tfd.HalfNormal(
416
+ name=dist_proto.name,
417
+ scale=_deserialize_sequence(dist_proto.half_normal.scales),
418
+ )
419
+ case sc.LOG_NORMAL_DISTRIBUTION:
420
+ return backend.tfd.LogNormal(
421
+ name=dist_proto.name,
422
+ loc=_deserialize_sequence(dist_proto.log_normal.locs),
423
+ scale=_deserialize_sequence(dist_proto.log_normal.scales),
424
+ )
425
+ case sc.NORMAL_DISTRIBUTION:
426
+ return backend.tfd.Normal(
427
+ name=dist_proto.name,
428
+ loc=_deserialize_sequence(dist_proto.normal.locs),
429
+ scale=_deserialize_sequence(dist_proto.normal.scales),
430
+ )
431
+ case sc.TRUNCATED_NORMAL_DISTRIBUTION:
432
+ if (
433
+ hasattr(dist_proto.truncated_normal, "lows")
434
+ and dist_proto.truncated_normal.lows
435
+ ):
436
+ if dist_proto.truncated_normal.low:
437
+ _show_warning("low", "TruncatedNormal")
438
+ low = _deserialize_sequence(dist_proto.truncated_normal.lows)
439
+ else:
440
+ low = dist_proto.truncated_normal.low
441
+
442
+ if (
443
+ hasattr(dist_proto.truncated_normal, "highs")
444
+ and dist_proto.truncated_normal.highs
445
+ ):
446
+ if dist_proto.truncated_normal.high:
447
+ _show_warning("high", "TruncatedNormal")
448
+ high = _deserialize_sequence(dist_proto.truncated_normal.highs)
449
+ else:
450
+ high = dist_proto.truncated_normal.high
451
+ return backend.tfd.TruncatedNormal(
452
+ name=dist_proto.name,
453
+ loc=_deserialize_sequence(dist_proto.truncated_normal.locs),
454
+ scale=_deserialize_sequence(dist_proto.truncated_normal.scales),
455
+ low=low,
456
+ high=high,
457
+ )
458
+ case sc.UNIFORM_DISTRIBUTION:
459
+ if hasattr(dist_proto.uniform, "lows") and dist_proto.uniform.lows:
460
+ if dist_proto.uniform.low:
461
+ _show_warning("low", "Uniform")
462
+ low = _deserialize_sequence(dist_proto.uniform.lows)
463
+ else:
464
+ low = dist_proto.uniform.low
465
+
466
+ if hasattr(dist_proto.uniform, "highs") and dist_proto.uniform.highs:
467
+ if dist_proto.uniform.high:
468
+ _show_warning("high", "Uniform")
469
+ high = _deserialize_sequence(dist_proto.uniform.highs)
470
+ else:
471
+ high = dist_proto.uniform.high
472
+
473
+ return backend.tfd.Uniform(
474
+ name=dist_proto.name,
475
+ low=low,
476
+ high=high,
477
+ )
478
+ case sc.BETA_DISTRIBUTION:
479
+ return backend.tfd.Beta(
480
+ name=dist_proto.name,
481
+ concentration1=_deserialize_sequence(dist_proto.beta.alpha),
482
+ concentration0=_deserialize_sequence(dist_proto.beta.beta),
483
+ )
484
+ case _:
485
+ raise ValueError(
486
+ f"Unsupported Distribution proto type: {dist_type_field};"
487
+ f" Distribution proto:\n{dist_proto}"
488
+ )
489
+
490
+
491
def _show_warning(field_name: str, dist_name: str) -> None:
  """Emits a `DeprecationWarning` when a plural and a singular bound are both set.

  Args:
    field_name: The deprecated singular field name, e.g. "low" or "high".
    dist_name: The distribution name to mention in the message.
  """
  plural = f"{field_name}s"
  message = (
      f"Both `{plural}` and `{field_name}` are specified in"
      f" {dist_name} distribution proto. Prioritizing `{plural}` since"
      f" `{field_name}` is deprecated."
  )
  warnings.warn(message, DeprecationWarning)
498
+
499
+
500
def _from_shape_proto(
    shape_proto: tensor_shape_pb2.TensorShapeProto,
) -> backend.TensorShapeInstance:
  """Builds a backend `TensorShape` from the `size` of each dim in `shape_proto`."""
  sizes = [dimension.size for dimension in shape_proto.dim]
  return backend.TensorShape(sizes)
505
+
506
+
507
+ T = TypeVar("T")
508
+
509
+
510
+ def _deserialize_sequence(args: Sequence[T]) -> T | Sequence[T] | None:
511
+ if not args:
512
+ return None
513
+ return args[0] if len(args) == 1 else list(args)
514
+
515
+ # copybara: strip_end