google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/summarizer.py +7 -2
- meridian/analysis/test_utils.py +934 -485
- meridian/analysis/visualizer.py +10 -6
- meridian/constants.py +1 -0
- meridian/data/test_utils.py +82 -10
- meridian/model/__init__.py +2 -0
- meridian/model/context.py +925 -0
- meridian/model/eda/constants.py +1 -0
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +93 -792
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +354 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1136 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +412 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/test_data.py +380 -0
- schema/utils/__init__.py +1 -0
- schema/utils/date_range_bucketing.py +117 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,832 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Defines a processor for budget optimization inference on a Meridian model.
|
|
16
|
+
|
|
17
|
+
This module provides the `BudgetOptimizationProcessor` class, which is used to
|
|
18
|
+
perform marketing budget optimization based on a trained Meridian model. The
|
|
19
|
+
processor takes a trained model and a `BudgetOptimizationSpec` object,
|
|
20
|
+
which defines the optimization parameters, constraints, and scenarios.
|
|
21
|
+
|
|
22
|
+
The optimization process aims to find the optimal allocation of budget across
|
|
23
|
+
different media channels to maximize a specified objective, such as Key
|
|
24
|
+
Performance Indicator (KPI) or Revenue, subject to various constraints.
|
|
25
|
+
|
|
26
|
+
Key Features:
|
|
27
|
+
|
|
28
|
+
- Supports both fixed and flexible budget scenarios.
|
|
29
|
+
- Allows setting channel-level budget constraints, either as absolute values
|
|
30
|
+
or relative to historical spend.
|
|
31
|
+
- Generates detailed optimization results, including optimal spends, expected
|
|
32
|
+
outcomes, and response curves.
|
|
33
|
+
- Outputs results in a structured protobuf format (`BudgetOptimization`).
|
|
34
|
+
|
|
35
|
+
Key Classes:
|
|
36
|
+
|
|
37
|
+
- `BudgetOptimizationSpec`: Dataclass to specify optimization parameters and
|
|
38
|
+
constraints.
|
|
39
|
+
- `BudgetOptimizationProcessor`: The main processor class to execute budget
|
|
40
|
+
optimization.
|
|
41
|
+
|
|
42
|
+
Example Usage:
|
|
43
|
+
|
|
44
|
+
1. **Fixed Budget Optimization:**
|
|
45
|
+
Optimize budget allocation for a fixed total budget, aiming to maximize KPI.
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
from schema.processors import budget_optimization_processor
|
|
49
|
+
from meridian.analysis import optimizer
|
|
50
|
+
from schema.processors import common
|
|
51
|
+
|
|
52
|
+
# Assuming 'trained_model' is a loaded Meridian model object
|
|
53
|
+
|
|
54
|
+
spec = budget_optimization_processor.BudgetOptimizationSpec(
|
|
55
|
+
optimization_name="fixed_budget_scenario_1",
|
|
56
|
+
scenario=optimizer.FixedBudgetScenario(total_budget=1000000),
|
|
57
|
+
kpi_type=common.KpiType.REVENUE, # Or common.KpiType.NON_REVENUE
|
|
58
|
+
# Optional: Add channel constraints
|
|
59
|
+
constraints=[
|
|
60
|
+
budget_optimization_processor.ChannelConstraintRel(
|
|
61
|
+
channel_name="channel_a",
|
|
62
|
+
spend_constraint_lower=0.1, # Allow 10% decrease
|
|
63
|
+
spend_constraint_upper=0.5 # Allow 50% increase
|
|
64
|
+
),
|
|
65
|
+
budget_optimization_processor.ChannelConstraintRel(
|
|
66
|
+
channel_name="channel_b",
|
|
67
|
+
spend_constraint_lower=0.0, # No decrease
|
|
68
|
+
spend_constraint_upper=1.0 # Allow 100% increase
|
|
69
|
+
)
|
|
70
|
+
],
|
|
71
|
+
include_response_curves=True,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
processor = budget_optimization_processor.BudgetOptimizationProcessor(
|
|
75
|
+
trained_model
|
|
76
|
+
)
|
|
77
|
+
# result is a `budget_pb.BudgetOptimization` proto
|
|
78
|
+
result = processor.execute([spec])
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
2. **Flexible Budget Optimization:**
|
|
82
|
+
Optimize budget to achieve a target Return on Investment (ROI).
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
from schema.processors import budget_optimization_processor
|
|
86
|
+
from meridian.analysis import optimizer
|
|
87
|
+
from schema.processors import common
|
|
88
|
+
import meridian.constants as c
|
|
89
|
+
|
|
90
|
+
# Assuming 'trained_model' is a loaded Meridian model object
|
|
91
|
+
|
|
92
|
+
spec = budget_optimization_processor.BudgetOptimizationSpec(
|
|
93
|
+
optimization_name="flexible_roi_target",
|
|
94
|
+
scenario=optimizer.FlexibleBudgetScenario(
|
|
95
|
+
target_metric=c.ROI,
|
|
96
|
+
target_value=3.5 # Target ROI of 3.5
|
|
97
|
+
),
|
|
98
|
+
kpi_type=common.KpiType.REVENUE,
|
|
99
|
+
date_interval_tag="optimization_period",
|
|
100
|
+
# Skip response curves for faster computation.
|
|
101
|
+
include_response_curves=False,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
processor = budget_optimization_processor.BudgetOptimizationProcessor(
|
|
105
|
+
trained_model
|
|
106
|
+
)
|
|
107
|
+
result = processor.execute([spec])
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
Note: You can provide the processor with multiple specs. This would result in
|
|
111
|
+
a `BudgetOptimization` output with multiple results therein.
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
from collections.abc import Mapping, Sequence
|
|
115
|
+
import dataclasses
|
|
116
|
+
from typing import TypeAlias
|
|
117
|
+
import warnings
|
|
118
|
+
|
|
119
|
+
from meridian import constants as c
|
|
120
|
+
from meridian.analysis import analyzer
|
|
121
|
+
from meridian.analysis import optimizer
|
|
122
|
+
from meridian.data import time_coordinates as tc
|
|
123
|
+
from mmm.v1 import mmm_pb2 as pb
|
|
124
|
+
from mmm.v1.common import estimate_pb2 as estimate_pb
|
|
125
|
+
from mmm.v1.common import kpi_type_pb2 as kpi_type_pb
|
|
126
|
+
from mmm.v1.common import target_metric_pb2 as target_pb
|
|
127
|
+
from mmm.v1.marketing.analysis import marketing_analysis_pb2 as analysis_pb
|
|
128
|
+
from mmm.v1.marketing.analysis import media_analysis_pb2 as media_analysis_pb
|
|
129
|
+
from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb
|
|
130
|
+
from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb
|
|
131
|
+
from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
|
|
132
|
+
from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
|
|
133
|
+
from schema.processors import common
|
|
134
|
+
from schema.processors import model_processor
|
|
135
|
+
from schema.utils import time_record
|
|
136
|
+
import numpy as np
|
|
137
|
+
from typing_extensions import override
|
|
138
|
+
import xarray as xr
|
|
139
|
+
|
|
140
|
+
__all__ = [
    'BudgetOptimizationProcessor',
    'BudgetOptimizationSpec',
    'ChannelConstraintAbs',
    'ChannelConstraintRel',
]


# Fallback *relative* ratios used for any paid channel that a spec leaves
# without an explicit `ChannelConstraintRel`. They are interpreted exactly like
# `ChannelConstraintRel.spend_constraint_lower` / `spend_constraint_upper`.
CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO = 1
CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO = 2
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@dataclasses.dataclass(frozen=True)
class ChannelConstraintAbs:
  """A per-channel budget constraint expressed in absolute spend values.

  Unlike `ChannelConstraintRel`, the bounds here are not ratios; they are the
  resolved absolute spend limits. This is the form used when a spec is
  serialized as output metadata.

  Attributes:
    channel_name: The name of the channel.
    abs_lowerbound: A simple absolute lower bound value for a channel's spend.
    abs_upperbound: A simple absolute upper bound value for a channel's spend.
  """

  channel_name: str
  abs_lowerbound: float
  abs_upperbound: float

  def to_proto(self) -> budget_pb.ChannelConstraint:
    """Serializes this constraint into a `ChannelConstraint` proto."""
    bounds = constraints_pb.BudgetConstraint(
        min_budget=self.abs_lowerbound,
        max_budget=self.abs_upperbound,
    )
    return budget_pb.ChannelConstraint(
        channel_name=self.channel_name,
        budget_constraint=bounds,
    )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@dataclasses.dataclass(frozen=True)
class ChannelConstraintRel:
  """A per-channel budget constraint expressed as ratios of historical spend.

  This is the form expected in user-supplied specs.

  Attributes:
    channel_name: The name of the channel.
    spend_constraint_lower: Allowed relative decrease w.r.t. the channel's
      historical spend. The absolute lower bound is
      `(1 - spend_constraint_lower) * hist_channel_spend`. Must lie in
      `[0, 1]`.
    spend_constraint_upper: Allowed relative increase w.r.t. the channel's
      historical spend. The absolute upper bound is
      `(1 + spend_constraint_upper) * hist_channel_spend`. Must be
      non-negative.
  """

  channel_name: str
  spend_constraint_lower: float
  spend_constraint_upper: float

  def __post_init__(self):
    # Checks run in order and stop at the first failure, so only one error
    # is ever reported.
    lower = self.spend_constraint_lower
    if lower < 0:
      problem = 'Spend constraint lower must be non-negative.'
    elif lower > 1:
      problem = 'Spend constraint lower must not be greater than 1.'
    elif self.spend_constraint_upper < 0:
      problem = 'Spend constraint upper must be non-negative.'
    else:
      return
    raise ValueError(problem)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Either flavour of per-channel constraint: relative ratios (user input) or
# resolved absolute values (serialized output).
ChannelConstraint: TypeAlias = (
    ChannelConstraintAbs | ChannelConstraintRel
)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
216
|
+
class BudgetOptimizationSpec(model_processor.OptimizationSpec):
|
|
217
|
+
"""Spec dataclass for marketing budget optimization processor.
|
|
218
|
+
|
|
219
|
+
This spec is used both as user input to inform the budget optimization
|
|
220
|
+
processor of its constraints and parameters, as well as an output structure
|
|
221
|
+
that is serializable to a `BudgetOptimizationSpec` proto. The latter serves
|
|
222
|
+
as a metadata embedded in a `BudgetOptimizationResult`.
|
|
223
|
+
|
|
224
|
+
Attributes:
|
|
225
|
+
objective: Always defined as KPI.
|
|
226
|
+
scenario: The optimization scenario (whether fixed or flexible).
|
|
227
|
+
constraints: Per-channel budget constraints. Defaults to relative
|
|
228
|
+
constraints `[1, 2]` for spend_constraint_lower and spend_constraint_upper
|
|
229
|
+
if not specified.
|
|
230
|
+
kpi_type: A `common.KpiType` enum denoting whether the optimized KPI is of a
|
|
231
|
+
`'revenue'` or `'non-revenue'` type.
|
|
232
|
+
grid: The optimization grid to use for the optimization. If None, a new grid
|
|
233
|
+
will be created within the optimizer.
|
|
234
|
+
include_response_curves: Whether to include response curves in the output.
|
|
235
|
+
Setting this to `False` improves performance if only optimization result
|
|
236
|
+
is needed.
|
|
237
|
+
new_data: The new data to use for the optimization. If None, the training
|
|
238
|
+
data will be used.
|
|
239
|
+
use_optimal_frequency: Whether to use the optimal frequency. If set to
|
|
240
|
+
`False`, `max_frequency` is ignored.
|
|
241
|
+
max_frequency: The max frequency to use for the optimal frequency search
|
|
242
|
+
space. If not set when `use_optimal_frequency` is set to `True`, the max
|
|
243
|
+
frequency of the input data is used.
|
|
244
|
+
"""
|
|
245
|
+
|
|
246
|
+
scenario: optimizer.FixedBudgetScenario | optimizer.FlexibleBudgetScenario = (
|
|
247
|
+
dataclasses.field(default_factory=optimizer.FixedBudgetScenario)
|
|
248
|
+
)
|
|
249
|
+
constraints: Sequence[ChannelConstraint] = dataclasses.field(
|
|
250
|
+
default_factory=list
|
|
251
|
+
)
|
|
252
|
+
kpi_type: common.KpiType = common.KpiType.REVENUE
|
|
253
|
+
grid: optimizer.OptimizationGrid | None = None
|
|
254
|
+
include_response_curves: bool = True
|
|
255
|
+
new_data: analyzer.DataTensors | None = None
|
|
256
|
+
use_optimal_frequency: bool = True
|
|
257
|
+
max_frequency: float | None = None
|
|
258
|
+
|
|
259
|
+
@property
|
|
260
|
+
def objective(self) -> common.TargetMetric:
|
|
261
|
+
"""A Meridian budget optimization objective is always KPI."""
|
|
262
|
+
return common.TargetMetric.KPI
|
|
263
|
+
|
|
264
|
+
@override
|
|
265
|
+
def validate(self):
|
|
266
|
+
super().validate()
|
|
267
|
+
if (self.new_data is not None) and (self.new_data.time is None):
|
|
268
|
+
raise ValueError('`time` must be provided in `new_data`.')
|
|
269
|
+
if self.use_optimal_frequency:
|
|
270
|
+
if self.max_frequency is not None and self.max_frequency < 1.0:
|
|
271
|
+
raise ValueError('`max_frequency` must be >= 1.')
|
|
272
|
+
elif self.max_frequency is not None:
|
|
273
|
+
warnings.warn(
|
|
274
|
+
'`max_frequency` is ignored because `use_optimal_frequency` is False.'
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# TODO: Populate `new_marketing_data`.
|
|
278
|
+
def to_proto(self) -> budget_pb.BudgetOptimizationSpec:
|
|
279
|
+
# When invoked as an output proto, the spec should have been fully resolved.
|
|
280
|
+
if self.start_date is None or self.end_date is None:
|
|
281
|
+
raise ValueError(
|
|
282
|
+
'Start and end dates must be resolved before this spec can be'
|
|
283
|
+
' serialized.'
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
proto = budget_pb.BudgetOptimizationSpec(
|
|
287
|
+
date_interval=time_record.create_date_interval_pb(
|
|
288
|
+
self.start_date, self.end_date, tag=self.date_interval_tag
|
|
289
|
+
),
|
|
290
|
+
objective=self.objective.value,
|
|
291
|
+
kpi_type=(
|
|
292
|
+
kpi_type_pb.KpiType.REVENUE
|
|
293
|
+
if self.kpi_type == common.KpiType.REVENUE
|
|
294
|
+
else kpi_type_pb.KpiType.NON_REVENUE
|
|
295
|
+
),
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
match self.scenario:
|
|
299
|
+
case optimizer.FixedBudgetScenario(total_budget):
|
|
300
|
+
if total_budget is None:
|
|
301
|
+
raise ValueError(
|
|
302
|
+
'Total budget must be resolved before this spec can be serialized'
|
|
303
|
+
)
|
|
304
|
+
proto.fixed_budget_scenario.total_budget = total_budget
|
|
305
|
+
case optimizer.FlexibleBudgetScenario(target_metric, target_value):
|
|
306
|
+
proto.flexible_budget_scenario.target_metric_constraints.append(
|
|
307
|
+
constraints_pb.TargetMetricConstraint(
|
|
308
|
+
target_metric=_target_metric_to_proto(target_metric),
|
|
309
|
+
target_value=target_value,
|
|
310
|
+
)
|
|
311
|
+
)
|
|
312
|
+
case _:
|
|
313
|
+
raise ValueError('Unsupported scenario type.')
|
|
314
|
+
|
|
315
|
+
for channel_constraint in self.constraints:
|
|
316
|
+
# When invoked as an output proto, the spec's constraints must have been
|
|
317
|
+
# resolved to absolute values.
|
|
318
|
+
if not isinstance(channel_constraint, ChannelConstraintAbs):
|
|
319
|
+
raise ValueError(
|
|
320
|
+
'Channel constraints must be resolved to absolute values before'
|
|
321
|
+
' this spec can be serialized.'
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
proto.channel_constraints.append(
|
|
325
|
+
budget_pb.ChannelConstraint(
|
|
326
|
+
channel_name=channel_constraint.channel_name,
|
|
327
|
+
budget_constraint=constraints_pb.BudgetConstraint(
|
|
328
|
+
min_budget=channel_constraint.abs_lowerbound,
|
|
329
|
+
max_budget=channel_constraint.abs_upperbound,
|
|
330
|
+
),
|
|
331
|
+
)
|
|
332
|
+
)
|
|
333
|
+
proto.use_optimal_frequency = self.use_optimal_frequency
|
|
334
|
+
if self.max_frequency is not None:
|
|
335
|
+
proto.max_frequency = self.max_frequency
|
|
336
|
+
return proto
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class BudgetOptimizationProcessor(
    model_processor.ModelProcessor[
        BudgetOptimizationSpec, budget_pb.BudgetOptimization
    ],
):
  """A Processor for marketing budget optimization."""

  def __init__(
      self,
      trained_model: model_processor.ModelType,
  ):
    self._trained_model = model_processor.ensure_trained_model(trained_model)
    self._internal_analyzer = self._trained_model.internal_analyzer
    self._internal_optimizer = self._trained_model.internal_optimizer

  @classmethod
  def spec_type(cls) -> type[BudgetOptimizationSpec]:
    return BudgetOptimizationSpec

  @classmethod
  def output_type(cls) -> type[budget_pb.BudgetOptimization]:
    return budget_pb.BudgetOptimization

  def _set_output(self, output: pb.Mmm, result: budget_pb.BudgetOptimization):
    output.marketing_optimization.budget_optimization.CopyFrom(result)

  def execute(
      self, specs: Sequence[BudgetOptimizationSpec]
  ) -> budget_pb.BudgetOptimization:
    """Runs one budget optimization per spec and collects the results.

    Args:
      specs: The optimization specs to execute.

    Returns:
      A `BudgetOptimization` proto with one result per spec, appended in the
      same order as `specs`.

    Raises:
      ValueError: If two specs share the same non-empty `group_id`, or if a
        spec's scenario or channel constraints are invalid.
    """
    output = budget_pb.BudgetOptimization()

    group_ids = [spec.group_id for spec in specs if spec.group_id]
    if len(set(group_ids)) != len(group_ids):
      raise ValueError(
          'Specified group_id must be unique among the given group of specs.'
      )

    # For each given spec:
    # 1. Translate its scenario and channel constraints into optimizer kwargs.
    # 2. Run `optimize()`, which computes channel outcomes, their optimal
    #    spends, and the optimization grid used for incremental outcomes.
    # 3. Compile the final `BudgetOptimizationResult` proto.
    for spec in specs:
      kwargs = build_scenario_kwargs(spec.scenario)
      constraints_kwargs = build_constraints_kwargs(
          spec.constraints,
          self._trained_model.mmm.input_data.get_all_paid_channels(),
      )
      kwargs.update(constraints_kwargs)
      resolver = spec.resolver(self._time_coordinates_for(spec))
      start_date, end_date = resolver.to_closed_date_interval_tuple()

      # Note that `optimize()` maximizes KPI if the input data is non-revenue
      # and the user selected `use_kpi=True`. Otherwise, it maximizes revenue.
      opt_result = self._internal_optimizer.optimize(
          start_date=start_date,
          end_date=end_date,
          fixed_budget=isinstance(spec.scenario, optimizer.FixedBudgetScenario),
          confidence_level=spec.confidence_level,
          use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE),
          optimization_grid=spec.grid,
          new_data=spec.new_data,
          use_optimal_frequency=spec.use_optimal_frequency,
          max_frequency=spec.max_frequency,
          **kwargs,
      )

      output.results.append(
          self._to_budget_optimization_result(
              spec, opt_result, resolver, **constraints_kwargs
          )
      )

    return output

  def _time_coordinates_for(
      self, spec: BudgetOptimizationSpec
  ) -> tc.TimeCoordinates:
    """Returns the time coordinates a spec's date interval is resolved against.

    Uses the time labels of `spec.new_data` when present, otherwise the trained
    model's own time coordinates.
    """
    if spec.new_data is None or spec.new_data.time is None:
      return self._trained_model.time_coordinates
    # Labels coming from string tensors arrive as `bytes` and need decoding;
    # labels that are already text are passed through instead of crashing on
    # a missing `.decode()`.
    return tc.TimeCoordinates.from_dates([
        label.decode() if isinstance(label, bytes) else str(label)
        for label in np.asarray(spec.new_data.time)
    ])

  def _to_budget_optimization_result(
      self,
      spec: BudgetOptimizationSpec,
      opt_result: optimizer.OptimizationResults,
      resolver: model_processor.DatedSpecResolver,
      spend_constraint_lower: Sequence[float],
      spend_constraint_upper: Sequence[float],
  ) -> budget_pb.BudgetOptimizationResult:
    """Converts an optimizer result to a BudgetOptimizationResult proto.

    Args:
      spec: The spec used to generate the optimization result.
      opt_result: The result of the optimization.
      resolver: A DatedSpecResolver instance.
      spend_constraint_lower: A sequence of lower bound constraints for each
        channel, in relative terms.
      spend_constraint_upper: A sequence of upper bound constraints for each
        channel, in relative terms.

    Returns:
      A BudgetOptimizationResult proto.
    """
    # Resolve the spec's date interval (open-ended at `end`).
    start, end = resolver.resolve_to_date_interval_open_end()

    # Resolve the given (input) spec to an (output) spec: the latter features
    # dates and absolute channel constraints resolution.
    spec = dataclasses.replace(
        spec,
        start_date=start,
        end_date=end,
        constraints=_get_channel_constraints_abs(
            opt_result=opt_result,
            constraint_lower=spend_constraint_lower,
            constraint_upper=spend_constraint_upper,
        ),
    )

    # If the spec is a fixed budget scenario, but the total budget is not
    # specified, then set it to the budget amount used in the optimization.
    resolve_historical_budget = (
        isinstance(spec.scenario, optimizer.FixedBudgetScenario)
        and spec.scenario.total_budget is None
    )
    if resolve_historical_budget:
      spec = dataclasses.replace(
          spec,
          scenario=optimizer.FixedBudgetScenario(
              total_budget=opt_result.optimized_data.attrs[c.BUDGET]
          ),
      )

    xr_response_curves = (
        opt_result.get_response_curves()
        if spec.include_response_curves
        else None
    )
    optimized_marketing_analysis = to_marketing_analysis(
        spec=spec,
        xr_data=opt_result.optimized_data,
        xr_response_curves=xr_response_curves,
    )
    nonoptimized_marketing_analysis = to_marketing_analysis(
        spec=spec,
        xr_data=opt_result.nonoptimized_data,
        xr_response_curves=xr_response_curves,
    )
    result = budget_pb.BudgetOptimizationResult(
        name=spec.optimization_name,
        spec=spec.to_proto(),
        optimized_marketing_analysis=optimized_marketing_analysis,
        nonoptimized_marketing_analysis=nonoptimized_marketing_analysis,
        incremental_outcome_grid=_to_incremental_outcome_grid(
            opt_result.optimization_grid.grid_dataset,
            grid_name=spec.grid_name,
        ),
    )

    if spec.group_id:
      result.group_id = spec.group_id
    return result
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def to_marketing_analysis(
    spec: model_processor.DatedSpec,
    xr_data: xr.Dataset,
    xr_response_curves: xr.Dataset | None,
) -> analysis_pb.MarketingAnalysis:
  """Converts one side of an optimization result to a MarketingAnalysis proto.

  Args:
    spec: The spec to build the MarketingAnalysis proto for. Its date interval
      must already be resolved (`start_date` and `end_date` set); its
      `date_interval_tag` is carried into the output date interval.
    xr_data: The xr.Dataset to convert. It is iterated over its `channel`
      coordinate, and each channel slice must provide scalar `spend` and
      `pct_of_spend` values, plus whatever `_to_outcome` reads.
    xr_response_curves: The xr.Dataset to convert into response curves, or
      None to leave response curves out of the output.

  Returns:
    A MarketingAnalysis proto with one MediaAnalysis per channel.

  Raises:
    ValueError: If `spec` does not have a resolved date interval.
  """
  # `spec` should have been resolved with concrete date interval parameters.
  # Raise (rather than assert) so the check is not stripped under `python -O`.
  if spec.start_date is None or spec.end_date is None:
    raise ValueError(
        'Start and end dates must be resolved before building a marketing'
        ' analysis.'
    )
  marketing_analysis = analysis_pb.MarketingAnalysis(
      date_interval=time_record.create_date_interval_pb(
          start_date=spec.start_date,
          end_date=spec.end_date,
          tag=spec.date_interval_tag,
      ),
  )
  # Include the response curves data for all channels at the optimized freq.
  # Only convert them when they were requested; otherwise they are never used.
  channel_response_curve_protos = (
      _to_channel_response_curve_protos(xr_response_curves)
      if xr_response_curves is not None
      else None
  )

  # Create a per-channel MediaAnalysis.
  for channel in xr_data.channel.values:
    channel_data = xr_data.sel(channel=channel)
    spend = channel_data.spend.item()
    # TODO: Resolve conflict definition of spend share.
    spend_share = channel_data.pct_of_spend.item()
    channel_media_analysis = media_analysis_pb.MediaAnalysis(
        channel_name=channel,
        spend_info=media_analysis_pb.SpendInfo(
            spend=spend,
            spend_share=spend_share,
        ),
    )
    # Output one outcome per channel: either revenue or non-revenue,
    # but not both.
    channel_media_analysis.media_outcomes.append(_to_outcome(channel_data))
    if channel_response_curve_protos is not None:
      channel_media_analysis.response_curve.CopyFrom(
          channel_response_curve_protos[channel]
      )
    marketing_analysis.media_analyses.append(channel_media_analysis)

  return marketing_analysis
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def _get_channel_constraints_abs(
    opt_result: optimizer.OptimizationResults,
    constraint_lower: Sequence[float],
    constraint_upper: Sequence[float],
) -> list[ChannelConstraintAbs]:
  """Resolves relative per-channel constraints into absolute spend bounds.

  The absolute bounds are obtained from `optimizer.get_optimization_bounds`,
  fed with the non-optimized spend and the optimization grid's round factor.

  Args:
    opt_result: The optimization result.
    constraint_lower: Relative lower-bound constraint for each channel.
    constraint_upper: Relative upper-bound constraint for each channel.

  Returns:
    One `ChannelConstraintAbs` per channel in `opt_result.optimized_data`, in
    that dataset's channel order.
  """
  channel_names = opt_result.optimized_data.channel.values
  lower_bounds, upper_bounds = optimizer.get_optimization_bounds(
      n_channels=len(channel_names),
      spend=opt_result.nonoptimized_data.spend.data,
      round_factor=opt_result.optimization_grid.round_factor,
      spend_constraint_lower=constraint_lower,
      spend_constraint_upper=constraint_upper,
  )
  # Index (rather than zip) into the bounds so a length mismatch surfaces as
  # an error instead of silently dropping channels.
  return [
      ChannelConstraintAbs(
          channel_name=name,
          abs_lowerbound=lower_bounds[idx],
          abs_upperbound=upper_bounds[idx],
      )
      for idx, name in enumerate(channel_names)
  ]
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
def build_scenario_kwargs(
|
|
597
|
+
scenario: optimizer.FixedBudgetScenario | optimizer.FlexibleBudgetScenario,
|
|
598
|
+
) -> dict[str, float]:
|
|
599
|
+
"""Returns keyword arguments for an optimizer, given a spec's scenario.
|
|
600
|
+
|
|
601
|
+
The keys in the returned kwargs are a subset of the parameters in
|
|
602
|
+
`optimizer.BudgetOptimizer.optimize()` method.
|
|
603
|
+
|
|
604
|
+
Args:
|
|
605
|
+
scenario: The scenario to build kwargs for.
|
|
606
|
+
|
|
607
|
+
Raises:
|
|
608
|
+
ValueError: If no scenario is specified in the spec, or if for a given
|
|
609
|
+
scenario type, its values are invalid.
|
|
610
|
+
"""
|
|
611
|
+
kwargs = {}
|
|
612
|
+
match scenario:
|
|
613
|
+
case optimizer.FixedBudgetScenario(total_budget):
|
|
614
|
+
if total_budget is not None: # if not specified => historical spend
|
|
615
|
+
kwargs['budget'] = total_budget
|
|
616
|
+
case optimizer.FlexibleBudgetScenario(target_metric, target_value):
|
|
617
|
+
match target_metric:
|
|
618
|
+
case c.ROI:
|
|
619
|
+
key = 'target_roi'
|
|
620
|
+
case c.MROI:
|
|
621
|
+
key = 'target_mroi'
|
|
622
|
+
case _:
|
|
623
|
+
# Technically dead code, since this is already checked in `validate()`
|
|
624
|
+
raise ValueError(
|
|
625
|
+
f'Unsupported target metric: {target_metric} for flexible'
|
|
626
|
+
' budget scenario.'
|
|
627
|
+
)
|
|
628
|
+
kwargs[key] = target_value
|
|
629
|
+
case _:
|
|
630
|
+
# Technically dead code.
|
|
631
|
+
raise ValueError('Unsupported scenario type.')
|
|
632
|
+
return kwargs
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def build_constraints_kwargs(
    constraints: Sequence[ChannelConstraint],
    model_channels: Sequence[str],
) -> dict[str, list[float]]:
  """Returns `spend_constraint_**` kwargs for given channel constraints.

  Both returned lists are positionally aligned with `model_channels`: entry
  `i` holds the bound for `model_channels[i]`.

  A media channel without an entry in the spec's channel constraints falls
  back to `CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO` and
  `CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO`. Presumably these defaults
  leave the channel effectively bounded only by the scenario's budget. TODO
  confirm against the constants' values.

  If several constraints name the same channel, the last one wins.

  Args:
    constraints: The channel constraints from the spec.
    model_channels: The list of channels in the model.

  Returns:
    A dict with keys `'spend_constraint_lower'` and `'spend_constraint_upper'`,
    each mapping to a list of relative-ratio bounds, one per model channel.

  Raises:
    ValueError: If the channel constraints are invalid, i.e. a constraint names
      a channel that is not in the model data, or a constraint for a model
      channel is not a `ChannelConstraintRel`.
  """
  # Validate user-configured channel constraints in the spec. The loop variable
  # is deliberately not named `c`, which would shadow the constants module
  # alias used throughout this module.
  constraints_by_channel_name = {
      constraint.channel_name: constraint for constraint in constraints
  }
  constraint_channel_names = set(constraints_by_channel_name)
  if not constraint_channel_names <= set(model_channels):
    raise ValueError(
        'Channel constraints must have channel names that are in the model'
        f' data. Expected {model_channels}, got {constraint_channel_names}.'
    )

  spend_constraint_lower = []
  spend_constraint_upper = []
  for channel in model_channels:
    constraint = constraints_by_channel_name.get(channel)
    if constraint is None:
      lowerbound = CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO
      upperbound = CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO
    else:
      # Only relative (ratio) constraints are accepted from user input.
      if not isinstance(constraint, ChannelConstraintRel):
        raise ValueError(
            'Channel constraints in user input must be expressed in relative'
            ' ratio terms.'
        )
      lowerbound = constraint.spend_constraint_lower
      upperbound = constraint.spend_constraint_upper

    spend_constraint_lower.append(lowerbound)
    spend_constraint_upper.append(upperbound)

  return {
      'spend_constraint_lower': spend_constraint_lower,
      'spend_constraint_upper': spend_constraint_upper,
  }
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def _to_channel_response_curve_protos(
    optimized_response_curves: xr.Dataset | None,
) -> Mapping[str, response_curve_pb.ResponseCurve]:
  """Converts response curve data to a map of channel to `ResponseCurve`.

  Args:
    optimized_response_curves: An `xr.Dataset` containing the response curve
      data, i.e. the output of `OptimizationResults.get_response_curves()`.
      May be `None`, in which case an empty map is returned.

  Returns:
    A map of channel to `ResponseCurve` proto. Each curve's `input_name` is
    `c.SPEND`. Its points are ordered by ascending spend and carry the mean
    incremental outcome at that spend as `incremental_kpi`.
  """
  if optimized_response_curves is None:
    return {}

  # Flatten the dataset into a table with one row per
  # (channel, spend, spend multiplier) and one column per `metric` label, so
  # the mean incremental outcome becomes its own column.
  df = (
      optimized_response_curves.to_dataframe()
      .reset_index()
      .pivot(
          index=[c.CHANNEL, c.SPEND, c.SPEND_MULTIPLIER],
          columns=c.METRIC,
          values=c.INCREMENTAL_OUTCOME,
      )
      .reset_index()
      .sort_values(by=[c.CHANNEL, c.SPEND])
  )

  # Every channel in the dataset gets a curve, even if it ends up with no
  # points.
  channel_response_curves = {
      channel: response_curve_pb.ResponseCurve(input_name=c.SPEND)
      for channel in optimized_response_curves.channel.values
  }

  # Walk the needed columns in lockstep rather than using
  # `DataFrame.iterrows()`. That builds a Series per row and upcasts mixed
  # dtypes, and is slow on large curves.
  for channel, spend, mean_outcome in zip(
      df[c.CHANNEL], df[c.SPEND], df[c.MEAN]
  ):
    channel_response_curves[channel].response_points.append(
        response_curve_pb.ResponsePoint(
            input_value=spend,
            incremental_kpi=mean_outcome,
        )
    )

  return channel_response_curves
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
def _to_outcome(channel_data: xr.Dataset) -> outcome_pb.Outcome:
  """Builds an `Outcome` proto from one channel's slice of optimization results.

  Args:
    channel_data: A channel-selected dataset from `OptimizationResults`. Its
      `attrs` must contain the `c.CONFIDENCE_LEVEL` and `c.IS_REVENUE_KPI`
      entries. Its `roi`, `mroi`, `cpik`, `incremental_outcome` and
      `effectiveness` variables are each converted via `_to_estimate`.

  Returns:
    An `Outcome` proto carrying the KPI type plus ROI, marginal ROI, cost per
    contribution, contribution and (impression-based) effectiveness estimates.
  """
  attributes = channel_data.attrs
  confidence_level = attributes[c.CONFIDENCE_LEVEL]
  if attributes[c.IS_REVENUE_KPI]:
    kpi_type = kpi_type_pb.REVENUE
  else:
    kpi_type = kpi_type_pb.NON_REVENUE

  roi = _to_estimate(channel_data.roi, confidence_level)
  marginal_roi = _to_estimate(channel_data.mroi, confidence_level)
  cost_per_contribution = _to_estimate(
      channel_data.cpik, confidence_level=confidence_level
  )
  contribution = outcome_pb.Contribution(
      value=_to_estimate(channel_data.incremental_outcome, confidence_level),
  )
  effectiveness = outcome_pb.Effectiveness(
      media_unit=c.IMPRESSIONS,
      value=_to_estimate(channel_data.effectiveness, confidence_level),
  )

  return outcome_pb.Outcome(
      kpi_type=kpi_type,
      roi=roi,
      marginal_roi=marginal_roi,
      cost_per_contribution=cost_per_contribution,
      contribution=contribution,
      effectiveness=effectiveness,
  )
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def _to_incremental_outcome_grid(
    optimization_grid: xr.Dataset,
    grid_name: str | None,
) -> budget_pb.IncrementalOutcomeGrid:
  """Converts an optimization grid to an `IncrementalOutcomeGrid` proto.

  Args:
    optimization_grid: The optimization grid dataset in
      `OptimizationResults.optimization_grid`.
    grid_name: A user-given name for this grid. `None` becomes an empty name.

  Returns:
    An `IncrementalOutcomeGrid` proto with one `ChannelCells` entry per
    channel. Each entry pairs the channel's non-NaN spend values with the
    incremental outcome at the same grid position.

  Raises:
    ValueError: If, for some channel, the spend grid and the incremental
      outcome grid do not have missing (NaN) entries at exactly the same
      positions, so the two cannot be paired cell by cell.
  """
  grid = budget_pb.IncrementalOutcomeGrid(
      name=(grid_name or ''),
      spend_step_size=optimization_grid.spend_step_size,
  )
  for channel in optimization_grid.channel.values:
    channel_grid = optimization_grid.sel(channel=channel)
    raw_spend_grid = channel_grid.spend_grid
    raw_incremental_outcome_grid = channel_grid.incremental_outcome_grid

    # NaN entries are dropped from each grid independently and the survivors
    # are zipped, so the pairing is only correct when the NaN padding sits at
    # identical positions. A mere length comparison would miss padding that
    # has the same count but is shifted.
    nan_masks_differ = (
        (raw_spend_grid.isnull() != raw_incremental_outcome_grid.isnull())
        .any()
        .item()
    )
    if nan_masks_differ:
      raise ValueError(
          f'Spend grid and incremental outcome grid for channel "{channel}" do'
          ' not agree.'
      )

    spend_grid = raw_spend_grid.dropna(dim=c.GRID_SPEND_INDEX)
    incremental_outcome_grid = raw_incremental_outcome_grid.dropna(
        dim=c.GRID_SPEND_INDEX
    )
    channel_cells = budget_pb.IncrementalOutcomeGrid.ChannelCells(
        channel_name=channel,
        cells=[
            budget_pb.IncrementalOutcomeGrid.Cell(
                spend=spend.item(),
                incremental_outcome=estimate_pb.Estimate(
                    value=incr_outcome.item()
                ),
            )
            for spend, incr_outcome in zip(
                spend_grid, incremental_outcome_grid
            )
        ],
    )
    grid.channel_cells.append(channel_cells)
  return grid
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def _to_estimate(
    dataarray: xr.DataArray,
    confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
) -> estimate_pb.Estimate:
  """Builds an `Estimate` proto from a DataArray indexed by a `metric` dim.

  Args:
    dataarray: A DataArray whose `metric` coordinate has the labels `c.MEAN`,
      `c.CI_LO` and `c.CI_HI`. Each selection must reduce to a single scalar.
    confidence_level: The probability recorded for the [ci_lo, ci_hi]
      interval.

  Returns:
    An `Estimate` whose value is the mean, with exactly one `Uncertainty`
    holding the interval bounds and `confidence_level`.
  """

  def metric_value(metric_name):
    return dataarray.sel(metric=metric_name).item()

  estimate = estimate_pb.Estimate(value=metric_value(c.MEAN))
  estimate.uncertainties.append(
      estimate_pb.Estimate.Uncertainty(
          probability=confidence_level,
          lowerbound=metric_value(c.CI_LO),
          upperbound=metric_value(c.CI_HI),
      )
  )
  return estimate
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def _target_metric_to_proto(
|
|
823
|
+
target_metric: str,
|
|
824
|
+
) -> target_pb.TargetMetric:
|
|
825
|
+
"""Converts a TargetMetric enum to a TargetMetric proto."""
|
|
826
|
+
match target_metric:
|
|
827
|
+
case c.ROI:
|
|
828
|
+
return target_pb.TargetMetric.ROI
|
|
829
|
+
case c.MROI:
|
|
830
|
+
return target_pb.TargetMetric.MARGINAL_ROI
|
|
831
|
+
case _:
|
|
832
|
+
raise ValueError(f'Unsupported target metric: {target_metric}')
|