google-meridian 1.3.0__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/RECORD +24 -10
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/backend/__init__.py +180 -23
- meridian/backend/test_utils.py +122 -0
- meridian/model/eda/eda_engine.py +54 -8
- meridian/model/model_test_data.py +15 -0
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.3.0.dist-info/METADATA +0 -409
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/model/eda/eda_engine.py
CHANGED
|
@@ -175,10 +175,22 @@ def _data_array_like(
|
|
|
175
175
|
)
|
|
176
176
|
|
|
177
177
|
|
|
178
|
-
def
|
|
178
|
+
def stack_variables(
|
|
179
179
|
ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
|
|
180
180
|
) -> xr.DataArray:
|
|
181
|
-
"""Stacks data variables
|
|
181
|
+
"""Stacks data variables of a Dataset into a single DataArray.
|
|
182
|
+
|
|
183
|
+
This function is designed to work with Datasets that have 'time' or 'geo'
|
|
184
|
+
dimensions, which are preserved. Other dimensions are stacked into a new
|
|
185
|
+
dimension.
|
|
186
|
+
|
|
187
|
+
Args:
|
|
188
|
+
ds: The input xarray.Dataset to stack.
|
|
189
|
+
coord_name: The name of the new coordinate for the stacked dimension.
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
An xarray.DataArray with the specified dimensions stacked.
|
|
193
|
+
"""
|
|
182
194
|
dims = []
|
|
183
195
|
coords = []
|
|
184
196
|
sample_dims = []
|
|
@@ -423,10 +435,11 @@ class EDAEngine:
|
|
|
423
435
|
@functools.cached_property
|
|
424
436
|
def national_media_spend_da(self) -> xr.DataArray | None:
|
|
425
437
|
"""Returns the national media spend data array."""
|
|
426
|
-
|
|
438
|
+
media_spend = self.media_spend_da
|
|
439
|
+
if media_spend is None:
|
|
427
440
|
return None
|
|
428
441
|
if self._is_national_data:
|
|
429
|
-
national_da =
|
|
442
|
+
national_da = media_spend.squeeze(constants.GEO, drop=True)
|
|
430
443
|
national_da.name = constants.NATIONAL_MEDIA_SPEND
|
|
431
444
|
else:
|
|
432
445
|
national_da = self._aggregate_and_scale_geo_da(
|
|
@@ -576,10 +589,11 @@ class EDAEngine:
|
|
|
576
589
|
@functools.cached_property
|
|
577
590
|
def national_rf_spend_da(self) -> xr.DataArray | None:
|
|
578
591
|
"""Returns the national RF spend data array."""
|
|
579
|
-
|
|
592
|
+
rf_spend = self.rf_spend_da
|
|
593
|
+
if rf_spend is None:
|
|
580
594
|
return None
|
|
581
595
|
if self._is_national_data:
|
|
582
|
-
national_da =
|
|
596
|
+
national_da = rf_spend.squeeze(constants.GEO, drop=True)
|
|
583
597
|
national_da.name = constants.NATIONAL_RF_SPEND
|
|
584
598
|
else:
|
|
585
599
|
national_da = self._aggregate_and_scale_geo_da(
|
|
@@ -806,10 +820,42 @@ class EDAEngine:
|
|
|
806
820
|
]
|
|
807
821
|
return xr.merge(to_merge, join='inner')
|
|
808
822
|
|
|
823
|
+
@functools.cached_property
|
|
824
|
+
def all_spend_ds(self) -> xr.Dataset:
|
|
825
|
+
"""Returns a Dataset containing all spend data.
|
|
826
|
+
|
|
827
|
+
This includes media spend and rf spend.
|
|
828
|
+
"""
|
|
829
|
+
to_merge = [
|
|
830
|
+
da
|
|
831
|
+
for da in [
|
|
832
|
+
self.media_spend_da,
|
|
833
|
+
self.rf_spend_da,
|
|
834
|
+
]
|
|
835
|
+
if da is not None
|
|
836
|
+
]
|
|
837
|
+
return xr.merge(to_merge, join='inner')
|
|
838
|
+
|
|
839
|
+
@functools.cached_property
|
|
840
|
+
def national_all_spend_ds(self) -> xr.Dataset:
|
|
841
|
+
"""Returns a Dataset containing all national spend data.
|
|
842
|
+
|
|
843
|
+
This includes media spend and rf spend.
|
|
844
|
+
"""
|
|
845
|
+
to_merge = [
|
|
846
|
+
da
|
|
847
|
+
for da in [
|
|
848
|
+
self.national_media_spend_da,
|
|
849
|
+
self.national_rf_spend_da,
|
|
850
|
+
]
|
|
851
|
+
if da is not None
|
|
852
|
+
]
|
|
853
|
+
return xr.merge(to_merge, join='inner')
|
|
854
|
+
|
|
809
855
|
@functools.cached_property
|
|
810
856
|
def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
811
857
|
"""Returns a stacked DataArray of treatment_control_scaled_ds."""
|
|
812
|
-
da =
|
|
858
|
+
da = stack_variables(self.treatment_control_scaled_ds)
|
|
813
859
|
da.name = constants.TREATMENT_CONTROL_SCALED
|
|
814
860
|
return da
|
|
815
861
|
|
|
@@ -837,7 +883,7 @@ class EDAEngine:
|
|
|
837
883
|
@functools.cached_property
|
|
838
884
|
def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
839
885
|
"""Returns a stacked DataArray of national_treatment_control_scaled_ds."""
|
|
840
|
-
da =
|
|
886
|
+
da = stack_variables(self.national_treatment_control_scaled_ds)
|
|
841
887
|
da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
|
|
842
888
|
return da
|
|
843
889
|
|
|
@@ -143,6 +143,7 @@ class WithInputDataSamples:
|
|
|
143
143
|
_short_input_data_with_rf_only: input_data.InputData
|
|
144
144
|
_short_input_data_with_media_and_rf: input_data.InputData
|
|
145
145
|
_national_input_data_media_only: input_data.InputData
|
|
146
|
+
_national_input_data_rf_only: input_data.InputData
|
|
146
147
|
_national_input_data_media_and_rf: input_data.InputData
|
|
147
148
|
_test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
|
|
148
149
|
_test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
|
|
@@ -282,6 +283,16 @@ class WithInputDataSamples:
|
|
|
282
283
|
seed=0,
|
|
283
284
|
)
|
|
284
285
|
)
|
|
286
|
+
cls._national_input_data_rf_only = (
|
|
287
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
288
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
289
|
+
n_times=cls._N_TIMES,
|
|
290
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
291
|
+
n_controls=cls._N_CONTROLS,
|
|
292
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
293
|
+
seed=0,
|
|
294
|
+
)
|
|
295
|
+
)
|
|
285
296
|
cls._national_input_data_media_only = (
|
|
286
297
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
287
298
|
n_geos=cls._N_GEOS_NATIONAL,
|
|
@@ -581,6 +592,10 @@ class WithInputDataSamples:
|
|
|
581
592
|
def national_input_data_media_only(self) -> input_data.InputData:
|
|
582
593
|
return self._national_input_data_media_only.copy(deep=True)
|
|
583
594
|
|
|
595
|
+
@property
|
|
596
|
+
def national_input_data_rf_only(self) -> input_data.InputData:
|
|
597
|
+
return self._national_input_data_rf_only.copy(deep=True)
|
|
598
|
+
|
|
584
599
|
@property
|
|
585
600
|
def national_input_data_media_and_rf(self) -> input_data.InputData:
|
|
586
601
|
return self._national_input_data_media_and_rf.copy(deep=True)
|
meridian/version.py
CHANGED
schema/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module containing MMM schema library."""
|
|
16
|
+
|
|
17
|
+
from schema import serde
|
|
18
|
+
from schema import utils
|
schema/serde/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""A serialization and deserialization library for Meridian models.
|
|
16
|
+
|
|
17
|
+
For entry points API, see `meridian_serde` module docs.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from schema.serde import constants
|
|
21
|
+
from schema.serde import distribution
|
|
22
|
+
from schema.serde import eda_spec
|
|
23
|
+
from schema.serde import hyperparameters
|
|
24
|
+
from schema.serde import inference_data
|
|
25
|
+
from schema.serde import meridian_serde
|
|
26
|
+
from schema.serde import serde
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Constants shared across the Meridian serde library."""
|
|
16
|
+
|
|
17
|
+
# Constants for hyperparameters protobuf structure
|
|
18
|
+
BASELINE_GEO_ONEOF = 'baseline_geo_oneof'
|
|
19
|
+
BASELINE_GEO_INT = 'baseline_geo_int'
|
|
20
|
+
BASELINE_GEO_STRING = 'baseline_geo_string'
|
|
21
|
+
CONTROL_POPULATION_SCALING_ID = 'control_population_scaling_id'
|
|
22
|
+
HOLDOUT_ID = 'holdout_id'
|
|
23
|
+
NON_MEDIA_POPULATION_SCALING_ID = 'non_media_population_scaling_id'
|
|
24
|
+
ADSTOCK_DECAY_SPEC = 'adstock_decay_spec'
|
|
25
|
+
GLOBAL_ADSTOCK_DECAY = 'global_adstock_decay'
|
|
26
|
+
ADSTOCK_DECAY_BY_CHANNEL = 'adstock_decay_by_channel'
|
|
27
|
+
DEFAULT_DECAY = 'geometric'
|
|
28
|
+
|
|
29
|
+
# Constants for marketing data protobuf structure
|
|
30
|
+
GEO_INFO = 'geo_info'
|
|
31
|
+
METADATA = 'metadata'
|
|
32
|
+
REACH_FREQUENCY = 'reach_frequency'
|
|
33
|
+
|
|
34
|
+
# Constants for distribution protobuf structure
|
|
35
|
+
DISTRIBUTION_TYPE = 'distribution_type'
|
|
36
|
+
BATCH_BROADCAST_DISTRIBUTION = 'batch_broadcast'
|
|
37
|
+
DETERMINISTIC_DISTRIBUTION = 'deterministic'
|
|
38
|
+
HALF_NORMAL_DISTRIBUTION = 'half_normal'
|
|
39
|
+
LOG_NORMAL_DISTRIBUTION = 'log_normal'
|
|
40
|
+
NORMAL_DISTRIBUTION = 'normal'
|
|
41
|
+
TRANSFORMED_DISTRIBUTION = 'transformed'
|
|
42
|
+
TRUNCATED_NORMAL_DISTRIBUTION = 'truncated_normal'
|
|
43
|
+
UNIFORM_DISTRIBUTION = 'uniform'
|
|
44
|
+
BETA_DISTRIBUTION = 'beta'
|
|
45
|
+
BIJECTOR_TYPE = 'bijector_type'
|
|
46
|
+
SHIFT_BIJECTOR = 'shift'
|
|
47
|
+
SCALE_BIJECTOR = 'scale'
|
|
48
|
+
RECIPROCAL_BIJECTOR = 'reciprocal'
|