google-meridian 1.3.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/METADATA +13 -9
- google_meridian-1.4.0.dist-info/RECORD +108 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +1 -2
- meridian/analysis/analyzer.py +0 -1
- meridian/analysis/optimizer.py +5 -3
- meridian/analysis/review/checks.py +81 -30
- meridian/analysis/review/constants.py +4 -0
- meridian/analysis/review/results.py +40 -9
- meridian/analysis/summarizer.py +8 -3
- meridian/analysis/test_utils.py +934 -485
- meridian/analysis/visualizer.py +11 -7
- meridian/backend/__init__.py +53 -5
- meridian/backend/test_utils.py +72 -0
- meridian/constants.py +2 -0
- meridian/data/load.py +2 -0
- meridian/data/test_utils.py +82 -10
- meridian/model/__init__.py +2 -0
- meridian/model/context.py +925 -0
- meridian/model/eda/__init__.py +0 -1
- meridian/model/eda/constants.py +13 -2
- meridian/model/eda/eda_engine.py +299 -37
- meridian/model/eda/eda_outcome.py +21 -1
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +75 -47
- meridian/model/model.py +93 -792
- meridian/{analysis/templates → templates}/card.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
- meridian/{analysis → templates}/formatter.py +12 -1
- meridian/templates/formatter_test.py +216 -0
- meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
- meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
- meridian/{analysis/templates → templates}/style.css +1 -1
- meridian/{analysis/templates → templates}/style.scss +1 -1
- meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
- meridian/{analysis/templates → templates}/table.html.jinja +1 -1
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +354 -0
- schema/__init__.py +15 -0
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1136 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +412 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/test_data.py +380 -0
- schema/utils/__init__.py +1 -0
- schema/utils/date_range_bucketing.py +117 -0
- google_meridian-1.3.1.dist-info/RECORD +0 -76
- meridian/model/eda/meridian_eda.py +0 -220
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
schema/processors/model_processor.py (new file)
@@ -0,0 +1,412 @@
# Copyright 2025 The Meridian Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines common and base classes for processing a trained Meridian model to an MMM schema."""

import abc
from collections.abc import Sequence
import dataclasses
import datetime
import functools
from typing import Generic, TypeVar

from google.protobuf import message
from meridian import constants as c
from meridian.analysis import analyzer
from meridian.analysis import optimizer
from meridian.analysis import visualizer
from meridian.data import time_coordinates as tc
from meridian.model import model
from mmm.v1 import mmm_pb2 as pb
from mmm.v1.common import date_interval_pb2
from schema.utils import time_record
from typing_extensions import override


__all__ = [
    'ModelProcessor',
    'TrainedModel',
    'DatedSpec',
    'DatedSpecResolver',
    'OptimizationSpec',
    'ensure_trained_model',
]


class TrainedModel(abc.ABC):
  """Encapsulates a trained MMM model."""

  def __init__(self, mmm: model.Meridian):
    """Initializes the TrainedModel with a Meridian model.

    Args:
      mmm: A Meridian model that has been fitted (posterior samples drawn).

    Raises:
      ValueError: If the model has not been fitted (posterior samples drawn).
    """
    # Ideally, this could be encoded in the model type itself, and we won't need
    # this extra runtime check.
    if mmm.inference_data.prior is None or mmm.inference_data.posterior is None:  # pytype: disable=attribute-error
      raise ValueError('MMM model has not been fitted.')
    self._mmm = mmm

  @property
  def mmm(self) -> model.Meridian:
    return self._mmm

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    return self._mmm.input_data.time_coordinates

  @functools.cached_property
  def internal_analyzer(self) -> analyzer.Analyzer:
    """Returns an internal `Analyzer` bound to this trained model."""
    return analyzer.Analyzer(self.mmm)

  @functools.cached_property
  def internal_optimizer(self) -> optimizer.BudgetOptimizer:
    """Returns an internal `BudgetOptimizer` bound to this trained model."""
    return optimizer.BudgetOptimizer(self.mmm)

  @functools.cached_property
  def internal_model_diagnostics(self) -> visualizer.ModelDiagnostics:
    """Returns an internal `ModelDiagnostics` bound to this trained model."""
    return visualizer.ModelDiagnostics(self.mmm)


ModelType = model.Meridian | TrainedModel


def ensure_trained_model(model_input: ModelType) -> TrainedModel:
  """Ensure the given model is a trained model, and wrap it in a TrainedModel."""
  if isinstance(model_input, TrainedModel):
    return model_input
  return TrainedModel(model_input)
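

# --- Editor's illustrative sketch; not part of the released file. ---
# Assumes `mmm` is a `model.Meridian` instance whose prior and posterior have
# already been sampled (otherwise `TrainedModel.__init__` raises ValueError).
# Shows that `ensure_trained_model` is idempotent and that the internal helper
# objects are built lazily and cached per wrapper instance.
def _example_wrap_trained_model(mmm: model.Meridian) -> TrainedModel:
  trained = ensure_trained_model(mmm)  # Wraps the raw Meridian model.
  assert ensure_trained_model(trained) is trained  # Already wrapped: pass-through.
  _ = trained.internal_analyzer  # analyzer.Analyzer, created once and cached.
  _ = trained.internal_optimizer  # optimizer.BudgetOptimizer, also cached.
  return trained
# --- End of sketch. ---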


class Spec(abc.ABC):
  """Contains parameters needed for model-based analysis/optimization."""

  @abc.abstractmethod
  def validate(self):
    """Checks whether each parameter in the Spec has a valid value.

    Raises:
      ValueError: If any parameter in the Spec has an invalid value.
    """

  def __post_init__(self):
    self.validate()


@dataclasses.dataclass(frozen=True)
class DatedSpec(Spec):
  """A spec with a `[start_date, end_date)` closed-open date range semantic.

  Attrs:
    start_date: The start date of the analysis/optimization. If left as `None`,
      then this will eventually resolve to a model's first time coordinate.
    end_date: The end date of the analysis/optimization. If left as `None`, then
      this will eventually resolve to a model's last time coordinate. When
      specified, this end date is exclusive.
    date_interval_tag: An optional tag that identifies the date interval.
  """

  start_date: datetime.date | None = None
  end_date: datetime.date | None = None
  date_interval_tag: str = ''

  @override
  def validate(self):
    """Overrides the Spec.validate() method to check that dates are valid."""
    if (
        self.start_date is not None
        and self.end_date is not None
        and self.start_date > self.end_date
    ):
      raise ValueError('Start date must be before end date.')

  def resolver(
      self, time_coordinates: tc.TimeCoordinates
  ) -> 'DatedSpecResolver':
    """Returns a date resolver for this spec, with the given Meridian model."""
    return DatedSpecResolver(self, time_coordinates)


class DatedSpecResolver:
  """Resolves date parameters in specs based on a model's time coordinates."""

  def __init__(self, spec: DatedSpec, time_coordinates: tc.TimeCoordinates):
    self._spec = spec
    self._time_coordinates = time_coordinates

  @property
  def _interval_days(self) -> int:
    return self._time_coordinates.interval_days

  @property
  def time_coordinates(self) -> tc.TimeCoordinates:
    return self._time_coordinates

  def to_closed_date_interval_tuple(
      self,
  ) -> tuple[str | None, str | None]:
    """Transforms given spec into a closed `[start, end]` date interval tuple.

    For each of the bookends in the tuple, `None` value indicates a time
    coordinate default (first or last time coordinate, respectively).

    Returns:
      A **closed** `[start, end]` date interval tuple.
    """
    start, end = (None, None)

    if self._spec.start_date is not None:
      start = self._spec.start_date.strftime(c.DATE_FORMAT)
    if self._spec.end_date is not None:
      inclusive_end_date = self._spec.end_date - datetime.timedelta(
          days=self._interval_days
      )
      end = inclusive_end_date.strftime(c.DATE_FORMAT)

    return (start, end)

  def resolve_to_enumerated_selected_times(self) -> list[str] | None:
    """Resolves the given spec into an enumerated list of time coordinates.

    Returns:
      An enumerated list of time coordinates, or None (semantic "All") if the
      bound spec is also None.
    """
    start, end = self.to_closed_date_interval_tuple()
    expanded = self._time_coordinates.expand_selected_time_dims(
        start_date=start, end_date=end
    )
    if expanded is None:
      return None
    return [date.strftime(c.DATE_FORMAT) for date in expanded]

  def resolve_to_bool_selected_times(self) -> list[bool] | None:
    """Resolves the given spec into a list of booleans indicating selected times.

    Returns:
      A list of booleans indicating selected times, or None (semantic "All") if
      the bound spec is also None.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      return None
    return [
        time in selected_times for time in self._time_coordinates.all_dates_str
    ]

  def collapse_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Collapses the given spec into a `DateInterval` proto.

    If the spec's date range is unbounded, then the DateInterval proto will have
    the semantic "All", and we resolve it by consulting the time coordinates of
    the model bound to this resolver.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A `DateInterval` proto that represents the date interval specified by the
      spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      start_date = self._time_coordinates.all_dates[0]
      end_date = self._time_coordinates.all_dates[-1]
    else:
      normalized_selected_times = [
          tc.normalize_date(date) for date in selected_times
      ]
      start_date = normalized_selected_times[0]
      end_date = normalized_selected_times[-1]

    # Adjust end_date to make it exclusive.
    end_date += datetime.timedelta(days=self._interval_days)

    return time_record.create_date_interval_pb(
        start_date, end_date, tag=self._spec.date_interval_tag
    )

  def transform_to_date_interval_protos(
      self,
  ) -> list[date_interval_pb2.DateInterval]:
    """Transforms the given spec into `DateInterval` protos.

    If the spec's date range is unbounded, then the DateInterval proto will have
    the semantic "All", and we resolve it by consulting the time coordinates of
    the model bound to this resolver.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A list of `DateInterval` protos that represent the date intervals
      specified by the spec.
    """
    selected_times = self.resolve_to_enumerated_selected_times()
    if selected_times is None:
      times_list = self._time_coordinates.all_dates
    else:
      times_list = [tc.normalize_date(date) for date in selected_times]

    date_intervals = []
    for start_date in times_list:
      date_interval = time_record.create_date_interval_pb(
          start_date=start_date,
          end_date=start_date + datetime.timedelta(days=self._interval_days),
          tag=self._spec.date_interval_tag,
      )
      date_intervals.append(date_interval)

    return date_intervals

  def resolve_to_date_interval_open_end(
      self,
  ) -> tuple[datetime.date, datetime.date]:
    """Resolves given spec into an open-ended `[start, end)` date interval."""
    start = self._spec.start_date or self._time_coordinates.all_dates[0]
    end = self._spec.end_date
    if end is None:
      end = self._time_coordinates.all_dates[-1]
      # Adjust `end` to make it exclusive, but only if we pulled it from the
      # time coordinates.
      end += datetime.timedelta(days=self._interval_days)
    return (start, end)

  def resolve_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
    """Resolves the given spec into a fully specified `DateInterval` proto.

    If either `start_date` or `end_date` is None in the bound spec, then we
    resolve it by consulting the time coordinates of the model bound to this
    resolver. They are resolved to the first and last time coordinates (plus
    interval length), respectively.

    Note that the exclusive end date semantic will be preserved in the returned
    proto.

    Returns:
      A resolved `DateInterval` proto that represents the date interval specified
      by the bound spec.
    """
    start_date, end_date = self.resolve_to_date_interval_open_end()
    return time_record.create_date_interval_pb(
        start_date, end_date, tag=self._spec.date_interval_tag
    )
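

# --- Editor's illustrative sketch; not part of the released file. ---
# Assumes `time_coords` comes from a fitted model (`trained.time_coordinates`).
# A spec covering January 2024 with an exclusive end date resolves against
# those coordinates; leaving either date as None falls back to the model's
# first/last time coordinate (the "All" semantic).
def _example_resolve_dated_spec(time_coords: tc.TimeCoordinates):
  spec = DatedSpec(
      start_date=datetime.date(2024, 1, 1),
      end_date=datetime.date(2024, 2, 1),  # Exclusive end of the range.
      date_interval_tag='jan_2024',
  )
  resolver = spec.resolver(time_coords)
  # Date strings formatted with `c.DATE_FORMAT`, or None meaning "All".
  selected = resolver.resolve_to_enumerated_selected_times()
  # Fully specified DateInterval proto; the exclusive end date is preserved.
  interval_pb = resolver.resolve_to_date_interval_proto()
  return selected, interval_pb
# --- End of sketch. ---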


@dataclasses.dataclass(frozen=True, kw_only=True)
class OptimizationSpec(DatedSpec):
  """A dated spec for optimization.

  Attrs:
    optimization_name: The name of the optimization in this spec.
    grid_name: The name of the optimization grid.
    group_id: An optional group ID for linking related optimizations.
    confidence_level: The threshold for computing confidence intervals. Defaults
      to 0.9. Must be a number between 0 and 1.
  """

  optimization_name: str
  grid_name: str
  group_id: str | None = None
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL

  @override
  def validate(self):
    """Check optimization parameters are valid."""
    super().validate()

    if not self.optimization_name or self.optimization_name.isspace():
      raise ValueError('Optimization name must not be empty or blank.')

    if not self.grid_name or self.grid_name.isspace():
      raise ValueError('Grid name must not be empty or blank.')

    if self.confidence_level < 0 or self.confidence_level > 1:
      raise ValueError('Confidence level must be between 0 and 1.')
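

# --- Editor's illustrative sketch; not part of the released file. ---
# `OptimizationSpec` is keyword-only and, via `Spec.__post_init__`, validates
# itself at construction time, so invalid parameters fail fast. The names used
# here are placeholders.
def _example_build_optimization_spec() -> OptimizationSpec:
  spec = OptimizationSpec(
      optimization_name='q1_budget_optimization',
      grid_name='q1_grid',
      start_date=datetime.date(2024, 1, 1),
      end_date=datetime.date(2024, 4, 1),
  )
  try:
    OptimizationSpec(optimization_name='  ', grid_name='q1_grid')
  except ValueError:
    pass  # A blank optimization name is rejected by validate().
  return spec
# --- End of sketch. ---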


S = TypeVar('S', bound=Spec)
M = TypeVar('M', bound=message.Message)


class ModelProcessor(abc.ABC, Generic[S, M]):
  """Performs model-based analysis or optimization."""

  @classmethod
  @abc.abstractmethod
  def spec_type(cls) -> type[S]:
    """Returns the concrete Spec type that this ModelProcessor operates on."""
    raise NotImplementedError()

  @classmethod
  @abc.abstractmethod
  def output_type(cls) -> type[M]:
    """Returns the concrete output type that this ModelProcessor produces."""
    raise NotImplementedError()

  @abc.abstractmethod
  def execute(self, specs: Sequence[S]) -> M:
    """Runs an analysis/optimization on the model using the given specs.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.

    Returns:
      A proto containing the results of the analysis/optimization.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def _set_output(self, output: pb.Mmm, result: M):
    """Sets the output field in the given `MmmOutput` proto.

    A model consumer that orchestrated this processor will indirectly call this
    method (via `__call__`) to attach the output of `execute()` (a
    processor-defined message `M`) into a partially built `MmmOutput` proto that
    the model consumer manages.

    Args:
      output: The container output proto to which the given result message
        should be attached.
      result: An output of `execute()`.
    """
    raise NotImplementedError()

  def __call__(self, specs: Sequence[S], output: pb.Mmm):
    """Runs an analysis/optimization on the model using the given specs.

    This also sets the appropriate output field in the given MmmOutput proto.

    Args:
      specs: Sequence of Specs containing parameters needed for the
        analysis/optimization. The specs must all be of the same type as
        `self.spec_type()` for this processor.
      output: The output proto to which the results of the analysis/optimization
        should be attached.

    Raises:
      ValueError: If any spec is not of the same type as `self.spec_type()`.
    """
    if not all([isinstance(spec, self.spec_type()) for spec in specs]):
      raise ValueError('Not all specs are of type %s' % self.spec_type())
    self._set_output(output, self.execute(specs))
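

# --- Editor's illustrative sketch; not part of the released file. ---
# A minimal concrete processor, just to show how `spec_type`, `output_type`,
# `execute`, and `_set_output` fit together. The output message here is
# `date_interval_pb2.DateInterval` purely for illustration; real processors
# (e.g. the budget optimization or marketing processors in this package)
# define their own spec and output messages, and `_set_output` below is a
# stub because the exact `pb.Mmm` field to populate depends on the schema.
def _example_processor_shape() -> type[ModelProcessor]:

  class ExampleDateIntervalProcessor(
      ModelProcessor[DatedSpec, date_interval_pb2.DateInterval]
  ):

    def __init__(self, trained_model: TrainedModel):
      self._trained_model = trained_model

    @classmethod
    def spec_type(cls) -> type[DatedSpec]:
      return DatedSpec

    @classmethod
    def output_type(cls) -> type[date_interval_pb2.DateInterval]:
      return date_interval_pb2.DateInterval

    def execute(
        self, specs: Sequence[DatedSpec]
    ) -> date_interval_pb2.DateInterval:
      # Resolve the first spec's date range against the model's coordinates.
      resolver = specs[0].resolver(self._trained_model.time_coordinates)
      return resolver.resolve_to_date_interval_proto()

    def _set_output(
        self, output: pb.Mmm, result: date_interval_pb2.DateInterval
    ):
      # Hypothetical: a real processor copies `result` into its dedicated
      # field of the `pb.Mmm` container here.
      del output, result

  # A model consumer would then call `processor(specs, output)`, which checks
  # the spec types and runs `_set_output(output, execute(specs))`.
  return ExampleDateIntervalProcessor
# --- End of sketch. ---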