google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,200 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian EDA Outcome."""
16
+
17
+ import dataclasses
18
+ import enum
19
+ import typing
20
+ import pandas as pd
21
+ import xarray as xr
22
+
23
+ __all__ = [
24
+ "EDASeverity",
25
+ "EDAFinding",
26
+ "AnalysisLevel",
27
+ "AnalysisArtifact",
28
+ "PairwiseCorrArtifact",
29
+ "StandardDeviationArtifact",
30
+ "VIFArtifact",
31
+ "KpiInvariabilityArtifact",
32
+ "EDACheckType",
33
+ "ArtifactType",
34
+ "EDAOutcome",
35
+ ]
36
+
37
+
38
+ @enum.unique
39
+ class EDASeverity(enum.Enum):
40
+ """Enumeration for the severity of an EDA check's finding."""
41
+
42
+ # For the non-critical findings.
43
+ INFO = enum.auto()
44
+ # For the non-critical findings that require user attention.
45
+ ATTENTION = enum.auto()
46
+ # For unacceptable, model-blocking data errors.
47
+ ERROR = enum.auto()
48
+
49
+
50
+ @dataclasses.dataclass(frozen=True)
51
+ class EDAFinding:
52
+ """Encapsulates a single, specific finding from an EDA check.
53
+
54
+ Attributes:
55
+ severity: The severity level of the finding.
56
+ explanation: A human-readable description of the EDA check and
57
+ potentially actionable guidance on how to address or interpret this
58
+ specific finding.
59
+ """
60
+
61
+ severity: EDASeverity
62
+ explanation: str
63
+
64
+
65
+ @enum.unique
66
+ class AnalysisLevel(enum.Enum):
67
+ """Enumeration for the level of an analysis.
68
+
69
+ Attributes:
70
+ OVERALL: Computed across all geos and time. When the analysis is performed
71
+ on national data, this level is equivalent to the NATIONAL level.
72
+ NATIONAL: Computed across time for data aggregated to the national level.
73
+ When the analysis is performed on national data, this level is equivalent
74
+ to the OVERALL level.
75
+ GEO: Computed across time, for each geo.
76
+ """
77
+
78
+ OVERALL = enum.auto()
79
+ NATIONAL = enum.auto()
80
+ GEO = enum.auto()
81
+
82
+
83
+ @dataclasses.dataclass(frozen=True)
84
+ class AnalysisArtifact:
85
+ """Base dataclass for analysis artifacts.
86
+
87
+ Specific EDA artifacts should inherit from this class to store check-specific
88
+ data for downstream processing (e.g., plotting).
89
+
90
+ Attributes:
91
+ level: The level of the analysis.
92
+ """
93
+
94
+ level: AnalysisLevel
95
+
96
+
97
+ @dataclasses.dataclass(frozen=True)
98
+ class PairwiseCorrArtifact(AnalysisArtifact):
99
+ """Encapsulates artifacts from a single pairwise correlation analysis.
100
+
101
+ Attributes:
102
+ corr_matrix: Pairwise correlation matrix.
103
+ extreme_corr_var_pairs: DataFrame of variable pairs exceeding the
104
+ correlation threshold.
105
+ extreme_corr_threshold: The threshold used to identify extreme correlation
106
+ pairs.
107
+ """
108
+
109
+ corr_matrix: xr.DataArray
110
+ extreme_corr_var_pairs: pd.DataFrame
111
+ extreme_corr_threshold: float
112
+
113
+
114
+ @dataclasses.dataclass(frozen=True)
115
+ class StandardDeviationArtifact(AnalysisArtifact):
116
+ """Encapsulates artifacts from a standard deviation analysis.
117
+
118
+ Attributes:
119
+ variable: The variable for which standard deviation is calculated.
120
+ std_ds: Dataset with stdev_with_outliers and stdev_without_outliers.
121
+ outlier_df: DataFrame with outliers.
122
+ """
123
+
124
+ variable: str
125
+ std_ds: xr.Dataset
126
+ outlier_df: pd.DataFrame
127
+
128
+
129
+ @dataclasses.dataclass(frozen=True)
130
+ class VIFArtifact(AnalysisArtifact):
131
+ """Encapsulates artifacts from a single VIF analysis.
132
+
133
+ Attributes:
134
+ vif_da: DataArray with VIF values.
135
+ outlier_df: DataFrame with extreme VIF values.
136
+ """
137
+
138
+ vif_da: xr.DataArray
139
+ outlier_df: pd.DataFrame
140
+
141
+
142
+ @dataclasses.dataclass(frozen=True)
143
+ class KpiInvariabilityArtifact(AnalysisArtifact):
144
+ """Encapsulates artifacts from a KPI invariability analysis.
145
+
146
+ Attributes:
147
+ kpi_da: DataArray of the KPI that is examined for variability.
148
+ kpi_stdev: The standard deviation of the KPI, which is used to test the KPI
149
+ invariability.
150
+ """
151
+
152
+ kpi_da: xr.DataArray
153
+ kpi_stdev: xr.DataArray
154
+
155
+
156
+ @enum.unique
157
+ class EDACheckType(enum.Enum):
158
+ """Enumeration for the type of an EDA check."""
159
+
160
+ PAIRWISE_CORRELATION = enum.auto()
161
+ STANDARD_DEVIATION = enum.auto()
162
+ MULTICOLLINEARITY = enum.auto()
163
+ KPI_INVARIABILITY = enum.auto()
164
+
165
+
166
+ ArtifactType = typing.TypeVar("ArtifactType", bound="AnalysisArtifact")
167
+
168
+
169
+ @dataclasses.dataclass(frozen=True)
170
+ class EDAOutcome(typing.Generic[ArtifactType]):
171
+ """A dataclass for the outcomes of a single EDA check function.
172
+
173
+ An EDA check function can discover multiple issues. This object groups all of
174
+ those individual issues, reported as a list of `EDAFinding` objects.
175
+
176
+ Attributes:
177
+ check_type: The type of the EDA check that is being performed.
178
+ findings: A list of all individual issues discovered by the check.
179
+ analysis_artifacts: A list of analysis artifacts from the EDA check.
180
+ """
181
+
182
+ check_type: EDACheckType
183
+ findings: list[EDAFinding]
184
+ analysis_artifacts: list[ArtifactType]
185
+
186
+ @property
187
+ def get_geo_artifact(self) -> ArtifactType | None:
188
+ """Returns the geo-level analysis artifact."""
189
+ for artifact in self.analysis_artifacts:
190
+ if artifact.level == AnalysisLevel.GEO:
191
+ return artifact
192
+ return None
193
+
194
+ @property
195
+ def get_national_artifact(self) -> ArtifactType | None:
196
+ """Returns the national-level analysis artifact."""
197
+ for artifact in self.analysis_artifacts:
198
+ if artifact.level == AnalysisLevel.NATIONAL:
199
+ return artifact
200
+ return None
@@ -0,0 +1,84 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian EDA Spec."""
16
+
17
+ import dataclasses
18
+ from typing import Any, Callable, Dict, TypeAlias
19
+
20
+ __all__ = [
21
+ "AggregationConfig",
22
+ "VIFSpec",
23
+ "EDASpec",
24
+ ]
25
+
26
+ AggregationFn: TypeAlias = Callable[..., Any]
27
+ AggregationMap: TypeAlias = Dict[str, AggregationFn]
28
+ _DEFAULT_VIF_THRESHOLD = 1000
29
+
30
+
31
+ @dataclasses.dataclass(frozen=True, kw_only=True)
32
+ class AggregationConfig:
33
+ """A configuration for customizing variable aggregation functions.
34
+
35
+ The aggregation function can be called in the form `f(x, axis=axis, **kwargs)`
36
+ to return the result of reducing an `np.ndarray` over an integer valued axis.
37
+ It's recommended to explicitly define the aggregation functions instead of
38
+ using lambdas.
39
+
40
+ Attributes:
41
+ control_variables: A dictionary mapping control variable names to
42
+ aggregation functions. Defaults to `np.sum` if a variable is not
43
+ specified.
44
+ non_media_treatments: A dictionary mapping non-media variable names to
45
+ aggregation functions. Defaults to `np.sum` if a variable is not
46
+ specified.
47
+ """
48
+
49
+ control_variables: AggregationMap = dataclasses.field(default_factory=dict)
50
+ non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
51
+
52
+
53
+ @dataclasses.dataclass(frozen=True, kw_only=True)
54
+ class VIFSpec:
55
+ """A spec for the EDA VIF check.
56
+
57
+ Attributes:
58
+ geo_threshold: The threshold for geo-level VIF.
59
+ overall_threshold: The threshold for overall VIF.
60
+ national_threshold: The threshold for national VIF.
61
+ """
62
+
63
+ geo_threshold: float = _DEFAULT_VIF_THRESHOLD
64
+ overall_threshold: float = _DEFAULT_VIF_THRESHOLD
65
+ national_threshold: float = _DEFAULT_VIF_THRESHOLD
66
+
67
+
68
+ @dataclasses.dataclass(frozen=True, kw_only=True)
69
+ class EDASpec:
70
+ """A container for all user-configurable EDA check specs.
71
+
72
+ This object allows users to customize the behavior of the EDA checks
73
+ by passing a single configuration object into the EDAEngine constructor,
74
+ avoiding a large number of arguments.
75
+
76
+ Attributes:
77
+ aggregation_config: A configuration object for custom aggregation functions.
78
+ vif_spec: A configuration object for the EDA VIF check.
79
+ """
80
+
81
+ aggregation_config: AggregationConfig = dataclasses.field(
82
+ default_factory=AggregationConfig
83
+ )
84
+ vif_spec: VIFSpec = dataclasses.field(default_factory=VIFSpec)
@@ -0,0 +1,220 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module containing Meridian related exploratory data analysis (EDA) functionalities."""
16
+ from __future__ import annotations
17
+
18
+ from typing import Literal, TYPE_CHECKING, Union
19
+
20
+ import altair as alt
21
+ from meridian import constants
22
+ from meridian.model.eda import constants as eda_constants
23
+ import pandas as pd
24
+
25
+ if TYPE_CHECKING:
26
+ from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
27
+
28
+ __all__ = [
29
+ 'MeridianEDA',
30
+ ]
31
+
32
+
33
+ class MeridianEDA:
34
+ """Class for running pre-modeling exploratory data analysis for Meridian InputData."""
35
+
36
+ _PAIRWISE_CORR_COLOR_SCALE = alt.Scale(
37
+ domain=[-1.0, 0.0, 1.0],
38
+ range=['#1f78b4', '#f7f7f7', '#e34a33'], # Blue-light grey-Orange
39
+ type='linear',
40
+ )
41
+
42
+ def __init__(
43
+ self,
44
+ meridian: model.Meridian,
45
+ ):
46
+ self._meridian = meridian
47
+
48
+ def generate_and_save_report(self, filename: str, filepath: str):
49
+ """Generates and saves the 2 page HTML report containing findings in EDA about given InputData.
50
+
51
+ Args:
52
+ filename: The filename for the generated HTML output.
53
+ filepath: The path to the directory where the file will be saved.
54
+ """
55
+ # TODO: Implement.
56
+ raise NotImplementedError()
57
+
58
+ def plot_pairwise_correlation(
59
+ self, geos: Union[int, list[str], Literal['nationalize']] = 1
60
+ ) -> alt.Chart:
61
+ """Plots the Pairwise Correlation data.
62
+
63
+ Args:
64
+ geos: Defines which geos to plot. - int: The number of top geos to plot,
65
+ ranked by population. - list[str]: A specific list of geo names to plot.
66
+ - 'nationalize': Aggregates all geos into a single national view.
67
+ Defaults to 1 (plotting the top geo). If the data is already at a
68
+ national level, this parameter is ignored and a national plot is
69
+ generated.
70
+
71
+ Returns:
72
+ Altair chart(s) of the Pairwise Correlation data.
73
+ """
74
+ geos_to_plot = self._validate_and_get_geos_to_plot(geos)
75
+ is_national = self._meridian.is_national
76
+ nationalize_geos = geos == 'nationalize'
77
+
78
+ if is_national or nationalize_geos:
79
+ pairwise_corr_artifact = (
80
+ self._meridian.eda_engine.check_national_pairwise_corr().get_national_artifact
81
+ )
82
+ if pairwise_corr_artifact is None:
83
+ raise ValueError('EDAOutcome does not have national artifact.')
84
+ else:
85
+ pairwise_corr_artifact = (
86
+ self._meridian.eda_engine.check_geo_pairwise_corr().get_geo_artifact
87
+ )
88
+ if pairwise_corr_artifact is None:
89
+ raise ValueError('EDAOutcome does not have geo artifact.')
90
+ pairwise_corr_data = pairwise_corr_artifact.corr_matrix.to_dataframe()
91
+
92
+ charts = []
93
+ for geo_to_plot in geos_to_plot:
94
+ title = (
95
+ 'Pairwise correlations among all treatments and controls for'
96
+ f' {geo_to_plot}'
97
+ )
98
+
99
+ if not (is_national or nationalize_geos):
100
+ plot_data = (
101
+ pairwise_corr_data.xs(geo_to_plot, level=constants.GEO)
102
+ .rename_axis(
103
+ index=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2]
104
+ )
105
+ .reset_index()
106
+ )
107
+ else:
108
+ plot_data = pairwise_corr_data.rename_axis(
109
+ index=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2]
110
+ ).reset_index()
111
+ plot_data.columns = [
112
+ eda_constants.VARIABLE_1,
113
+ eda_constants.VARIABLE_2,
114
+ eda_constants.CORRELATION,
115
+ ]
116
+ unique_variables = plot_data[eda_constants.VARIABLE_1].unique()
117
+ variable_to_index = {name: i for i, name in enumerate(unique_variables)}
118
+
119
+ plot_data['idx1'] = plot_data[eda_constants.VARIABLE_1].map(
120
+ variable_to_index
121
+ )
122
+ plot_data['idx2'] = plot_data[eda_constants.VARIABLE_2].map(
123
+ variable_to_index
124
+ )
125
+ lower_triangle_data = plot_data[plot_data['idx2'] > plot_data['idx1']]
126
+
127
+ charts.append(
128
+ self._plot_2d_heatmap(lower_triangle_data, title, unique_variables)
129
+ )
130
+ final_chart = (
131
+ alt.vconcat(*charts)
132
+ .resolve_legend(color='independent')
133
+ .configure_axis(labelAngle=315)
134
+ .configure_title(anchor='start')
135
+ .configure_view(stroke=None)
136
+ )
137
+ return final_chart
138
+
139
+ def _plot_2d_heatmap(
140
+ self, data: pd.DataFrame, title: str, unique_variables: list[str]
141
+ ) -> alt.Chart:
142
+ """Plots a 2D heatmap."""
143
+ # Base chart with position encodings
144
+ base = (
145
+ alt.Chart(data)
146
+ .encode(
147
+ x=alt.X(
148
+ f'{eda_constants.VARIABLE_1}:N',
149
+ title=None,
150
+ sort=unique_variables,
151
+ scale=alt.Scale(domain=unique_variables),
152
+ ),
153
+ y=alt.Y(
154
+ f'{eda_constants.VARIABLE_2}:N',
155
+ title=None,
156
+ sort=unique_variables,
157
+ scale=alt.Scale(domain=unique_variables),
158
+ ),
159
+ )
160
+ .properties(title=title)
161
+ )
162
+
163
+ # Heatmap layer (rectangles)
164
+ heatmap = base.mark_rect().encode(
165
+ color=alt.Color(
166
+ f'{eda_constants.CORRELATION}:Q',
167
+ scale=self._PAIRWISE_CORR_COLOR_SCALE,
168
+ legend=alt.Legend(title=eda_constants.CORRELATION),
169
+ ),
170
+ tooltip=[
171
+ eda_constants.VARIABLE_1,
172
+ eda_constants.VARIABLE_2,
173
+ alt.Tooltip(f'{eda_constants.CORRELATION}:Q', format='.3f'),
174
+ ],
175
+ )
176
+
177
+ # Text annotation layer (values)
178
+ text = base.mark_text().encode(
179
+ text=alt.Text(f'{eda_constants.CORRELATION}:Q', format='.3f'),
180
+ color=alt.value('black'),
181
+ )
182
+
183
+ # Combine layers and apply final configurations
184
+ chart = (heatmap + text).properties(width=350, height=350)
185
+
186
+ return chart
187
+
188
+ def _generate_pairwise_correlation_report(self) -> str:
189
+ """Creates the HTML snippet for Pairwise Correlation report section."""
190
+ # TODO: Implement.
191
+ raise NotImplementedError()
192
+
193
+ def _validate_and_get_geos_to_plot(
194
+ self, geos: Union[int, list[str], Literal['nationalize']]
195
+ ) -> list[str]:
196
+ """Validates and returns the geos to plot."""
197
+ ## Validate
198
+ is_national = self._meridian.is_national
199
+ if is_national or geos == 'nationalize':
200
+ geos_to_plot = [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]
201
+ elif isinstance(geos, int):
202
+ if geos > len(self._meridian.input_data.geo) or geos <= 0:
203
+ raise ValueError(
204
+ 'geos must be a positive integer less than or equal to the number'
205
+ ' of geos in the data.'
206
+ )
207
+ geos_to_plot = self._meridian.input_data.get_n_top_largest_geos(geos)
208
+ else:
209
+ geos_to_plot = geos
210
+
211
+ if (
212
+ not is_national and geos != 'nationalize'
213
+ ): # if national then geos_to_plot will be ignored
214
+ for geo in geos_to_plot:
215
+ if geo not in self._meridian.input_data.geo:
216
+ raise ValueError(f'Geo {geo} does not exist in the data.')
217
+ if len(geos_to_plot) != len(set(geos_to_plot)):
218
+ raise ValueError('geos must not contain duplicate values.')
219
+
220
+ return geos_to_plot
meridian/model/knots.py CHANGED
@@ -250,12 +250,28 @@ class AKS:
250
250
  def __init__(self, data: input_data.InputData):
251
251
  self._data = data
252
252
 
253
- def automatic_knot_selection(self) -> AKSResult:
253
+ def automatic_knot_selection(
254
+ self,
255
+ base_penalty: np.ndarray | None = None,
256
+ min_internal_knots: int = 1,
257
+ max_internal_knots: int | None = None,
258
+ ) -> AKSResult:
254
259
  """Calculates the optimal number of knots for Meridian model using Automatic knot selection with A-spline.
255
260
 
261
+ Args:
262
+ base_penalty: A vector of positive penalty values. The adaptive spline
263
+ regression is performed for every value of penalty.
264
+ min_internal_knots: The minimum number of internal knots. Defaults to 1.
265
+ max_internal_knots: The maximum number of internal knots. If None, this
266
+ value is calculated as the number of initial knots minus the total count
267
+ of all treatment and control variables. Otherwise, the user-provided
268
+ value will be used.
269
+
256
270
  Returns:
257
271
  Selected knots and the corresponding B-spline model.
258
272
  """
273
+ if base_penalty is None:
274
+ base_penalty = self._BASE_PENALTY
259
275
  n_times = len(self._data.time)
260
276
  n_geos = len(self._data.geo)
261
277
 
@@ -265,11 +281,12 @@ class AKS:
265
281
  np.repeat([range(n_times)], n_geos, axis=0), (n_geos * n_times,)
266
282
  )
267
283
 
268
- knots, min_internal_knots, max_internal_knots = (
269
- self._calculate_initial_knots(x)
284
+ knots = self._calculate_initial_knots(x)
285
+ max_internal_knots = self._calculate_and_validate_max_internal_knots(
286
+ knots, min_internal_knots, max_internal_knots
270
287
  )
271
288
  geo_scaling_factor = 1 / np.sqrt(len(self._data.geo))
272
- penalty = geo_scaling_factor * self._BASE_PENALTY
289
+ penalty = geo_scaling_factor * base_penalty
273
290
 
274
291
  aspline = self.aspline(x=x, y=y, knots=knots, penalty=penalty)
275
292
  n_knots = np.array([len(x) for x in aspline[constants.KNOTS_SELECTED]])
@@ -288,7 +305,7 @@ class AKS:
288
305
  def _calculate_initial_knots(
289
306
  self,
290
307
  x: np.ndarray,
291
- ) -> tuple[np.ndarray, int, int]:
308
+ ) -> np.ndarray:
292
309
  """Calculates initial knots based on unique x values.
293
310
 
294
311
  Args:
@@ -298,38 +315,7 @@ class AKS:
298
315
  Returns:
299
316
  A tuple containing:
300
317
  - The calculated knots.
301
- - The minimum number of internal knots.
302
- - The maximum number of internal knots.
303
318
  """
304
- n_media = (
305
- len(self._data.media_channel)
306
- if self._data.media_channel is not None
307
- else 0
308
- )
309
- n_rf = (
310
- len(self._data.rf_channel) if self._data.rf_channel is not None else 0
311
- )
312
- n_organic_media = (
313
- len(self._data.organic_media_channel)
314
- if self._data.organic_media_channel is not None
315
- else 0
316
- )
317
- n_organic_rf = (
318
- len(self._data.organic_rf_channel)
319
- if self._data.organic_rf_channel is not None
320
- else 0
321
- )
322
- n_non_media = (
323
- len(self._data.non_media_channel)
324
- if self._data.non_media_channel is not None
325
- else 0
326
- )
327
- n_controls = (
328
- len(self._data.control_variable)
329
- if self._data.control_variable is not None
330
- else 0
331
- )
332
-
333
319
  x_vals_unique = np.unique(x)
334
320
  min_x_data, max_x_data = x_vals_unique.min(), x_vals_unique.max()
335
321
  knots = x_vals_unique[
@@ -340,18 +326,39 @@ class AKS:
340
326
  # fewer degree of freedom than the total number of knots to function.
341
327
  # Dropping the final knot is a natural and practical choice because it
342
328
  # often has minimal impact on the overall model fit.
343
- knots = knots[:-1]
344
- min_internal_knots = 1
345
-
346
- max_internal_knots = (
347
- len(knots)
348
- - n_media
349
- - n_rf
350
- - n_organic_media
351
- - n_organic_rf
352
- - n_non_media
353
- - n_controls
329
+ return knots[:-1]
330
+
331
+ def _calculate_and_validate_max_internal_knots(
332
+ self,
333
+ knots: np.ndarray,
334
+ min_internal_knots: int = 1,
335
+ max_internal_knots: int | None = None,
336
+ ) -> int:
337
+ """Calculates the max internal knots, and validates the range.
338
+
339
+ Args:
340
+ knots: Initial knots to calculate max internal knots from.
341
+ min_internal_knots: The minimum number of internal knots.
342
+ max_internal_knots: The maximum number of internal knots. If None, this
343
+ value will be calculated, otherwise will use user provided value.
344
+
345
+ Returns:
346
+ The maximum number of internal knots.
347
+ """
348
+ n_features = sum(
349
+ len(feature) if feature is not None else 0
350
+ for feature in [
351
+ self._data.media_channel,
352
+ self._data.rf_channel,
353
+ self._data.organic_media_channel,
354
+ self._data.organic_rf_channel,
355
+ self._data.non_media_channel,
356
+ self._data.control_variable,
357
+ ]
354
358
  )
359
+
360
+ if max_internal_knots is None:
361
+ max_internal_knots = len(knots) - n_features
355
362
  if min_internal_knots > len(knots):
356
363
  raise ValueError(
357
364
  'The minimum number of internal knots cannot be greater than the'
@@ -362,8 +369,7 @@ class AKS:
362
369
  'The maximum number of internal knots cannot be less than the minimum'
363
370
  ' number of internal knots.'
364
371
  )
365
-
366
- return knots, min_internal_knots, max_internal_knots
372
+ return max_internal_knots
367
373
 
368
374
  def aspline(
369
375
  self,
meridian/model/media.py CHANGED
@@ -63,8 +63,8 @@ class MediaTensors:
63
63
  media_spend: A tensor constructed from `InputData.media_spend`.
64
64
  media_transformer: A `MediaTransformer` to scale media tensors using the
65
65
  model's media data.
66
- media_scaled: The media tensor normalized by population and by the median
67
- value.
66
+ media_scaled: The media tensor after pre-modeling transformations including
67
+ population scaling and scaling by the median non-zero value.
68
68
  prior_media_scaled_counterfactual: A tensor containing `media_scaled` values
69
69
  corresponding to the counterfactual scenario required for the prior
70
70
  calculation. For ROI priors, the counterfactual scenario is where media is
@@ -169,8 +169,9 @@ class OrganicMediaTensors:
169
169
  organic_media: A tensor constructed from `InputData.organic_media`.
170
170
  organic_media_transformer: A `MediaTransformer` to scale media tensors using
171
171
  the model's organic media data.
172
- organic_media_scaled: The organic media tensor normalized by population and
173
- by the median value.
172
+ organic_media_scaled: The organic media tensor after pre-modeling
173
+ transformations including population scaling and scaling by the median
174
+ non-zero value.
174
175
  """
175
176
 
176
177
  organic_media: backend.Tensor | None = None
@@ -214,8 +215,8 @@ class RfTensors:
214
215
  rf_spend: A tensor constructed from `InputData.rf_spend`.
215
216
  reach_transformer: A `MediaTransformer` to scale RF tensors using the
216
217
  model's RF data.
217
- reach_scaled: A reach tensor normalized by population and by the median
218
- value.
218
+ reach_scaled: A reach tensor after pre-modeling transformations including
219
+ population scaling and scaling by the median non-zero value.
219
220
  prior_reach_scaled_counterfactual: A tensor containing `reach_scaled` values
220
221
  corresponding to the counterfactual scenario required for the prior
221
222
  calculation. For ROI priors, the counterfactual scenario is where reach is
@@ -324,8 +325,9 @@ class OrganicRfTensors:
324
325
  organic_frequency: A tensor constructed from `InputData.organic_frequency`.
325
326
  organic_reach_transformer: A `MediaTransformer` to scale organic RF tensors
326
327
  using the model's organic RF data.
327
- organic_reach_scaled: An organic reach tensor normalized by population and
328
- by the median value.
328
+ organic_reach_scaled: An organic reach tensor after pre-modeling
329
+ transformations including population scaling and scaling by the median
330
+ non-zero value.
329
331
  """
330
332
 
331
333
  organic_reach: backend.Tensor | None = None