google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,306 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian EDA Engine."""
16
+
17
+ import functools
18
+ from typing import Callable, Dict, Optional, TypeAlias
19
+ from meridian import constants
20
+ from meridian.model import model
21
+ from meridian.model import transformers
22
+ import numpy as np
23
+ import tensorflow as tf
24
+ import xarray as xr
25
+
26
+
27
# Default aggregation applied across the geo dimension when a variable has no
# entry in the caller-supplied AggregationMap.
_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
# Maps a variable name (a coordinate value along a channel dimension) to the
# function used to reduce that variable's values across geos.
AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
29
+
30
+
31
class EDAEngine:
  """Meridian EDA Engine.

  Exposes a Meridian model's input data as cached `xr.DataArray` views in
  raw, scaled, and geo-aggregated ("national") forms, for use in exploratory
  data analysis. Each public property returns `None` when the corresponding
  input is absent from the model.
  """

  def __init__(self, meridian: model.Meridian):
    """Initializes the engine.

    Args:
      meridian: The Meridian model whose input data is explored.
    """
    self._meridian = meridian

  @functools.cached_property
  def controls_scaled_da(self) -> xr.DataArray | None:
    """Scaled controls, or None if the model has no controls."""
    controls = self._meridian.input_data.controls
    if controls is None:
      return None
    return _data_array_like(
        da=controls,
        values=self._meridian.controls_scaled,
    )

  @functools.cached_property
  def media_raw_da(self) -> xr.DataArray | None:
    """Raw media truncated to the modeling time window, or None."""
    media = self._meridian.input_data.media
    if media is None:
      return None
    return self._truncate_media_time(media)

  @functools.cached_property
  def media_scaled_da(self) -> xr.DataArray | None:
    """Scaled media truncated to the modeling time window, or None."""
    media = self._meridian.input_data.media
    if media is None:
      return None
    media_scaled_da = _data_array_like(
        da=media,
        values=self._meridian.media_tensors.media_scaled,
    )
    return self._truncate_media_time(media_scaled_da)

  @functools.cached_property
  def media_spend_da(self) -> xr.DataArray | None:
    """Media spend, or None if the model has no paid media spend."""
    media_spend = self._meridian.input_data.media_spend
    if media_spend is None:
      return None
    # Spend is indexed by `time` (not `media_time`), so no truncation needed.
    return _data_array_like(
        da=media_spend,
        values=self._meridian.media_tensors.media_spend,
    )

  @functools.cached_property
  def media_raw_da_national(self) -> xr.DataArray | None:
    """Raw media summed across geos, or None if the model has no media."""
    if self.media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.media_raw_da
    return self._aggregate_and_scale_geo_da(self.media_raw_da, None)

  @functools.cached_property
  def media_scaled_da_national(self) -> xr.DataArray | None:
    """National-level scaled media, or None if the model has no media.

    For geo models, *raw* media is summed across geos and then scaled with a
    national-level `MediaTransformer` — scaling after aggregation, not
    aggregating the per-geo scaled values.
    """
    if self.media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.media_scaled_da
    return self._aggregate_and_scale_geo_da(
        self.media_raw_da,
        transformers.MediaTransformer,
    )

  @functools.cached_property
  def organic_media_raw_da(self) -> xr.DataArray | None:
    """Raw organic media truncated to the modeling time window, or None."""
    organic_media = self._meridian.input_data.organic_media
    if organic_media is None:
      return None
    return self._truncate_media_time(organic_media)

  @functools.cached_property
  def organic_media_scaled_da(self) -> xr.DataArray | None:
    """Scaled organic media truncated to the modeling time window, or None."""
    organic_media = self._meridian.input_data.organic_media
    if organic_media is None:
      return None
    organic_media_scaled_da = _data_array_like(
        da=organic_media,
        values=self._meridian.organic_media_tensors.organic_media_scaled,
    )
    return self._truncate_media_time(organic_media_scaled_da)

  @functools.cached_property
  def organic_media_raw_da_national(self) -> xr.DataArray | None:
    """Raw organic media summed across geos, or None."""
    if self.organic_media_raw_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_raw_da
    return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)

  @functools.cached_property
  def organic_media_scaled_da_national(self) -> xr.DataArray | None:
    """National-level scaled organic media, or None.

    As with `media_scaled_da_national`, raw values are aggregated first and
    then scaled with a national-level `MediaTransformer`.
    """
    if self.organic_media_scaled_da is None:
      return None
    if self._meridian.is_national:
      return self.organic_media_scaled_da
    return self._aggregate_and_scale_geo_da(
        self.organic_media_raw_da,
        transformers.MediaTransformer,
    )

  @functools.cached_property
  def non_media_scaled_da(self) -> xr.DataArray | None:
    """Normalized non-media treatments, or None if the model has none."""
    non_media = self._meridian.input_data.non_media_treatments
    if non_media is None:
      return None
    return _data_array_like(
        da=non_media,
        values=self._meridian.non_media_treatments_normalized,
    )

  @functools.cached_property
  def rf_spend_da(self) -> xr.DataArray | None:
    """Reach & frequency spend, or None if the model has no R&F spend."""
    rf_spend = self._meridian.input_data.rf_spend
    if rf_spend is None:
      return None
    return _data_array_like(
        da=rf_spend,
        values=self._meridian.rf_tensors.rf_spend,
    )

  @functools.cached_property
  def rf_spend_da_national(self) -> xr.DataArray | None:
    """R&F spend summed across geos, or None if the model has no R&F spend."""
    # Check and aggregate the tensor-backed `rf_spend_da` (rather than the raw
    # `input_data.rf_spend`) for consistency with the other `*_national`
    # properties; values are numerically identical.
    if self.rf_spend_da is None:
      return None
    if self._meridian.is_national:
      return self.rf_spend_da
    return self._aggregate_and_scale_geo_da(self.rf_spend_da, None)

  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
    """Truncates the first `start` elements of the media time of a variable.

    Media variables span `n_media_times` periods, the first
    `n_media_times - n_times` of which are lag-only history; these are dropped
    so the result aligns with the KPI `time` coordinate.

    Args:
      da: A DataArray with a `media_time` coordinate.

    Returns:
      The truncated DataArray, re-indexed on `time`.

    Raises:
      ValueError: If `da` has no media time coordinate (this function is
        mis-used).
    """
    if constants.MEDIA_TIME not in da.coords:
      raise ValueError(
          f'Variable does not have a media time coordinate: {da.name}.'
      )
    start = self._meridian.n_media_times - self._meridian.n_times
    truncated = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
    return truncated.rename({constants.MEDIA_TIME: constants.TIME})

  def _scale_xarray(
      self,
      xarray: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      population: tf.Tensor = tf.constant([1.0], dtype=tf.float32),
  ):
    """Scales xarray values with a TensorTransformer.

    Args:
      xarray: The DataArray to scale.
      transformer_class: None (no-op), `CenteringAndScalingTransformer`, or
        `MediaTransformer`.
      population: Population tensor passed to the transformer; the default
        `[1.0]` yields population-neutral (national-level) scaling.

    Returns:
      A scaled copy of the DataArray, or the input itself when
      `transformer_class` is None.

    Raises:
      ValueError: If `transformer_class` is not one of the supported classes.
    """
    if transformer_class is None:
      return xarray
    elif transformer_class is transformers.CenteringAndScalingTransformer:
      xarray_transformer = transformers.CenteringAndScalingTransformer(
          tensor=xarray.values, population=population
      )
    elif transformer_class is transformers.MediaTransformer:
      xarray_transformer = transformers.MediaTransformer(
          media=xarray.values, population=population
      )
    else:
      raise ValueError(
          'Unknown transformer class: '
          + str(transformer_class)
          + '.\nMust be one of: CenteringAndScalingTransformer or'
          ' MediaTransformer.'
      )
    # Scale a copy rather than mutating the caller's (possibly cached) array
    # in place.
    scaled = xarray.copy()
    scaled.values = xarray_transformer.forward(xarray.values)
    return scaled

  def _aggregate_variables(
      self,
      da_geo: xr.DataArray,
      channel_dim: str,
      da_var_agg_map: AggregationMap,
      keepdims: bool = True,
  ) -> xr.DataArray:
    """Aggregates variables within a DataArray based on user-defined functions.

    Args:
      da_geo: The geo-level DataArray containing multiple variables along
        channel_dim.
      channel_dim: The name of the dimension coordinate to aggregate over
        (e.g., constants.CONTROL_VARIABLE).
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Variables without an entry fall back to
        `_DEFAULT_DA_VAR_AGG_FUNCTION` (sum).
      keepdims: Whether to keep the dimensions of the aggregated DataArray.

    Returns:
      An xr.DataArray aggregated to the national level, with each variable
      aggregated according to the da_var_agg_map.
    """
    agg_results = []
    for var_name in da_geo[channel_dim].values:
      var_data = da_geo.sel({channel_dim: var_name})
      agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
      # Apply the aggregation function over the GEO dimension.
      aggregated_data = var_data.reduce(
          agg_func, dim=constants.GEO, keepdims=keepdims
      )
      agg_results.append(aggregated_data)

    # Combine the aggregated variables back into a single DataArray, keeping
    # the channel dimension last.
    return xr.concat(agg_results, dim=channel_dim).transpose(..., channel_dim)

  def _aggregate_and_scale_geo_da(
      self,
      da_geo: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      channel_dim: Optional[str] = None,
      da_var_agg_map: Optional[AggregationMap] = None,
  ) -> xr.DataArray:
    """Aggregate geo-level xr.DataArray to national level and then scale values.

    Args:
      da_geo: The geo-level DataArray to convert.
      transformer_class: The TensorTransformer class to apply after summing to
        national level. Must be None, CenteringAndScalingTransformer, or
        MediaTransformer.
      channel_dim: The name of the dimension coordinate to aggregate over
        (e.g., constants.CONTROL_VARIABLE). If None, standard sum aggregation
        is used.
      da_var_agg_map: A dictionary mapping dataArray variable names to
        aggregation functions. Used only if channel_dim is not None.

    Returns:
      An xr.DataArray representing the aggregated and scaled national-level
      data (the geo dimension is dropped).
    """
    temp_geo_dim = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME

    if da_var_agg_map is None:
      da_var_agg_map = {}

    if channel_dim is not None:
      da_national = self._aggregate_variables(
          da_geo, channel_dim, da_var_agg_map
      )
    else:
      # Default to sum aggregation if no channel dimension is provided.
      # skipna=False propagates NaNs so missing geo data is not silently
      # dropped from the national total.
      da_national = da_geo.sum(
          dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
      )

    da_national = da_national.assign_coords({constants.GEO: [temp_geo_dim]})
    da_national.values = tf.cast(da_national.values, tf.float32)
    da_national = self._scale_xarray(da_national, transformer_class)

    # Drop the placeholder geo coordinate so the result is purely national.
    return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)
284
+
285
+
286
def _data_array_like(
    *, da: xr.DataArray, values: np.ndarray | tf.Tensor
) -> xr.DataArray:
  """Wraps `values` in a DataArray mirroring the structure of `da`.

  The result reuses `da`'s dimensions, coordinates, name, and attributes;
  only the underlying data differs.

  Args:
    da: Template DataArray supplying the dimensions, coordinates, name, and
      attrs for the new DataArray.
    values: The numpy array or tensorflow tensor to use as the data of the
      new DataArray.

  Returns:
    A new DataArray holding `values` with the same structure as `da`.
  """
  return xr.DataArray(
      data=values,
      dims=da.dims,
      coords=da.coords,
      attrs=da.attrs,
      name=da.name,
  )