google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py CHANGED
@@ -14,6 +14,7 @@
 
 """Meridian module for the geo-level Bayesian hierarchical media mix model."""
 
+import collections
 from collections.abc import Mapping, Sequence
 import functools
 import numbers
@@ -34,19 +35,26 @@ from meridian.model import prior_distribution
 from meridian.model import prior_sampler
 from meridian.model import spec
 from meridian.model import transformers
+from meridian.model.eda import eda_engine
+from meridian.model.eda import eda_outcome
+from meridian.model.eda import eda_spec as eda_spec_module
 import numpy as np
 
-
 __all__ = [
     "MCMCSamplingError",
    "MCMCOOMError",
    "Meridian",
+    "ModelFittingError",
    "NotFittedModelError",
    "save_mmm",
    "load_mmm",
 ]
 
 
+class ModelFittingError(Exception):
+  """Model has critical issues preventing fitting."""
+
+
 class NotFittedModelError(Exception):
   """Model has not been fitted."""
 
@@ -91,6 +99,10 @@ class Meridian:
     model_spec: A `ModelSpec` object containing the model specification.
     inference_data: A _mutable_ `arviz.InferenceData` object containing the
       resulting data from fitting the model.
+    eda_engine: An `EDAEngine` object containing the EDA engine.
+    eda_spec: An `EDASpec` object containing the EDA specification.
+    eda_outcomes: A list of `EDAOutcome` objects containing the outcomes from
+      running critical EDA checks.
     n_geos: Number of geos in the data.
     n_media_channels: Number of media channels in the data.
     n_rf_channels: Number of reach and frequency (RF) channels in the data.
@@ -126,11 +138,17 @@ class Meridian:
       treatment tensors using the model's non-media treatment data.
     kpi_transformer: A `KpiTransformer` to scale KPI tensors using the model's
       KPI data.
-    controls_scaled: The controls tensor normalized by population and by the
-      median value.
-    non_media_treatments_scaled: The non-media treatment tensor normalized by
-      population and by the median value.
-    kpi_scaled: The KPI tensor normalized by population and by the median value.
+    controls_scaled: The controls tensor after pre-modeling transformations
+      including population scaling (for variables with
+      `ModelSpec.control_population_scaling_id` set to `True`), centering by the
+      mean, and scaling by the standard deviation.
+    non_media_treatments_scaled: The non-media treatment tensor after
+      pre-modeling transformations including population scaling (for variables
+      with `ModelSpec.non_media_population_scaling_id` set to `True`), centering
+      by the mean, and scaling by the standard deviation.
+    kpi_scaled: The KPI tensor after pre-modeling transformations including
+      population scaling, centering by the mean, and scaling by the standard
+      deviation.
     media_effects_dist: A string to specify the distribution of media random
       effects across geos.
     unique_sigma_for_each_geo: A boolean indicating whether to use a unique
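The docstring rewrite above replaces "normalized by population and by the median value" with an explicit centering-and-scaling description. A minimal NumPy sketch of that transform, as described in the new docstring (a hypothetical helper, not Meridian's `CenteringAndScalingTransformer` implementation):

import numpy as np

def center_and_scale(x, population, scale_by_population):
  """Illustrative transform: optional population scaling, then z-scoring."""
  # Assumed shapes: x is (n_geos, n_times, n_vars); population is (n_geos,).
  x = np.asarray(x, dtype=np.float64)
  if scale_by_population:
    x = x / population[:, None, None]
  mean = x.mean(axis=(0, 1), keepdims=True)    # per-variable mean over geos and times
  stdev = x.std(axis=(0, 1), keepdims=True)    # per-variable standard deviation
  return (x - mean) / stdev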
@@ -148,6 +166,7 @@ class Meridian:
       inference_data: (
           az.InferenceData | None
       ) = None,  # for deserializer use only
+      eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
   ):
     self._input_data = input_data
     self._model_spec = model_spec if model_spec else spec.ModelSpec()
@@ -155,6 +174,8 @@ class Meridian:
         inference_data if inference_data else az.InferenceData()
     )
 
+    self._eda_spec = eda_spec
+
     self._validate_data_dependent_model_spec()
     self._validate_injected_inference_data()
 
@@ -184,6 +205,18 @@ class Meridian:
   def inference_data(self) -> az.InferenceData:
     return self._inference_data
 
+  @functools.cached_property
+  def eda_engine(self) -> eda_engine.EDAEngine:
+    return eda_engine.EDAEngine(self, spec=self._eda_spec)
+
+  @property
+  def eda_spec(self) -> eda_spec_module.EDASpec:
+    return self._eda_spec
+
+  @property
+  def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
+    return self.eda_engine.run_all_critical_checks()
+
   @functools.cached_property
   def media_tensors(self) -> media.MediaTensors:
     return media.build_media_tensors(self.input_data, self.model_spec)
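A short usage sketch of the new EDA surface added above. `data` stands in for an existing `InputData` object, and the attribute names follow this diff; the snippet is illustrative, not taken from the package documentation:

from meridian.model import model, spec
from meridian.model.eda import eda_spec as eda_spec_module

mmm = model.Meridian(
    input_data=data,  # assumed: an existing meridian InputData object
    model_spec=spec.ModelSpec(),
    eda_spec=eda_spec_module.EDASpec(),  # optional; defaults mirror the new constructor argument
)

# `eda_outcomes` runs the critical EDA checks via the cached EDAEngine.
for outcome in mmm.eda_outcomes:
  for finding in outcome.findings:
    print(outcome.check_type, finding.severity, finding.explanation)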
@@ -444,7 +477,8 @@ class Meridian:
           f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
           " either contain only channel_names"
           f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
-          " be one or more of {'media', 'rf', 'organic_media', 'organic_rf'}."
+          " be one or more of {'media', 'rf', 'organic_media',"
+          " 'organic_rf'}."
       ) from e
 
   @functools.cached_property
@@ -561,7 +595,9 @@ class Meridian:
             non_media_treatments_population_scaled[..., channel], axis=[0, 1]
         )
       elif isinstance(baseline_value, numbers.Number):
-        baseline_for_channel = backend.cast(baseline_value, backend.float32)
+        baseline_for_channel = backend.to_tensor(
+            baseline_value, dtype=backend.float32
+        )
       else:
         raise ValueError(
             f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
@@ -1135,16 +1171,11 @@ class Meridian:
           " time."
       )
 
-  def _kpi_has_variability(self):
-    """Returns True if the KPI has variability across geos and times."""
-    return self.kpi_transformer.population_scaled_stdev != 0
-
   def _validate_kpi_transformer(self):
     """Validates the KPI transformer."""
-    if self._kpi_has_variability():
+    if self.eda_engine.kpi_has_variability:
       return
-
-    kpi = "kpi" if self.is_national else "population_scaled_kpi"
+    kpi = self.eda_engine.kpi_scaled_da.name
 
     if (
         self.n_media_channels > 0
@@ -1569,6 +1600,36 @@ class Meridian:
     """
     self.prior_sampler_callable(n_draws=n_draws, seed=seed)
 
+  def _run_model_fitting_guardrail(self):
+    """Raises an error if the model has critical EDA issues."""
+    error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
+        collections.defaultdict(list)
+    )
+    for outcome in self.eda_outcomes:
+      error_findings = [
+          finding
+          for finding in outcome.findings
+          if finding.severity == eda_outcome.EDASeverity.ERROR
+      ]
+      if error_findings:
+        error_findings_by_type[outcome.check_type].extend(
+            [finding.explanation for finding in error_findings]
+        )
+
+    if error_findings_by_type:
+      error_message_lines = [
+          "Model has critical EDA issues. Please fix before running"
+          " `sample_posterior`.\n"
+      ]
+      for check_type, explanations in error_findings_by_type.items():
+        error_message_lines.append(f"Check type: {check_type.name}")
+        for explanation in explanations:
+          error_message_lines.append(f"- {explanation}")
+      error_message_lines.append(
+          "For further details, please refer to `Meridian.eda_outcomes`."
+      )
+      raise ModelFittingError("\n".join(error_message_lines))
+
   def sample_posterior(
       self,
       n_chains: Sequence[int] | int,
@@ -1644,8 +1705,10 @@ class Meridian:
        a list of integers as `n_chains` to sample chains serially. For more
        information, see
        [ResourceExhaustedError when running Meridian.sample_posterior]
-        (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
+        (https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
     """
+    self._run_model_fitting_guardrail()
+
     self.posterior_sampler_callable(
         n_chains=n_chains,
         n_adapt=n_adapt,
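From the caller's perspective, the guardrail above converts ERROR-severity EDA findings into an exception before any MCMC work starts. A hedged sketch (argument values are illustrative):

from meridian.model import model

try:
  mmm.sample_posterior(n_chains=4, n_adapt=500, n_burnin=500, n_keep=1000)
except model.ModelFittingError as err:
  # The message groups ERROR-severity findings by check type; the full
  # outcome objects remain available for inspection on the model.
  print(err)
  outcomes = mmm.eda_outcomes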
meridian/model/model_test_data.py CHANGED
@@ -143,6 +143,7 @@ class WithInputDataSamples:
   _short_input_data_with_rf_only: input_data.InputData
   _short_input_data_with_media_and_rf: input_data.InputData
   _national_input_data_media_only: input_data.InputData
+  _national_input_data_rf_only: input_data.InputData
   _national_input_data_media_and_rf: input_data.InputData
   _test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
   _test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
@@ -156,6 +157,8 @@ class WithInputDataSamples:
   _short_input_data_non_media_and_organic: input_data.InputData
   _short_input_data_non_media: input_data.InputData
   _input_data_non_media_and_organic_same_time_dims: input_data.InputData
+  _input_data_organic_only: input_data.InputData
+  _national_input_data_organic_only: input_data.InputData
 
   # The following NamedTuples and their attributes are immutable, so they can
   # be accessed directly.
@@ -280,6 +283,16 @@ class WithInputDataSamples:
             seed=0,
         )
     )
+    cls._national_input_data_rf_only = (
+        test_utils.sample_input_data_non_revenue_revenue_per_kpi(
+            n_geos=cls._N_GEOS_NATIONAL,
+            n_times=cls._N_TIMES,
+            n_media_times=cls._N_MEDIA_TIMES,
+            n_controls=cls._N_CONTROLS,
+            n_rf_channels=cls._N_RF_CHANNELS,
+            seed=0,
+        )
+    )
     cls._national_input_data_media_only = (
         test_utils.sample_input_data_non_revenue_revenue_per_kpi(
             n_geos=cls._N_GEOS_NATIONAL,
@@ -496,6 +509,34 @@ class WithInputDataSamples:
             seed=0,
         )
     )
+    cls._input_data_organic_only = (
+        test_utils.sample_input_data_non_revenue_revenue_per_kpi(
+            n_geos=cls._N_GEOS,
+            n_times=cls._N_TIMES,
+            n_media_times=cls._N_MEDIA_TIMES,
+            n_controls=cls._N_CONTROLS,
+            n_non_media_channels=0,
+            n_media_channels=cls._N_MEDIA_CHANNELS,
+            n_rf_channels=0,
+            n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
+            n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
+            seed=0,
+        )
+    )
+    cls._national_input_data_organic_only = (
+        test_utils.sample_input_data_non_revenue_revenue_per_kpi(
+            n_geos=cls._N_GEOS_NATIONAL,
+            n_times=cls._N_TIMES,
+            n_media_times=cls._N_MEDIA_TIMES,
+            n_controls=cls._N_CONTROLS,
+            n_non_media_channels=0,
+            n_media_channels=cls._N_MEDIA_CHANNELS,
+            n_rf_channels=0,
+            n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
+            n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
+            seed=0,
+        )
+    )
 
   @property
   def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
@@ -551,6 +592,10 @@ class WithInputDataSamples:
   def national_input_data_media_only(self) -> input_data.InputData:
     return self._national_input_data_media_only.copy(deep=True)
 
+  @property
+  def national_input_data_rf_only(self) -> input_data.InputData:
+    return self._national_input_data_rf_only.copy(deep=True)
+
   @property
   def national_input_data_media_and_rf(self) -> input_data.InputData:
     return self._national_input_data_media_and_rf.copy(deep=True)
@@ -606,3 +651,11 @@ class WithInputDataSamples:
       self,
   ) -> input_data.InputData:
     return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
+
+  @property
+  def input_data_organic_only(self) -> input_data.InputData:
+    return self._input_data_organic_only.copy(deep=True)
+
+  @property
+  def national_input_data_organic_only(self) -> input_data.InputData:
+    return self._national_input_data_organic_only.copy(deep=True)
meridian/model/posterior_sampler.py CHANGED
@@ -72,12 +72,6 @@ def _get_tau_g(
   return backend.tfd.Deterministic(tau_g, name="tau_g")
 
 
-@backend.function(autograph=False, jit_compile=True)
-def _xla_windowed_adaptive_nuts(**kwargs):
-  """XLA wrapper for windowed_adaptive_nuts."""
-  return backend.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
-
-
 def _joint_dist_unpinned(mmm: "model.Meridian"):
   """Returns unpinned joint distribution."""
 
@@ -447,26 +441,44 @@ class PosteriorMCMCSampler:
 
   def __init__(self, meridian: "model.Meridian"):
     self._meridian = meridian
+    self._joint_dist = None
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    # Exclude unpickleable objects.
+    if "_joint_dist" in state:
+      del state["_joint_dist"]
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._joint_dist = None
 
   @property
   def model(self) -> "model.Meridian":
     return self._meridian
 
+  def _joint_dist_unpinned_fn(self):
+    return _joint_dist_unpinned(self.model)
+
   def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
-    """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
+    """Builds a `JointDistributionCoroutineAutoBatched` function for MCMC."""
     mmm = self.model
     mmm.populate_cached_properties()
-    fn = lambda: _joint_dist_unpinned(mmm)
-    return backend.tfd.JointDistributionCoroutineAutoBatched(fn)
+    return backend.tfd.JointDistributionCoroutineAutoBatched(
+        self._joint_dist_unpinned_fn
+    )
 
   def _get_joint_dist(self) -> backend.tfd.Distribution:
-    mmm = self.model
-    y = (
-        backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
-        if mmm.holdout_id is not None
-        else mmm.kpi_scaled
-    )
-    return self._get_joint_dist_unpinned().experimental_pin(y=y)
+    if self._joint_dist is None:
+      mmm = self.model
+      y = (
+          backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
+          if mmm.holdout_id is not None
+          else mmm.kpi_scaled
+      )
+      self._joint_dist = self._get_joint_dist_unpinned().experimental_pin(y=y)
+    return self._joint_dist
 
   def __call__(
       self,
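The cached pinned joint distribution is not picklable, which is why `__getstate__`/`__setstate__` drop it above. A generic sketch of the same pattern (not Meridian code):

import pickle

class CachingSampler:
  """Drops a non-picklable, lazily built cache during serialization."""

  def __init__(self, model):
    self._model = model
    self._joint_dist = None  # built lazily; assumed not picklable

  def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("_joint_dist", None)  # exclude the cache from the pickled state
    return state

  def __setstate__(self, state):
    self.__dict__.update(state)
    self._joint_dist = None  # rebuilt on first use after unpickling

sampler = CachingSampler(model=object())
restored = pickle.loads(pickle.dumps(sampler))  # round-trips without the cache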
@@ -541,26 +553,22 @@ class PosteriorMCMCSampler:
        a list of integers as `n_chains` to sample chains serially. For more
        information, see
        [ResourceExhaustedError when running Meridian.sample_posterior]
-        (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
+        (https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
     """
-    if seed is not None and isinstance(seed, Sequence) and len(seed) != 2:
-      raise ValueError(
-          "Invalid seed: Must be either a single integer (stateful seed) or a"
-          " pair of two integers (stateless seed). See"
-          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
-          " for details."
-      )
-    if seed is not None and isinstance(seed, int):
-      seed = (seed, seed)
-    seed = backend.random.sanitize_seed(seed) if seed is not None else None
+    rng_handler = backend.RNGHandler(seed)
     n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
     total_chains = np.sum(n_chains_list)
 
+    # Clear joint distribution cache prior to sampling.
+    self._joint_dist = None
+
     states = []
     traces = []
     for n_chains_batch in n_chains_list:
+      kernel_seed = rng_handler.get_kernel_seed()
+
       try:
-        mcmc = _xla_windowed_adaptive_nuts(
+        mcmc = backend.xla_windowed_adaptive_nuts(
             n_draws=n_burnin + n_keep,
             joint_dist=self._get_joint_dist(),
             n_chains=n_chains_batch,
@@ -572,17 +580,16 @@ class PosteriorMCMCSampler:
             max_energy_diff=max_energy_diff,
             unrolled_leapfrog_steps=unrolled_leapfrog_steps,
             parallel_iterations=parallel_iterations,
-            seed=seed,
+            seed=kernel_seed,
             **pins,
         )
       except backend.errors.ResourceExhaustedError as error:
         raise MCMCOOMError(
            "ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
            " integers as `n_chains` to sample chains serially (see"
-            " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)"
+            " https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error)"
        ) from error
-      if seed is not None:
-        seed += 1
+      rng_handler = rng_handler.advance_handler()
      states.append(mcmc.all_states._asdict())
      traces.append(mcmc.trace)
 
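Seed handling is now delegated to `backend.RNGHandler`. Based only on the calls visible in this hunk (the constructor, `get_kernel_seed`, and `advance_handler`), the per-batch pattern for serial chain sampling is roughly:

# Sketch of the seeding loop above; `backend.RNGHandler` and its two methods are
# taken from the diff, `run_mcmc_batch` is a hypothetical stand-in for the NUTS call.
rng_handler = backend.RNGHandler(seed)
for n_chains_batch in n_chains_list:
  kernel_seed = rng_handler.get_kernel_seed()        # seed for this batch's kernel
  run_mcmc_batch(n_chains_batch, seed=kernel_seed)   # hypothetical helper
  rng_handler = rng_handler.advance_handler()        # fresh handler for the next batch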
meridian/model/prior_distribution.py CHANGED
@@ -35,6 +35,7 @@ __all__ = [
    'PriorDistribution',
    'distributions_are_equal',
    'lognormal_dist_from_mean_std',
+    'lognormal_dist_from_range',
 ]
 
 
@@ -1195,7 +1196,7 @@ def lognormal_dist_from_range(
   """Define a LogNormal distribution from a specified range.
 
   This function parameterizes lognormal distributions by the bounds of a range,
-  so that the specificed probability mass falls within the bounds defined by
+  so that the specified probability mass falls within the bounds defined by
   `low` and `high`. The probability mass is symmetric about the median. For
   example, to define a lognormal distribution with a 95% probability mass of
   (1, 10), use:
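The example the docstring refers to presumably looks like the following usage sketch, consistent with the parameters documented in the next hunk:

from meridian.model import prior_distribution

# LogNormal with ~95% of its probability mass between 1 and 10,
# symmetric about the median.
dist = prior_distribution.lognormal_dist_from_range(
    low=1.0, high=10.0, mass_percent=0.95
)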
@@ -1210,7 +1211,7 @@ def lognormal_dist_from_range(
    high: Float or array-like denoting the upper bound of range. Values must be
      non-negative.
    mass_percent: Float or array-like denoting the probability mass. Values must
-      be between 0 and 1 (exlusive). Default: 0.95.
+      be between 0 and 1 (exclusive). Default: 0.95.
 
   Returns:
    A `backend.tfd.LogNormal` object with the input percentage mass falling
@@ -1341,6 +1342,15 @@ def _validate_support(
          f'{parameter_name} was assigned a point mass (deterministic) prior'
          f' at {bounds[i]}, which is not allowed.'
      )
+    elif isinstance(tfp_dist, backend.tfd.TruncatedNormal):
+      # TruncatedNormal quantile method is not reliable, particularly when the
+      # `low` or `high` value falls into extreme percentile of the untruncated
+      # distribution. Note that
+      # `TruncatedNormal.experimental_default_event_space_bijector()([-inf, inf])`
+      # returns the correct support range, so this method could be used if the
+      # `quantile` method is found to be unreliable for other distributions.
+      support_min_vals = tfp_dist.low
+      support_max_vals = tfp_dist.high
    else:
      try:
        support_min_vals = tfp_dist.quantile(0)
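A condensed sketch of the branch added above, using TensorFlow Probability directly for illustration: for `TruncatedNormal`, the support bounds are read from the distribution's `low`/`high` parameters instead of the quantile method.

import tensorflow_probability as tfp

tfd = tfp.distributions

def support_bounds(dist):
  # Simplified mirror of `_validate_support`'s branching: trust `low`/`high`
  # for TruncatedNormal, otherwise fall back to the quantile method.
  if isinstance(dist, tfd.TruncatedNormal):
    return dist.low, dist.high
  return dist.quantile(0.0), dist.quantile(1.0)

print(support_bounds(tfd.TruncatedNormal(loc=0.0, scale=1.0, low=0.0, high=5.0)))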