google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +84 -77
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.5.dist-info/RECORD +0 -47
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
meridian/model/eda/eda_engine.py (new file):

@@ -0,0 +1,306 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Meridian EDA Engine."""
+
+import functools
+from typing import Callable, Dict, Optional, TypeAlias
+from meridian import constants
+from meridian.model import model
+from meridian.model import transformers
+import numpy as np
+import tensorflow as tf
+import xarray as xr
+
+
+_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
+AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
+
+
+class EDAEngine:
+  """Meridian EDA Engine."""
+
+  def __init__(self, meridian: model.Meridian):
+    self._meridian = meridian
+
+  @functools.cached_property
+  def controls_scaled_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.controls is None:
+      return None
+    controls_scaled_da = _data_array_like(
+        da=self._meridian.input_data.controls,
+        values=self._meridian.controls_scaled,
+    )
+    return controls_scaled_da
+
+  @functools.cached_property
+  def media_raw_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.media is None:
+      return None
+    return self._truncate_media_time(self._meridian.input_data.media)
+
+  @functools.cached_property
+  def media_scaled_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.media is None:
+      return None
+    media_scaled_da = _data_array_like(
+        da=self._meridian.input_data.media,
+        values=self._meridian.media_tensors.media_scaled,
+    )
+    return self._truncate_media_time(media_scaled_da)
+
+  @functools.cached_property
+  def media_spend_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.media_spend is None:
+      return None
+    media_spend_da = _data_array_like(
+        da=self._meridian.input_data.media_spend,
+        values=self._meridian.media_tensors.media_spend,
+    )
+    # No need to truncate the media time for media spend.
+    return media_spend_da
+
+  @functools.cached_property
+  def media_raw_da_national(self) -> xr.DataArray | None:
+    if self.media_raw_da is None:
+      return None
+    if self._meridian.is_national:
+      return self.media_raw_da
+    else:
+      return self._aggregate_and_scale_geo_da(
+          self.media_raw_da,
+          None,
+      )
+
+  @functools.cached_property
+  def media_scaled_da_national(self) -> xr.DataArray | None:
+    if self.media_scaled_da is None:
+      return None
+    if self._meridian.is_national:
+      return self.media_scaled_da
+    else:
+      return self._aggregate_and_scale_geo_da(
+          self.media_raw_da,
+          transformers.MediaTransformer,
+      )
+
+  @functools.cached_property
+  def organic_media_raw_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.organic_media is None:
+      return None
+    return self._truncate_media_time(self._meridian.input_data.organic_media)
+
+  @functools.cached_property
+  def organic_media_scaled_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.organic_media is None:
+      return None
+    organic_media_scaled_da = _data_array_like(
+        da=self._meridian.input_data.organic_media,
+        values=self._meridian.organic_media_tensors.organic_media_scaled,
+    )
+    return self._truncate_media_time(organic_media_scaled_da)
+
+  @functools.cached_property
+  def organic_media_raw_da_national(self) -> xr.DataArray | None:
+    if self.organic_media_raw_da is None:
+      return None
+    if self._meridian.is_national:
+      return self.organic_media_raw_da
+    else:
+      return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)
+
+  @functools.cached_property
+  def organic_media_scaled_da_national(self) -> xr.DataArray | None:
+    if self.organic_media_scaled_da is None:
+      return None
+    if self._meridian.is_national:
+      return self.organic_media_scaled_da
+    else:
+      return self._aggregate_and_scale_geo_da(
+          self.organic_media_raw_da,
+          transformers.MediaTransformer,
+      )
+
+  @functools.cached_property
+  def non_media_scaled_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.non_media_treatments is None:
+      return None
+    non_media_scaled_da = _data_array_like(
+        da=self._meridian.input_data.non_media_treatments,
+        values=self._meridian.non_media_treatments_normalized,
+    )
+    return non_media_scaled_da
+
+  @functools.cached_property
+  def rf_spend_da(self) -> xr.DataArray | None:
+    if self._meridian.input_data.rf_spend is None:
+      return None
+    rf_spend_da = _data_array_like(
+        da=self._meridian.input_data.rf_spend,
+        values=self._meridian.rf_tensors.rf_spend,
+    )
+    return rf_spend_da
+
+  @functools.cached_property
+  def rf_spend_da_national(self) -> xr.DataArray | None:
+    if self._meridian.input_data.rf_spend is None:
+      return None
+    if self._meridian.is_national:
+      return self.rf_spend_da
+    else:
+      return self._aggregate_and_scale_geo_da(
+          self._meridian.input_data.rf_spend, None
+      )
+
+  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
+    """Truncates the first `start` elements of the media time of a variable."""
+    # This should not happen. If it does, it means this function is mis-used.
+    if constants.MEDIA_TIME not in da.coords:
+      raise ValueError(
+          f'Variable does not have a media time coordinate: {da.name}.'
+      )
+
+    start = self._meridian.n_media_times - self._meridian.n_times
+    da = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
+    da = da.rename({constants.MEDIA_TIME: constants.TIME})
+    return da
+
+  def _scale_xarray(
+      self,
+      xarray: xr.DataArray,
+      transformer_class: Optional[type[transformers.TensorTransformer]],
+      population: tf.Tensor = tf.constant([1.0], dtype=tf.float32),
+  ):
+    """Scales xarray values with a TensorTransformer."""
+    if transformer_class is None:
+      return xarray
+    elif transformer_class is transformers.CenteringAndScalingTransformer:
+      xarray_transformer = transformers.CenteringAndScalingTransformer(
+          tensor=xarray.values, population=population
+      )
+    elif transformer_class is transformers.MediaTransformer:
+      xarray_transformer = transformers.MediaTransformer(
+          media=xarray.values, population=population
+      )
+    else:
+      raise ValueError(
+          'Unknown transformer class: '
+          + str(transformer_class)
+          + '.\nMust be one of: CenteringAndScalingTransformer or'
+          ' MediaTransformer.'
+      )
+    xarray.values = xarray_transformer.forward(xarray.values)
+    return xarray
+
+  def _aggregate_variables(
+      self,
+      da_geo: xr.DataArray,
+      channel_dim: str,
+      da_var_agg_map: AggregationMap,
+      keepdims: bool = True,
+  ) -> xr.DataArray:
+    """Aggregates variables within a DataArray based on user-defined functions.
+
+    Args:
+      da_geo: The geo-level DataArray containing multiple variables along
+        channel_dim.
+      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
+        constants.CONTROL_VARIABLE).
+      da_var_agg_map: A dictionary mapping dataArray variable names to
+        aggregation functions.
+      keepdims: Whether to keep the dimensions of the aggregated DataArray.
+
+    Returns:
+      An xr.DataArray aggregated to the national level, with each variable
+        aggregated according to the da_var_agg_map.
+    """
+    agg_results = []
+    for var_name in da_geo[channel_dim].values:
+      var_data = da_geo.sel({channel_dim: var_name})
+      agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
+      # Apply the aggregation function over the GEO dimension
+      aggregated_data = var_data.reduce(
+          agg_func, dim=constants.GEO, keepdims=keepdims
+      )
+      agg_results.append(aggregated_data)
+
+    # Combine the aggregated variables back into a single DataArray
+    return xr.concat(agg_results, dim=channel_dim).transpose(..., channel_dim)
+
+  def _aggregate_and_scale_geo_da(
+      self,
+      da_geo: xr.DataArray,
+      transformer_class: Optional[type[transformers.TensorTransformer]],
+      channel_dim: Optional[str] = None,
+      da_var_agg_map: Optional[AggregationMap] = None,
+  ) -> xr.DataArray:
+    """Aggregate geo-level xr.DataArray to national level and then scale values.
+
+    Args:
+      da_geo: The geo-level DataArray to convert.
+      transformer_class: The TensorTransformer class to apply after summing to
+        national level. Must be None, CenteringAndScalingTransformer, or
+        MediaTransformer.
+      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
+        constants.CONTROL_VARIABLE). If None, standard sum aggregation is used.
+      da_var_agg_map: A dictionary mapping dataArray variable names to
+        aggregation functions. Used only if channel_dim is not None.
+
+    Returns:
+      An xr.DataArray representing the aggregated and scaled national-level
+      data.
+    """
+    temp_geo_dim = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
+
+    if da_var_agg_map is None:
+      da_var_agg_map = {}
+
+    if channel_dim is not None:
+      da_national = self._aggregate_variables(
+          da_geo, channel_dim, da_var_agg_map
+      )
+    else:
+      # Default to sum aggregation if no channel dimension is provided
+      da_national = da_geo.sum(
+          dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
+      )
+
+    da_national = da_national.assign_coords({constants.GEO: [temp_geo_dim]})
+    da_national.values = tf.cast(da_national.values, tf.float32)
+    da_national = self._scale_xarray(da_national, transformer_class)
+
+    return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)
+
+
+def _data_array_like(
+    *, da: xr.DataArray, values: np.ndarray | tf.Tensor
+) -> xr.DataArray:
+  """Returns a DataArray from `values` with the same structure as `da`.
+
+  Args:
+    da: The DataArray whose structure (dimensions, coordinates, name, and attrs)
+      will be used for the new DataArray.
+    values: The numpy array or tensorflow tensor to use as the values for the
+      new DataArray.
+
+  Returns:
+    A new DataArray with the provided `values` and the same structure as `da`.
+  """
+  return xr.DataArray(
+      values,
+      coords=da.coords,
+      dims=da.dims,
+      name=da.name,
+      attrs=da.attrs,
+  )