google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,735 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian EDA Engine."""
16
+
17
+ import dataclasses
18
+ import functools
19
+ from typing import Callable, Dict, Optional, TypeAlias
20
+ from meridian import constants
21
+ from meridian.model import model
22
+ from meridian.model import transformers
23
+ import numpy as np
24
+ import tensorflow as tf
25
+ import xarray as xr
26
+
27
+
28
+ # Reducer applied across geos to any variable that has no custom aggregation
+ # function registered in an `AggregationMap`.
+ _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
29
+ # Maps a variable name to the function used to aggregate that variable's
+ # values when reducing over the geo dimension.
+ AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
30
+
31
+
32
@dataclasses.dataclass(frozen=True, kw_only=True)
class ReachFrequencyData:
  """Immutable bundle of reach, frequency and impressions data arrays.

  Every quantity is stored twice: once as given and once in its `_national`
  variant. Impressions are the product of reach and frequency.

  Attributes:
    reach_raw_da: Reach before scaling.
    reach_scaled_da: Reach after scaling.
    reach_raw_da_national: National reach before scaling.
    reach_scaled_da_national: National reach after scaling.
    frequency_da: Frequency.
    frequency_da_national: National frequency.
    rf_impressions_scaled_da: Reach * frequency impressions after scaling.
    rf_impressions_scaled_da_national: National reach * frequency impressions
      after scaling.
    rf_impressions_raw_da: Reach * frequency impressions before scaling.
    rf_impressions_raw_da_national: National reach * frequency impressions
      before scaling.
  """

  # Reach.
  reach_raw_da: xr.DataArray
  reach_scaled_da: xr.DataArray
  reach_raw_da_national: xr.DataArray
  reach_scaled_da_national: xr.DataArray
  # Frequency.
  frequency_da: xr.DataArray
  frequency_da_national: xr.DataArray
  # Impressions (reach * frequency).
  rf_impressions_scaled_da: xr.DataArray
  rf_impressions_scaled_da_national: xr.DataArray
  rf_impressions_raw_da: xr.DataArray
  rf_impressions_raw_da_national: xr.DataArray
61
+
62
+
63
@dataclasses.dataclass(frozen=True, kw_only=True)
class AggregationConfig:
  """Per-variable overrides for how values are aggregated.

  Any variable that is missing from a mapping is aggregated with `np.sum`.

  Attributes:
    control_variables: Aggregation function to use for each control variable,
      keyed by the control variable's name.
    non_media_treatments: Aggregation function to use for each non-media
      variable, keyed by the non-media variable's name.
  """

  # Both mappings start out empty, i.e. `np.sum` is used for every variable.
  control_variables: AggregationMap = dataclasses.field(default_factory=dict)
  non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
78
+
79
+
80
class EDAEngine:
  """Meridian EDA Engine.

  Exposes the inputs of a Meridian model as `xr.DataArray` objects (plus merged
  `xr.Dataset` objects) in raw and scaled form, both per geo and aggregated to
  the national level.

  Accessors for optional inputs return None when that input is absent from the
  model's input data. Results are computed on first access and cached on the
  instance.
  """

  def __init__(
      self,
      meridian: model.Meridian,
      agg_config: AggregationConfig | None = None,
  ):
    """Initializes the engine.

    Args:
      meridian: The Meridian model whose input data is explored.
      agg_config: Custom aggregation functions for control variables and
        non-media treatments, used when geo-level data is aggregated to the
        national level. If None, a default `AggregationConfig()` is created.
    """
    self._meridian = meridian
    # The default config is created per instance instead of once as a default
    # argument value, so engines never share its mutable mapping fields.
    self._agg_config = (
        agg_config if agg_config is not None else AggregationConfig()
    )

  @functools.cached_property
  def controls_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled controls, or None without controls."""
    if self._meridian.input_data.controls is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.controls,
        values=self._meridian.controls_scaled,
    )

  @functools.cached_property
  def controls_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national controls data array.

    For a national model the geo dimension of the scaled controls is squeezed
    out. Otherwise the raw controls are aggregated over geos, per variable, with
    the functions configured in `agg_config.control_variables`, and then
    centered and scaled.
    """
    if self._meridian.input_data.controls is None:
      return None
    if self._meridian.is_national:
      if self.controls_scaled_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'controls_scaled_da is None when controls is not None.'
        )
      return self.controls_scaled_da.squeeze(constants.GEO)
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.controls,
        transformers.CenteringAndScalingTransformer,
        constants.CONTROL_VARIABLE,
        self._agg_config.control_variables,
    )

  @functools.cached_property
  def media_raw_da(self) -> xr.DataArray | None:
    """Returns the raw media on the `time` axis, or None without media."""
    if self._meridian.input_data.media is None:
      return None
    return self._truncate_media_time(self._meridian.input_data.media)

  @functools.cached_property
  def media_scaled_da(self) -> xr.DataArray | None:
    """Returns the scaled media on the `time` axis, or None without media."""
    if self._meridian.input_data.media is None:
      return None
    media_scaled_da = _data_array_like(
        da=self._meridian.input_data.media,
        values=self._meridian.media_tensors.media_scaled,
    )
    return self._truncate_media_time(media_scaled_da)

  @functools.cached_property
  def media_spend_da(self) -> xr.DataArray | None:
    """Returns the media spend, or None without media spend."""
    if self._meridian.input_data.media_spend is None:
      return None
    # No need to truncate the media time for media spend.
    return _data_array_like(
        da=self._meridian.input_data.media_spend,
        values=self._meridian.media_tensors.media_spend,
    )

  @functools.cached_property
  def media_spend_da_national(self) -> xr.DataArray | None:
    """Returns the national media spend data array."""
    if self._meridian.input_data.media_spend is None:
      return None
    if self._meridian.is_national:
      if self.media_spend_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'media_spend_da is None when media_spend is not None.'
        )
      return self.media_spend_da.squeeze(constants.GEO)
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.media_spend,
        None,
    )

  @functools.cached_property
  def media_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw media, or None without media."""
    if self.media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.media_raw_da.squeeze(constants.GEO)
    # Note that media is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self.media_raw_da,
        None,
    )

  @functools.cached_property
  def media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled media, or None without media.

    For a geo model the raw media is summed over geos first and the scaling is
    then computed on the national totals.
    """
    if self.media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.media_scaled_da.squeeze(constants.GEO)
    # Note that media is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self.media_raw_da,
        transformers.MediaTransformer,
    )

  @functools.cached_property
  def organic_media_raw_da(self) -> xr.DataArray | None:
    """Returns the raw organic media on the `time` axis, or None if absent."""
    if self._meridian.input_data.organic_media is None:
      return None
    return self._truncate_media_time(self._meridian.input_data.organic_media)

  @functools.cached_property
  def organic_media_scaled_da(self) -> xr.DataArray | None:
    """Returns the scaled organic media on the `time` axis, or None if absent."""
    if self._meridian.input_data.organic_media is None:
      return None
    organic_media_scaled_da = _data_array_like(
        da=self._meridian.input_data.organic_media,
        values=self._meridian.organic_media_tensors.organic_media_scaled,
    )
    return self._truncate_media_time(organic_media_scaled_da)

  @functools.cached_property
  def organic_media_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw organic media, or None if absent."""
    if self.organic_media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_raw_da.squeeze(constants.GEO)
    # Note that organic media is summable by assumption.
    return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)

  @functools.cached_property
  def organic_media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled organic media, or None if absent.

    For a geo model the raw organic media is summed over geos first and the
    scaling is then computed on the national totals.
    """
    if self.organic_media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_scaled_da.squeeze(constants.GEO)
    # Note that organic media is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self.organic_media_raw_da,
        transformers.MediaTransformer,
    )

  @functools.cached_property
  def non_media_scaled_da(self) -> xr.DataArray | None:
    """Returns the normalized non-media treatments, or None if absent."""
    if self._meridian.input_data.non_media_treatments is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.non_media_treatments,
        values=self._meridian.non_media_treatments_normalized,
    )

  @functools.cached_property
  def non_media_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national non-media treatment data array.

    For a national model the geo dimension of the scaled non-media treatments
    is squeezed out. Otherwise the raw treatments are aggregated over geos, per
    variable, with the functions configured in
    `agg_config.non_media_treatments`, and then centered and scaled.
    """
    if self._meridian.input_data.non_media_treatments is None:
      return None
    if self._meridian.is_national:
      if self.non_media_scaled_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'non_media_scaled_da is None when non_media_treatments is not None.'
        )
      return self.non_media_scaled_da.squeeze(constants.GEO)
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.non_media_treatments,
        transformers.CenteringAndScalingTransformer,
        constants.NON_MEDIA_CHANNEL,
        self._agg_config.non_media_treatments,
    )

  @functools.cached_property
  def rf_spend_da(self) -> xr.DataArray | None:
    """Returns the RF spend, or None without RF spend."""
    if self._meridian.input_data.rf_spend is None:
      return None
    return _data_array_like(
        da=self._meridian.input_data.rf_spend,
        values=self._meridian.rf_tensors.rf_spend,
    )

  @functools.cached_property
  def rf_spend_da_national(self) -> xr.DataArray | None:
    """Returns the national RF spend, or None without RF spend."""
    if self._meridian.input_data.rf_spend is None:
      return None
    if self._meridian.is_national:
      if self.rf_spend_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
      return self.rf_spend_da.squeeze(constants.GEO)
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.rf_spend, None
    )

  @functools.cached_property
  def _rf_data(self) -> ReachFrequencyData | None:
    """Returns all paid reach and frequency arrays, or None without reach."""
    if self._meridian.input_data.reach is None:
      return None
    return self._get_rf_data(
        self._meridian.input_data.reach,
        self._meridian.input_data.frequency,
        is_organic=False,
    )

  @property
  def reach_raw_da(self) -> xr.DataArray | None:
    """Returns the geo-level raw reach, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.reach_raw_da

  @property
  def reach_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled reach, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.reach_scaled_da

  @property
  def reach_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw reach, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.reach_raw_da_national

  @property
  def reach_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled reach, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.reach_scaled_da_national

  @property
  def frequency_da(self) -> xr.DataArray | None:
    """Returns the geo-level frequency, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.frequency_da

  @property
  def frequency_da_national(self) -> xr.DataArray | None:
    """Returns the national frequency, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.frequency_da_national

  @property
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
    """Returns the geo-level raw RF impressions, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_raw_da

  @property
  def rf_impressions_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw RF impressions, or None without reach data."""
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_raw_da_national

  @property
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled RF impressions, or None without reach."""
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_scaled_da

  @property
  def rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled RF impressions, or None without reach."""
    if self._rf_data is None:
      return None
    return self._rf_data.rf_impressions_scaled_da_national

  @functools.cached_property
  def _organic_rf_data(self) -> ReachFrequencyData | None:
    """Returns all organic reach and frequency arrays, or None if absent."""
    if self._meridian.input_data.organic_reach is None:
      return None
    return self._get_rf_data(
        self._meridian.input_data.organic_reach,
        self._meridian.input_data.organic_frequency,
        is_organic=True,
    )

  @property
  def organic_reach_raw_da(self) -> xr.DataArray | None:
    """Returns the geo-level raw organic reach, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_raw_da

  @property
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled organic reach, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_scaled_da

  @property
  def organic_reach_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw organic reach, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_raw_da_national

  @property
  def organic_reach_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled organic reach, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.reach_scaled_da_national

  @property
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
    """Returns the geo-level scaled organic RF impressions, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_scaled_da

  @property
  def organic_rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
    """Returns the national scaled organic RF impressions, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_scaled_da_national

  @property
  def organic_frequency_da(self) -> xr.DataArray | None:
    """Returns the geo-level organic frequency, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.frequency_da

  @property
  def organic_frequency_da_national(self) -> xr.DataArray | None:
    """Returns the national organic frequency, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.frequency_da_national

  @property
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
    """Returns the geo-level raw organic RF impressions, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_raw_da

  @property
  def organic_rf_impressions_raw_da_national(self) -> xr.DataArray | None:
    """Returns the national raw organic RF impressions, or None if absent."""
    if self._organic_rf_data is None:
      return None
    return self._organic_rf_data.rf_impressions_raw_da_national

  @functools.cached_property
  def geo_population_da(self) -> xr.DataArray | None:
    """Returns the population indexed by geo, or None for a national model."""
    if self._meridian.is_national:
      return None
    return xr.DataArray(
        self._meridian.population,
        coords={constants.GEO: self._meridian.input_data.geo.values},
        dims=[constants.GEO],
        name=constants.POPULATION,
    )

  @functools.cached_property
  def kpi_scaled_da(self) -> xr.DataArray:
    """Returns the geo-level scaled KPI."""
    return _data_array_like(
        da=self._meridian.input_data.kpi,
        values=self._meridian.kpi_scaled,
    )

  @functools.cached_property
  def kpi_scaled_da_national(self) -> xr.DataArray:
    """Returns the national scaled KPI.

    For a geo model the raw KPI is summed over geos first and then centered and
    scaled.
    """
    if self._meridian.is_national:
      return self.kpi_scaled_da.squeeze(constants.GEO)
    # Note that kpi is summable by assumption.
    return self._aggregate_and_scale_geo_da(
        self._meridian.input_data.kpi,
        transformers.CenteringAndScalingTransformer,
    )

  @functools.cached_property
  def treatment_control_scaled_ds(self) -> xr.Dataset:
    """Returns a Dataset containing all scaled treatments and controls.

    This includes media, RF impressions, organic media, organic RF impressions,
    non-media treatments, and control variables, all at the geo level. Inputs
    that are absent from the model are left out of the merge.
    """
    to_merge = [
        da
        for da in [
            self.media_scaled_da,
            self.rf_impressions_scaled_da,
            self.organic_media_scaled_da,
            self.organic_rf_impressions_scaled_da,
            self.controls_scaled_da,
            self.non_media_scaled_da,
        ]
        if da is not None
    ]
    return xr.merge(to_merge, join='inner')

  @functools.cached_property
  def treatment_control_scaled_ds_national(self) -> xr.Dataset:
    """Returns a Dataset containing all scaled treatments and controls.

    This includes media, RF impressions, organic media, organic RF impressions,
    non-media treatments, and control variables, all at the national level.
    Inputs that are absent from the model are left out of the merge.
    """
    to_merge_national = [
        da
        for da in [
            self.media_scaled_da_national,
            self.rf_impressions_scaled_da_national,
            self.organic_media_scaled_da_national,
            self.organic_rf_impressions_scaled_da_national,
            self.controls_scaled_da_national,
            self.non_media_scaled_da_national,
        ]
        if da is not None
    ]
    return xr.merge(to_merge_national, join='inner')

  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
    """Truncates the first `start` elements of the media time of a variable.

    `start` is `n_media_times - n_times` of the model. The media time dimension
    of the result is renamed to the time dimension.

    Args:
      da: A DataArray that has a media time coordinate.

    Returns:
      A truncated copy of `da`, indexed by time instead of media time.

    Raises:
      ValueError: If `da` has no media time coordinate.
    """
    # This should not happen. If it does, it means this function is mis-used.
    if constants.MEDIA_TIME not in da.coords:
      raise ValueError(
          f'Variable does not have a media time coordinate: {da.name}.'
      )

    start = self._meridian.n_media_times - self._meridian.n_times
    da = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
    da = da.rename({constants.MEDIA_TIME: constants.TIME})
    return da

  def _scale_xarray(
      self,
      xarray: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      population: Optional[tf.Tensor] = None,
  ) -> xr.DataArray:
    """Scales xarray values with a TensorTransformer.

    Args:
      xarray: The DataArray to scale. It is not modified.
      transformer_class: The transformer to fit on and apply to the values. Must
        be None (no scaling), CenteringAndScalingTransformer, or
        MediaTransformer.
      population: The population passed to the transformer. If None, a float32
        tensor holding the single value 1.0 is used.

    Returns:
      A copy of `xarray` holding the transformed values.

    Raises:
      ValueError: If `transformer_class` is not one of the supported classes.
    """
    da = xarray.copy()

    if transformer_class is None:
      return da

    if population is None:
      # Created on demand instead of as a default argument value, so that no
      # TensorFlow tensor is built while the class body is being executed.
      population = tf.constant([1.0], dtype=tf.float32)

    if transformer_class is transformers.CenteringAndScalingTransformer:
      xarray_transformer = transformers.CenteringAndScalingTransformer(
          tensor=da.values, population=population
      )
    elif transformer_class is transformers.MediaTransformer:
      xarray_transformer = transformers.MediaTransformer(
          media=da.values, population=population
      )
    else:
      raise ValueError(
          'Unknown transformer class: '
          + str(transformer_class)
          + '.\nMust be one of: CenteringAndScalingTransformer or'
          ' MediaTransformer.'
      )
    da.values = xarray_transformer.forward(da.values)
    return da

  def _aggregate_variables(
      self,
      da_geo: xr.DataArray,
      channel_dim: str,
      da_var_agg_map: AggregationMap,
      keepdims: bool = True,
  ) -> xr.DataArray:
    """Aggregates variables within a DataArray based on user-defined functions.

    Args:
      da_geo: The geo-level DataArray containing multiple variables along
        channel_dim.
      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
        constants.CONTROL_VARIABLE).
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Variables without an entry are aggregated with
        `np.sum`.
      keepdims: Whether to keep the dimensions of the aggregated DataArray.

    Returns:
      An xr.DataArray aggregated to the national level, with each variable
      aggregated according to the da_var_agg_map.
    """
    agg_results = []
    for var_name in da_geo[channel_dim].values:
      var_data = da_geo.sel({channel_dim: var_name})
      agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
      # Apply the aggregation function over the GEO dimension
      aggregated_data = var_data.reduce(
          agg_func, dim=constants.GEO, keepdims=keepdims
      )
      agg_results.append(aggregated_data)

    # Combine the aggregated variables back into a single DataArray, with the
    # channel dimension last.
    return xr.concat(agg_results, dim=channel_dim).transpose(..., channel_dim)

  def _aggregate_and_scale_geo_da(
      self,
      da_geo: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      channel_dim: Optional[str] = None,
      da_var_agg_map: Optional[AggregationMap] = None,
  ) -> xr.DataArray:
    """Aggregate geo-level xr.DataArray to national level and then scale values.

    Args:
      da_geo: The geo-level DataArray to convert.
      transformer_class: The TensorTransformer class to apply after summing to
        national level. Must be None, CenteringAndScalingTransformer, or
        MediaTransformer.
      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
        constants.CONTROL_VARIABLE). If None, standard sum aggregation is used.
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Used only if channel_dim is not None.

    Returns:
      An xr.DataArray representing the aggregated and scaled national-level
      data, without a geo dimension.
    """
    temp_geo_dim = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME

    if da_var_agg_map is None:
      da_var_agg_map = {}

    if channel_dim is not None:
      da_national = self._aggregate_variables(
          da_geo, channel_dim, da_var_agg_map
      )
    else:
      # Default to sum aggregation if no channel dimension is provided
      da_national = da_geo.sum(
          dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
      )

    # The geo dimension is kept with a single placeholder entry while scaling
    # and is dropped again on return.
    da_national = da_national.assign_coords({constants.GEO: [temp_geo_dim]})
    da_national.values = tf.cast(da_national.values, tf.float32)
    da_national = self._scale_xarray(da_national, transformer_class)

    return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)

  def _get_rf_data(
      self,
      reach_raw_da: xr.DataArray,
      freq_raw_da: xr.DataArray,
      is_organic: bool,
  ) -> ReachFrequencyData:
    """Get impressions and frequencies data arrays for RF channels.

    Args:
      reach_raw_da: The raw reach, with a media time coordinate.
      freq_raw_da: The raw frequency, with a media time coordinate.
      is_organic: Whether the data belongs to organic RF channels. Selects the
        scaled reach tensor of the model and the names given to the derived
        arrays.

    Returns:
      The raw and scaled reach, the frequency, and the raw and scaled
      impressions (reach * frequency), each at the geo and national level.
    """
    if is_organic:
      scaled_reach_values = (
          self._meridian.organic_rf_tensors.organic_reach_scaled
      )
    else:
      scaled_reach_values = self._meridian.rf_tensors.reach_scaled
    reach_scaled_da = _data_array_like(
        da=reach_raw_da, values=scaled_reach_values
    )
    # Truncate the media time for reach and scaled reach.
    reach_raw_da = self._truncate_media_time(reach_raw_da)
    reach_scaled_da = self._truncate_media_time(reach_scaled_da)

    # The geo level frequency
    frequency_da = self._truncate_media_time(freq_raw_da)

    # The raw geo level impression
    # It's equal to reach * frequency.
    impressions_raw_da = reach_raw_da * frequency_da
    impressions_raw_da.name = (
        constants.ORGANIC_RF_IMPRESSIONS
        if is_organic
        else constants.RF_IMPRESSIONS
    )
    impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)

    if self._meridian.is_national:
      reach_raw_da_national = reach_raw_da.squeeze(constants.GEO)
      reach_scaled_da_national = reach_scaled_da.squeeze(constants.GEO)
      impressions_raw_da_national = impressions_raw_da.squeeze(constants.GEO)
      frequency_da_national = frequency_da.squeeze(constants.GEO)

      # Scaled impressions
      impressions_scaled_da = self._scale_xarray(
          impressions_raw_da, transformers.MediaTransformer
      )
      impressions_scaled_da_national = impressions_scaled_da.squeeze(
          constants.GEO
      )
    else:
      reach_raw_da_national = self._aggregate_and_scale_geo_da(
          reach_raw_da, None
      )
      reach_scaled_da_national = self._aggregate_and_scale_geo_da(
          reach_raw_da, transformers.MediaTransformer
      )
      impressions_raw_da_national = self._aggregate_and_scale_geo_da(
          impressions_raw_da, None
      )

      # National frequency is a weighted average of geo frequencies,
      # weighted by reach. It is set to 0 wherever the national reach is 0.
      frequency_da_national = xr.where(
          reach_raw_da_national == 0.0,
          0.0,
          impressions_raw_da_national / reach_raw_da_national,
      )
      frequency_da_national.name = (
          constants.ORGANIC_PREFIX if is_organic else ''
      ) + constants.FREQUENCY
      frequency_da_national.values = tf.cast(
          frequency_da_national.values, tf.float32
      )

      # Scale the impressions by population
      impressions_scaled_da = self._scale_xarray(
          impressions_raw_da,
          transformers.MediaTransformer,
          population=self._meridian.population,
      )

      # Scale the national impressions
      impressions_scaled_da_national = self._aggregate_and_scale_geo_da(
          impressions_raw_da,
          transformers.MediaTransformer,
      )

    return ReachFrequencyData(
        reach_raw_da=reach_raw_da,
        reach_scaled_da=reach_scaled_da,
        reach_raw_da_national=reach_raw_da_national,
        reach_scaled_da_national=reach_scaled_da_national,
        frequency_da=frequency_da,
        frequency_da_national=frequency_da_national,
        rf_impressions_scaled_da=impressions_scaled_da,
        rf_impressions_scaled_da_national=impressions_scaled_da_national,
        rf_impressions_raw_da=impressions_raw_da,
        rf_impressions_raw_da_national=impressions_raw_da_national,
    )
713
+
714
+
715
def _data_array_like(
    *, da: xr.DataArray, values: np.ndarray | tf.Tensor
) -> xr.DataArray:
  """Wraps `values` in a new DataArray laid out like `da`.

  Args:
    da: The template DataArray. Its coordinates, dimensions, name, and attrs
      are reused for the result.
    values: The numpy array or tensorflow tensor that supplies the data of the
      result.

  Returns:
    A new DataArray holding `values` with the structure of `da`.
  """
  # Collect everything except the data from the template.
  structure = {
      'coords': da.coords,
      'dims': da.dims,
      'name': da.name,
      'attrs': da.attrs,
  }
  return xr.DataArray(values, **structure)