google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
  2. google_meridian-1.1.2.dist-info/RECORD +46 -0
  3. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
  4. meridian/__init__.py +2 -2
  5. meridian/analysis/__init__.py +1 -1
  6. meridian/analysis/analyzer.py +29 -22
  7. meridian/analysis/formatter.py +1 -1
  8. meridian/analysis/optimizer.py +70 -44
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/summary_text.py +1 -1
  11. meridian/analysis/test_utils.py +1 -1
  12. meridian/analysis/visualizer.py +17 -8
  13. meridian/constants.py +3 -3
  14. meridian/data/__init__.py +4 -1
  15. meridian/data/arg_builder.py +1 -1
  16. meridian/data/data_frame_input_data_builder.py +614 -0
  17. meridian/data/input_data.py +12 -8
  18. meridian/data/input_data_builder.py +817 -0
  19. meridian/data/load.py +121 -428
  20. meridian/data/nd_array_input_data_builder.py +509 -0
  21. meridian/data/test_utils.py +60 -43
  22. meridian/data/time_coordinates.py +1 -1
  23. meridian/mlflow/__init__.py +17 -0
  24. meridian/mlflow/autolog.py +54 -0
  25. meridian/model/__init__.py +1 -1
  26. meridian/model/adstock_hill.py +1 -1
  27. meridian/model/knots.py +1 -1
  28. meridian/model/media.py +1 -1
  29. meridian/model/model.py +65 -37
  30. meridian/model/model_test_data.py +75 -1
  31. meridian/model/posterior_sampler.py +19 -15
  32. meridian/model/prior_distribution.py +1 -1
  33. meridian/model/prior_sampler.py +32 -26
  34. meridian/model/spec.py +18 -8
  35. meridian/model/transformers.py +1 -1
  36. google_meridian-1.1.0.dist-info/RECORD +0 -41
  37. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
  38. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,509 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """An implementation of `InputDataBuilder` with n-dimensional array primitives."""
16
+
17
+ import logging
18
+ import warnings
19
+ from meridian import constants
20
+ from meridian.data import input_data_builder
21
+ import numpy as np
22
+ import xarray as xr
23
+
24
+
25
+ __all__ = [
26
+ 'NDArrayInputDataBuilder',
27
+ ]
28
+
29
+
30
+ class NDArrayInputDataBuilder(input_data_builder.InputDataBuilder):
31
+ """Builds `InputData` from n-dimensional arrays."""
32
+
33
+ # Unlike `DataFrameInputDataBuilder`, each piecemeal data has no coordinate
34
+ # information; they're purely data values. It's up to the user to provide
35
+ # coordinates with setter methods from the abstract base class above.
36
+ # Validation is done on each piece w.r.t. dimensional consistency by
37
+ # shape alone.
38
+
39
def with_kpi(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Stores KPI values from an ndarray as the `kpi` data array.

  `nd` is checked against the shape `(n_geos, n_time)`, that is
  `(len(self.geos), len(self.time_coords))`. The leading geo axis is always
  required by the shape check, so national-level data must be shaped
  `(1, n_time)`.

  Time coordinates must be set before calling this method. If no geos have
  been set, a warning is logged and the data is treated as national-level
  using the default national geo name.

  Args:
    nd: The ndarray holding the KPI values.

  Returns:
    This `NDArrayInputDataBuilder`, with the KPI data added.

  Raises:
    ValueError: If time coordinates are not set, or if `nd` does not have the
      expected shape.
  """
  # Coordinates first (this may default the geos), then the shape, because
  # the shape check measures `nd` against the coordinate lengths.
  self._validate_coords()
  self._validate_shape(nd)

  kpi_dims = [constants.GEO, constants.TIME]
  kpi_coords = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
  }
  self.kpi = xr.DataArray(
      nd, dims=kpi_dims, coords=kpi_coords, name=constants.KPI
  )
  return self
70
+
71
def with_controls(
    self, nd: np.ndarray, control_names: list[str]
) -> 'NDArrayInputDataBuilder':
  """Stores control variable values from an ndarray as the `controls` array.

  `nd` is checked against the shape `(n_geos, n_time, n_controls)`, where
  `n_controls` is `len(control_names)`. The leading geo axis is always
  required by the shape check, so national-level data must be shaped
  `(1, n_time, n_controls)`.

  Time coordinates must be set before calling this method. If no geos have
  been set, a warning is logged and the data is treated as national-level
  using the default national geo name.

  Args:
    nd: The ndarray holding the control values.
    control_names: The names of the control variables, one per entry along
      the last axis of `nd`. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the controls data added.

  Raises:
    ValueError: If time coordinates are not set, if `control_names` contains
      duplicates, or if `nd` does not have the expected shape.
  """
  self._validate_coords()
  self._validate_shape(nd, control_names)

  controls_dims = [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
  controls_coords = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.CONTROL_VARIABLE: control_names,
  }
  self.controls = xr.DataArray(
      nd,
      dims=controls_dims,
      coords=controls_coords,
      name=constants.CONTROLS,
  )
  return self
106
+
107
def with_population(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Stores population values from an ndarray as the `population` array.

  `nd` is checked against the shape `(n_geos,)`, that is
  `(len(self.geos),)`. Population has no time axis, so no time coordinates
  are required for this call.

  If no geos have been set, a warning is logged and the data is treated as
  national-level using the default national geo name.

  Args:
    nd: The ndarray holding one population value per geo.

  Returns:
    This `NDArrayInputDataBuilder`, with the population data added.

  Raises:
    ValueError: If `nd` does not have the expected shape.
  """
  # `is_population=True` skips every time-related check in both validators.
  self._validate_coords(is_population=True)
  self._validate_shape(nd, is_population=True)

  population_coords = {constants.GEO: self.geos}
  self.population = xr.DataArray(
      nd,
      dims=[constants.GEO],
      coords=population_coords,
      name=constants.POPULATION,
  )
  return self
134
+
135
def with_revenue_per_kpi(self, nd: np.ndarray) -> 'NDArrayInputDataBuilder':
  """Stores revenue-per-KPI values as the `revenue_per_kpi` data array.

  `nd` is checked against the shape `(n_geos, n_time)`, that is
  `(len(self.geos), len(self.time_coords))`. The leading geo axis is always
  required by the shape check, so national-level data must be shaped
  `(1, n_time)`.

  When the builder's KPI type is revenue, the given values are not used: a
  warning is issued and an array of ones with the same shape as `nd` is
  stored instead.

  Time coordinates must be set before calling this method. If no geos have
  been set, a warning is logged and the data is treated as national-level
  using the default national geo name.

  Args:
    nd: The ndarray holding the revenue per KPI values.

  Returns:
    This `NDArrayInputDataBuilder`, with the revenue per KPI data added.

  Raises:
    ValueError: If time coordinates are not set, or if `nd` does not have the
      expected shape.
  """
  self._validate_coords()
  self._validate_shape(nd)
  # May swap the input for an all-ones array; see the docstring above.
  values = self._check_revenue_per_kpi_defaults(nd)

  rpk_dims = [constants.GEO, constants.TIME]
  rpk_coords = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
  }
  self.revenue_per_kpi = xr.DataArray(
      values,
      dims=rpk_dims,
      coords=rpk_coords,
      name=constants.REVENUE_PER_KPI,
  )
  return self
167
+
168
def with_media(
    self, m_nd: np.ndarray, ms_nd: np.ndarray, media_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Stores media and media spend values from two ndarrays.

  `m_nd` is checked against the shape
  `(n_geos, n_media_times, n_media_channels)` and `ms_nd` against
  `(n_geos, n_times, n_media_channels)`, where `n_media_channels` is
  `len(media_channels)`. The leading geo axis is always required by the
  shape check, so national-level data must have a first axis of length 1.

  Both media time coordinates and time coordinates must be set before
  calling this method. If no geos have been set, a warning is logged and the
  data is treated as national-level using the default national geo name.

  Args:
    m_nd: The ndarray holding the media values, indexed by media time.
    ms_nd: The ndarray holding the media spend values, indexed by time.
    media_channels: The names of the media channels, one per entry along the
      last axis of both arrays. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the media and media spend data
    added.

  Raises:
    ValueError: If media time or time coordinates are not set, if
      `media_channels` contains duplicates, or if either array does not have
      the expected shape.
  """
  # Media uses the media-time axis and spend uses the time axis, so both
  # coordinate sets are required up front.
  self._validate_coords(is_media_time=True)
  self._validate_coords(is_media_time=False)
  self._validate_shape(nd=m_nd, dims=media_channels, is_media_time=True)
  self._validate_shape(nd=ms_nd, dims=media_channels, is_media_time=False)

  def _per_channel_array(values, time_dim, time_values, name):
    # The two arrays differ only in their time axis and their name.
    return xr.DataArray(
        values,
        dims=[constants.GEO, time_dim, constants.MEDIA_CHANNEL],
        coords={
            constants.GEO: self.geos,
            time_dim: time_values,
            constants.MEDIA_CHANNEL: media_channels,
        },
        name=name,
    )

  self.media = _per_channel_array(
      m_nd, constants.MEDIA_TIME, self.media_time_coords, constants.MEDIA
  )
  self.media_spend = _per_channel_array(
      ms_nd, constants.TIME, self.time_coords, constants.MEDIA_SPEND
  )
  return self
230
+
231
def with_reach(
    self,
    r_nd: np.ndarray,
    f_nd: np.ndarray,
    rfs_nd: np.ndarray,
    rf_channels: list[str],
) -> 'NDArrayInputDataBuilder':
  """Stores reach, frequency, and rf_spend values from three ndarrays.

  `r_nd` and `f_nd` are checked against the shape
  `(n_geos, n_media_times, n_rf_channels)` and `rfs_nd` against
  `(n_geos, n_times, n_rf_channels)`, where `n_rf_channels` is
  `len(rf_channels)`. The leading geo axis is always required by the shape
  check, so national-level data must have a first axis of length 1.

  Both media time coordinates and time coordinates must be set before
  calling this method. If no geos have been set, a warning is logged and the
  data is treated as national-level using the default national geo name.

  Args:
    r_nd: The ndarray holding the reach values, indexed by media time.
    f_nd: The ndarray holding the frequency values, indexed by media time.
    rfs_nd: The ndarray holding the rf_spend values, indexed by time.
    rf_channels: The names of the rf channels, one per entry along the last
      axis of all three arrays. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the reach, frequency, and rf_spend
    data added.

  Raises:
    ValueError: If media time or time coordinates are not set, if
      `rf_channels` contains duplicates, or if any array does not have the
      expected shape.
  """
  # Reach and frequency use the media-time axis and spend uses the time
  # axis, so both coordinate sets are required up front.
  self._validate_coords(is_media_time=True)
  self._validate_coords(is_media_time=False)
  self._validate_shape(nd=r_nd, dims=rf_channels, is_media_time=True)
  self._validate_shape(nd=f_nd, dims=rf_channels, is_media_time=True)
  self._validate_shape(nd=rfs_nd, dims=rf_channels, is_media_time=False)

  def _per_rf_channel_array(values, time_dim, time_values, name):
    # The three arrays differ only in their time axis and their name.
    return xr.DataArray(
        values,
        dims=[constants.GEO, time_dim, constants.RF_CHANNEL],
        coords={
            constants.GEO: self.geos,
            time_dim: time_values,
            constants.RF_CHANNEL: rf_channels,
        },
        name=name,
    )

  self.reach = _per_rf_channel_array(
      r_nd, constants.MEDIA_TIME, self.media_time_coords, constants.REACH
  )
  self.frequency = _per_rf_channel_array(
      f_nd, constants.MEDIA_TIME, self.media_time_coords, constants.FREQUENCY
  )
  self.rf_spend = _per_rf_channel_array(
      rfs_nd, constants.TIME, self.time_coords, constants.RF_SPEND
  )
  return self
314
+
315
def with_organic_media(
    self, nd: np.ndarray, organic_media_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Stores organic media values from an ndarray as `organic_media`.

  `nd` is checked against the shape
  `(n_geos, n_media_times, n_organic_media_channels)`, where
  `n_organic_media_channels` is `len(organic_media_channels)`. The leading
  geo axis is always required by the shape check, so national-level data
  must have a first axis of length 1.

  Media time coordinates must be set before calling this method. If no geos
  have been set, a warning is logged and the data is treated as
  national-level using the default national geo name.

  Args:
    nd: The ndarray holding the organic media values.
    organic_media_channels: The names of the organic media channels, one per
      entry along the last axis of `nd`. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the organic media data added.

  Raises:
    ValueError: If media time coordinates are not set, if
      `organic_media_channels` contains duplicates, or if `nd` does not have
      the expected shape.
  """
  self._validate_coords(is_media_time=True)
  self._validate_shape(nd=nd, dims=organic_media_channels, is_media_time=True)

  organic_dims = [
      constants.GEO,
      constants.MEDIA_TIME,
      constants.ORGANIC_MEDIA_CHANNEL,
  ]
  organic_coords = {
      constants.GEO: self.geos,
      constants.MEDIA_TIME: self.media_time_coords,
      constants.ORGANIC_MEDIA_CHANNEL: organic_media_channels,
  }
  self.organic_media = xr.DataArray(
      nd,
      dims=organic_dims,
      coords=organic_coords,
      name=constants.ORGANIC_MEDIA,
  )
  return self
355
+
356
def with_organic_reach(
    self, or_nd: np.ndarray, of_nd: np.ndarray, organic_rf_channels: list[str]
) -> 'NDArrayInputDataBuilder':
  """Stores organic reach and organic frequency values from two ndarrays.

  `or_nd` and `of_nd` are both checked against the shape
  `(n_geos, n_media_times, n_organic_rf_channels)`, where
  `n_organic_rf_channels` is `len(organic_rf_channels)`. The leading geo
  axis is always required by the shape check, so national-level data must
  have a first axis of length 1.

  Media time coordinates must be set before calling this method. If no geos
  have been set, a warning is logged and the data is treated as
  national-level using the default national geo name.

  Args:
    or_nd: The ndarray holding the organic reach values.
    of_nd: The ndarray holding the organic frequency values.
    organic_rf_channels: The names of the organic rf channels, one per entry
      along the last axis of both arrays. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the organic reach and organic
    frequency data added.

  Raises:
    ValueError: If media time coordinates are not set, if
      `organic_rf_channels` contains duplicates, or if either array does not
      have the expected shape.
  """
  ### Validate ###
  self._validate_coords(is_media_time=True)
  self._validate_shape(nd=or_nd, dims=organic_rf_channels, is_media_time=True)
  self._validate_shape(nd=of_nd, dims=organic_rf_channels, is_media_time=True)

  ### Transform ###
  def _per_organic_rf_channel_array(values, name):
    # Both arrays share dims and coords; only the values and name differ.
    return xr.DataArray(
        values,
        dims=[
            constants.GEO,
            constants.MEDIA_TIME,
            constants.ORGANIC_RF_CHANNEL,
        ],
        coords={
            constants.GEO: self.geos,
            constants.MEDIA_TIME: self.media_time_coords,
            constants.ORGANIC_RF_CHANNEL: organic_rf_channels,
        },
        name=name,
    )

  self.organic_reach = _per_organic_rf_channel_array(
      or_nd, constants.ORGANIC_REACH
  )
  # The frequency array must carry its own name. It was previously labelled
  # ORGANIC_REACH, leaving two distinct arrays with the same name.
  self.organic_frequency = _per_organic_rf_channel_array(
      of_nd, constants.ORGANIC_FREQUENCY
  )
  return self
413
+
414
def with_non_media_treatments(
    self, nd: np.ndarray, non_media_channel_names: list[str]
) -> 'NDArrayInputDataBuilder':
  """Stores non-media treatment values as `non_media_treatments`.

  `nd` is checked against the shape `(n_geos, n_time, n_non_media_channels)`,
  where `n_non_media_channels` is `len(non_media_channel_names)`. The
  leading geo axis is always required by the shape check, so national-level
  data must be shaped `(1, n_time, n_non_media_channels)`.

  Time coordinates must be set before calling this method. If no geos have
  been set, a warning is logged and the data is treated as national-level
  using the default national geo name.

  Args:
    nd: The ndarray holding the non-media treatment values.
    non_media_channel_names: The names of the non-media channels, one per
      entry along the last axis of `nd`. Names must be unique.

  Returns:
    This `NDArrayInputDataBuilder`, with the non-media treatments data
    added.

  Raises:
    ValueError: If time coordinates are not set, if
      `non_media_channel_names` contains duplicates, or if `nd` does not have
      the expected shape.
  """
  self._validate_coords()
  self._validate_shape(nd, non_media_channel_names)

  treatment_dims = [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
  treatment_coords = {
      constants.GEO: self.geos,
      constants.TIME: self.time_coords,
      constants.NON_MEDIA_CHANNEL: non_media_channel_names,
  }
  self.non_media_treatments = xr.DataArray(
      nd,
      dims=treatment_dims,
      coords=treatment_coords,
      name=constants.NON_MEDIA_TREATMENTS,
  )
  return self
450
+
451
+ def _validate_coords(
452
+ self, is_population: bool = False, is_media_time: bool = False
453
+ ):
454
+ """Validates that the data has the expected coordinates."""
455
+ if not is_population:
456
+ if is_media_time and self._media_time_coords is None:
457
+ raise ValueError(
458
+ 'Media times are required first. Set using .media_time_coords()'
459
+ )
460
+ if not is_media_time and self.time_coords is None:
461
+ raise ValueError(
462
+ 'Time coordinates are required first. Set using .time_coords()'
463
+ )
464
+ if self.geos is None:
465
+ logging.warning(
466
+ 'No geo coordinates set. Assuming NATIONAL model and geos will be set'
467
+ ' to the default value.'
468
+ )
469
+ self.geos = [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]
470
+
471
+ def _validate_shape(
472
+ self,
473
+ nd: np.ndarray,
474
+ dims: list[str] | None = None,
475
+ is_population: bool = False,
476
+ is_media_time: bool = False,
477
+ ):
478
+ """Validates that the data has the expected shape."""
479
+ # Since all data has a geo dimension (even for national data),
480
+ # Expect the first axis to have the shape of the geo dimension.
481
+ expected_shape = (len(self.geos),)
482
+ detailed_info = f'Expected: {len(self.geos)} geos'
483
+ if not is_population:
484
+ if is_media_time:
485
+ expected_shape += (len(self.media_time_coords),)
486
+ detailed_info += f' x {len(self.media_time_coords)} media times'
487
+ else:
488
+ expected_shape += (len(self.time_coords),)
489
+ detailed_info += f' x {len(self.time_coords)} times'
490
+
491
+ if dims is not None:
492
+ if len(dims) != len(set(dims)):
493
+ raise ValueError('given dimensions must be unique.')
494
+ expected_shape += (len(dims),)
495
+ detailed_info += f' x {len(dims)} dims'
496
+
497
+ if expected_shape != nd.shape:
498
+ raise ValueError(f'{detailed_info}. Got: {nd.shape}.')
499
+
500
def _check_revenue_per_kpi_defaults(self, nd: np.ndarray):
  """Returns the revenue-per-KPI values to store for the given input.

  Args:
    nd: The revenue per KPI array supplied by the caller.

  Returns:
    `nd` unchanged when the builder's KPI type is not revenue. Otherwise a
    new array of ones with the same shape as `nd`, after issuing a warning.
  """
  if self._kpi_type != constants.REVENUE:
    return nd

  # The KPI type is revenue, so the supplied values are replaced with ones.
  warnings.warn(
      'with_revenue_per_kpi was called but kpi_type was set to revenue.'
      ' Assuming revenue per kpi with values [1].'
  )
  return np.ones(nd.shape)