google-meridian 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.1.4
3
+ Version: 1.1.6
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  License:
@@ -397,7 +397,7 @@ To cite this repository:
397
397
  author = {Google Meridian Marketing Mix Modeling Team},
398
398
  title = {Meridian: Marketing Mix Modeling},
399
399
  url = {https://github.com/google/meridian},
400
- version = {1.1.4},
400
+ version = {1.1.6},
401
401
  year = {2025},
402
402
  }
403
403
  ```
@@ -1,7 +1,7 @@
1
- google_meridian-1.1.4.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
1
+ google_meridian-1.1.6.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
2
2
  meridian/__init__.py,sha256=XROKwHNVQvEa371QCXAHik5wN_YKObOdJQX9bJ2c4M4,832
3
- meridian/constants.py,sha256=U-XNPHsEGBOySsJI8EYnSp7J7bQ_uEIrilFwXO1StZs,17223
4
- meridian/version.py,sha256=o6Skp0JcQqMgRgK6jJoh12OYsXy0ncCvRZr1jECnVbE,644
3
+ meridian/constants.py,sha256=YE5h3qKH8e2lI3d9vWxc5TsSHUm5bHcz1Lq-2LurJnw,17204
4
+ meridian/version.py,sha256=Po1LKdWVufT1_bpzSPvWjGDDYioK0kiMOjW4fb3VUyM,644
5
5
  meridian/analysis/__init__.py,sha256=nGBYz7k9FVdadO_WVGMKJcfq7Yy_TuuP8zgee4i9pSA,836
6
6
  meridian/analysis/analyzer.py,sha256=L7XyCTd4e_Bqfi8a0bW1WaXjH2ZvSVTPs0VP12a209c,206559
7
7
  meridian/analysis/formatter.py,sha256=ENIdR1CRiaVqIGEXx1HcnsA4ewgDD_nhsYCweJAThaw,7270
@@ -21,7 +21,7 @@ meridian/analysis/templates/summary.html.jinja,sha256=LuENVDHYIpNo4pzloYaCR2K9XN
21
21
  meridian/analysis/templates/table.html.jinja,sha256=mvLMZx92RcD2JAS2w2eZtfYG-6WdfwYVo7pM8TbHp4g,1176
22
22
  meridian/data/__init__.py,sha256=StIe-wfYnnbfUbKtZHwnAQcRQUS8XCZk_PCaEzw90Ww,929
23
23
  meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
24
- meridian/data/data_frame_input_data_builder.py,sha256=1upb0gfEmU-E8GX2C60NABRNEE8_iIPHARwF4OPnbEk,23195
24
+ meridian/data/data_frame_input_data_builder.py,sha256=_hexZMFAuAowgo6FaOGElHSFHqhGnHQwEEBcwnT3zUE,27295
25
25
  meridian/data/input_data.py,sha256=teJPKTBfW-AzBWgf_fEO_S_Z1J_veqQkCvctINaid6I,39749
26
26
  meridian/data/input_data_builder.py,sha256=tbZjVXPDfmtndVyJA0fmzGzZwZb0RCEjXOTXb-ga8Nc,25648
27
27
  meridian/data/load.py,sha256=X2nmYCC-7A0RUgmdolTqCt0TD3NEZabQ5oGv-TugE00,40129
@@ -34,14 +34,14 @@ meridian/model/__init__.py,sha256=9NFfqUE5WgFc-9lQMkbfkwwV-bQIz0tsQ_3Jyq0A4SU,98
34
34
  meridian/model/adstock_hill.py,sha256=20A_6rbDUAADEkkHspB7JpCm5tYfYS1FQ6hJMLu21Pk,9283
35
35
  meridian/model/knots.py,sha256=KPEgnb-UdQQ4QBugOYEke-zBgEghgTmeCMoeiJ30meY,8054
36
36
  meridian/model/media.py,sha256=3BaPX8xYAFMEvf0mz3mBSCIDWViIs7M218nrCklc6Fk,14099
37
- meridian/model/model.py,sha256=BlLPyskHrEx5D71mUZFbNxS2VjkQgaiaE6hLKvQ5D3A,61489
37
+ meridian/model/model.py,sha256=XxVJaJtfUnCWI6gM7hWC6yC64yXECi91r1LHP2B23SQ,61216
38
38
  meridian/model/model_test_data.py,sha256=hDDTEzm72LknW9c5E_dNsy4Mm4Tfs6AirhGf_QxykFs,15552
39
- meridian/model/posterior_sampler.py,sha256=K49zWTTelME2rL1JLeFAdMPzL0OwrBvyAXA3oR-kgSI,27801
40
- meridian/model/prior_distribution.py,sha256=IEDU1rabcmKNY8lxwbbO4OUAlMHPIMa7flM_zsu3DLM,42417
41
- meridian/model/prior_sampler.py,sha256=cmu6jG-bSEkYDkjVUxl3iSxrL7r-LN7a77cb2Vc0LoA,23218
39
+ meridian/model/posterior_sampler.py,sha256=aOYMu4R1ltak3VC0scjrAPig5ExSjkpagk4pjmxKeh4,27884
40
+ meridian/model/prior_distribution.py,sha256=1Qh7jQ2py7tdhLPDyeQzZ0doU6NhQRVaA0lGZNnOVZA,42554
41
+ meridian/model/prior_sampler.py,sha256=by41y2g56jEeJ1cxJi_s45uaUBySgf7wtL5u7-GpVE8,23325
42
42
  meridian/model/spec.py,sha256=0HNiMQUWQpYvWYOZr1_fj2ah8tH-bEyfEjoqgBZ9Lc0,18049
43
43
  meridian/model/transformers.py,sha256=nRjzq1fQG0ypldxboM7Gqok6WSAXAS1witRXoAzeH9Q,7763
44
- google_meridian-1.1.4.dist-info/METADATA,sha256=_VAPyn1fgR57O8dJ8nVznDy1EQXbe5FVLdAbibd1GWU,22201
45
- google_meridian-1.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
- google_meridian-1.1.4.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
47
- google_meridian-1.1.4.dist-info/RECORD,,
44
+ google_meridian-1.1.6.dist-info/METADATA,sha256=HV2L4mWfmMtz4hTWOEsVJYbTQ-aGa_0DeIr45ScZQEw,22201
45
+ google_meridian-1.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
+ google_meridian-1.1.6.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
47
+ google_meridian-1.1.6.dist-info/RECORD,,
meridian/constants.py CHANGED
@@ -218,9 +218,10 @@ PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
218
218
  # Represents a 1% increase in spend.
219
219
  MROI_FACTOR = 1.01
220
220
 
221
- NATIONAL_MODEL_SPEC_ARGS = immutabledict.immutabledict(
222
- {MEDIA_EFFECTS_DIST: MEDIA_EFFECTS_NORMAL, UNIQUE_SIGMA_FOR_EACH_GEO: False}
223
- )
221
+ NATIONAL_MODEL_SPEC_ARGS = immutabledict.immutabledict({
222
+ MEDIA_EFFECTS_DIST: MEDIA_EFFECTS_NORMAL,
223
+ UNIQUE_SIGMA_FOR_EACH_GEO: False,
224
+ })
224
225
 
225
226
  NATIONAL_ANALYZER_PARAMETERS_DEFAULTS = immutabledict.immutabledict(
226
227
  {'aggregate_geos': True, 'geos_to_include': None}
@@ -231,7 +232,6 @@ NATIONAL_ANALYZER_PARAMETERS_DEFAULTS = immutabledict.immutabledict(
231
232
  CHAIN = 'chain'
232
233
  DRAW = 'draw'
233
234
  KNOTS = 'knots'
234
- SIGMA_DIM = 'sigma_dim'
235
235
 
236
236
 
237
237
  # Model parameters.
@@ -30,25 +30,112 @@ __all__ = [
30
30
  class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
31
31
  """Builds `InputData` from DataFrames."""
32
32
 
33
+ def __init__(
34
+ self,
35
+ kpi_type: str,
36
+ default_geo_column: str = constants.GEO,
37
+ default_time_column: str = constants.TIME,
38
+ default_media_time_column: str = constants.TIME,
39
+ default_population_column: str = constants.POPULATION,
40
+ default_kpi_column: str = constants.KPI,
41
+ default_revenue_per_kpi_column: str = constants.REVENUE_PER_KPI,
42
+ ):
43
+ super().__init__(kpi_type)
44
+
45
+ self._default_geo_column = default_geo_column
46
+ self._default_time_column = default_time_column
47
+ self._default_media_time_column = default_media_time_column
48
+ self._default_population_column = default_population_column
49
+ self._default_kpi_column = default_kpi_column
50
+ self._default_revenue_per_kpi_column = default_revenue_per_kpi_column
51
+
52
+ @property
53
+ def default_geo_column(self) -> str:
54
+ """The default geo column name for this builder to use.
55
+
56
+ This column name is used when `geo_col` is not explicitly provided to a data
57
+ setter method.
58
+
59
+ By default, this is `"geo"`.
60
+ """
61
+ return self._default_geo_column
62
+
63
+ @property
64
+ def default_time_column(self) -> str:
65
+ """The default time column name for this builder to use.
66
+
67
+ This column name is used when `time_col` is not explicitly provided to a
68
+ data setter method.
69
+
70
+ By default, this is `"time"`.
71
+ """
72
+ return self._default_time_column
73
+
74
+ @property
75
+ def default_media_time_column(self) -> str:
76
+ """The default *media* time column name for this builder to use.
77
+
78
+ This column name is used when `media_time_col` is not explicitly provided to
79
+ a data setter method.
80
+
81
+ By default, this is also `"time"`, since most input dataframes are likely
82
+ to use the same time column for both their media execution and media spend
83
+ data.
84
+ """
85
+ return self._default_media_time_column
86
+
87
+ @property
88
+ def default_population_column(self) -> str:
89
+ """The default population column name for this builder to use.
90
+
91
+ This column name is used when `population_col` is not explicitly provided to
92
+ a data setter method.
93
+
94
+ By default, this is `"population"`.
95
+ """
96
+ return self._default_population_column
97
+
98
+ @property
99
+ def default_kpi_column(self) -> str:
100
+ """The default kpi column name for this builder to use.
101
+
102
+ This column name is used when `kpi_col` is not explicitly provided to a data
103
+ setter method.
104
+
105
+ By default, this is `"kpi"`.
106
+ """
107
+ return self._default_kpi_column
108
+
109
+ @property
110
+ def default_revenue_per_kpi_column(self) -> str:
111
+ """The default revenue per kpi column name for this builder to use.
112
+
113
+ This column name is used when `revenue_per_kpi_col` is not explicitly
114
+ provided to a data setter method.
115
+
116
+ By default, this is `"revenue_per_kpi"`.
117
+ """
118
+ return self._default_revenue_per_kpi_column
119
+
33
120
  def with_kpi(
34
121
  self,
35
122
  df: pd.DataFrame,
36
- kpi_col: str = constants.KPI,
37
- time_col: str = constants.TIME,
38
- geo_col: str = constants.GEO,
123
+ kpi_col: str | None = None,
124
+ time_col: str | None = None,
125
+ geo_col: str | None = None,
39
126
  ) -> 'DataFrameInputDataBuilder':
40
127
  """Reads KPI data from a DataFrame.
41
128
 
42
129
  Args:
43
130
  df: The DataFrame to read the KPI data from.
44
131
  kpi_col: The name of the column containing the KPI values. If not
45
- provided, the default name is `kpi`.
132
+ provided, `self.default_kpi_column` is used.
46
133
  time_col: The name of the column containing the time coordinates. If not
47
- provided, the default name is `time`.
134
+ provided, `self.default_time_column` is used.
48
135
  geo_col: (Optional) The name of the column containing the geo coordinates.
49
- If not provided, the default name is `geo`. If the DataFrame provided
50
- has no geo column, a national model data is assumed and a geo dimension
51
- will be created internally with a single coordinate value
136
+ If not provided, `self.default_geo_column` is used. If the DataFrame
137
+ provided has no geo column, a national model data is assumed and a geo
138
+ dimension will be created internally with a single coordinate value
52
139
  `national_geo`.
53
140
 
54
141
  Returns:
@@ -56,6 +143,10 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
56
143
  """
57
144
  kpi_df = df.copy()
58
145
 
146
+ kpi_col = kpi_col or self.default_kpi_column
147
+ time_col = time_col or self.default_time_column
148
+ geo_col = geo_col or self.default_geo_column
149
+
59
150
  ### Validate ###
60
151
  self._validate_cols(kpi_df, [kpi_col, time_col], [geo_col])
61
152
  self._validate_coords(kpi_df, geo_col, time_col)
@@ -73,8 +164,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
73
164
  self,
74
165
  df: pd.DataFrame,
75
166
  control_cols: list[str],
76
- time_col: str = constants.TIME,
77
- geo_col: str = constants.GEO,
167
+ time_col: str | None = None,
168
+ geo_col: str | None = None,
78
169
  ) -> 'DataFrameInputDataBuilder':
79
170
  """Reads controls data from a DataFrame.
80
171
 
@@ -82,11 +173,11 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
82
173
  df: The DataFrame to read the controls data from.
83
174
  control_cols: The names of the columns containing the controls values.
84
175
  time_col: The name of the column containing the time coordinates. If not
85
- provided, the default name is `time`.
176
+ provided, `self.default_time_column` is used.
86
177
  geo_col: (Optional) The name of the column containing the geo coordinates.
87
- If not provided, the default name is `geo`. If the DataFrame provided
88
- has no geo column, a national model data is assumed and a geo dimension
89
- will be created internally with a single coordinate value
178
+ If not provided, `self.default_geo_column` is used. If the DataFrame
179
+ provided has no geo column, a national model data is assumed and a geo
180
+ dimension will be created internally with a single coordinate value
90
181
  `national_geo`.
91
182
 
92
183
  Returns:
@@ -98,6 +189,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
98
189
 
99
190
  controls_df = df.copy()
100
191
 
192
+ time_col = time_col or self.default_time_column
193
+ geo_col = geo_col or self.default_geo_column
194
+
101
195
  ### Validate ###
102
196
  self._validate_cols(
103
197
  controls_df,
@@ -120,19 +214,19 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
120
214
  def with_population(
121
215
  self,
122
216
  df: pd.DataFrame,
123
- population_col: str = constants.POPULATION,
124
- geo_col: str = constants.GEO,
217
+ population_col: str | None = None,
218
+ geo_col: str | None = None,
125
219
  ) -> 'DataFrameInputDataBuilder':
126
220
  """Reads population data from a DataFrame.
127
221
 
128
222
  Args:
129
223
  df: The DataFrame to read the population data from.
130
224
  population_col: The name of the column containing the population values.
131
- If not provided, the default name is `population`.
225
+ If not provided, `self.default_population_column` is used.
132
226
  geo_col: (Optional) The name of the column containing the geo coordinates.
133
- If not provided, the default name is `geo`. If the DataFrame provided
134
- has no geo column, a national model data is assumed and a geo dimension
135
- will be created internally with a single coordinate value
227
+ If not provided, `self.default_geo_column` is used. If the DataFrame
228
+ provided has no geo column, a national model data is assumed and a geo
229
+ dimension will be created internally with a single coordinate value
136
230
  `national_geo`.
137
231
 
138
232
  Returns:
@@ -140,6 +234,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
140
234
  """
141
235
  population_df = df.copy()
142
236
 
237
+ population_col = population_col or self.default_population_column
238
+ geo_col = geo_col or self.default_geo_column
239
+
143
240
  ### Validate ###
144
241
  self._validate_cols(population_df, [population_col], [geo_col])
145
242
  self._validate_coords(population_df, geo_col)
@@ -161,22 +258,22 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
161
258
  def with_revenue_per_kpi(
162
259
  self,
163
260
  df: pd.DataFrame,
164
- revenue_per_kpi_col: str = constants.REVENUE_PER_KPI,
165
- time_col: str = constants.TIME,
166
- geo_col: str = constants.GEO,
261
+ revenue_per_kpi_col: str | None = None,
262
+ time_col: str | None = None,
263
+ geo_col: str | None = None,
167
264
  ) -> 'DataFrameInputDataBuilder':
168
265
  """Reads revenue per KPI data from a DataFrame.
169
266
 
170
267
  Args:
171
268
  df: The DataFrame to read the revenue per KPI data from.
172
269
  revenue_per_kpi_col: The name of the column containing the revenue per KPI
173
- values. If not provided, the default name is `revenue_per_kpi`.
270
+ values. If not provided, `self.default_revenue_per_kpi_column` is used.
174
271
  time_col: The name of the column containing the time coordinates. If not
175
- provided, the default name is `time`.
272
+ provided, `self.default_time_column` is used.
176
273
  geo_col: (Optional) The name of the column containing the geo coordinates.
177
- If not provided, the default name is `geo`. If the DataFrame provided
178
- has no geo column, a national model data is assumed and a geo dimension
179
- will be created internally with a single coordinate value
274
+ If not provided, `self.default_geo_column` is used. If the DataFrame
275
+ provided has no geo column, a national model data is assumed and a geo
276
+ dimension will be created internally with a single coordinate value
180
277
  `national_geo`.
181
278
 
182
279
  Returns:
@@ -184,6 +281,12 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
184
281
  """
185
282
  revenue_per_kpi_df = df.copy()
186
283
 
284
+ revenue_per_kpi_col = (
285
+ revenue_per_kpi_col or self.default_revenue_per_kpi_column
286
+ )
287
+ time_col = time_col or self.default_time_column
288
+ geo_col = geo_col or self.default_geo_column
289
+
187
290
  ### Validate ###
188
291
  self._validate_cols(
189
292
  revenue_per_kpi_df,
@@ -213,8 +316,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
213
316
  media_cols: list[str],
214
317
  media_spend_cols: list[str],
215
318
  media_channels: list[str],
216
- time_col: str = constants.TIME,
217
- geo_col: str = constants.GEO,
319
+ time_col: str | None = None,
320
+ geo_col: str | None = None,
218
321
  ) -> 'DataFrameInputDataBuilder':
219
322
  """Reads media and media spend data from a DataFrame.
220
323
 
@@ -227,14 +330,15 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
227
330
  `media_cols` and `media_spend_cols` in length. These are also index
228
331
  mapped.
229
332
  time_col: The name of the column containing the time coordinates for media
230
- spend and media time coordinates for media. If not provided, the default
231
- name is `time`. Media time coordinates will be shorter than time
333
+ spend and media time coordinates for media. If not provided,
334
+ `self.default_time_column` is used. Media time coordinates are inferred
335
+ from the same `time_col` and are potentially shorter than time
232
336
  coordinates if media spend values are missing (NaN) for some t in
233
337
  `time`. Media time must be equal or a subset of time.
234
338
  geo_col: (Optional) The name of the column containing the geo coordinates.
235
- If not provided, the default name is `geo`. If the DataFrame provided
236
- has no geo column, a national model data is assumed and a geo dimension
237
- will be created internally with a single coordinate value
339
+ If not provided, `self.default_geo_column` is used. If the DataFrame
340
+ provided has no geo column, a national model data is assumed and a geo
341
+ dimension will be created internally with a single coordinate value
238
342
  `national_geo`.
239
343
 
240
344
  Returns:
@@ -248,6 +352,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
248
352
 
249
353
  media_df = df.copy()
250
354
 
355
+ time_col = time_col or self.default_time_column
356
+ geo_col = geo_col or self.default_geo_column
357
+
251
358
  ### Validate ###
252
359
  # For a media dataframe, media and media_spend columns may be the same
253
360
  # (e.g. if using media spend as media execution value), so here we validate
@@ -290,8 +397,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
290
397
  frequency_cols: list[str],
291
398
  rf_spend_cols: list[str],
292
399
  rf_channels: list[str],
293
- time_col: str = constants.TIME,
294
- geo_col: str = constants.GEO,
400
+ time_col: str | None = None,
401
+ geo_col: str | None = None,
295
402
  ) -> 'DataFrameInputDataBuilder':
296
403
  """Reads reach, frequency, and rf spend data from a DataFrame.
297
404
 
@@ -305,13 +412,14 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
305
412
  also index mapped.
306
413
  time_col: The name of the column containing the time coordinates for rf
307
414
  spend and media time coordinates for reach and frequency. If not
308
- provided, the default name is `time`. Media time coordinates will be
309
- shorter than time coordinates if media spend values are missing (NaN)
310
- for some t in `time`. Media time must be equal or a subset of time.
415
+ provided, `self.default_time_column` is used. Media time coordinates are
416
+ inferred from the same `time_col` and are potentially shorter than time
417
+ coordinates if media spend values are missing (NaN) for some t in
418
+ `time`. Media time must be equal or a subset of time.
311
419
  geo_col: (Optional) The name of the column containing the geo coordinates.
312
- If not provided, the default name is `geo`. If the DataFrame provided
313
- has no geo column, a national model data is assumed and a geo dimension
314
- will be created internally with a single coordinate value
420
+ If not provided, `self.default_geo_column` is used. If the DataFrame
421
+ provided has no geo column, a national model data is assumed and a geo
422
+ dimension will be created internally with a single coordinate value
315
423
  `national_geo`.
316
424
 
317
425
  Returns:
@@ -331,6 +439,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
331
439
 
332
440
  reach_df = df.copy()
333
441
 
442
+ time_col = time_col or self.default_time_column
443
+ geo_col = geo_col or self.default_geo_column
444
+
334
445
  ### Validate ###
335
446
  self._validate_cols(
336
447
  reach_df,
@@ -389,8 +500,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
389
500
  df: pd.DataFrame,
390
501
  organic_media_cols: list[str],
391
502
  organic_media_channels: list[str] | None = None,
392
- media_time_col: str = constants.MEDIA_TIME,
393
- geo_col: str = constants.GEO,
503
+ media_time_col: str | None = None,
504
+ geo_col: str | None = None,
394
505
  ) -> 'DataFrameInputDataBuilder':
395
506
  """Reads organic media data from a DataFrame.
396
507
 
@@ -403,11 +514,11 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
403
514
  provided, must match `organic_media_cols` in length. This is index
404
515
  mapped.
405
516
  media_time_col: The name of the column containing the media time
406
- coordinates. If not provided, the default name is `media_time`.
517
+ coordinates. If not provided, `self.default_media_time_column` is used.
407
518
  geo_col: (Optional) The name of the column containing the geo coordinates.
408
- If not provided, the default name is `geo`. If the DataFrame provided
409
- has no geo column, a national model data is assumed and a geo dimension
410
- will be created internally with a single coordinate value
519
+ If not provided, `self.default_geo_column` is used. If the DataFrame
520
+ provided has no geo column, a national model data is assumed and a geo
521
+ dimension will be created internally with a single coordinate value
411
522
  `national_geo`.
412
523
 
413
524
  Returns:
@@ -418,6 +529,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
418
529
 
419
530
  organic_media_df = df.copy()
420
531
 
532
+ media_time_col = media_time_col or self.default_media_time_column
533
+ geo_col = geo_col or self.default_geo_column
534
+
421
535
  ### Validate ###
422
536
  if not organic_media_channels:
423
537
  organic_media_channels = organic_media_cols
@@ -456,8 +570,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
456
570
  organic_reach_cols: list[str],
457
571
  organic_frequency_cols: list[str],
458
572
  organic_rf_channels: list[str],
459
- media_time_col: str = constants.MEDIA_TIME,
460
- geo_col: str = constants.GEO,
573
+ media_time_col: str | None = None,
574
+ geo_col: str | None = None,
461
575
  ) -> 'DataFrameInputDataBuilder':
462
576
  """Reads organic reach and organic frequency data from a DataFrame.
463
577
 
@@ -471,11 +585,11 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
471
585
  match `organic_reach_cols` and `organic_frequency_cols` in length. These
472
586
  are also index mapped.
473
587
  media_time_col: The name of the column containing the media time
474
- coordinates. If not provided, the default name is `media_time`.
588
+ coordinates. If not provided, `self.default_media_time_column` is used.
475
589
  geo_col: (Optional) The name of the column containing the geo coordinates.
476
- If not provided, the default name is `geo`. If the DataFrame provided
477
- has no geo column, a national model data is assumed and a geo dimension
478
- will be created internally with a single coordinate value
590
+ If not provided, `self.default_geo_column` is used. If the DataFrame
591
+ provided has no geo column, a national model data is assumed and a geo
592
+ dimension will be created internally with a single coordinate value
479
593
  `national_geo`.
480
594
 
481
595
  Returns:
@@ -494,6 +608,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
494
608
 
495
609
  organic_reach_frequency_df = df.copy()
496
610
 
611
+ media_time_col = media_time_col or self.default_media_time_column
612
+ geo_col = geo_col or self.default_geo_column
613
+
497
614
  ### Validate ###
498
615
  self._validate_cols(
499
616
  organic_reach_frequency_df,
@@ -540,8 +657,8 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
540
657
  self,
541
658
  df: pd.DataFrame,
542
659
  non_media_treatment_cols: list[str],
543
- time_col: str = constants.TIME,
544
- geo_col: str = constants.GEO,
660
+ time_col: str | None = None,
661
+ geo_col: str | None = None,
545
662
  ) -> 'DataFrameInputDataBuilder':
546
663
  """Reads non-media treatments data from a DataFrame.
547
664
 
@@ -550,11 +667,11 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
550
667
  non_media_treatment_cols: The names of the columns containing the
551
668
  non-media treatments values.
552
669
  time_col: The name of the column containing the time coordinates. If not
553
- provided, the default name is `time`.
670
+ provided, `self.default_time_column` is used.
554
671
  geo_col: (Optional) The name of the column containing the geo coordinates.
555
- If not provided, the default name is `geo`. If the DataFrame provided
556
- has no geo column, a national model data is assumed and a geo dimension
557
- will be created internally with a single coordinate value
672
+ If not provided, `self.default_geo_column` is used. If the DataFrame
673
+ provided has no geo column, a national model data is assumed and a geo
674
+ dimension will be created internally with a single coordinate value
558
675
  `national_geo`.
559
676
 
560
677
  Returns:
@@ -569,6 +686,9 @@ class DataFrameInputDataBuilder(input_data_builder.InputDataBuilder):
569
686
 
570
687
  non_media_treatments_df = df.copy()
571
688
 
689
+ time_col = time_col or self.default_time_column
690
+ geo_col = geo_col or self.default_geo_column
691
+
572
692
  ### Validate ###
573
693
  self._validate_cols(
574
694
  non_media_treatments_df,
meridian/model/model.py CHANGED
@@ -295,10 +295,6 @@ class Meridian:
295
295
  def is_national(self) -> bool:
296
296
  return self.n_geos == 1
297
297
 
298
- @property
299
- def _sigma_shape(self) -> int:
300
- return len(self.input_data.geo) if self.unique_sigma_for_each_geo else 1
301
-
302
298
  @functools.cached_property
303
299
  def knot_info(self) -> knots.KnotInfo:
304
300
  return knots.get_knot_info(
@@ -389,6 +385,7 @@ class Meridian:
389
385
  @functools.cached_property
390
386
  def unique_sigma_for_each_geo(self) -> bool:
391
387
  if self.is_national:
388
+ # Should evaluate to False.
392
389
  return constants.NATIONAL_MODEL_SPEC_ARGS[
393
390
  constants.UNIQUE_SIGMA_FOR_EACH_GEO
394
391
  ]
@@ -449,7 +446,7 @@ class Meridian:
449
446
  n_organic_rf_channels=self.n_organic_rf_channels,
450
447
  n_controls=self.n_controls,
451
448
  n_non_media_channels=self.n_non_media_channels,
452
- sigma_shape=self._sigma_shape,
449
+ unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
453
450
  n_knots=self.knot_info.n_knots,
454
451
  is_national=self.is_national,
455
452
  set_total_media_contribution_prior=self._set_total_media_contribution_prior,
@@ -663,10 +660,6 @@ class Meridian:
663
660
  self._validate_injected_inference_data_group_coord(
664
661
  inference_data, group, constants.TIME, self.n_times
665
662
  )
666
- if not self.model_spec.unique_sigma_for_each_geo:
667
- self._validate_injected_inference_data_group_coord(
668
- inference_data, group, constants.SIGMA_DIM, self._sigma_shape
669
- )
670
663
  self._validate_injected_inference_data_group_coord(
671
664
  inference_data,
672
665
  group,
@@ -1429,7 +1422,7 @@ class Meridian:
1429
1422
  if self.unique_sigma_for_each_geo:
1430
1423
  inference_dims[constants.SIGMA] = [constants.GEO]
1431
1424
  else:
1432
- inference_dims[constants.SIGMA] = [constants.SIGMA_DIM]
1425
+ inference_dims[constants.SIGMA] = []
1433
1426
 
1434
1427
  return {
1435
1428
  param: [constants.CHAIN, constants.DRAW] + list(dims)
@@ -528,7 +528,7 @@ class PosteriorMCMCSampler:
528
528
  be a positive integer. For more information, see `tf.while_loop`.
529
529
  seed: An `int32[2]` Tensor or a Python list or tuple of 2 `int`s, which
530
530
  will be treated as stateless seeds; or a Python `int` or `None`, which
531
- will be treated as stateful seeds. See [tfp.random.sanitize_seed]
531
+ will be converted into a stateless seed. See [tfp.random.sanitize_seed]
532
532
  (https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed).
533
533
  **pins: These are used to condition the provided joint distribution, and
534
534
  are passed directly to `joint_dist.experimental_pin(**pins)`.
@@ -547,6 +547,8 @@ class PosteriorMCMCSampler:
547
547
  " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
548
548
  " for details."
549
549
  )
550
+ if seed is not None and isinstance(seed, int):
551
+ seed = (seed, seed)
550
552
  seed = tfp.random.sanitize_seed(seed) if seed is not None else None
551
553
  n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
552
554
  total_chains = np.sum(n_chains_list)
@@ -509,7 +509,7 @@ class PriorDistribution:
509
509
  n_organic_rf_channels: int,
510
510
  n_controls: int,
511
511
  n_non_media_channels: int,
512
- sigma_shape: int,
512
+ unique_sigma_for_each_geo: bool,
513
513
  n_knots: int,
514
514
  is_national: bool,
515
515
  set_total_media_contribution_prior: bool,
@@ -527,9 +527,9 @@ class PriorDistribution:
527
527
  used.
528
528
  n_controls: Number of controls used.
529
529
  n_non_media_channels: Number of non-media channels used.
530
- sigma_shape: A number describing the shape of the sigma parameter. It's
531
- either `1` (if `sigma_for_each_geo=False`) or `n_geos` (if
532
- `sigma_for_each_geo=True`). For more information, see `ModelSpec`.
530
+ unique_sigma_for_each_geo: A boolean indicator whether to use the same
531
+ sigma parameter for all geos. Only used if `n_geos > 1`. For more
532
+ information, see `ModelSpec`.
533
533
  n_knots: Number of knots used.
534
534
  is_national: A boolean indicator whether the prior distribution will be
535
535
  adapted for a national model.
@@ -801,6 +801,9 @@ class PriorDistribution:
801
801
  slope_orf = tfp.distributions.BatchBroadcast(
802
802
  self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
803
803
  )
804
+
805
+ # If `unique_sigma_for_each_geo == False`, then make a scalar batch.
806
+ sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
804
807
  sigma = tfp.distributions.BatchBroadcast(
805
808
  self.sigma, sigma_shape, name=constants.SIGMA
806
809
  )
@@ -510,6 +510,8 @@ class PriorDistributionSampler:
510
510
  tf.keras.utils.set_random_seed(1)
511
511
 
512
512
  prior = mmm.prior_broadcast
513
+ # `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
514
+ # when it is sampled.
513
515
  sample_shape = [1, n_draws]
514
516
  sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
515
517
 
meridian/version.py CHANGED
@@ -14,4 +14,4 @@
14
14
 
15
15
  """Module for Meridian version."""
16
16
 
17
- __version__ = "1.1.4"
17
+ __version__ = "1.1.6"