google-meridian: google_meridian-1.4.0-py3-none-any.whl → google_meridian-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/reviewer.py +4 -1
  7. meridian/analysis/summarizer.py +6 -1
  8. meridian/analysis/test_utils.py +2898 -2538
  9. meridian/analysis/visualizer.py +28 -9
  10. meridian/backend/__init__.py +106 -0
  11. meridian/constants.py +1 -0
  12. meridian/data/input_data.py +30 -52
  13. meridian/data/input_data_builder.py +2 -9
  14. meridian/data/test_utils.py +25 -41
  15. meridian/data/validator.py +48 -0
  16. meridian/mlflow/autolog.py +19 -9
  17. meridian/model/adstock_hill.py +3 -5
  18. meridian/model/context.py +134 -0
  19. meridian/model/eda/constants.py +334 -4
  20. meridian/model/eda/eda_engine.py +723 -312
  21. meridian/model/eda/eda_outcome.py +177 -33
  22. meridian/model/model.py +159 -110
  23. meridian/model/model_test_data.py +38 -0
  24. meridian/model/posterior_sampler.py +103 -62
  25. meridian/model/prior_sampler.py +114 -94
  26. meridian/model/spec.py +23 -14
  27. meridian/templates/card.html.jinja +9 -7
  28. meridian/templates/chart.html.jinja +1 -6
  29. meridian/templates/finding.html.jinja +19 -0
  30. meridian/templates/findings.html.jinja +33 -0
  31. meridian/templates/formatter.py +41 -5
  32. meridian/templates/formatter_test.py +127 -0
  33. meridian/templates/style.css +66 -9
  34. meridian/templates/style.scss +85 -4
  35. meridian/templates/table.html.jinja +1 -0
  36. meridian/version.py +1 -1
  37. scenarioplanner/linkingapi/constants.py +1 -1
  38. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  39. schema/processors/marketing_processor.py +11 -10
  40. schema/processors/model_processor.py +4 -1
  41. schema/serde/distribution.py +12 -7
  42. schema/serde/hyperparameters.py +54 -107
  43. schema/serde/meridian_serde.py +6 -1
  44. schema/utils/__init__.py +1 -0
  45. schema/utils/proto_enum_converter.py +127 -0
  46. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
  47. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ from collections.abc import Sequence
18
18
  import dataclasses
19
19
  import math
20
20
  import os
21
+ import re
21
22
 
22
23
  import altair as alt
23
24
  import immutabledict
@@ -46,6 +47,9 @@ class ChartSpec:
46
47
  id: str
47
48
  chart_json: str
48
49
  description: str | None = None
50
+ errors: Sequence[str] | None = None
51
+ warnings: Sequence[str] | None = None
52
+ infos: Sequence[str] | None = None
49
53
 
50
54
 
51
55
  @dataclasses.dataclass(frozen=True)
@@ -55,6 +59,9 @@ class TableSpec:
55
59
  column_headers: Sequence[str]
56
60
  row_values: Sequence[Sequence[str]]
57
61
  description: str | None = None
62
+ errors: Sequence[str] | None = None
63
+ warnings: Sequence[str] | None = None
64
+ infos: Sequence[str] | None = None
58
65
 
59
66
 
60
67
  @dataclasses.dataclass(frozen=True)
@@ -198,6 +205,25 @@ def format_monetary_num(num: float, currency: str) -> str:
198
205
  return compact_number(num, precision=precision, currency=currency)
199
206
 
200
207
 
208
+ def format_col_names(headers: Sequence[str]) -> Sequence[str]:
209
+ """Turns underscores to spaces and capitalizes words.
210
+
211
+ Ex. ['col_name', ...] to ['Col Name', ...])
212
+
213
+ Args:
214
+ headers: The list of column names to format.
215
+
216
+ Returns:
217
+ Human readable list of column names.
218
+ """
219
+ # \b matches the start of a word
220
+ # [a-z] matches only if the first letter is lowercase
221
+ return [
222
+ re.sub(r'\b[a-z]', lambda m: m.group().upper(), header.replace('_', ' '))
223
+ for header in headers
224
+ ]
225
+
226
+
201
227
  def create_template_env() -> jinja2.Environment:
202
228
  """Creates a Jinja2 template environment."""
203
229
  return jinja2.Environment(
@@ -220,19 +246,20 @@ def create_summary_html(
220
246
  def create_card_html(
221
247
  template_env: jinja2.Environment,
222
248
  card_spec: CardSpec,
223
- insights: str,
249
+ insights: str | None = None,
224
250
  chart_specs: Sequence[ChartSpec | TableSpec] | None = None,
225
251
  stats_specs: Sequence[StatsSpec] | None = None,
226
252
  ) -> str:
227
253
  """Creates a card's HTML snippet that includes given card and chart specs."""
228
- insights_html = template_env.get_template('insights.html.jinja').render(
229
- text_html=insights
230
- )
231
254
  card_params = dataclasses.asdict(card_spec)
232
255
  card_params[c.CARD_CHARTS] = (
233
256
  _create_charts_htmls(template_env, chart_specs) if chart_specs else None
234
257
  )
235
- card_params[c.CARD_INSIGHTS] = insights_html
258
+ if insights:
259
+ insights_html = template_env.get_template('insights.html.jinja').render(
260
+ text_html=insights
261
+ )
262
+ card_params[c.CARD_INSIGHTS] = insights_html
236
263
  card_params[c.CARD_STATS] = (
237
264
  _create_stats_htmls(template_env, stats_specs) if stats_specs else None
238
265
  )
@@ -267,3 +294,12 @@ def _create_charts_htmls(
267
294
  else:
268
295
  htmls.append(table_template.render(dataclasses.asdict(spec)))
269
296
  return htmls
297
+
298
+
299
+ def create_finding_html(
300
+ template_env: jinja2.Environment, text: str, finding_type: str
301
+ ) -> str:
302
+ """Generates an HTML tag for the table finding."""
303
+ return template_env.get_template('finding.html.jinja').render(
304
+ finding_class=finding_type, text=text
305
+ )
@@ -92,6 +92,36 @@ class FormatterTest(parameterized.TestCase):
92
92
  formatted_number = formatter.compact_number(num, precision, currency)
93
93
  self.assertEqual(formatted_number, expected)
94
94
 
95
+ @parameterized.named_parameters(
96
+ dict(
97
+ testcase_name='basic_snake_case',
98
+ input_headers=['finding_cause'],
99
+ expected=['Finding Cause'],
100
+ ),
101
+ dict(
102
+ testcase_name='multiple_columns',
103
+ input_headers=['geo', 'time_index', 'channel_name'],
104
+ expected=['Geo', 'Time Index', 'Channel Name'],
105
+ ),
106
+ dict(
107
+ testcase_name='preserves_acronyms',
108
+ input_headers=['VIF_score', 'national_KPI'],
109
+ expected=['VIF Score', 'National KPI'],
110
+ ),
111
+ dict(
112
+ testcase_name='handles_tuples_input',
113
+ input_headers=('row_id', 'value'),
114
+ expected=['Row Id', 'Value'],
115
+ ),
116
+ dict(
117
+ testcase_name='empty_input',
118
+ input_headers=[],
119
+ expected=[],
120
+ ),
121
+ )
122
+ def test_format_col_names(self, input_headers, expected):
123
+ self.assertEqual(formatter.format_col_names(input_headers), expected)
124
+
95
125
  def test_create_summary_html(self):
96
126
  template_env = formatter.create_template_env()
97
127
  title = 'Integration Test Report'
@@ -211,6 +241,103 @@ class FormatterTest(parameterized.TestCase):
211
241
  self.assertEqual(stats_html[0][2].tag, 'delta')
212
242
  self.assertContainsSubset('+0.3', stats_html[0][2].text)
213
243
 
244
+ def test_create_card_html_no_insights(self):
245
+ template_env = formatter.create_template_env()
246
+ card_spec = formatter.CardSpec(id='test_id', title='test_title')
247
+ stats_spec = formatter.StatsSpec(title='stats_title', stat='test_stat')
248
+
249
+ card_html = ET.fromstring(
250
+ formatter.create_card_html(
251
+ template_env, card_spec, insights=None, stats_specs=[stats_spec]
252
+ )
253
+ )
254
+
255
+ self.assertEqual(card_html.tag, 'card')
256
+ self.assertIsNone(card_html.find('card-insights'))
257
+ self.assertIsNotNone(card_html.find('stats-section'))
258
+
259
+ def test_create_card_html_chart_findings(self):
260
+ """Tests that errors, warnings, and infos render inside a chart."""
261
+ template_env = formatter.create_template_env()
262
+ card_spec = formatter.CardSpec(id='test_id', title='test_title')
263
+ chart_spec = formatter.ChartSpec(
264
+ id='id',
265
+ chart_json='{}',
266
+ errors=['Chart Error'],
267
+ warnings=['Chart Warning'],
268
+ infos=['Chart Info'],
269
+ )
270
+ card_html = ET.fromstring(
271
+ formatter.create_card_html(
272
+ template_env, card_spec, insights=None, chart_specs=[chart_spec]
273
+ )
274
+ )
275
+
276
+ charts_elem = card_html.find('charts')
277
+ self.assertIsNotNone(charts_elem)
278
+ chart_elem = charts_elem.find('chart')
279
+ self.assertIsNotNone(chart_elem)
280
+
281
+ error_elem = chart_elem.find('errors')
282
+ self.assertIsNotNone(error_elem)
283
+ error_p = error_elem.find('p')
284
+ self.assertIsNotNone(error_p)
285
+ self.assertIn('Chart Error', error_p.text)
286
+
287
+ warning_elem = chart_elem.find('warnings')
288
+ self.assertIsNotNone(warning_elem)
289
+ warning_p = warning_elem.find('p')
290
+ self.assertIsNotNone(warning_p)
291
+ self.assertIn('Chart Warning', warning_p.text)
292
+
293
+ info_elem = chart_elem.find('infos')
294
+ self.assertIsNotNone(info_elem)
295
+ info_p = info_elem.find('p')
296
+ self.assertIsNotNone(info_p)
297
+ self.assertIn('Chart Info', info_p.text)
298
+
299
+ def test_create_card_html_table_findings(self):
300
+ """Tests that errors, warnings, and infos render inside a table."""
301
+ template_env = formatter.create_template_env()
302
+ card_spec = formatter.CardSpec(id='test_id', title='test_title')
303
+ table_spec = formatter.TableSpec(
304
+ id='table_id',
305
+ title='Table Title',
306
+ column_headers=['Col1'],
307
+ row_values=[['Val1']],
308
+ errors=['Table Error'],
309
+ warnings=['Table Warning'],
310
+ infos=['Table Info'],
311
+ )
312
+ card_html = ET.fromstring(
313
+ formatter.create_card_html(
314
+ template_env, card_spec, insights=None, chart_specs=[table_spec]
315
+ )
316
+ )
317
+
318
+ charts_elem = card_html.find('charts')
319
+ self.assertIsNotNone(charts_elem)
320
+ table_elem = charts_elem.find('chart-table')
321
+ self.assertIsNotNone(table_elem)
322
+
323
+ error_elem = table_elem.find('errors')
324
+ self.assertIsNotNone(error_elem)
325
+ error_p = error_elem.find('p')
326
+ self.assertIsNotNone(error_p)
327
+ self.assertIn('Table Error', error_p.text)
328
+
329
+ warning_elem = table_elem.find('warnings')
330
+ self.assertIsNotNone(warning_elem)
331
+ warning_p = warning_elem.find('p')
332
+ self.assertIsNotNone(warning_p)
333
+ self.assertIn('Table Warning', warning_p.text)
334
+
335
+ info_elem = table_elem.find('infos')
336
+ self.assertIsNotNone(info_elem)
337
+ info_p = info_elem.find('p')
338
+ self.assertIsNotNone(info_p)
339
+ self.assertIn('Table Info', info_p.text)
340
+
214
341
 
215
342
  if __name__ == '__main__':
216
343
  absltest.main()
@@ -86,11 +86,30 @@ card {
86
86
  font-weight: 400;
87
87
  line-height: 50px; }
88
88
  card > card-insights {
89
+ background: #eee; }
90
+ card errors,
91
+ card warnings,
92
+ card infos {
93
+ margin-top: -39px;
94
+ margin-left: -69px;
95
+ margin-right: -69px; }
96
+ card errors {
97
+ background: #f8d7da; }
98
+ card warnings {
99
+ background: #fff3cd; }
100
+ card infos {
101
+ background: #dae8fc; }
102
+ card > card-insights,
103
+ card errors,
104
+ card warnings,
105
+ card infos {
89
106
  display: flex;
90
107
  flex-direction: row;
91
- padding: 22px 20px;
92
- background: #eee; }
93
- card > card-insights p.insights-text {
108
+ padding: 22px 20px; }
109
+ card > card-insights p,
110
+ card errors p,
111
+ card warnings p,
112
+ card infos p {
94
113
  margin: 0px 25px;
95
114
  color: var(--Grey-800, #3c4043);
96
115
  font-family: Roboto;
@@ -110,7 +129,10 @@ charts {
110
129
  charts chart {
111
130
  display: flex;
112
131
  flex-flow: column nowrap;
113
- gap: 12px; }
132
+ gap: 12px;
133
+ flex: 1 1 auto;
134
+ min-width: 0;
135
+ max-width: 100%; }
114
136
  charts chart > chart-description {
115
137
  color: var(--Grey-800, #3c4043);
116
138
  font-family: "Google Sans Display", "Google Sans", sans-serif;
@@ -119,10 +141,18 @@ charts {
119
141
  font-weight: 400;
120
142
  line-height: 16px;
121
143
  max-width: 450px; }
144
+ charts chart-embed {
145
+ display: block;
146
+ width: 100%;
147
+ overflow-x: auto;
148
+ padding: 10px 5px 20px 2px; }
122
149
  charts chart-table {
123
150
  display: flex;
124
151
  flex-flow: column nowrap;
125
- gap: 12px; }
152
+ gap: 12px;
153
+ flex: 1 1 auto;
154
+ min-width: 0;
155
+ max-width: 100%; }
126
156
  charts chart-table .chart-table-title {
127
157
  color: var(--Grey-800, #3c4043);
128
158
  font-family: "Google Sans Display", "Google Sans", sans-serif;
@@ -138,7 +168,10 @@ charts {
138
168
  font-style: normal;
139
169
  font-weight: 400;
140
170
  line-height: 20px;
141
- letter-spacing: 0.2px; }
171
+ letter-spacing: 0.2px;
172
+ width: 100%;
173
+ overflow-x: auto;
174
+ padding-bottom: 20px; }
142
175
  charts chart-table .chart-table-content table {
143
176
  border-radius: 4px;
144
177
  border: 1px solid var(--Grey-300, #dadce0);
@@ -160,15 +193,39 @@ charts {
160
193
  font-weight: 400;
161
194
  line-height: 16px;
162
195
  max-width: 450px; }
196
+ charts chart-table finding {
197
+ display: flex;
198
+ width: fit-content;
199
+ align-items: center;
200
+ justify-content: center;
201
+ white-space: nowrap;
202
+ padding: 4px 12px;
203
+ border-radius: 16px;
204
+ margin: 2px 4px 2px 0;
205
+ font-family: "Google Sans Display", "Google Sans", sans-serif;
206
+ font-weight: 500;
207
+ font-size: 13px;
208
+ line-height: 20px; }
209
+ charts chart-table finding.error {
210
+ background-color: #f8d7da;
211
+ color: var(--Red-600, #d93025); }
212
+ charts chart-table finding.attention {
213
+ background-color: #fff3cd;
214
+ color: #664d03; }
215
+ charts chart-table finding.info {
216
+ background-color: #dae8fc;
217
+ color: var(--blue-700, #1967d2); }
163
218
 
164
219
  stats-section {
165
220
  display: flex;
166
- flex-direction: row;
167
- justify-content: space-around;
221
+ flex-flow: row wrap;
222
+ justify-content: flex-start;
223
+ gap: 40px;
168
224
  padding: 32px; }
169
225
  stats-section stats {
170
226
  display: flex;
171
- flex-direction: column; }
227
+ flex-direction: column;
228
+ flex: 0 0 auto; }
172
229
  stats-section stats > stats-title {
173
230
  color: var(--grey-900, #202124);
174
231
  font-family: Roboto;
@@ -20,6 +20,10 @@ $insights_bg_grey: #eee;
20
20
  $text_grey: var(--Grey-800, #3c4043);
21
21
  $text_green: var(--Green-500, #34a853);
22
22
  $text_red: var(--Red-600, #d93025);
23
+ $error_red: #f8d7da;
24
+ $text_yellow: #664d03;
25
+ $warning_yellow: #fff3cd;
26
+ $info_blue: #dae8fc;
23
27
 
24
28
  $google_sans: 'Google Sans Display', 'Google Sans', sans-serif;
25
29
  $roboto: Roboto;
@@ -112,12 +116,38 @@ card {
112
116
  }
113
117
 
114
118
  > card-insights {
119
+ background: $insights_bg_grey;
120
+ }
121
+
122
+ errors,
123
+ warnings,
124
+ infos {
125
+ margin-top: -39px;
126
+ margin-left: -69px;
127
+ margin-right: -69px;
128
+ }
129
+
130
+ errors {
131
+ background: $error_red;
132
+ }
133
+
134
+ warnings {
135
+ background: $warning_yellow;
136
+ }
137
+
138
+ infos {
139
+ background: $info_blue;
140
+ }
141
+
142
+ > card-insights,
143
+ errors,
144
+ warnings,
145
+ infos {
115
146
  display: flex;
116
147
  flex-direction: row;
117
148
  padding: 22px 20px;
118
- background: $insights_bg_grey;
119
149
 
120
- p.insights-text {
150
+ p {
121
151
  margin: 0px 25px;
122
152
  color: $text_grey;
123
153
 
@@ -149,6 +179,9 @@ charts {
149
179
 
150
180
  chart {
151
181
  @include chart-style;
182
+ flex: 1 1 auto;
183
+ min-width: 0;
184
+ max-width: 100%;
152
185
 
153
186
  > chart-description {
154
187
  color: $text_grey;
@@ -162,8 +195,19 @@ charts {
162
195
  }
163
196
  }
164
197
 
198
+ chart-embed {
199
+ display: block;
200
+ width: 100%;
201
+
202
+ overflow-x: auto;
203
+ padding: 10px 5px 20px 2px;
204
+ }
205
+
165
206
  chart-table {
166
207
  @include chart-style;
208
+ flex: 1 1 auto;
209
+ min-width: 0;
210
+ max-width: 100%;
167
211
 
168
212
  .chart-table-title {
169
213
  color: $text_grey;
@@ -185,6 +229,9 @@ charts {
185
229
  line-height: 20px;
186
230
  letter-spacing: 0.2px;
187
231
 
232
+ width: 100%;
233
+ overflow-x: auto;
234
+ padding-bottom: 20px;
188
235
  @mixin border-style {
189
236
  border-radius: 4px;
190
237
  border: 1px solid var(--Grey-300, #dadce0);
@@ -216,18 +263,52 @@ charts {
216
263
  line-height: 16px;
217
264
  max-width: 450px;
218
265
  }
266
+
267
+ finding {
268
+ display: flex;
269
+ width: fit-content;
270
+ align-items: center;
271
+ justify-content: center;
272
+ white-space: nowrap;
273
+
274
+ padding: 4px 12px;
275
+ border-radius: 16px;
276
+ margin: 2px 4px 2px 0;
277
+
278
+ font-family: $google_sans;
279
+ font-weight: 500;
280
+ font-size: 13px;
281
+ line-height: 20px;
282
+
283
+ &.error {
284
+ background-color: $error_red;
285
+ color: $text_red;
286
+ }
287
+
288
+ &.attention {
289
+ background-color: $warning_yellow;
290
+ color: $text_yellow;
291
+ }
292
+
293
+ &.info {
294
+ background-color: $info_blue;
295
+ color: $chip_blue;
296
+ }
297
+ }
219
298
  }
220
299
  }
221
300
 
222
301
  stats-section {
223
302
  display: flex;
224
- flex-direction: row;
225
- justify-content: space-around;
303
+ flex-flow: row wrap;
304
+ justify-content: flex-start;
305
+ gap: 40px;
226
306
  padding: 32px;
227
307
 
228
308
  stats {
229
309
  display: flex;
230
310
  flex-direction: column;
311
+ flex: 0 0 auto;
231
312
 
232
313
  > stats-title {
233
314
  color: $title_dark_grey;
@@ -15,6 +15,7 @@ limitations under the License.
15
15
  #}
16
16
 
17
17
  <chart-table id="{{ id }}">
18
+ {% include "findings.html.jinja" %}
18
19
  <div class="chart-table-title">{{ title }}</div>
19
20
  <div class="chart-table-content">
20
21
  <table>
meridian/version.py CHANGED
@@ -14,4 +14,4 @@
14
14
 
15
15
  """Module for Meridian version."""
16
16
 
17
- __version__ = "1.4.0"
17
+ __version__ = "1.5.0"
@@ -21,7 +21,7 @@ parameter names.
21
21
  REPORT_TEMPLATE_ID = 'fbd3aeff-fc00-45fd-83f7-1ec5f21c9f56'
22
22
  COMMUNITY_CONNECTOR_NAME = 'community'
23
23
  COMMUNITY_CONNECTOR_ID = (
24
- 'AKfycbz-xdEN-GbTuQ9MjEddS-64wLgXwMMTp9a4zFE4PO_kwT6wDgZPsN4Y19oKmLLHD6xk'
24
+ 'AKfycbx-HSSApXV7VTvYCgzvifyQ-bRlB3Oo_uAAvRKBkwUG1nOVIO_dTlpTmVASYM7oeN0D'
25
25
  )
26
26
  SHEETS_CONNECTOR_NAME = 'googleSheets'
27
27
  GA4_MEASUREMENT_ID = 'G-R6C81BNHJ4'
@@ -191,6 +191,7 @@ class MmmUiProtoGenerator:
191
191
  grid_name=rf_opt_grid_name,
192
192
  group_id=budget_opt_spec.group_id,
193
193
  confidence_level=budget_opt_spec.confidence_level,
194
+ max_frequency=budget_opt_spec.max_frequency,
194
195
  )
195
196
 
196
197
  def _enumerate_dates_open_end(
@@ -141,7 +141,7 @@ __all__ = [
141
141
  ]
142
142
 
143
143
 
144
- @dataclasses.dataclass(frozen=True)
144
+ @dataclasses.dataclass(frozen=True, kw_only=True)
145
145
  class MediaSummarySpec(model_processor.Spec):
146
146
  """Stores parameters needed for creating media summary metrics.
147
147
 
@@ -172,7 +172,6 @@ class MediaSummarySpec(model_processor.Spec):
172
172
  aggregate_times: bool = True
173
173
  marginal_roi_by_reach: bool = True
174
174
  include_non_paid_channels: bool = False
175
- # b/384034128 Use new args in `summary_metrics`.
176
175
  new_data: analyzer.DataTensors | None = None
177
176
  media_selected_times: Sequence[bool] | None = None
178
177
 
@@ -368,15 +367,16 @@ class MarketingProcessor(
368
367
  marketing_analysis_list: list[marketing_analysis_pb2.MarketingAnalysis] = []
369
368
 
370
369
  for spec in marketing_analysis_specs:
371
- if (
372
- spec.incremental_outcome_spec is not None
373
- and spec.incremental_outcome_spec.new_data is not None
374
- and spec.incremental_outcome_spec.new_data.time is not None
375
- ):
370
+ if spec.incremental_outcome_spec is not None:
371
+ new_data = spec.incremental_outcome_spec.new_data
372
+ elif spec.media_summary_spec is not None:
373
+ new_data = spec.media_summary_spec.new_data
374
+ else:
375
+ new_data = None
376
+
377
+ if new_data is not None and new_data.time is not None:
376
378
  new_time_coords = time_coordinates.TimeCoordinates.from_dates(
377
- np.asarray(spec.incremental_outcome_spec.new_data.time)
378
- .astype(str)
379
- .tolist()
379
+ np.asarray(new_data.time).astype(str).tolist()
380
380
  )
381
381
  resolver = spec.resolver(new_time_coords)
382
382
  else:
@@ -507,6 +507,7 @@ class MarketingProcessor(
507
507
  selected_times=selected_times,
508
508
  aggregate_geos=True,
509
509
  aggregate_times=media_summary_spec.aggregate_times,
510
+ new_data=media_summary_spec.new_data,
510
511
  confidence_level=confidence_level,
511
512
  )
512
513
 
@@ -73,7 +73,10 @@ class TrainedModel(abc.ABC):
73
73
  @functools.cached_property
74
74
  def internal_analyzer(self) -> analyzer.Analyzer:
75
75
  """Returns an internal `Analyzer` bound to this trained model."""
76
- return analyzer.Analyzer(self.mmm)
76
+ return analyzer.Analyzer(
77
+ model_context=self.mmm.model_context,
78
+ inference_data=self.mmm.inference_data,
79
+ )
77
80
 
78
81
  @functools.cached_property
79
82
  def internal_optimizer(self) -> optimizer.BudgetOptimizer:
@@ -193,6 +193,8 @@ class DistributionSerde(
193
193
  return meridian_pb.TfpParameterValue(scalar_value=value)
194
194
  case int():
195
195
  return meridian_pb.TfpParameterValue(int_value=value)
196
+ # TODO: b/470407198 - case bool() has to be before int() because bool is a
197
+ # subtype of int.
196
198
  case bool():
197
199
  return meridian_pb.TfpParameterValue(bool_value=value)
198
200
  case str():
@@ -216,10 +218,6 @@ class DistributionSerde(
216
218
  return meridian_pb.TfpParameterValue(
217
219
  dict_value=meridian_pb.TfpParameterValue.Dict(value_map=dict_value)
218
220
  )
219
- case backend.Tensor():
220
- return meridian_pb.TfpParameterValue(
221
- tensor_value=backend.make_tensor_proto(value)
222
- )
223
221
  case backend.tfd.Distribution():
224
222
  return meridian_pb.TfpParameterValue(
225
223
  distribution_value=self._to_distribution_proto(value)
@@ -257,9 +255,16 @@ class DistributionSerde(
257
255
  f" {type(dist).__name__}, but not found in registry. Please"
258
256
  " add custom functions to registry when saving models."
259
257
  )
260
-
261
- # Handle unsupported types.
262
- raise TypeError(f"Unsupported type: {type(value)}, {value}")
258
+ case _:
259
+ # Handle unsupported types by attempting to convert to a tensor proto.
260
+ # This allows for more flexibility in handling types that are not
261
+ # explicitly handled above, such as numpy arrays or backend tensors.
262
+ try:
263
+ return meridian_pb.TfpParameterValue(
264
+ tensor_value=backend.make_tensor_proto(value)
265
+ )
266
+ except TypeError as e:
267
+ raise TypeError(f"Unsupported type: {type(value)}, {value!r}") from e
263
268
 
264
269
  def _from_distribution_proto(
265
270
  self,