google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1136 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian module for analyzing marketing data in a Meridian model.
16
+
17
+ This module provides a `MarketingProcessor`, designed to extract key marketing
18
+ insights from a trained Meridian model. It allows users to understand the impact
19
+ of different marketing channels, calculate return on investment (ROI), and
20
+ generate response curves.
21
+
22
+ The processor uses specifications defined in `MarketingAnalysisSpec` to control
23
+ the analysis. Users can request:
24
+
25
+ 1. **Media Summary Metrics:** Aggregated performance metrics for each media
26
+ channel, including spend, contribution, ROI, and effectiveness.
27
+ 2. **Incremental Outcomes:** The additional KPI or revenue driven by marketing
28
+ activities, calculated by comparing against a baseline scenario (e.g., zero
29
+ spend).
30
+ 3. **Response Curves:** Visualizations of how the predicted KPI or revenue
31
+ changes as spend on a particular channel increases, helping to identify
32
+ diminishing returns.
33
+
34
+ The results are output as a `MarketingAnalysisList` protobuf message, containing
35
+ detailed breakdowns per channel and for the baseline.
36
+
37
+ Key Classes:
38
+
39
+ - `MediaSummarySpec`: Configures the calculation of summary metrics like ROI.
40
+ - `IncrementalOutcomeSpec`: Configures the calculation of incremental impact.
41
+ - `ResponseCurveSpec`: Configures response curve generation.
42
+ - `MarketingAnalysisSpec`: The main specification to combine the above,
43
+ define date ranges, and set confidence levels.
44
+ - `MarketingProcessor`: The processor class that executes the analysis based
45
+ on the provided specs.
46
+
47
+ Example Usage:
48
+
49
+ 1. **Get Media Summary Metrics for a specific period:**
50
+
51
+ ```python
52
+ from schema.processors import marketing_processor
53
+ import datetime
54
+
55
+ # Assuming 'trained_model' is a loaded Meridian model object
56
+
57
+ spec = marketing_processor.MarketingAnalysisSpec(
58
+ analysis_name="q1_summary",
59
+ start_date=datetime.date(2023, 1, 1),
60
+ end_date=datetime.date(2023, 3, 31),
61
+ media_summary_spec=marketing_processor.MediaSummarySpec(
62
+ aggregate_times=True
63
+ ),
64
+ response_curve_spec=marketing_processor.ResponseCurveSpec(),
65
+ confidence_level=0.9,
66
+ )
67
+
68
+ processor = marketing_processor.MarketingProcessor(trained_model)
69
+ # `result` is a `marketing_analysis_pb2.MarketingAnalysisList` proto
70
+ result = processor.execute([spec])
71
+ ```
72
+
73
+ 2. **Calculate Incremental Outcome with new spend data:**
74
+
75
+ ```python
76
+ from schema.processors import marketing_processor
77
+ from meridian.analysis import analyzer
78
+ import datetime
79
+ import numpy as np
80
+
81
+ # Assuming 'trained_model' is a loaded Meridian model object
82
+ # Assuming 'new_media_spend' is a numpy array with shape (time, channels)
83
+
84
+ # Create DataTensors for the new data
85
+ # Example:
86
+ # new_data = analyzer.DataTensors(
87
+ # media=new_media_spend,
88
+ # time=new_time_index,
89
+ # )
90
+
91
+ spec = marketing_processor.MarketingAnalysisSpec(
92
+ analysis_name="what_if_scenario",
93
+ # NOTE: Dates must align with `new_data.time`
94
+ start_date=datetime.date(2023, 1, 1),
95
+ end_date=datetime.date(2023, 1, 31),
96
+ incremental_outcome_spec=marketing_processor.IncrementalOutcomeSpec(
97
+ new_data=new_data,
98
+ aggregate_times=True,
99
+ ),
100
+ )
101
+
102
+ processor = marketing_processor.MarketingProcessor(trained_model)
103
+ result = processor.execute([spec])
104
+
105
+ print(f"Incremental Outcome for {spec.analysis_name}:")
106
+ # Process results from result.marketing_analyses
107
+ ```
108
+
109
+ Note: You can provide the processor with multiple specs. This would result in
110
+ multiple marketing analysis results in the output.
111
+ """
112
+
113
+ from collections.abc import Sequence
114
+ import dataclasses
115
+ import datetime
116
+ import functools
117
+ import warnings
118
+
119
+ from meridian import constants
120
+ from meridian.analysis import analyzer
121
+ from meridian.data import time_coordinates
122
+ from mmm.v1 import mmm_pb2
123
+ from mmm.v1.common import date_interval_pb2
124
+ from mmm.v1.common import kpi_type_pb2
125
+ from mmm.v1.marketing.analysis import marketing_analysis_pb2
126
+ from mmm.v1.marketing.analysis import media_analysis_pb2
127
+ from mmm.v1.marketing.analysis import non_media_analysis_pb2
128
+ from mmm.v1.marketing.analysis import outcome_pb2
129
+ from mmm.v1.marketing.analysis import response_curve_pb2
130
+ from schema.processors import common
131
+ from schema.processors import model_processor
132
+ import numpy as np
133
+ import xarray as xr
134
+
135
+ __all__ = [
136
+ "MediaSummarySpec",
137
+ "IncrementalOutcomeSpec",
138
+ "ResponseCurveSpec",
139
+ "MarketingAnalysisSpec",
140
+ "MarketingProcessor",
141
+ ]
142
+
143
+
144
@dataclasses.dataclass(frozen=True)
class MediaSummarySpec(model_processor.Spec):
  """Parameters controlling the computation of media summary metrics.

  Attributes:
    aggregate_times: If `True` (the default), summary metrics are aggregated
      across all selected time periods into a single value per channel.
    marginal_roi_by_reach: Controls the marginal ROI (mROI) assumption, i.e.
      the return on the next dollar spent. When `True` (the default), the next
      dollar is assumed to buy additional reach at constant frequency; when
      `False`, it buys additional frequency at constant reach.
    include_non_paid_channels: If `True`, non-paid channels are included in the
      summary metrics. Defaults to `False`.
    new_data: Optional `DataTensors` container with optional tensors: `media`,
      `reach`, `frequency`, `organic_media`, `organic_reach`,
      `organic_frequency`, `non_media_treatments` and `revenue_per_kpi`. When
      `None`, metrics are computed from the `InputData` originally given to
      the Meridian object; when provided, the supplied tensors override the
      originals and the remaining tensors keep their original values.
    media_selected_times: Optional boolean list whose length equals the number
      of time periods in `new_data` (when provided) or in the model's original
      media data otherwise; selects the subset of time periods to analyze.
  """

  aggregate_times: bool = True
  marginal_roi_by_reach: bool = True
  include_non_paid_channels: bool = False
  # b/384034128 Use new args in `summary_metrics`.
  new_data: analyzer.DataTensors | None = None
  media_selected_times: Sequence[bool] | None = None

  def validate(self):
    """No additional constraints beyond the field defaults."""
    pass
181
+
182
+
183
@dataclasses.dataclass(frozen=True, kw_only=True)
class IncrementalOutcomeSpec(model_processor.Spec):
  """Parameters controlling the computation of incremental outcomes.

  Attributes:
    aggregate_times: If `True` (the default), the incremental outcome is
      aggregated across all selected time periods.
    new_data: Optional `DataTensors` container with optional tensors: `media`,
      `reach`, `frequency`, `organic_media`, `organic_reach`,
      `organic_frequency`, `non_media_treatments` and `revenue_per_kpi`. When
      `None`, the incremental outcome is computed from the model's original
      `InputData`. When provided, the supplied tensors override the originals;
      e.g. `DataTensors(media=new_media)` uses `new_media` together with the
      original values of every other tensor. If any supplied tensor has a
      different number of time periods than `InputData`, all tensors must be
      supplied with that same number of time periods.
    media_selected_times: Optional boolean list selecting time periods. When
      `new_data` is provided its length must match the time periods of
      `new_data`; otherwise it must match the model's original media data.
    include_non_paid_channels: If `True`, non-paid channels are included.
      Defaults to `False`.
  """

  aggregate_times: bool = True
  new_data: analyzer.DataTensors | None = None
  media_selected_times: Sequence[bool] | None = None
  include_non_paid_channels: bool = False

  def validate(self):
    """Raises `ValueError` when `new_data` is set without a `time` tensor."""
    super().validate()
    if self.new_data is not None and self.new_data.time is None:
      raise ValueError("`time` must be provided in `new_data`.")
223
+
224
+
225
@dataclasses.dataclass(frozen=True)
class ResponseCurveSpec(model_processor.Spec):
  """Parameters controlling response-curve generation.

  Attributes:
    by_reach: Applies to channels with reach and frequency data. When `True`,
      the response curve is plotted as a function of reach; when `False`, as a
      function of frequency.
  """

  by_reach: bool = True

  def validate(self):
    """No additional constraints beyond the field defaults."""
    pass
239
+
240
+
241
@dataclasses.dataclass(frozen=True, kw_only=True)
class MarketingAnalysisSpec(model_processor.DatedSpec):
  """Parameters for processing a model into `MarketingAnalysis` protos.

  Exactly one of `media_summary_spec` or `incremental_outcome_spec` must be
  provided — never both, never neither.

  Attributes:
    media_summary_spec: Parameters for media summary metrics. Mutually
      exclusive with `incremental_outcome_spec`.
    incremental_outcome_spec: Parameters for incremental outcome. Mutually
      exclusive with `media_summary_spec`. When its `new_data` is set, this
      spec's start and end dates must fall within `new_data.time`.
    response_curve_spec: Parameters for response curves. Response curves are
      only computed for specs that aggregate times and carry a
      `media_summary_spec`.
    confidence_level: Confidence level for credible intervals, strictly
      between zero and one. Defaults to 0.9.
  """

  media_summary_spec: MediaSummarySpec | None = None
  incremental_outcome_spec: IncrementalOutcomeSpec | None = None
  response_curve_spec: ResponseCurveSpec = dataclasses.field(
      default_factory=ResponseCurveSpec
  )
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL

  def validate(self):
    """Validates the confidence level and the sub-spec exclusivity rule."""
    super().validate()
    if self.confidence_level <= 0 or self.confidence_level >= 1:
      raise ValueError(
          "Confidence level must be greater than 0 and less than 1."
      )
    # Exactly one of the two analysis sub-specs must be set.
    num_set = sum(
        sub is not None
        for sub in (self.media_summary_spec, self.incremental_outcome_spec)
    )
    if num_set == 0:
      raise ValueError(
          "At least one of `media_summary_spec` or `incremental_outcome_spec`"
          " must be provided."
      )
    if num_set == 2:
      raise ValueError(
          "Only one of `media_summary_spec` or `incremental_outcome_spec` can"
          " be provided."
      )
291
+
292
+
293
+ class MarketingProcessor(
294
+ model_processor.ModelProcessor[
295
+ MarketingAnalysisSpec, marketing_analysis_pb2.MarketingAnalysisList
296
+ ]
297
+ ):
298
+ """Generates `MarketingAnalysis` protos for a given trained Meridian model.
299
+
300
+ A `MarketingAnalysis` proto is generated for each spec supplied to
301
+ `execute()`. Within each `MarketingAnalysis` proto, a `MediaAnalysis` proto
302
+ is created for each channel in the model. One `NonMediaAnalysis` proto is also
303
+ created for the model's baseline data.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ trained_model: model_processor.ModelType,
309
+ ):
310
+ trained_model = model_processor.ensure_trained_model(trained_model)
311
+ self._analyzer = trained_model.internal_analyzer
312
+ self._meridian = trained_model.mmm
313
+ self._model_time_coordinates = trained_model.time_coordinates
314
+ self._interval_length = self._model_time_coordinates.interval_days
315
+
316
+ # If the input data KPI type is "revenue", then the `revenue_per_kpi` tensor
317
+ # must exist, and general-KPI type outcomes should not be defined.
318
+ self._revenue_kpi_type = (
319
+ trained_model.mmm.input_data.kpi_type == constants.REVENUE
320
+ )
321
+ # `_kpi_only` is TRUE iff the input data KPI type is "non-revenue" AND the
322
+ # `revenue_per_kpi` tensor is None.
323
+ self._kpi_only = trained_model.mmm.input_data.revenue_per_kpi is None
324
+
325
+ @classmethod
326
+ def spec_type(cls) -> type[MarketingAnalysisSpec]:
327
+ return MarketingAnalysisSpec
328
+
329
+ @classmethod
330
+ def output_type(cls) -> type[marketing_analysis_pb2.MarketingAnalysisList]:
331
+ return marketing_analysis_pb2.MarketingAnalysisList
332
+
333
+ def _set_output(
334
+ self,
335
+ output: mmm_pb2.Mmm,
336
+ result: marketing_analysis_pb2.MarketingAnalysisList,
337
+ ):
338
+ output.marketing_analysis_list.CopyFrom(result)
339
+
340
  def execute(
      self, marketing_analysis_specs: Sequence[MarketingAnalysisSpec]
  ) -> marketing_analysis_pb2.MarketingAnalysisList:
    """Runs a marketing analysis on the model based on the given specs.

    A `MarketingAnalysis` proto is created for each of the given specs. Each
    `MarketingAnalysis` proto contains a list of `MediaAnalysis` protos and a
    singleton `NonMediaAnalysis` proto for the baseline analysis. The analysis
    covers the time period bounded by the spec's start and end dates.

    The singleton non-media analysis is performed on the model's baseline data,
    and contains metrics such as incremental outcome and baseline percent of
    contribution across media and non-media.

    A media analysis is performed for each channel in the model, plus an
    "All Channels" synthetic channel. The media analysis contains metrics such
    as spend, percent of spend, incremental outcome, percent of contribution,
    and effectiveness. Depending on the type of data (revenue-based or
    non-revenue-based) in the model, the analysis also contains CPIK
    (non-revenue-based) or ROI and MROI (revenue-based).

    Args:
      marketing_analysis_specs: A sequence of MarketingAnalysisSpec objects.

    Returns:
      A MarketingAnalysisList proto containing the results of the marketing
      analysis for each spec.
    """
    marketing_analysis_list: list[marketing_analysis_pb2.MarketingAnalysis] = []

    for spec in marketing_analysis_specs:
      # When an incremental-outcome spec carries `new_data` with its own time
      # coordinates, the spec's dates are resolved against those coordinates
      # instead of the model's original time dimension.
      if (
          spec.incremental_outcome_spec is not None
          and spec.incremental_outcome_spec.new_data is not None
          and spec.incremental_outcome_spec.new_data.time is not None
      ):
        new_time_coords = time_coordinates.TimeCoordinates.from_dates(
            np.asarray(spec.incremental_outcome_spec.new_data.time)
            .astype(str)
            .tolist()
        )
        resolver = spec.resolver(new_time_coords)
      else:
        resolver = spec.resolver(self._model_time_coordinates)
      # Each generator returns [] when its corresponding sub-spec is unset, and
      # `MarketingAnalysisSpec.validate` enforces that exactly one sub-spec is
      # set, so at most one of these lists is non-empty per spec.
      media_summary_marketing_analyses = (
          self._generate_marketing_analyses_for_media_summary_spec(
              spec, resolver
          )
      )
      incremental_outcome_marketing_analyses = (
          self._generate_marketing_analyses_for_incremental_outcome_spec(
              spec, resolver
          )
      )
      marketing_analysis_list.extend(
          media_summary_marketing_analyses
          + incremental_outcome_marketing_analyses
      )

    return marketing_analysis_pb2.MarketingAnalysisList(
        marketing_analyses=marketing_analysis_list
    )
402
+
403
  def _generate_marketing_analyses_for_media_summary_spec(
      self,
      marketing_analysis_spec: MarketingAnalysisSpec,
      resolver: model_processor.DatedSpecResolver,
  ) -> list[marketing_analysis_pb2.MarketingAnalysis]:
    """Creates a list of MarketingAnalysis protos based on the given spec.

    If spec's `aggregate_times` is True, then only one MarketingAnalysis proto
    is created. Otherwise, one MarketingAnalysis proto is created for each date
    interval in the spec.

    Args:
      marketing_analysis_spec: An instance of MarketingAnalysisSpec.
      resolver: A DatedSpecResolver instance.

    Returns:
      A list of `MarketingAnalysis` protos containing the results of the
      marketing analysis for the given spec.
    """
    media_summary_spec = marketing_analysis_spec.media_summary_spec
    if media_summary_spec is None:
      # This spec did not request media summary metrics; nothing to produce.
      return []

    selected_times = resolver.resolve_to_enumerated_selected_times()
    # This contains either a revenue-based KPI or a non-revenue KPI analysis.
    media_summary_metrics, non_media_summary_metrics = (
        self._generate_media_and_non_media_summary_metrics(
            media_summary_spec,
            selected_times,
            marketing_analysis_spec.confidence_level,
            self._kpi_only,
        )
    )

    secondary_non_revenue_kpi_metrics = None
    secondary_non_revenue_kpi_non_media_metrics = None
    # If the input data KPI type is "non-revenue", and we calculated its
    # revenue-based KPI outcomes above, then we should also compute its
    # non-revenue KPI outcomes.
    if not self._revenue_kpi_type and not self._kpi_only:
      (
          secondary_non_revenue_kpi_metrics,
          secondary_non_revenue_kpi_non_media_metrics,
      ) = self._generate_media_and_non_media_summary_metrics(
          media_summary_spec,
          selected_times,
          marketing_analysis_spec.confidence_level,
          use_kpi=True,
      )

    # Note: baseline_summary_metrics() prefers computing revenue (scaled from
    # generic KPI with `revenue_per_kpi` when defined) baseline outcome here.
    # TODO: Baseline outcomes for both revenue and non-revenue
    # KPI types should be computed, when possible.
    baseline_outcome = self._analyzer.baseline_summary_metrics(
        confidence_level=marketing_analysis_spec.confidence_level,
        aggregate_times=media_summary_spec.aggregate_times,
        selected_times=selected_times,
    ).sel(distribution=constants.POSTERIOR)

    # Response curves are only computed for specs that aggregate times.
    if media_summary_spec.aggregate_times:
      response_curve_spec = marketing_analysis_spec.response_curve_spec
      response_curves = self._analyzer.response_curves(
          confidence_level=marketing_analysis_spec.confidence_level,
          use_posterior=True,
          selected_times=selected_times,
          use_kpi=self._kpi_only,
          by_reach=response_curve_spec.by_reach,
      )
    else:
      response_curves = None
      warnings.warn(
          "Response curves are not computed for non-aggregated time periods."
      )

    date_intervals = self._build_time_intervals(
        aggregate_times=media_summary_spec.aggregate_times,
        resolver=resolver,
    )

    return self._marketing_metrics_to_protos(
        media_summary_metrics,
        non_media_summary_metrics,
        baseline_outcome,
        secondary_non_revenue_kpi_metrics,
        secondary_non_revenue_kpi_non_media_metrics,
        response_curves,
        marketing_analysis_spec,
        date_intervals,
    )
494
+
495
  def _generate_media_and_non_media_summary_metrics(
      self,
      media_summary_spec: MediaSummarySpec,
      selected_times: list[str] | None,
      confidence_level: float,
      use_kpi: bool,
  ) -> tuple[xr.Dataset | None, xr.Dataset | None]:
    """Computes posterior summary metrics for paid and non-paid channels.

    Args:
      media_summary_spec: Spec controlling time aggregation, the mROI
        assumption, and whether non-paid channels are included.
      selected_times: Enumerated time coordinates to analyze, or `None`.
      confidence_level: Confidence level for credible intervals.
      use_kpi: If `True`, metrics are computed in generic-KPI units rather
        than revenue units.

    Returns:
      A `(media_metrics, non_media_metrics)` tuple of posterior datasets.
      `non_media_metrics` is `None` unless
      `media_summary_spec.include_non_paid_channels` is set, and both are
      `None` when `media_summary_spec` is `None`.
    """
    if media_summary_spec is None:
      return (None, None)
    # Bind the spec-invariant arguments once; only `use_kpi` and
    # `include_non_paid_channels` vary between the calls below.
    compute_media_summary_metrics = functools.partial(
        self._analyzer.summary_metrics,
        marginal_roi_by_reach=media_summary_spec.marginal_roi_by_reach,
        selected_times=selected_times,
        aggregate_geos=True,
        aggregate_times=media_summary_spec.aggregate_times,
        confidence_level=confidence_level,
    )

    media_summary_metrics = compute_media_summary_metrics(
        use_kpi=use_kpi,
        include_non_paid_channels=False,
    ).sel(distribution=constants.POSTERIOR)
    # TODO: Produce one metrics dataset for both paid and non-paid channels.
    non_media_summary_metrics = None
    if media_summary_spec.include_non_paid_channels:
      # Drop the synthetic "All Channels" row so its channel name does not
      # collide when the non-paid metrics are separated out below.
      media_summary_metrics = media_summary_metrics.drop_sel(
          channel=constants.ALL_CHANNELS
      )
      # The non-paid dataset is the all-channels run minus every channel
      # already present in the paid dataset.
      non_media_summary_metrics = (
          compute_media_summary_metrics(
              use_kpi=use_kpi,
              include_non_paid_channels=True,
          )
          .sel(distribution=constants.POSTERIOR)
          .drop_sel(
              channel=media_summary_metrics.coords[constants.CHANNEL].data
          )
      )
    return media_summary_metrics, non_media_summary_metrics
534
+
535
  def _generate_marketing_analyses_for_incremental_outcome_spec(
      self,
      marketing_analysis_spec: MarketingAnalysisSpec,
      resolver: model_processor.DatedSpecResolver,
  ) -> list[marketing_analysis_pb2.MarketingAnalysis]:
    """Creates a list of `MarketingAnalysis` protos based on the given spec.

    If the spec's `aggregate_times` is True, then only one `MarketingAnalysis`
    proto is created. Otherwise, one `MarketingAnalysis` proto is created for
    each date interval in the spec.

    Args:
      marketing_analysis_spec: An instance of MarketingAnalysisSpec.
      resolver: A DatedSpecResolver instance.

    Returns:
      A list of `MarketingAnalysis` protos containing the results of the
      marketing analysis for the given spec.
    """
    incremental_outcome_spec = marketing_analysis_spec.incremental_outcome_spec
    if incremental_outcome_spec is None:
      # This spec did not request an incremental-outcome analysis.
      return []

    # Bind every argument except `use_kpi`, which varies between the primary
    # and the secondary (generic-KPI) computations below.
    compute_incremental_outcome = functools.partial(
        self._incremental_outcome_dataset,
        resolver=resolver,
        new_data=incremental_outcome_spec.new_data,
        media_selected_times=incremental_outcome_spec.media_selected_times,
        aggregate_geos=True,
        aggregate_times=incremental_outcome_spec.aggregate_times,
        confidence_level=marketing_analysis_spec.confidence_level,
        include_non_paid_channels=False,
    )
    # This contains either a revenue-based KPI or a non-revenue KPI analysis.
    incremental_outcome = compute_incremental_outcome(use_kpi=self._kpi_only)

    secondary_non_revenue_kpi_metrics = None
    # If the input data KPI type is "non-revenue", and we calculated its
    # revenue-based KPI outcomes above, then we should also compute its
    # non-revenue KPI outcomes.
    if not self._revenue_kpi_type and not self._kpi_only:
      secondary_non_revenue_kpi_metrics = compute_incremental_outcome(
          use_kpi=True
      )

    date_intervals = self._build_time_intervals(
        aggregate_times=incremental_outcome_spec.aggregate_times,
        resolver=resolver,
    )

    # No non-media, baseline, or response-curve artifacts are produced for an
    # incremental-outcome analysis; only paid-channel outcomes are reported.
    return self._marketing_metrics_to_protos(
        metrics=incremental_outcome,
        non_media_metrics=None,
        baseline_outcome=None,
        secondary_non_revenue_kpi_metrics=secondary_non_revenue_kpi_metrics,
        secondary_non_revenue_kpi_non_media_metrics=None,
        response_curves=None,
        marketing_analysis_spec=marketing_analysis_spec,
        date_intervals=date_intervals,
    )
595
+
596
+ def _build_time_intervals(
597
+ self,
598
+ aggregate_times: bool,
599
+ resolver: model_processor.DatedSpecResolver,
600
+ ) -> list[date_interval_pb2.DateInterval]:
601
+ """Creates a list of `DateInterval` protos for the given spec.
602
+
603
+ Args:
604
+ aggregate_times: Whether to aggregate times.
605
+ resolver: A DatedSpecResolver instance.
606
+
607
+ Returns:
608
+ A list of `DateInterval` protos for the given spec.
609
+ """
610
+ if aggregate_times:
611
+ date_interval = resolver.collapse_to_date_interval_proto()
612
+ # This means metrics are aggregated over time, only one date interval is
613
+ # needed.
614
+ return [date_interval]
615
+
616
+ # This list will contain all date intervals for the given spec. All dates
617
+ # in this list will share a common tag.
618
+ return resolver.transform_to_date_interval_protos()
619
+
620
  def _marketing_metrics_to_protos(
      self,
      metrics: xr.Dataset,
      non_media_metrics: xr.Dataset | None,
      baseline_outcome: xr.Dataset | None,
      secondary_non_revenue_kpi_metrics: xr.Dataset | None,
      secondary_non_revenue_kpi_non_media_metrics: xr.Dataset | None,
      response_curves: xr.Dataset | None,
      marketing_analysis_spec: MarketingAnalysisSpec,
      date_intervals: Sequence[date_interval_pb2.DateInterval],
  ) -> list[marketing_analysis_pb2.MarketingAnalysis]:
    """Creates a list of MarketingAnalysis protos from datasets.

    One `MarketingAnalysis` proto is emitted per entry in `date_intervals`;
    within each, one analysis proto is emitted per channel found in `metrics`
    (media channels) and `non_media_metrics` (non-media channels), plus an
    optional baseline analysis when `baseline_outcome` is given.
    """
    if metrics is None:
      raise ValueError("metrics is None")

    media_channels = list(metrics.coords[constants.CHANNEL].data)
    # NOTE(review): `if non_media_metrics` / `if response_curves` rely on
    # xarray Dataset truthiness rather than an explicit `is not None` check —
    # an empty-but-present dataset would be treated as absent; confirm intended.
    non_media_channels = (
        list(non_media_metrics.coords[constants.CHANNEL].data)
        if non_media_metrics
        else []
    )
    channels = media_channels + non_media_channels
    channels_with_response_curve = (
        response_curves.coords[constants.CHANNEL].data
        if response_curves
        else []
    )
    marketing_analyses = []
    for date_interval in date_intervals:
      # Non-aggregated metrics are keyed by their interval's start date, so
      # the start date (formatted per `constants.DATE_FORMAT`) doubles as the
      # time selector for `_get_channel_metrics`.
      start_date = date_interval.start_date
      start_date_str = datetime.date(
          start_date.year, start_date.month, start_date.day
      ).strftime(constants.DATE_FORMAT)
      media_analyses: list[media_analysis_pb2.MediaAnalysis] = []
      non_media_analyses: list[non_media_analysis_pb2.NonMediaAnalysis] = []

      # For all channels reported in the media summary metrics
      for channel_name in channels:
        channel_response_curve = None
        if response_curves and (channel_name in channels_with_response_curve):
          channel_response_curve = response_curves.sel(
              {constants.CHANNEL: channel_name}
          )
        is_media_channel = channel_name in media_channels

        channel_analysis = self._get_channel_metrics(
            marketing_analysis_spec,
            channel_name,
            start_date_str,
            metrics if is_media_channel else non_media_metrics,
            secondary_non_revenue_kpi_metrics
            if is_media_channel
            else secondary_non_revenue_kpi_non_media_metrics,
            channel_response_curve,
            is_media_channel,
        )
        # Route the result by its concrete proto type.
        if isinstance(channel_analysis, media_analysis_pb2.MediaAnalysis):
          media_analyses.append(channel_analysis)

        if isinstance(
            channel_analysis, non_media_analysis_pb2.NonMediaAnalysis
        ):
          non_media_analyses.append(channel_analysis)

      marketing_analysis = marketing_analysis_pb2.MarketingAnalysis(
          date_interval=date_interval,
          media_analyses=media_analyses,
          non_media_analyses=non_media_analyses,
      )
      if baseline_outcome is not None:
        # The baseline is reported as an extra singleton non-media analysis.
        baseline_analysis = self._get_baseline_metrics(
            marketing_analysis_spec=marketing_analysis_spec,
            baseline_outcome=baseline_outcome,
            start_date=start_date_str,
        )
        marketing_analysis.non_media_analyses.append(baseline_analysis)

      marketing_analyses.append(marketing_analysis)

    return marketing_analyses
700
+
701
+ def _get_channel_metrics(
702
+ self,
703
+ marketing_analysis_spec: MarketingAnalysisSpec,
704
+ channel_name: str,
705
+ start_date_str: str,
706
+ metrics: xr.Dataset,
707
+ secondary_metrics: xr.Dataset | None,
708
+ channel_response_curves: xr.Dataset | None,
709
+ is_media_channel: bool,
710
+ ) -> (
711
+ media_analysis_pb2.MediaAnalysis | non_media_analysis_pb2.NonMediaAnalysis
712
+ ):
713
+ """Returns a MediaAnalysis proto for the given channel."""
714
+ if constants.TIME in metrics.coords:
715
+ sel = {
716
+ constants.CHANNEL: channel_name,
717
+ constants.TIME: start_date_str,
718
+ }
719
+ else:
720
+ sel = {constants.CHANNEL: channel_name}
721
+
722
+ channel_metrics = metrics.sel(sel)
723
+ if secondary_metrics is not None:
724
+ channel_secondary_metrics = secondary_metrics.sel(sel)
725
+ else:
726
+ channel_secondary_metrics = None
727
+
728
+ return self._channel_metrics_to_proto(
729
+ channel_metrics,
730
+ channel_secondary_metrics,
731
+ channel_response_curves,
732
+ channel_name,
733
+ is_media_channel,
734
+ marketing_analysis_spec.confidence_level,
735
+ )
736
+
737
+ def _channel_metrics_to_proto(
738
+ self,
739
+ channel_media_summary_metrics: xr.Dataset,
740
+ channel_secondary_non_revenue_metrics: xr.Dataset | None,
741
+ channel_response_curve: xr.Dataset | None,
742
+ channel_name: str,
743
+ is_media_channel: bool,
744
+ confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
745
+ ) -> (
746
+ media_analysis_pb2.MediaAnalysis | non_media_analysis_pb2.NonMediaAnalysis
747
+ ):
748
+ """Creates a MediaAnalysis proto for the given channel from datasets.
749
+
750
+ Args:
751
+ channel_media_summary_metrics: A dataset containing the model's media
752
+ summary metrics. This dataset is pre-filtered to `channel_name`. This
753
+ dataset contains revenue-based metrics if the model's input data is
754
+ revenue-based, or if `revenue_per_kpi` is defined. Otherwise, it
755
+ contains non-revenue generic KPI metrics.
756
+ channel_secondary_non_revenue_metrics: A dataset containing the model's
757
+ non-revenue-based media summary metrics. This is only defined iff the
758
+ input data is non-revenue type AND `revenue_per_kpi` is available. In
759
+ this case, `channel_media_summary_metrics` contains revenue-based
760
+ metrics computed from `KPI * revenue_per_kpi`, and this dataset contains
761
+ media summary metrics based on the model's generic KPI alone. In all
762
+ other cases, this is `None`.
763
+ channel_response_curve: A dataset containing the data needed to generate a
764
+ response curve. This dataset is pre-filtered to `channel_name`.
765
+ channel_name: The name of the channel to analyze.
766
+ is_media_channel: Whether the channel is a media channel.
767
+ confidence_level: Confidence level for credible intervals, represented as
768
+ a value between zero and one.
769
+
770
+ Returns:
771
+ A proto containing the media analysis results for the given channel.
772
+ """
773
+
774
+ spend_info = _compute_spend(channel_media_summary_metrics)
775
+ is_all_channels = channel_name == constants.ALL_CHANNELS
776
+
777
+ compute_outcome = functools.partial(
778
+ self._compute_outcome,
779
+ is_all_channels=is_all_channels,
780
+ confidence_level=confidence_level,
781
+ )
782
+
783
+ outcomes = [
784
+ compute_outcome(
785
+ channel_media_summary_metrics,
786
+ is_revenue_type=(not self._kpi_only),
787
+ )
788
+ ]
789
+ # If `channel_media_summary_metrics` represented non-revenue data with
790
+ # revenue-type outcome (i.e. `is_revenue_type_kpi` is defined), then we
791
+ # should also have been provided with media summary metrics for their
792
+ # generic KPI counterparts, as well.
793
+ if channel_secondary_non_revenue_metrics is not None:
794
+ outcomes.append(
795
+ compute_outcome(
796
+ channel_secondary_non_revenue_metrics,
797
+ is_revenue_type=False,
798
+ )
799
+ )
800
+
801
+ if not is_media_channel:
802
+ return non_media_analysis_pb2.NonMediaAnalysis(
803
+ non_media_name=channel_name,
804
+ non_media_outcomes=outcomes,
805
+ )
806
+
807
+ media_analysis = media_analysis_pb2.MediaAnalysis(
808
+ channel_name=channel_name,
809
+ media_outcomes=outcomes,
810
+ )
811
+
812
+ if spend_info is not None:
813
+ media_analysis.spend_info.CopyFrom(spend_info)
814
+
815
+ if channel_response_curve is not None:
816
+ media_analysis.response_curve.CopyFrom(
817
+ self._compute_response_curve(
818
+ channel_response_curve,
819
+ )
820
+ )
821
+
822
+ return media_analysis
823
+
824
+ def _get_baseline_metrics(
825
+ self,
826
+ marketing_analysis_spec: MarketingAnalysisSpec,
827
+ baseline_outcome: xr.Dataset,
828
+ start_date: str,
829
+ ) -> non_media_analysis_pb2.NonMediaAnalysis:
830
+ """Analyzes "baseline" pseudo-channel outcomes over the given time points.
831
+
832
+ Args:
833
+ marketing_analysis_spec: A user input parameter specs for this analysis.
834
+ baseline_outcome: A dataset containing the model's baseline summary
835
+ metrics.
836
+ start_date: The date of the analysis.
837
+
838
+ Returns:
839
+ A `NonMediaAnalysis` representing baseline analysis.
840
+ """
841
+ if constants.TIME in baseline_outcome.coords:
842
+ baseline_outcome = baseline_outcome.sel(
843
+ time=start_date,
844
+ )
845
+ incremental_outcome = baseline_outcome[constants.BASELINE_OUTCOME]
846
+ # Convert percentage to decimal.
847
+ contribution_share = baseline_outcome[constants.PCT_OF_CONTRIBUTION] / 100
848
+
849
+ contribution = outcome_pb2.Contribution(
850
+ value=common.to_estimate(
851
+ incremental_outcome, marketing_analysis_spec.confidence_level
852
+ ),
853
+ share=common.to_estimate(
854
+ contribution_share, marketing_analysis_spec.confidence_level
855
+ ),
856
+ )
857
+ baseline_analysis = non_media_analysis_pb2.NonMediaAnalysis(
858
+ non_media_name=constants.BASELINE,
859
+ )
860
+ baseline_outcome = outcome_pb2.Outcome(
861
+ contribution=contribution,
862
+ # Baseline outcome is always revenue-based, unless `revenue_per_kpi`
863
+ # is undefined.
864
+ # TODO: kpi_type here is synced with what is used inside
865
+ # `baseline_summary_metrics()`. Ideally, really, we should inject this
866
+ # value into that function rather than re-deriving it here.
867
+ kpi_type=(
868
+ kpi_type_pb2.KpiType.NON_REVENUE
869
+ if self._kpi_only
870
+ else kpi_type_pb2.KpiType.REVENUE
871
+ ),
872
+ )
873
+ baseline_analysis.non_media_outcomes.append(baseline_outcome)
874
+
875
+ return baseline_analysis
876
+
877
+ def _compute_outcome(
878
+ self,
879
+ media_summary_metrics: xr.Dataset,
880
+ is_revenue_type: bool,
881
+ is_all_channels: bool,
882
+ confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
883
+ ) -> outcome_pb2.Outcome:
884
+ """Returns an `Outcome` proto for the given channel's media analysis.
885
+
886
+ Args:
887
+ media_summary_metrics: A dataset containing the model's media summary
888
+ metrics.
889
+ is_revenue_type: Whether the media summary metrics above are revenue
890
+ based.
891
+ is_all_channels: If True, the given media summary represents the aggregate
892
+ "All Channels". Omit `effectiveness` and `mroi` in this case.
893
+ confidence_level: Confidence level for credible intervals, represented as
894
+ a value between zero and one.
895
+ """
896
+ data_vars = media_summary_metrics.data_vars
897
+
898
+ effectiveness = roi = mroi = cpik = None
899
+ if not is_all_channels and constants.EFFECTIVENESS in data_vars:
900
+ effectiveness = outcome_pb2.Effectiveness(
901
+ media_unit=constants.IMPRESSIONS,
902
+ value=common.to_estimate(
903
+ media_summary_metrics[constants.EFFECTIVENESS],
904
+ confidence_level,
905
+ ),
906
+ )
907
+ if not is_all_channels and constants.MROI in data_vars:
908
+ mroi = common.to_estimate(
909
+ media_summary_metrics[constants.MROI],
910
+ confidence_level,
911
+ )
912
+
913
+ contribution_value = media_summary_metrics[constants.INCREMENTAL_OUTCOME]
914
+ contribution = outcome_pb2.Contribution(
915
+ value=common.to_estimate(
916
+ contribution_value,
917
+ confidence_level,
918
+ ),
919
+ )
920
+ # Convert percentage to decimal.
921
+ if constants.PCT_OF_CONTRIBUTION in data_vars:
922
+ contribution_share = (
923
+ media_summary_metrics[constants.PCT_OF_CONTRIBUTION] / 100
924
+ )
925
+ contribution.share.CopyFrom(
926
+ common.to_estimate(
927
+ contribution_share,
928
+ confidence_level,
929
+ )
930
+ )
931
+
932
+ if constants.CPIK in data_vars:
933
+ cpik = common.to_estimate(
934
+ media_summary_metrics[constants.CPIK],
935
+ confidence_level,
936
+ metric=constants.MEDIAN,
937
+ )
938
+
939
+ if constants.ROI in data_vars:
940
+ roi = common.to_estimate(
941
+ media_summary_metrics[constants.ROI],
942
+ confidence_level,
943
+ )
944
+
945
+ return outcome_pb2.Outcome(
946
+ kpi_type=(
947
+ kpi_type_pb2.KpiType.REVENUE
948
+ if is_revenue_type
949
+ else kpi_type_pb2.KpiType.NON_REVENUE
950
+ ),
951
+ contribution=contribution,
952
+ effectiveness=effectiveness,
953
+ cost_per_contribution=cpik,
954
+ roi=roi,
955
+ marginal_roi=mroi,
956
+ )
957
+
958
+ def _compute_response_curve(
959
+ self,
960
+ response_curve_dataset: xr.Dataset,
961
+ ) -> response_curve_pb2.ResponseCurve:
962
+ """Returns a `ResponseCurve` proto for the given channel.
963
+
964
+ Args:
965
+ response_curve_dataset: A dataset containing the data needed to generate a
966
+ response curve.
967
+ """
968
+
969
+ spend_multiplier_list = response_curve_dataset.coords[
970
+ constants.SPEND_MULTIPLIER
971
+ ].data
972
+ response_points: list[response_curve_pb2.ResponsePoint] = []
973
+
974
+ for spend_multiplier in spend_multiplier_list:
975
+ spend = (
976
+ response_curve_dataset[constants.SPEND]
977
+ .sel(spend_multiplier=spend_multiplier)
978
+ .data.item()
979
+ )
980
+ incremental_outcome = (
981
+ response_curve_dataset[constants.INCREMENTAL_OUTCOME]
982
+ .sel(
983
+ spend_multiplier=spend_multiplier,
984
+ metric=constants.MEAN,
985
+ )
986
+ .data.item()
987
+ )
988
+
989
+ response_point = response_curve_pb2.ResponsePoint(
990
+ input_value=spend,
991
+ incremental_kpi=incremental_outcome,
992
+ )
993
+ response_points.append(response_point)
994
+
995
+ return response_curve_pb2.ResponseCurve(
996
+ input_name=constants.SPEND,
997
+ response_points=response_points,
998
+ )
999
+
1000
  # TODO: Create an abstraction/container around these inference
  # parameters.
  def _incremental_outcome_dataset(
      self,
      resolver: model_processor.DatedSpecResolver,
      new_data: analyzer.DataTensors | None = None,
      media_selected_times: Sequence[bool] | None = None,
      selected_geos: Sequence[str] | None = None,
      aggregate_geos: bool = True,
      aggregate_times: bool = True,
      use_kpi: bool = False,
      confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
      batch_size: int = constants.DEFAULT_BATCH_SIZE,
      include_non_paid_channels: bool = False,
  ) -> xr.Dataset:
    """Returns incremental outcome for each channel with dimensions.

    Args:
      resolver: A `DatedSpecResolver` instance.
      new_data: A dataset containing the new data to use in the analysis.
      media_selected_times: A boolean array of length `n_times` indicating which
        time periods are media-active.
      selected_geos: Optional list containing a subset of geos to include. By
        default, all geos are included.
      aggregate_geos: Boolean. If `True`, the expected outcome is summed over
        all of the regions.
      aggregate_times: Boolean. If `True`, the expected outcome is summed over
        all of the time periods.
      use_kpi: Boolean. If `True`, the summary metrics are calculated using KPI.
        If `False`, the metrics are calculated using revenue.
      confidence_level: Confidence level for summary metrics credible intervals,
        represented as a value between zero and one.
      batch_size: Integer representing the maximum draws per chain in each
        batch. The calculation is run in batches to avoid memory exhaustion. If
        a memory error occurs, try reducing `batch_size`. The calculation will
        generally be faster with larger `batch_size` values.
      include_non_paid_channels: Boolean. If `True`, non-paid channels (organic
        media, organic reach and frequency, and non-media treatments) are
        included in the summary but only the metrics independent of spend are
        reported. If `False`, only the paid channels (media, reach and
        frequency) are included but the summary contains also the metrics
        dependent on spend. Default: `False`.

    Returns:
      An `xr.Dataset` and containing `incremental_outcome` for each channel. The
      coordinates are: `channel` and `metric` (`mean`, `median`, `ci_low`,
      `ci_high`)
    """
    # Selected times in boolean form are supported by the analyzer with and
    # without the new data.
    selected_times_bool = resolver.resolve_to_bool_selected_times()
    kwargs = {
        "selected_geos": selected_geos,
        "selected_times": selected_times_bool,
        "aggregate_geos": aggregate_geos,
        "aggregate_times": aggregate_times,
        "batch_size": batch_size,
    }
    incremental_outcome_posterior = (
        self._analyzer.compute_incremental_outcome_aggregate(
            new_data=new_data,
            media_selected_times=media_selected_times,
            use_posterior=True,
            use_kpi=use_kpi,
            include_non_paid_channels=include_non_paid_channels,
            **kwargs,
        )
    )

    # Dimension order: optional geo, optional time, then channel and metric.
    # This must line up with the axis order of the array produced by
    # `get_central_tendency_and_ci` below.
    xr_dims = (
        ((constants.GEO,) if not aggregate_geos else ())
        + ((constants.TIME,) if not aggregate_times else ())
        + (constants.CHANNEL, constants.METRIC)
    )
    channels = (
        self._meridian.input_data.get_all_channels()
        if include_non_paid_channels
        else self._meridian.input_data.get_all_paid_channels()
    )
    # The channel coordinate always ends with the aggregate "All Channels"
    # pseudo-channel.
    xr_coords = {
        constants.CHANNEL: (
            [constants.CHANNEL],
            list(channels) + [constants.ALL_CHANNELS],
        ),
    }
    if not aggregate_geos:
      geo_dims = (
          self._meridian.input_data.geo.data
          if selected_geos is None
          else selected_geos
      )
      xr_coords[constants.GEO] = ([constants.GEO], geo_dims)
    if not aggregate_times:
      # Prefer the resolver's explicit time selection; fall back to all dates
      # when it resolves to None.
      selected_times_str = resolver.resolve_to_enumerated_selected_times()
      if selected_times_str is not None:
        time_dims = selected_times_str
      else:
        time_dims = resolver.time_coordinates.all_dates_str
      xr_coords[constants.TIME] = ([constants.TIME], time_dims)
    xr_coords_with_ci = {
        constants.METRIC: (
            [constants.METRIC],
            [
                constants.MEAN,
                constants.MEDIAN,
                constants.CI_LO,
                constants.CI_HI,
            ],
        ),
        **xr_coords,
    }
    metrics = analyzer.get_central_tendency_and_ci(
        incremental_outcome_posterior, confidence_level, include_median=True
    )
    xr_data = {constants.INCREMENTAL_OUTCOME: (xr_dims, metrics)}
    return xr.Dataset(data_vars=xr_data, coords=xr_coords_with_ci)
+
1117
+
1118
def _compute_spend(
    media_summary_metrics: xr.Dataset,
) -> media_analysis_pb2.SpendInfo | None:
  """Returns a `SpendInfo` proto with spend information for the given channel.

  Args:
    media_summary_metrics: A dataset containing the model's media summary
      metrics.
  """
  if constants.SPEND not in media_summary_metrics.data_vars:
    # No spend data for this channel (e.g. non-paid channels).
    return None

  # `pct_of_spend` is a percentage; the proto expects a decimal share.
  return media_analysis_pb2.SpendInfo(
      spend=media_summary_metrics[constants.SPEND].item(),
      spend_share=(
          media_summary_metrics[constants.PCT_OF_SPEND].data.item() / 100
      ),
  )