google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,584 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines a processor for reach and frequency optimization inference on a Meridian model.
16
+
17
+ This module provides the `ReachFrequencyOptimizationProcessor`, which optimizes
18
+ the average frequency for reach and frequency (R&F) media channels in a trained
19
+ Meridian model to maximize ROI.
20
+
21
+ The processor takes a trained model and a `ReachFrequencyOptimizationSpec`
22
+ object. The spec defines the constraints for the optimization, such as the
23
+ minimum and maximum average frequency to consider for each channel.
24
+
25
+ Key Features:
26
+
27
+ - Optimizes average frequency for all R&F channels simultaneously.
28
+ - Allows setting minimum and maximum frequency constraints.
29
+ - Generates detailed results, including the optimal average frequency for
30
+ each channel, the expected outcomes at this optimal frequency, and
31
+ response curves showing KPI/Revenue as a function of spend.
32
+ - Outputs results in a structured protobuf format
33
+ (`ReachFrequencyOptimization`).
34
+
35
+ Key Classes:
36
+
37
+ - `ReachFrequencyOptimizationSpec`: Dataclass to specify optimization
38
+ parameters and constraints.
39
+ - `ReachFrequencyOptimizationProcessor`: The main processor class to execute
40
+ the R&F optimization.
41
+
42
+ Example Usage:
43
+
44
+ ```python
45
+ from schema.processors import reach_frequency_optimization_processor
46
+ from schema.processors import common
47
+ from schema.processors import model_processor
48
+ import datetime
49
+
50
+ # Assuming 'mmm' is a trained Meridian model object with R&F channels
51
+ trained_model = model_processor.TrainedModel(mmm)
52
+
53
+ spec = reach_frequency_optimization_processor.ReachFrequencyOptimizationSpec(
54
+ optimization_name="rf_optimize_q1",
55
+ start_date=datetime.date(2023, 1, 1),
56
+ end_date=datetime.date(2023, 4, 1),
57
+ min_frequency=1.0,
58
+ max_frequency=10.0, # Optional, defaults to model's max frequency
59
+ kpi_type=common.KpiType.REVENUE,
60
+ )
61
+
62
+ processor = (
63
+ reach_frequency_optimization_processor.ReachFrequencyOptimizationProcessor(
64
+ trained_model
65
+ )
66
+ )
67
+ # result is a rf_pb.ReachFrequencyOptimization proto
68
+ result = processor.execute([spec])
69
+
70
+ print(f"R&F Optimization results for {spec.optimization_name}:")
71
+ # Access results from the proto, e.g.:
72
+ # result.results[0].optimized_channel_frequencies
73
+ # result.results[0].optimized_marketing_analysis
74
+ # result.results[0].frequency_outcome_grid
75
+ ```
76
+
77
+ Note: You can provide the processor with multiple specs; the returned
78
+ `ReachFrequencyOptimization` output then contains one result per spec.
79
+ """
80
+
81
+ from collections.abc import Sequence
82
+ import dataclasses
83
+
84
+ from meridian import backend
85
+ from meridian import constants
86
+ from mmm.v1 import mmm_pb2 as pb
87
+ from mmm.v1.common import kpi_type_pb2 as kpi_type_pb
88
+ from mmm.v1.marketing.analysis import marketing_analysis_pb2 as analysis_pb
89
+ from mmm.v1.marketing.analysis import media_analysis_pb2 as media_analysis_pb
90
+ from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb
91
+ from mmm.v1.marketing.analysis import response_curve_pb2
92
+ from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
93
+ from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
94
+ from schema.processors import common
95
+ from schema.processors import model_processor
96
+ from schema.utils import time_record
97
+ import numpy as np
98
+ import xarray as xr
99
+
100
+
101
+ __all__ = [
102
+ "ReachFrequencyOptimizationSpec",
103
+ "ReachFrequencyOptimizationProcessor",
104
+ ]
105
+
106
+
107
+ _STEP_SIZE_DECIMAL_PRECISION = 1
108
+ _STEP_SIZE = _STEP_SIZE_DECIMAL_PRECISION / 10
109
+ _TOL = 1e-6
110
+
111
+
112
+ @dataclasses.dataclass(frozen=True, kw_only=True)
113
+ class ReachFrequencyOptimizationSpec(model_processor.OptimizationSpec):
114
+ """Spec dataclass for marketing reach and frequency optimization processor.
115
+
116
+ A frequency grid is generated using the range `[rounded_min_frequency,
117
+ rounded_max_frequency]` and a step size of `STEP_SIZE=0.1`.
118
+ `rounded_min_frequency` and `rounded_max_frequency` are rounded to the
119
+ nearest multiple of `STEP_SIZE`.
120
+
121
+ This spec is used both as user input to inform the R&F optimization processor
122
+ of its constraints and parameters, as well as an output structure that is
123
+ serializable to a `ReachFrequencyOptimizationSpec` proto. The latter serves as
124
+ metadata embedded in a `ReachFrequencyOptimizationResult`. The output spec
125
+ in the proto reflects the actual numbers used to generate the reach and
126
+ frequency optimization result.
127
+
128
+ Attributes:
129
+ min_frequency: The minimum frequency constraint for each channel. Must be
130
+ greater than or equal to `1.0`. Defaults to `1.0`.
131
+ max_frequency: The maximum frequency constraint for each channel. Must be
132
+ greater than or equal to min_frequency. Defaults to None, in which case
133
+ the model's maximum observed frequency is used.
134
+ rf_channels: The R&F media channels in the model. When resolved with a
135
+ model, the model's R&F channels will be present here. Ignored when used as
136
+ input.
137
+ kpi_type: A `common.KpiType` enum denoting whether the optimized KPI is of a
138
+ `'revenue'` or `'non-revenue'` type.
139
+ """
140
+
141
+ min_frequency: float = 1.0
142
+ max_frequency: float | None = None
143
+ rf_channels: Sequence[str] = dataclasses.field(default_factory=list)
144
+ kpi_type: common.KpiType = common.KpiType.REVENUE
145
+
146
+ @property
147
+ def selected_times(self) -> tuple[str | None, str | None] | None:
148
+ """The start and end dates, as a tuple of date strings."""
149
+ start, end = (None, None)
150
+ if self.start_date is not None:
151
+ start = self.start_date.strftime(constants.DATE_FORMAT)
152
+ if self.end_date is not None:
153
+ end = self.end_date.strftime(constants.DATE_FORMAT)
154
+
155
+ if start or end:
156
+ return (start, end)
157
+ return None
158
+
159
+ @property
160
+ def objective(self) -> common.TargetMetric:
161
+ """A Meridian budget optimization objective is always ROI."""
162
+ return common.TargetMetric.ROI
163
+
164
+ def validate(self):
165
+ super().validate()
166
+ if self.min_frequency < 0:
167
+ raise ValueError("Min frequency must be non-negative.")
168
+ if (
169
+ self.max_frequency is not None
170
+ and self.max_frequency < self.min_frequency
171
+ ):
172
+ raise ValueError("Max frequency must be greater than min frequency.")
173
+
174
+ def to_proto(self) -> rf_pb.ReachFrequencyOptimizationSpec:
175
+ # When invoked as an output proto, the spec should have been fully resolved.
176
+ if self.start_date is None or self.end_date is None:
177
+ raise ValueError(
178
+ "Start and end dates must be resolved before this spec can be"
179
+ " serialized."
180
+ )
181
+
182
+ return rf_pb.ReachFrequencyOptimizationSpec(
183
+ date_interval=time_record.create_date_interval_pb(
184
+ self.start_date, self.end_date, tag=self.date_interval_tag
185
+ ),
186
+ rf_channel_constraints=[
187
+ rf_pb.RfChannelConstraint(
188
+ channel_name=channel,
189
+ frequency_constraint=constraints_pb.FrequencyConstraint(
190
+ min_frequency=self.min_frequency,
191
+ max_frequency=self.max_frequency,
192
+ ),
193
+ )
194
+ for channel in self.rf_channels
195
+ ],
196
+ objective=self.objective.value,
197
+ kpi_type=(
198
+ kpi_type_pb.KpiType.REVENUE
199
+ if self.kpi_type == common.KpiType.REVENUE
200
+ else kpi_type_pb.KpiType.NON_REVENUE
201
+ ),
202
+ )
203
+
204
+
205
+ class ReachFrequencyOptimizationProcessor(
206
+ model_processor.ModelProcessor[
207
+ ReachFrequencyOptimizationSpec, rf_pb.ReachFrequencyOptimization
208
+ ],
209
+ ):
210
+ """A Processor for marketing reach and frequency optimization."""
211
+
212
+ def __init__(
213
+ self,
214
+ trained_model: model_processor.ModelType,
215
+ ):
216
+ trained_model = model_processor.ensure_trained_model(trained_model)
217
+ self._internal_analyzer = trained_model.internal_analyzer
218
+ self._meridian = trained_model.mmm
219
+
220
+ if trained_model.mmm.input_data.rf_channel is None:
221
+ raise ValueError("RF channels must be set in the model.")
222
+
223
+ self._all_rf_channels = trained_model.mmm.input_data.rf_channel.data
224
+
225
+ @classmethod
226
+ def spec_type(cls) -> type[ReachFrequencyOptimizationSpec]:
227
+ return ReachFrequencyOptimizationSpec
228
+
229
+ @classmethod
230
+ def output_type(cls) -> type[rf_pb.ReachFrequencyOptimization]:
231
+ return rf_pb.ReachFrequencyOptimization
232
+
233
+ def _to_target_precision(self, value: float) -> float:
234
+ return round(value, _STEP_SIZE_DECIMAL_PRECISION)
235
+
236
+ def _set_output(
237
+ self, output: pb.Mmm, result: rf_pb.ReachFrequencyOptimization
238
+ ):
239
+ output.marketing_optimization.reach_frequency_optimization.CopyFrom(result)
240
+
241
+ def execute(
242
+ self, specs: Sequence[ReachFrequencyOptimizationSpec]
243
+ ) -> rf_pb.ReachFrequencyOptimization:
244
+ output = rf_pb.ReachFrequencyOptimization()
245
+
246
+ group_ids = [spec.group_id for spec in specs if spec.group_id]
247
+ if len(set(group_ids)) != len(group_ids):
248
+ raise ValueError(
249
+ "Specified group_id must be unique among the given group of specs."
250
+ )
251
+
252
+ for spec in specs:
253
+ selected_times = spec.resolver(
254
+ self._meridian.input_data.time_coordinates
255
+ ).resolve_to_enumerated_selected_times()
256
+
257
+ grid_min_freq = self._to_target_precision(spec.min_frequency)
258
+ # If the max frequency is not set, use the model's max frequency.
259
+ grid_max_freq = self._to_target_precision(
260
+ spec.max_frequency or np.max(self._meridian.rf_tensors.frequency)
261
+ )
262
+ grid = [
263
+ self._to_target_precision(f)
264
+ for f in np.arange(grid_min_freq, grid_max_freq + _TOL, _STEP_SIZE)
265
+ ]
266
+
267
+ # Note that the internal analyzer, like the budget optimizer, maximizes
268
+ # non-revenue KPI when the input data is non-revenue and the user selects
269
+ # `use_kpi=True`. Otherwise, it maximizes revenue KPI.
270
+ optimal_frequency = self._internal_analyzer.optimal_freq(
271
+ selected_times=selected_times,
272
+ confidence_level=spec.confidence_level,
273
+ freq_grid=grid,
274
+ use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE),
275
+ max_frequency=spec.max_frequency,
276
+ )
277
+ response_curve = self._internal_analyzer.response_curves(
278
+ confidence_level=spec.confidence_level,
279
+ selected_times=selected_times,
280
+ by_reach=False,
281
+ use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE),
282
+ use_optimal_frequency=True,
283
+ )
284
+
285
+ spend_data = self._compute_spend_data(selected_times=selected_times)
286
+
287
+ # Obtain the output spec.
288
+ start, end = spec.resolver(
289
+ self._meridian.input_data.time_coordinates
290
+ ).resolve_to_date_interval_open_end()
291
+
292
+ # Copy the current spec, and resolve its date interval as well as model-
293
+ # dependent parameters.
294
+ output_spec = dataclasses.replace(
295
+ spec,
296
+ rf_channels=self._all_rf_channels,
297
+ min_frequency=grid_min_freq,
298
+ max_frequency=grid_max_freq,
299
+ start_date=start,
300
+ end_date=end,
301
+ )
302
+
303
+ output.results.append(
304
+ self._to_reach_frequency_optimization_result(
305
+ output_spec,
306
+ optimal_frequency,
307
+ response_curve,
308
+ spend_data,
309
+ )
310
+ )
311
+ return output
312
+
313
+ def _compute_spend_data(
314
+ self, selected_times: list[str] | None = None
315
+ ) -> xr.Dataset:
316
+ aggregated_spends = self._internal_analyzer.get_historical_spend(
317
+ selected_times
318
+ )
319
+ aggregated_rf_spend = aggregated_spends.sel(
320
+ {constants.CHANNEL: self._all_rf_channels}
321
+ ).data
322
+ total_spend = np.sum(aggregated_spends.data)
323
+ pct_of_spend = 100.0 * aggregated_rf_spend / total_spend
324
+
325
+ xr_dims = (constants.RF_CHANNEL,)
326
+ xr_coords = {
327
+ constants.RF_CHANNEL: (
328
+ [constants.RF_CHANNEL],
329
+ list(self._all_rf_channels),
330
+ ),
331
+ }
332
+ xr_data_vars = {
333
+ constants.SPEND: (xr_dims, aggregated_rf_spend),
334
+ constants.PCT_OF_SPEND: (xr_dims, pct_of_spend),
335
+ }
336
+
337
+ return xr.Dataset(
338
+ data_vars=xr_data_vars,
339
+ coords=xr_coords,
340
+ )
341
+
342
+ def _to_reach_frequency_optimization_result(
343
+ self,
344
+ spec: ReachFrequencyOptimizationSpec,
345
+ optimal_frequency: xr.Dataset,
346
+ response_curve: xr.Dataset,
347
+ spend_data: xr.Dataset,
348
+ ) -> rf_pb.ReachFrequencyOptimizationResult:
349
+ """Converts given optimal frequency dataset to protobuf form."""
350
+ result = rf_pb.ReachFrequencyOptimizationResult(
351
+ name=spec.optimization_name,
352
+ spec=spec.to_proto(),
353
+ optimized_channel_frequencies=_create_optimized_channel_frequencies(
354
+ optimal_frequency
355
+ ),
356
+ optimized_marketing_analysis=self._to_marketing_analysis(
357
+ spec,
358
+ optimal_frequency,
359
+ response_curve,
360
+ spend_data,
361
+ ),
362
+ frequency_outcome_grid=self._create_frequency_outcome_grid(
363
+ optimal_frequency,
364
+ spec,
365
+ ),
366
+ )
367
+ if spec.group_id:
368
+ result.group_id = spec.group_id
369
+ return result
370
+
371
+ def _to_marketing_analysis(
372
+ self,
373
+ spec: ReachFrequencyOptimizationSpec,
374
+ optimal_frequency: xr.Dataset,
375
+ response_curve: xr.Dataset,
376
+ spend_data: xr.Dataset,
377
+ ) -> analysis_pb.MarketingAnalysis:
378
+ """Converts an optimal frequency dataset to a `MarketingAnalysis` proto."""
379
+ # `spec` should have been resolved with concrete date interval parameters.
380
+ assert spec.start_date is not None and spec.end_date is not None
381
+
382
+ optimized_marketing_analysis = analysis_pb.MarketingAnalysis(
383
+ date_interval=time_record.create_date_interval_pb(
384
+ start_date=spec.start_date,
385
+ end_date=spec.end_date,
386
+ ),
387
+ )
388
+
389
+ # Create a per-channel MediaAnalysis.
390
+ channels = optimal_frequency.coords[constants.RF_CHANNEL].data
391
+ for channel in channels:
392
+ channel_optimal_frequency = optimal_frequency.sel(rf_channel=channel)
393
+ channel_spend_data = spend_data.sel(rf_channel=channel)
394
+
395
+ # TODO Add non-media analyses.
396
+ channel_media_analysis = media_analysis_pb.MediaAnalysis(
397
+ channel_name=channel,
398
+ response_curve=_compute_response_curve(
399
+ response_curve,
400
+ channel,
401
+ ),
402
+ spend_info=media_analysis_pb.SpendInfo(
403
+ spend=channel_spend_data[constants.SPEND].data.item(),
404
+ spend_share=(
405
+ channel_spend_data[constants.PCT_OF_SPEND].data.item()
406
+ ),
407
+ ),
408
+ )
409
+
410
+ # Output one outcome per channel: either revenue or non-revenue.
411
+ channel_media_analysis.media_outcomes.append(
412
+ _to_outcome(
413
+ channel_optimal_frequency,
414
+ is_revenue_kpi=optimal_frequency.attrs[constants.IS_REVENUE_KPI],
415
+ )
416
+ )
417
+
418
+ optimized_marketing_analysis.media_analyses.append(channel_media_analysis)
419
+
420
+ return optimized_marketing_analysis
421
+
422
+ def _create_frequency_outcome_grid(
423
+ self,
424
+ optimal_frequency_dataset: xr.Dataset,
425
+ spec: ReachFrequencyOptimizationSpec,
426
+ ) -> rf_pb.FrequencyOutcomeGrid:
427
+ """Creates a FrequencyOutcomeGrid proto."""
428
+ channel_cells = []
429
+ frequencies = optimal_frequency_dataset.coords[constants.FREQUENCY].data
430
+ channels = optimal_frequency_dataset.coords[constants.RF_CHANNEL].data
431
+ input_tensor_dims = "gtc"
432
+ output_tensor_dims = "c"
433
+
434
+ for channel in channels:
435
+ cells = []
436
+ for frequency in frequencies:
437
+ new_frequency = (
438
+ backend.ones_like(self._meridian.rf_tensors.frequency) * frequency
439
+ )
440
+ new_reach = (
441
+ self._meridian.rf_tensors.frequency
442
+ * self._meridian.rf_tensors.reach
443
+ / new_frequency
444
+ )
445
+ channel_mask = [c == channel for c in channels]
446
+ filtered_reach = backend.boolean_mask(new_reach, channel_mask, axis=2)
447
+ aggregated_reach = backend.einsum(
448
+ f"{input_tensor_dims}->...{output_tensor_dims}", filtered_reach
449
+ )
450
+ reach = aggregated_reach.numpy()[-1]
451
+
452
+ metric_data_array = optimal_frequency_dataset[constants.ROI].sel(
453
+ frequency=frequency, rf_channel=channel
454
+ )
455
+ outcome = common.to_estimate(metric_data_array, spec.confidence_level)
456
+
457
+ cell = rf_pb.FrequencyOutcomeGrid.Cell(
458
+ outcome=outcome,
459
+ reach_frequency=rf_pb.ReachFrequency(
460
+ reach=int(reach),
461
+ average_frequency=frequency,
462
+ ),
463
+ )
464
+ cells.append(cell)
465
+
466
+ channel_cell = rf_pb.FrequencyOutcomeGrid.ChannelCells(
467
+ channel_name=channel,
468
+ cells=cells,
469
+ )
470
+ channel_cells.append(channel_cell)
471
+
472
+ return rf_pb.FrequencyOutcomeGrid(
473
+ name=spec.grid_name,
474
+ frequency_step_size=_STEP_SIZE,
475
+ channel_cells=channel_cells,
476
+ )
477
+
478
+
479
+ def _create_optimized_channel_frequencies(
480
+ optimal_frequency_dataset: xr.Dataset,
481
+ ) -> list[rf_pb.OptimizedChannelFrequency]:
482
+ """Creates an OptimizedChannelFrequency proto for each channel in the dataset."""
483
+ optimal_frequency_protos = []
484
+ optimal_frequency = optimal_frequency_dataset[constants.OPTIMAL_FREQUENCY]
485
+ channels = optimal_frequency.coords[constants.RF_CHANNEL].data
486
+
487
+ for channel in channels:
488
+ optimal_frequency_protos.append(
489
+ rf_pb.OptimizedChannelFrequency(
490
+ channel_name=channel,
491
+ optimal_average_frequency=optimal_frequency.sel(
492
+ rf_channel=channel
493
+ ).item(),
494
+ )
495
+ )
496
+ return optimal_frequency_protos
497
+
498
+
499
+ def _to_outcome(
500
+ channel_optimal_frequency: xr.Dataset,
501
+ is_revenue_kpi: bool,
502
+ ) -> outcome_pb.Outcome:
503
+ """Returns an `Outcome` value for a given channel's optimized media analysis.
504
+
505
+ Args:
506
+ channel_optimal_frequency: A channel-selected dataset from
507
+ `Analyzer.optimal_freq()`.
508
+ is_revenue_kpi: Whether the KPI is revenue-based.
509
+ """
510
+ confidence_level = channel_optimal_frequency.attrs[constants.CONFIDENCE_LEVEL]
511
+ return outcome_pb.Outcome(
512
+ kpi_type=(
513
+ kpi_type_pb.REVENUE if is_revenue_kpi else kpi_type_pb.NON_REVENUE
514
+ ),
515
+ roi=common.to_estimate(
516
+ channel_optimal_frequency.optimized_roi, confidence_level
517
+ ),
518
+ marginal_roi=common.to_estimate(
519
+ channel_optimal_frequency.optimized_mroi_by_frequency,
520
+ confidence_level,
521
+ ),
522
+ cost_per_contribution=common.to_estimate(
523
+ channel_optimal_frequency.optimized_cpik,
524
+ confidence_level=confidence_level,
525
+ ),
526
+ contribution=outcome_pb.Contribution(
527
+ value=common.to_estimate(
528
+ channel_optimal_frequency.optimized_incremental_outcome,
529
+ confidence_level,
530
+ ),
531
+ ),
532
+ effectiveness=outcome_pb.Effectiveness(
533
+ media_unit=constants.IMPRESSIONS,
534
+ value=common.to_estimate(
535
+ channel_optimal_frequency.optimized_effectiveness,
536
+ confidence_level,
537
+ ),
538
+ ),
539
+ )
540
+
541
+
542
+ def _compute_response_curve(
543
+ response_curve_dataset: xr.Dataset,
544
+ channel_name: str,
545
+ ) -> response_curve_pb2.ResponseCurve:
546
+ """Returns a ResponseCurve proto for the given channel.
547
+
548
+ Args:
549
+ response_curve_dataset: A dataset containing the data needed to generate a
550
+ response curve.
551
+ channel_name: The name of the channel to analyze.
552
+ """
553
+
554
+ spend_multiplier_list = response_curve_dataset.coords[
555
+ constants.SPEND_MULTIPLIER
556
+ ].data
557
+ response_points: list[response_curve_pb2.ResponsePoint] = []
558
+
559
+ for spend_multiplier in spend_multiplier_list:
560
+ spend = (
561
+ response_curve_dataset[constants.SPEND]
562
+ .sel(spend_multiplier=spend_multiplier, channel=channel_name)
563
+ .data.item()
564
+ )
565
+ incremental_outcome = (
566
+ response_curve_dataset[constants.INCREMENTAL_OUTCOME]
567
+ .sel(
568
+ spend_multiplier=spend_multiplier,
569
+ channel=channel_name,
570
+ metric=constants.MEAN,
571
+ )
572
+ .data.item()
573
+ )
574
+
575
+ response_point = response_curve_pb2.ResponsePoint(
576
+ input_value=spend,
577
+ incremental_kpi=incremental_outcome,
578
+ )
579
+ response_points.append(response_point)
580
+
581
+ return response_curve_pb2.ResponseCurve(
582
+ input_name=constants.SPEND,
583
+ response_points=response_points,
584
+ )
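The frequency grid described in the `ReachFrequencyOptimizationSpec` docstring is materialized in `execute` by rounding both bounds to the step size and enumerating candidates with `np.arange`. A minimal standalone sketch of that logic, using hypothetical bounds rather than values from the package:

```python
import numpy as np

# Constants mirroring the module: one decimal of precision, 0.1 step, and a
# small tolerance so the upper bound survives floating-point round-off.
_STEP_SIZE_DECIMAL_PRECISION = 1
_STEP_SIZE = 0.1
_TOL = 1e-6


def build_frequency_grid(min_frequency: float, max_frequency: float) -> list[float]:
  """Enumerates candidate average frequencies from min to max, inclusive."""
  grid_min = round(min_frequency, _STEP_SIZE_DECIMAL_PRECISION)
  grid_max = round(max_frequency, _STEP_SIZE_DECIMAL_PRECISION)
  return [
      round(f, _STEP_SIZE_DECIMAL_PRECISION)
      for f in np.arange(grid_min, grid_max + _TOL, _STEP_SIZE)
  ]


# Hypothetical bounds; yields [1.0, 1.1, ..., 3.0].
print(build_frequency_grid(1.02, 2.98))
```

Each grid value is then passed to the analyzer via `freq_grid`, and in `_create_frequency_outcome_grid` the reach tensor is rescaled (`frequency * reach / new_frequency`) so that total impressions stay constant while the average frequency varies.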
@@ -193,6 +193,8 @@ class DistributionSerde(
193
193
  return meridian_pb.TfpParameterValue(scalar_value=value)
194
194
  case int():
195
195
  return meridian_pb.TfpParameterValue(int_value=value)
196
+ # TODO: b/470407198 - case bool() has to be before int() because bool is a
197
+ # subtype of int.
196
198
  case bool():
197
199
  return meridian_pb.TfpParameterValue(bool_value=value)
198
200
  case str():
@@ -216,10 +218,6 @@ class DistributionSerde(
216
218
  return meridian_pb.TfpParameterValue(
217
219
  dict_value=meridian_pb.TfpParameterValue.Dict(value_map=dict_value)
218
220
  )
219
- case backend.Tensor():
220
- return meridian_pb.TfpParameterValue(
221
- tensor_value=backend.make_tensor_proto(value)
222
- )
223
221
  case backend.tfd.Distribution():
224
222
  return meridian_pb.TfpParameterValue(
225
223
  distribution_value=self._to_distribution_proto(value)
@@ -257,9 +255,16 @@ class DistributionSerde(
257
255
  f" {type(dist).__name__}, but not found in registry. Please"
258
256
  " add custom functions to registry when saving models."
259
257
  )
260
-
261
- # Handle unsupported types.
262
- raise TypeError(f"Unsupported type: {type(value)}, {value}")
258
+ case _:
259
+ # Handle unsupported types by attempting to convert to a tensor proto.
260
+ # This allows for more flexibility in handling types that are not
261
+ # explicitly handled above, such as numpy arrays or backend tensors.
262
+ try:
263
+ return meridian_pb.TfpParameterValue(
264
+ tensor_value=backend.make_tensor_proto(value)
265
+ )
266
+ except TypeError as e:
267
+ raise TypeError(f"Unsupported type: {type(value)}, {value!r}") from e
263
268
 
264
269
  def _from_distribution_proto(
265
270
  self,
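The TODO added in the `DistributionSerde` hunk notes that `case bool()` must precede `case int()`: class patterns in a `match` statement use `isinstance`, and `bool` is a subclass of `int`, so boolean values are captured by the `int()` arm first. A small self-contained illustration of the pitfall, with hypothetical helper names that are not part of the package:

```python
def classify_wrong(value):
  # bool values never reach the bool() arm: isinstance(True, int) is True,
  # so `case int()` matches first.
  match value:
    case int():
      return "int"
    case bool():
      return "bool"


def classify_right(value):
  # Listing the more specific type first restores the intended behavior.
  match value:
    case bool():
      return "bool"
    case int():
      return "int"


print(classify_wrong(True))   # -> "int" (surprising)
print(classify_right(True))   # -> "bool"
print(classify_right(7))      # -> "int"
```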