google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,9 +18,8 @@ from collections.abc import Mapping
18
18
  from typing import TYPE_CHECKING
19
19
 
20
20
  import arviz as az
21
+ from meridian import backend
21
22
  from meridian import constants
22
- import tensorflow as tf
23
- import tensorflow_probability as tfp
24
23
 
25
24
  if TYPE_CHECKING:
26
25
  from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
@@ -32,8 +31,8 @@ __all__ = [
32
31
 
33
32
 
34
33
  def _get_tau_g(
35
- tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
36
- ) -> tfp.distributions.Distribution:
34
+ tau_g_excl_baseline: backend.Tensor, baseline_geo_idx: int
35
+ ) -> backend.tfd.Distribution:
37
36
  """Computes `tau_g` from `tau_g_excl_baseline`.
38
37
 
39
38
  This function computes `tau_g` by inserting a column of zeros at the
@@ -51,15 +50,15 @@ def _get_tau_g(
51
50
  """
52
51
  rank = len(tau_g_excl_baseline.shape)
53
52
  shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
54
- tau_g = tf.concat(
53
+ tau_g = backend.concatenate(
55
54
  [
56
55
  tau_g_excl_baseline[..., :baseline_geo_idx],
57
- tf.zeros(shape, dtype=tau_g_excl_baseline.dtype),
56
+ backend.zeros(shape, dtype=tau_g_excl_baseline.dtype),
58
57
  tau_g_excl_baseline[..., baseline_geo_idx:],
59
58
  ],
60
59
  axis=rank - 1,
61
60
  )
62
- return tfp.distributions.Deterministic(tau_g, name="tau_g")
61
+ return backend.tfd.Deterministic(tau_g, name="tau_g")
63
62
 
64
63
 
65
64
  class PriorDistributionSampler:
@@ -72,7 +71,7 @@ class PriorDistributionSampler:
72
71
  self,
73
72
  n_draws: int,
74
73
  seed: int | None = None,
75
- ) -> Mapping[str, tf.Tensor]:
74
+ ) -> Mapping[str, backend.Tensor]:
76
75
  """Draws samples from the prior distributions of the media variables.
77
76
 
78
77
  Args:
@@ -97,8 +96,8 @@ class PriorDistributionSampler:
97
96
  constants.ETA_M: prior.eta_m.sample(**sample_kwargs),
98
97
  constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs),
99
98
  }
100
- beta_gm_dev = tfp.distributions.Sample(
101
- tfp.distributions.Normal(0, 1),
99
+ beta_gm_dev = backend.tfd.Sample(
100
+ backend.tfd.Normal(0, 1),
102
101
  [mmm.n_geos, mmm.n_media_channels],
103
102
  name=constants.BETA_GM_DEV,
104
103
  ).sample(**sample_kwargs)
@@ -126,6 +125,7 @@ class PriorDistributionSampler:
126
125
  alpha=media_vars[constants.ALPHA_M],
127
126
  ec=media_vars[constants.EC_M],
128
127
  slope=media_vars[constants.SLOPE_M],
128
+ decay_functions=mmm.adstock_decay_spec.media
129
129
  )
130
130
  linear_predictor_counterfactual_difference = (
131
131
  mmm.linear_predictor_counterfactual_difference_media(
@@ -142,20 +142,20 @@ class PriorDistributionSampler:
142
142
  eta_x=media_vars[constants.ETA_M],
143
143
  beta_gx_dev=beta_gm_dev,
144
144
  )
145
- media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
145
+ media_vars[constants.BETA_M] = backend.tfd.Deterministic(
146
146
  beta_m_value, name=constants.BETA_M
147
147
  ).sample()
148
148
 
149
149
  beta_eta_combined = (
150
- media_vars[constants.BETA_M][..., tf.newaxis, :]
151
- + media_vars[constants.ETA_M][..., tf.newaxis, :] * beta_gm_dev
150
+ media_vars[constants.BETA_M][..., backend.newaxis, :]
151
+ + media_vars[constants.ETA_M][..., backend.newaxis, :] * beta_gm_dev
152
152
  )
153
153
  beta_gm_value = (
154
154
  beta_eta_combined
155
155
  if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
156
- else tf.math.exp(beta_eta_combined)
156
+ else backend.exp(beta_eta_combined)
157
157
  )
158
- media_vars[constants.BETA_GM] = tfp.distributions.Deterministic(
158
+ media_vars[constants.BETA_GM] = backend.tfd.Deterministic(
159
159
  beta_gm_value, name=constants.BETA_GM
160
160
  ).sample()
161
161
 
@@ -165,7 +165,7 @@ class PriorDistributionSampler:
165
165
  self,
166
166
  n_draws: int,
167
167
  seed: int | None = None,
168
- ) -> Mapping[str, tf.Tensor]:
168
+ ) -> Mapping[str, backend.Tensor]:
169
169
  """Draws samples from the prior distributions of the RF variables.
170
170
 
171
171
  Args:
@@ -190,8 +190,8 @@ class PriorDistributionSampler:
190
190
  constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs),
191
191
  constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs),
192
192
  }
193
- beta_grf_dev = tfp.distributions.Sample(
194
- tfp.distributions.Normal(0, 1),
193
+ beta_grf_dev = backend.tfd.Sample(
194
+ backend.tfd.Normal(0, 1),
195
195
  [mmm.n_geos, mmm.n_rf_channels],
196
196
  name=constants.BETA_GRF_DEV,
197
197
  ).sample(**sample_kwargs)
@@ -220,6 +220,7 @@ class PriorDistributionSampler:
220
220
  alpha=rf_vars[constants.ALPHA_RF],
221
221
  ec=rf_vars[constants.EC_RF],
222
222
  slope=rf_vars[constants.SLOPE_RF],
223
+ decay_functions=mmm.adstock_decay_spec.rf,
223
224
  )
224
225
  linear_predictor_counterfactual_difference = (
225
226
  mmm.linear_predictor_counterfactual_difference_rf(
@@ -236,21 +237,21 @@ class PriorDistributionSampler:
236
237
  eta_x=rf_vars[constants.ETA_RF],
237
238
  beta_gx_dev=beta_grf_dev,
238
239
  )
239
- rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
240
+ rf_vars[constants.BETA_RF] = backend.tfd.Deterministic(
240
241
  beta_rf_value,
241
242
  name=constants.BETA_RF,
242
243
  ).sample()
243
244
 
244
245
  beta_eta_combined = (
245
- rf_vars[constants.BETA_RF][..., tf.newaxis, :]
246
- + rf_vars[constants.ETA_RF][..., tf.newaxis, :] * beta_grf_dev
246
+ rf_vars[constants.BETA_RF][..., backend.newaxis, :]
247
+ + rf_vars[constants.ETA_RF][..., backend.newaxis, :] * beta_grf_dev
247
248
  )
248
249
  beta_grf_value = (
249
250
  beta_eta_combined
250
251
  if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
251
- else tf.math.exp(beta_eta_combined)
252
+ else backend.exp(beta_eta_combined)
252
253
  )
253
- rf_vars[constants.BETA_GRF] = tfp.distributions.Deterministic(
254
+ rf_vars[constants.BETA_GRF] = backend.tfd.Deterministic(
254
255
  beta_grf_value, name=constants.BETA_GRF
255
256
  ).sample()
256
257
 
@@ -260,7 +261,7 @@ class PriorDistributionSampler:
260
261
  self,
261
262
  n_draws: int,
262
263
  seed: int | None = None,
263
- ) -> Mapping[str, tf.Tensor]:
264
+ ) -> Mapping[str, backend.Tensor]:
264
265
  """Draws samples from the prior distributions of organic media variables.
265
266
 
266
267
  Args:
@@ -285,8 +286,8 @@ class PriorDistributionSampler:
285
286
  constants.ETA_OM: prior.eta_om.sample(**sample_kwargs),
286
287
  constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs),
287
288
  }
288
- beta_gom_dev = tfp.distributions.Sample(
289
- tfp.distributions.Normal(0, 1),
289
+ beta_gom_dev = backend.tfd.Sample(
290
+ backend.tfd.Normal(0, 1),
290
291
  [mmm.n_geos, mmm.n_organic_media_channels],
291
292
  name=constants.BETA_GOM_DEV,
292
293
  ).sample(**sample_kwargs)
@@ -308,6 +309,7 @@ class PriorDistributionSampler:
308
309
  alpha=organic_media_vars[constants.ALPHA_OM],
309
310
  ec=organic_media_vars[constants.EC_OM],
310
311
  slope=organic_media_vars[constants.SLOPE_OM],
312
+ decay_functions=mmm.adstock_decay_spec.organic_media,
311
313
  )
312
314
  beta_om_value = mmm.calculate_beta_x(
313
315
  is_non_media=False,
@@ -316,7 +318,7 @@ class PriorDistributionSampler:
316
318
  eta_x=organic_media_vars[constants.ETA_OM],
317
319
  beta_gx_dev=beta_gom_dev,
318
320
  )
319
- organic_media_vars[constants.BETA_OM] = tfp.distributions.Deterministic(
321
+ organic_media_vars[constants.BETA_OM] = backend.tfd.Deterministic(
320
322
  beta_om_value,
321
323
  name=constants.BETA_OM,
322
324
  ).sample()
@@ -324,16 +326,16 @@ class PriorDistributionSampler:
324
326
  raise ValueError(f"Unsupported prior type: {prior_type}")
325
327
 
326
328
  beta_eta_combined = (
327
- organic_media_vars[constants.BETA_OM][..., tf.newaxis, :]
328
- + organic_media_vars[constants.ETA_OM][..., tf.newaxis, :]
329
+ organic_media_vars[constants.BETA_OM][..., backend.newaxis, :]
330
+ + organic_media_vars[constants.ETA_OM][..., backend.newaxis, :]
329
331
  * beta_gom_dev
330
332
  )
331
333
  beta_gom_value = (
332
334
  beta_eta_combined
333
335
  if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
334
- else tf.math.exp(beta_eta_combined)
336
+ else backend.exp(beta_eta_combined)
335
337
  )
336
- organic_media_vars[constants.BETA_GOM] = tfp.distributions.Deterministic(
338
+ organic_media_vars[constants.BETA_GOM] = backend.tfd.Deterministic(
337
339
  beta_gom_value, name=constants.BETA_GOM
338
340
  ).sample()
339
341
 
@@ -343,7 +345,7 @@ class PriorDistributionSampler:
343
345
  self,
344
346
  n_draws: int,
345
347
  seed: int | None = None,
346
- ) -> Mapping[str, tf.Tensor]:
348
+ ) -> Mapping[str, backend.Tensor]:
347
349
  """Draws samples from the prior distributions of the organic RF variables.
348
350
 
349
351
  Args:
@@ -368,8 +370,8 @@ class PriorDistributionSampler:
368
370
  constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs),
369
371
  constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs),
370
372
  }
371
- beta_gorf_dev = tfp.distributions.Sample(
372
- tfp.distributions.Normal(0, 1),
373
+ beta_gorf_dev = backend.tfd.Sample(
374
+ backend.tfd.Normal(0, 1),
373
375
  [mmm.n_geos, mmm.n_organic_rf_channels],
374
376
  name=constants.BETA_GORF_DEV,
375
377
  ).sample(**sample_kwargs)
@@ -392,6 +394,7 @@ class PriorDistributionSampler:
392
394
  alpha=organic_rf_vars[constants.ALPHA_ORF],
393
395
  ec=organic_rf_vars[constants.EC_ORF],
394
396
  slope=organic_rf_vars[constants.SLOPE_ORF],
397
+ decay_functions=mmm.adstock_decay_spec.organic_rf,
395
398
  )
396
399
  beta_orf_value = mmm.calculate_beta_x(
397
400
  is_non_media=False,
@@ -400,7 +403,7 @@ class PriorDistributionSampler:
400
403
  eta_x=organic_rf_vars[constants.ETA_ORF],
401
404
  beta_gx_dev=beta_gorf_dev,
402
405
  )
403
- organic_rf_vars[constants.BETA_ORF] = tfp.distributions.Deterministic(
406
+ organic_rf_vars[constants.BETA_ORF] = backend.tfd.Deterministic(
404
407
  beta_orf_value,
405
408
  name=constants.BETA_ORF,
406
409
  ).sample()
@@ -408,15 +411,16 @@ class PriorDistributionSampler:
408
411
  raise ValueError(f"Unsupported prior type: {prior_type}")
409
412
 
410
413
  beta_eta_combined = (
411
- organic_rf_vars[constants.BETA_ORF][..., tf.newaxis, :]
412
- + organic_rf_vars[constants.ETA_ORF][..., tf.newaxis, :] * beta_gorf_dev
414
+ organic_rf_vars[constants.BETA_ORF][..., backend.newaxis, :]
415
+ + organic_rf_vars[constants.ETA_ORF][..., backend.newaxis, :]
416
+ * beta_gorf_dev
413
417
  )
414
418
  beta_gorf_value = (
415
419
  beta_eta_combined
416
420
  if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
417
- else tf.math.exp(beta_eta_combined)
421
+ else backend.exp(beta_eta_combined)
418
422
  )
419
- organic_rf_vars[constants.BETA_GORF] = tfp.distributions.Deterministic(
423
+ organic_rf_vars[constants.BETA_GORF] = backend.tfd.Deterministic(
420
424
  beta_gorf_value, name=constants.BETA_GORF
421
425
  ).sample()
422
426
 
@@ -426,7 +430,7 @@ class PriorDistributionSampler:
426
430
  self,
427
431
  n_draws: int,
428
432
  seed: int | None = None,
429
- ) -> Mapping[str, tf.Tensor]:
433
+ ) -> Mapping[str, backend.Tensor]:
430
434
  """Draws from the prior distributions of the non-media treatment variables.
431
435
 
432
436
  Args:
@@ -448,8 +452,8 @@ class PriorDistributionSampler:
448
452
  non_media_treatments_vars = {
449
453
  constants.XI_N: prior.xi_n.sample(**sample_kwargs),
450
454
  }
451
- gamma_gn_dev = tfp.distributions.Sample(
452
- tfp.distributions.Normal(0, 1),
455
+ gamma_gn_dev = backend.tfd.Sample(
456
+ backend.tfd.Normal(0, 1),
453
457
  [mmm.n_geos, mmm.n_non_media_channels],
454
458
  name=constants.GAMMA_GN_DEV,
455
459
  ).sample(**sample_kwargs)
@@ -479,35 +483,31 @@ class PriorDistributionSampler:
479
483
  eta_x=non_media_treatments_vars[constants.XI_N],
480
484
  beta_gx_dev=gamma_gn_dev,
481
485
  )
482
- non_media_treatments_vars[constants.GAMMA_N] = (
483
- tfp.distributions.Deterministic(
484
- gamma_n_value, name=constants.GAMMA_N
485
- ).sample()
486
- )
486
+ non_media_treatments_vars[constants.GAMMA_N] = backend.tfd.Deterministic(
487
+ gamma_n_value, name=constants.GAMMA_N
488
+ ).sample()
487
489
  else:
488
490
  raise ValueError(f"Unsupported prior type: {prior_type}")
489
- non_media_treatments_vars[constants.GAMMA_GN] = (
490
- tfp.distributions.Deterministic(
491
- non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :]
492
- + non_media_treatments_vars[constants.XI_N][..., tf.newaxis, :]
493
- * gamma_gn_dev,
494
- name=constants.GAMMA_GN,
495
- ).sample()
496
- )
491
+ non_media_treatments_vars[constants.GAMMA_GN] = backend.tfd.Deterministic(
492
+ non_media_treatments_vars[constants.GAMMA_N][..., backend.newaxis, :]
493
+ + non_media_treatments_vars[constants.XI_N][..., backend.newaxis, :]
494
+ * gamma_gn_dev,
495
+ name=constants.GAMMA_GN,
496
+ ).sample()
497
497
  return non_media_treatments_vars
498
498
 
499
499
  def _sample_prior(
500
500
  self,
501
501
  n_draws: int,
502
502
  seed: int | None = None,
503
- ) -> Mapping[str, tf.Tensor]:
503
+ ) -> Mapping[str, backend.Tensor]:
504
504
  """Returns a mapping of prior parameters to tensors of the samples."""
505
505
  mmm = self._meridian
506
506
 
507
507
  # For stateful sampling, the random seed must be set to ensure that any
508
508
  # random numbers that are generated are deterministic.
509
509
  if seed is not None:
510
- tf.keras.utils.set_random_seed(1)
510
+ backend.set_random_seed(seed)
511
511
 
512
512
  prior = mmm.prior_broadcast
513
513
  # `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
@@ -527,11 +527,11 @@ class PriorDistributionSampler:
527
527
  ),
528
528
  }
529
529
 
530
- base_vars[constants.MU_T] = tfp.distributions.Deterministic(
531
- tf.einsum(
530
+ base_vars[constants.MU_T] = backend.tfd.Deterministic(
531
+ backend.einsum(
532
532
  "...k,kt->...t",
533
533
  base_vars[constants.KNOT_VALUES],
534
- tf.convert_to_tensor(mmm.knot_info.weights),
534
+ backend.to_tensor(mmm.knot_info.weights),
535
535
  ),
536
536
  name=constants.MU_T,
537
537
  ).sample()
@@ -544,14 +544,14 @@ class PriorDistributionSampler:
544
544
  constants.XI_C: prior.xi_c.sample(**sample_kwargs),
545
545
  }
546
546
 
547
- gamma_gc_dev = tfp.distributions.Sample(
548
- tfp.distributions.Normal(0, 1),
547
+ gamma_gc_dev = backend.tfd.Sample(
548
+ backend.tfd.Normal(0, 1),
549
549
  [mmm.n_geos, mmm.n_controls],
550
550
  name=constants.GAMMA_GC_DEV,
551
551
  ).sample(**sample_kwargs)
552
- base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic(
553
- base_vars[constants.GAMMA_C][..., tf.newaxis, :]
554
- + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev,
552
+ base_vars[constants.GAMMA_GC] = backend.tfd.Deterministic(
553
+ base_vars[constants.GAMMA_C][..., backend.newaxis, :]
554
+ + base_vars[constants.XI_C][..., backend.newaxis, :] * gamma_gc_dev,
555
555
  name=constants.GAMMA_GC,
556
556
  ).sample()
557
557
 
meridian/model/spec.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  """Defines model specification parameters for Meridian."""
16
16
 
17
+ from collections.abc import Mapping
17
18
  import dataclasses
18
19
  from typing import Sequence
19
20
  import warnings
@@ -72,8 +73,7 @@ class ModelSpec:
72
73
  before Hill. This argument does not apply to RF channels. Default:
73
74
  `False`.
74
75
  max_lag: An integer indicating the maximum number of lag periods (≥ `0`) to
75
- include in the Adstock calculation. Can also be set to `None`, which is
76
- equivalent to infinite max lag. Default: `8`.
76
+ include in the Adstock calculation. Default: `8`.
77
77
  unique_sigma_for_each_geo: A boolean indicating whether to use a unique
78
78
  residual variance for each geo. If `False`, then a single residual
79
79
  variance is used for all geos. Default: `False`.
@@ -202,6 +202,20 @@ class ModelSpec:
202
202
  `(n_non_media_channels,)` indicating the non-media variables for which the
203
203
  non-media value will be scaled by population. If `None`, then no non-media
204
204
  variables are scaled by population. Default: `None`.
205
+ adstock_decay_spec: A string or mapping specifying the adstock decay
206
+ function for each media, RF, organic media and organic RF channel.
207
+ * If a string, must be either `'geometric'` or `'binomial'`, specifying
208
+ that decay function for all channels.
209
+ * If a mapping, keys should be channel names and values should be
210
+ `'geometric'` or `'binomial'`, with each key-value pair denoting the
211
+ adstock decay function to use for that channel. Channels that are not
212
+ specified in the mapping default to using 'geometric'.
213
+ Default: `'geometric'`.
214
+ enable_aks: A boolean indicating whether to use the Automatic Knot Selection
215
+ algorithm to select an optimal number of knots for running the model
216
+ instead of the default 1 for national models and n_times for geo models.
217
+ If this is set to `True` and the `knots` arg is provided, then an error
218
+ will be raised. Default: `False`.
205
219
  """
206
220
 
207
221
  prior: prior_distribution.PriorDistribution = dataclasses.field(
@@ -209,7 +223,7 @@ class ModelSpec:
209
223
  )
210
224
  media_effects_dist: str = constants.MEDIA_EFFECTS_LOG_NORMAL
211
225
  hill_before_adstock: bool = False
212
- max_lag: int | None = 8
226
+ max_lag: int = 8
213
227
  unique_sigma_for_each_geo: bool = False
214
228
  media_prior_type: str | None = None
215
229
  rf_prior_type: str | None = None
@@ -227,6 +241,8 @@ class ModelSpec:
227
241
  holdout_id: np.ndarray | None = None
228
242
  control_population_scaling_id: np.ndarray | None = None
229
243
  non_media_population_scaling_id: np.ndarray | None = None
244
+ adstock_decay_spec: str | Mapping[str, str] = constants.GEOMETRIC_DECAY
245
+ enable_aks: bool = False
230
246
 
231
247
  def __post_init__(self):
232
248
  # Validate media_effects_dist.
@@ -311,6 +327,10 @@ class ModelSpec:
311
327
  raise ValueError("The `knots` parameter cannot be an empty list.")
312
328
  if isinstance(self.knots, int) and self.knots == 0:
313
329
  raise ValueError("The `knots` parameter cannot be zero.")
330
+ if self.knots is not None and self.enable_aks:
331
+ raise ValueError(
332
+ "The `knots` parameter cannot be set when `enable_aks` is True."
333
+ )
314
334
 
315
335
  @property
316
336
  def effective_media_prior_type(self) -> str:
@@ -15,8 +15,9 @@
15
15
  """Contains data transformers for various inputs of the Meridian model."""
16
16
 
17
17
  import abc
18
+
19
+ from meridian import backend
18
20
  import numpy as np
19
- import tensorflow as tf
20
21
 
21
22
 
22
23
  __all__ = [
@@ -31,14 +32,14 @@ class TensorTransformer(abc.ABC):
31
32
  """Abstract class for data transformers."""
32
33
 
33
34
  @abc.abstractmethod
34
- @tf.function(jit_compile=True)
35
- def forward(self, tensor: tf.Tensor) -> tf.Tensor:
35
+ @backend.function(jit_compile=True)
36
+ def forward(self, tensor: backend.Tensor) -> backend.Tensor:
36
37
  """Transforms a given tensor."""
37
38
  raise NotImplementedError("`forward` must be implemented.")
38
39
 
39
40
  @abc.abstractmethod
40
- @tf.function(jit_compile=True)
41
- def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
41
+ @backend.function(jit_compile=True)
42
+ def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
42
43
  """Transforms back a given tensor."""
43
44
  raise NotImplementedError("`inverse` must be implemented.")
44
45
 
@@ -52,8 +53,8 @@ class MediaTransformer(TensorTransformer):
52
53
 
53
54
  def __init__(
54
55
  self,
55
- media: tf.Tensor,
56
- population: tf.Tensor,
56
+ media: backend.Tensor,
57
+ population: backend.Tensor,
57
58
  ):
58
59
  """`MediaTransformer` constructor.
59
60
 
@@ -63,22 +64,27 @@ class MediaTransformer(TensorTransformer):
63
64
  population: A tensor of dimension `(n_geos,)` containing the population of
64
65
  each geo, used to compute the scale factors.
65
66
  """
66
- population_scaled_media = tf.math.divide_no_nan(
67
- media, population[:, tf.newaxis, tf.newaxis]
67
+ population_scaled_media = backend.divide_no_nan(
68
+ media, population[:, backend.newaxis, backend.newaxis]
68
69
  )
69
70
  # Replace zeros with NaNs
70
- population_scaled_media_nan = tf.where(
71
+ population_scaled_media_nan = backend.where(
71
72
  population_scaled_media == 0, np.nan, population_scaled_media
72
73
  )
73
74
  # Tensor of medians of the positive portion of `media`. Used as a component
74
75
  # for scaling.
75
- self._population_scaled_median_m = tf.numpy_function(
76
+ self._population_scaled_median_m = backend.numpy_function(
76
77
  func=lambda x: np.nanmedian(x, axis=[0, 1]),
77
78
  inp=[population_scaled_media_nan],
78
- Tout=tf.float32,
79
+ Tout=backend.float32,
79
80
  )
81
+ if backend.reduce_any(backend.is_nan(self._population_scaled_median_m)):
82
+ raise ValueError(
83
+ "MediaTransformer has a NaN population-scaled non-zero median due to"
84
+ " a media channel with either all zeroes or all NaNs."
85
+ )
80
86
  # Tensor of dimensions (`n_geos` x 1) of weights for scaling `metric`.
81
- self._scale_factors_gm = tf.einsum(
87
+ self._scale_factors_gm = backend.einsum(
82
88
  "g,m->gm", population, self._population_scaled_median_m
83
89
  )
84
90
 
@@ -86,15 +92,15 @@ class MediaTransformer(TensorTransformer):
86
92
  def population_scaled_median_m(self):
87
93
  return self._population_scaled_median_m
88
94
 
89
- @tf.function(jit_compile=True)
90
- def forward(self, tensor: tf.Tensor) -> tf.Tensor:
95
+ @backend.function(jit_compile=True)
96
+ def forward(self, tensor: backend.Tensor) -> backend.Tensor:
91
97
  """Scales a given tensor using the stored scale factors."""
92
- return tensor / self._scale_factors_gm[:, tf.newaxis, :]
98
+ return tensor / self._scale_factors_gm[:, backend.newaxis, :]
93
99
 
94
- @tf.function(jit_compile=True)
95
- def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
100
+ @backend.function(jit_compile=True)
101
+ def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
96
102
  """Scales a given tensor using the inversed stored scale factors."""
97
- return tensor * self._scale_factors_gm[:, tf.newaxis, :]
103
+ return tensor * self._scale_factors_gm[:, backend.newaxis, :]
98
104
 
99
105
 
100
106
  class CenteringAndScalingTransformer(TensorTransformer):
@@ -108,9 +114,9 @@ class CenteringAndScalingTransformer(TensorTransformer):
108
114
 
109
115
  def __init__(
110
116
  self,
111
- tensor: tf.Tensor,
112
- population: tf.Tensor,
113
- population_scaling_id: tf.Tensor | None = None,
117
+ tensor: backend.Tensor,
118
+ population: backend.Tensor,
119
+ population_scaling_id: backend.Tensor | None = None,
114
120
  ):
115
121
  """`CenteringAndScalingTransformer` constructor.
116
122
 
@@ -124,25 +130,25 @@ class CenteringAndScalingTransformer(TensorTransformer):
124
130
  scaled by population.
125
131
  """
126
132
  if population_scaling_id is not None:
127
- self._population_scaling_factors = tf.where(
133
+ self._population_scaling_factors = backend.where(
128
134
  population_scaling_id,
129
135
  population[:, None],
130
- tf.ones_like(population)[:, None],
136
+ backend.ones_like(population)[:, None],
131
137
  )
132
138
  population_scaled_tensor = (
133
139
  tensor / self._population_scaling_factors[:, None, :]
134
140
  )
135
- self._means = tf.reduce_mean(population_scaled_tensor, axis=(0, 1))
136
- self._stdevs = tf.math.reduce_std(population_scaled_tensor, axis=(0, 1))
141
+ self._means = backend.reduce_mean(population_scaled_tensor, axis=(0, 1))
142
+ self._stdevs = backend.reduce_std(population_scaled_tensor, axis=(0, 1))
137
143
  else:
138
144
  self._population_scaling_factors = None
139
- self._means = tf.reduce_mean(tensor, axis=(0, 1))
140
- self._stdevs = tf.math.reduce_std(tensor, axis=(0, 1))
145
+ self._means = backend.reduce_mean(tensor, axis=(0, 1))
146
+ self._stdevs = backend.reduce_std(tensor, axis=(0, 1))
141
147
 
142
- @tf.function(jit_compile=True)
148
+ @backend.function(jit_compile=True)
143
149
  def forward(
144
- self, tensor: tf.Tensor, apply_population_scaling: bool = True
145
- ) -> tf.Tensor:
150
+ self, tensor: backend.Tensor, apply_population_scaling: bool = True
151
+ ) -> backend.Tensor:
146
152
  """Scales a given tensor using the stored coefficients.
147
153
 
148
154
  Args:
@@ -156,10 +162,10 @@ class CenteringAndScalingTransformer(TensorTransformer):
156
162
  and self._population_scaling_factors is not None
157
163
  ):
158
164
  tensor /= self._population_scaling_factors[:, None, :]
159
- return tf.math.divide_no_nan(tensor - self._means, self._stdevs)
165
+ return backend.divide_no_nan(tensor - self._means, self._stdevs)
160
166
 
161
- @tf.function(jit_compile=True)
162
- def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
167
+ @backend.function(jit_compile=True)
168
+ def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
163
169
  """Scales back a given tensor using the stored coefficients."""
164
170
  scaled_tensor = tensor * self._stdevs + self._means
165
171
  return (
@@ -178,8 +184,8 @@ class KpiTransformer(TensorTransformer):
178
184
 
179
185
  def __init__(
180
186
  self,
181
- kpi: tf.Tensor,
182
- population: tf.Tensor,
187
+ kpi: backend.Tensor,
188
+ population: backend.Tensor,
183
189
  ):
184
190
  """`KpiTransformer` constructor.
185
191
 
@@ -190,11 +196,11 @@ class KpiTransformer(TensorTransformer):
190
196
  each geo, used to to compute the population scale factors.
191
197
  """
192
198
  self._population = population
193
- population_scaled_kpi = tf.math.divide_no_nan(
194
- kpi, self._population[:, tf.newaxis]
199
+ population_scaled_kpi = backend.divide_no_nan(
200
+ kpi, self._population[:, backend.newaxis]
195
201
  )
196
- self._population_scaled_mean = tf.reduce_mean(population_scaled_kpi)
197
- self._population_scaled_stdev = tf.math.reduce_std(population_scaled_kpi)
202
+ self._population_scaled_mean = backend.reduce_mean(population_scaled_kpi)
203
+ self._population_scaled_stdev = backend.reduce_std(population_scaled_kpi)
198
204
 
199
205
  @property
200
206
  def population_scaled_mean(self):
@@ -204,18 +210,18 @@ class KpiTransformer(TensorTransformer):
204
210
  def population_scaled_stdev(self):
205
211
  return self._population_scaled_stdev
206
212
 
207
- @tf.function(jit_compile=True)
208
- def forward(self, tensor: tf.Tensor) -> tf.Tensor:
213
+ @backend.function(jit_compile=True)
214
+ def forward(self, tensor: backend.Tensor) -> backend.Tensor:
209
215
  """Scales a given tensor using the stored coefficients."""
210
- return tf.math.divide_no_nan(
211
- tf.math.divide_no_nan(tensor, self._population[:, tf.newaxis])
216
+ return backend.divide_no_nan(
217
+ backend.divide_no_nan(tensor, self._population[:, backend.newaxis])
212
218
  - self._population_scaled_mean,
213
219
  self._population_scaled_stdev,
214
220
  )
215
221
 
216
- @tf.function(jit_compile=True)
217
- def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
222
+ @backend.function(jit_compile=True)
223
+ def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
218
224
  """Scales back a given tensor using the stored coefficients."""
219
225
  return (
220
226
  tensor * self._population_scaled_stdev + self._population_scaled_mean
221
- ) * self._population[:, tf.newaxis]
227
+ ) * self._population[:, backend.newaxis]
meridian/version.py CHANGED
@@ -14,4 +14,4 @@
14
14
 
15
15
  """Module for Meridian version."""
16
16
 
17
- __version__ = "1.1.5"
17
+ __version__ = "1.2.0"