google-meridian 1.3.2-py3-none-any.whl → 1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -29,3 +29,4 @@ NATIONALIZE = 'nationalize'
 MEDIA_IMPRESSIONS_SCALED = 'media_impressions_scaled'
 IMPRESSION_SHARE_SCALED = 'impression_share_scaled'
 SPEND_SHARE = 'spend_share'
+LABEL = 'label'
meridian/model/equations.py ADDED
@@ -0,0 +1,418 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core mathematical equations for the Meridian model.
+
+This module defines the `ModelEquations` class, which encapsulates the stateless
+mathematical functions used in the Meridian MMM. This includes the core model
+definitions, such as adstock, hill, and other transformations used
+during model fitting. It requires a `ModelContext` instance for data access.
+"""
+
+from collections.abc import Sequence
+import numbers
+
+from meridian import backend
+from meridian import constants
+from meridian.model import adstock_hill
+from meridian.model import context
+
+
+__all__ = [
+    "ModelEquations",
+]
+
+
+class ModelEquations:
+  """Provides core, stateless mathematical functions for Meridian MMM."""
+
+  def __init__(self, model_context: context.ModelContext):
+    self._context = model_context
+
+  def adstock_hill_media(
+      self,
+      *,
+      media: backend.Tensor,
+      alpha: backend.Tensor,
+      ec: backend.Tensor,
+      slope: backend.Tensor,
+      decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
+      n_times_output: int | None = None,
+  ) -> backend.Tensor:
+    """Transforms media using Adstock and Hill functions in the desired order.
+
+    Args:
+      media: Tensor of dimensions `(n_geos, n_media_times, n_media_channels)`
+        containing non-negative media execution values. Typically this is
+        impressions, but it can be any metric, such as `media_spend`. Clicks are
+        often used for paid search ads.
+      alpha: Uniform distribution for Adstock and Hill calculations.
+      ec: Shifted half-normal distribution for Adstock and Hill calculations.
+      slope: Deterministic distribution for Adstock and Hill calculations.
+      decay_functions: String or sequence of strings denoting the adstock decay
+        function(s) for each channel. Default: 'geometric'.
+      n_times_output: Number of time periods to output. This argument is
+        optional when the number of time periods in `media` equals
+        `n_media_times`, in which case `n_times_output` defaults to `n_times`.
+
+    Returns:
+      Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
+      representing Adstock and Hill-transformed media.
+    """
+    if n_times_output is None and (
+        media.shape[1] == self._context.n_media_times
+    ):
+      n_times_output = self._context.n_times
+    elif n_times_output is None:
+      raise ValueError(
+          "n_times_output is required. This argument is only optional when "
+          "`media` has a number of time periods equal to `n_media_times`."
+      )
+
+    adstock_transformer = adstock_hill.AdstockTransformer(
+        alpha=alpha,
+        max_lag=self._context.model_spec.max_lag,
+        n_times_output=n_times_output,
+        decay_functions=decay_functions,
+    )
+    hill_transformer = adstock_hill.HillTransformer(
+        ec=ec,
+        slope=slope,
+    )
+    transformers_list = (
+        [hill_transformer, adstock_transformer]
+        if self._context.model_spec.hill_before_adstock
+        else [adstock_transformer, hill_transformer]
+    )
+
+    media_out = media
+    for transformer in transformers_list:
+      media_out = transformer.forward(media_out)
+    return media_out
+
+  def adstock_hill_rf(
+      self,
+      *,
+      reach: backend.Tensor,
+      frequency: backend.Tensor,
+      alpha: backend.Tensor,
+      ec: backend.Tensor,
+      slope: backend.Tensor,
+      decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
+      n_times_output: int | None = None,
+  ) -> backend.Tensor:
+    """Transforms reach and frequency (RF) using Hill and Adstock functions.
+
+    Args:
+      reach: Tensor of dimensions `(n_geos, n_media_times, n_rf_channels)`
+        containing non-negative media for reach.
+      frequency: Tensor of dimensions `(n_geos, n_media_times, n_rf_channels)`
+        containing non-negative media for frequency.
+      alpha: Uniform distribution for Adstock and Hill calculations.
+      ec: Shifted half-normal distribution for Adstock and Hill calculations.
+      slope: Deterministic distribution for Adstock and Hill calculations.
+      decay_functions: String or sequence of strings denoting the adstock decay
+        function(s) for each channel. Default: 'geometric'.
+      n_times_output: Number of time periods to output. This argument is
+        optional when the number of time periods in `reach` equals
+        `n_media_times`, in which case `n_times_output` defaults to `n_times`.
+
+    Returns:
+      Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
+      representing Hill and Adstock-transformed RF.
+    """
+    if n_times_output is None and (
+        reach.shape[1] == self._context.n_media_times
+    ):
+      n_times_output = self._context.n_times
+    elif n_times_output is None:
+      raise ValueError(
+          "n_times_output is required. This argument is only optional when "
+          "`reach` has a number of time periods equal to `n_media_times`."
+      )
+
+    hill_transformer = adstock_hill.HillTransformer(
+        ec=ec,
+        slope=slope,
+    )
+    adstock_transformer = adstock_hill.AdstockTransformer(
+        alpha=alpha,
+        max_lag=self._context.model_spec.max_lag,
+        n_times_output=n_times_output,
+        decay_functions=decay_functions,
+    )
+    adj_frequency = hill_transformer.forward(frequency)
+    rf_out = adstock_transformer.forward(reach * adj_frequency)
+
+    return rf_out
+
+  def compute_non_media_treatments_baseline(
+      self,
+      non_media_baseline_values: Sequence[str | float] | None = None,
+  ) -> backend.Tensor:
+    """Computes the baseline for each non-media treatment channel.
+
+    Args:
+      non_media_baseline_values: Optional list of shape
+        `(n_non_media_channels,)`. Each element is either a float (which means
+        that the fixed value will be used as baseline for the given channel) or
+        one of the strings "min" or "max" (which mean that the global minimum or
+        maximum value will be used as baseline for the values of the given
+        non_media treatment channel). If float values are provided, it is
+        expected that they are scaled by population for the channels where
+        `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
+        `model_spec.non_media_baseline_values` is used, which defaults to the
+        minimum value for each non_media treatment channel.
+
+    Returns:
+      A tensor of shape `(n_non_media_channels,)` containing the
+      baseline values for each non-media treatment channel.
+    """
+    if non_media_baseline_values is None:
+      non_media_baseline_values = (
+          self._context.model_spec.non_media_baseline_values
+      )
+
+    no_op_scaling_factor = backend.ones_like(self._context.population)[
+        :, backend.newaxis, backend.newaxis
+    ]
+    if self._context.model_spec.non_media_population_scaling_id is not None:
+      scaling_factors = backend.where(
+          self._context.model_spec.non_media_population_scaling_id,
+          self._context.population[:, backend.newaxis, backend.newaxis],
+          no_op_scaling_factor,
+      )
+    else:
+      scaling_factors = no_op_scaling_factor
+
+    non_media_treatments_population_scaled = backend.divide_no_nan(
+        self._context.non_media_treatments, scaling_factors
+    )
+
+    if non_media_baseline_values is None:
+      # If non_media_baseline_values is not provided, use the minimum
+      # value for each non_media treatment channel as the baseline.
+      non_media_baseline_values_filled = [
+          constants.NON_MEDIA_BASELINE_MIN
+      ] * non_media_treatments_population_scaled.shape[-1]
+    else:
+      non_media_baseline_values_filled = non_media_baseline_values
+
+    if non_media_treatments_population_scaled.shape[-1] != len(
+        non_media_baseline_values_filled
+    ):
+      raise ValueError(
+          "The number of non-media channels"
+          f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
+          " match the number of baseline values"
+          f" ({len(non_media_baseline_values_filled)})."
+      )
+
+    baseline_list = []
+    for channel in range(non_media_treatments_population_scaled.shape[-1]):
+      baseline_value = non_media_baseline_values_filled[channel]
+
+      if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
+        baseline_for_channel = backend.reduce_min(
+            non_media_treatments_population_scaled[..., channel], axis=[0, 1]
+        )
+      elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
+        baseline_for_channel = backend.reduce_max(
+            non_media_treatments_population_scaled[..., channel], axis=[0, 1]
+        )
+      elif isinstance(baseline_value, numbers.Number):
+        baseline_for_channel = backend.to_tensor(
+            baseline_value, dtype=backend.float32
+        )
+      else:
+        raise ValueError(
+            f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
+            " float numbers and strings 'min' and 'max' are supported."
+        )
+
+      baseline_list.append(baseline_for_channel)
+
+    return backend.stack(baseline_list, axis=-1)
+
+  def linear_predictor_counterfactual_difference_media(
+      self,
+      *,
+      media_transformed: backend.Tensor,
+      alpha_m: backend.Tensor,
+      ec_m: backend.Tensor,
+      slope_m: backend.Tensor,
+  ) -> backend.Tensor:
+    """Calculates linear predictor counterfactual difference for non-RF media.
+
+    For non-RF media variables (paid or organic), this function calculates the
+    linear predictor difference between the treatment variable and its
+    counterfactual. "Linear predictor" refers to the output of the hill/adstock
+    function, which is multiplied by the geo-level coefficient.
+
+    This function does the calculation efficiently by only calling
+    the hill/adstock function if the prior counterfactual is not all zeros.
+
+    Args:
+      media_transformed: The output of the hill/adstock function for actual
+        historical media data.
+      alpha_m: The adstock alpha parameter values.
+      ec_m: The adstock ec parameter values.
+      slope_m: The adstock hill slope parameter values.
+
+    Returns:
+      The linear predictor difference between the treatment variable and its
+      counterfactual.
+    """
+    if self._context.media_tensors.prior_media_scaled_counterfactual is None:
+      return media_transformed
+    media_transformed_counterfactual = self.adstock_hill_media(
+        media=self._context.media_tensors.prior_media_scaled_counterfactual,
+        alpha=alpha_m,
+        ec=ec_m,
+        slope=slope_m,
+        decay_functions=self._context.adstock_decay_spec.media,
+    )
+    # Absolute value is needed because the difference is negative for mROI
+    # priors and positive for ROI and contribution priors.
+    return backend.absolute(
+        media_transformed - media_transformed_counterfactual
+    )
+
+  def linear_predictor_counterfactual_difference_rf(
+      self,
+      *,
+      rf_transformed: backend.Tensor,
+      alpha_rf: backend.Tensor,
+      ec_rf: backend.Tensor,
+      slope_rf: backend.Tensor,
+  ) -> backend.Tensor:
+    """Calculates linear predictor counterfactual difference for RF media.
+
+    For RF media variables (paid or organic), this function calculates the
+    linear predictor difference between the treatment variable and its
+    counterfactual. "Linear predictor" refers to the output of the hill/adstock
+    function, which is multiplied by the geo-level coefficient.
+
+    This function does the calculation efficiently by only calling
+    the hill/adstock function if the prior counterfactual is not all zeros.
+
+    Args:
+      rf_transformed: The output of the hill/adstock function for actual
+        historical media data.
+      alpha_rf: The adstock alpha parameter values.
+      ec_rf: The adstock ec parameter values.
+      slope_rf: The adstock hill slope parameter values.
+
+    Returns:
+      The linear predictor difference between the treatment variable and its
+      counterfactual.
+    """
+    if self._context.rf_tensors.prior_reach_scaled_counterfactual is None:
+      return rf_transformed
+    rf_transformed_counterfactual = self.adstock_hill_rf(
+        reach=self._context.rf_tensors.prior_reach_scaled_counterfactual,
+        frequency=self._context.rf_tensors.frequency,
+        alpha=alpha_rf,
+        ec=ec_rf,
+        slope=slope_rf,
+        decay_functions=self._context.adstock_decay_spec.rf,
+    )
+    # Absolute value is needed because the difference is negative for mROI
+    # priors and positive for ROI and contribution priors.
+    return backend.absolute(rf_transformed - rf_transformed_counterfactual)
+
+  def calculate_beta_x(
+      self,
+      *,
+      is_non_media: bool,
+      incremental_outcome_x: backend.Tensor,
+      linear_predictor_counterfactual_difference: backend.Tensor,
+      eta_x: backend.Tensor,
+      beta_gx_dev: backend.Tensor,
+  ) -> backend.Tensor:
+    """Calculates coefficient mean parameter for any treatment variable type.
+
+    The "beta_x" in the function name refers to the coefficient mean parameter
+    of any treatment variable. The "x" can represent "m", "rf", "om", or "orf".
+    This function can also be used to calculate "gamma_n" for any non-media
+    treatments.
+
+    Args:
+      is_non_media: Boolean indicating whether the treatment variable is a
+        non-media treatment. This argument is used to determine whether the
+        coefficient random effects are normal or log-normal. If `True`, then
+        random effects are assumed to be normal. Otherwise, the distribution is
+        inferred from `self._context.media_effects_dist`.
+      incremental_outcome_x: The incremental outcome of the treatment variable,
+        which depends on the parameter values of a particular prior or posterior
+        draw. The "_x" indicates that this is a tensor with length equal to the
+        dimension of the treatment variable.
+      linear_predictor_counterfactual_difference: The difference between the
+        treatment variable and its counterfactual on the linear predictor scale.
+        "Linear predictor" refers to the quantity that is multiplied by the
+        geo-level coefficient. For media variables, this is the output of the
+        hill/adstock transformation function. For non-media treatments, this is
+        simply the treatment variable after centering/scaling transformations.
+        This tensor has dimensions for geo, time, and channel.
+      eta_x: The random effect standard deviation parameter values. For media
+        variables, the "x" represents "m", "rf", "om", or "orf". For non-media
+        treatments, this argument should be set to `xi_n`, which is analogous to
+        "eta".
+      beta_gx_dev: The latent standard normal parameter values of the geo-level
+        coefficients. For media variables, the "x" represents "m", "rf", "om",
+        or "orf". For non-media treatments, this argument should be set to
+        `gamma_gn_dev`, which is analogous to "beta_gx_dev".
+
+    Returns:
+      The coefficient mean parameter of the treatment variable, which has
+      dimension equal to the number of treatment channels.
+    """
+    if is_non_media:
+      random_effects_normal = True
+    else:
+      random_effects_normal = (
+          self._context.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+      )
+    if self._context.revenue_per_kpi is None:
+      revenue_per_kpi = backend.ones(
+          [self._context.n_geos, self._context.n_times], dtype=backend.float32
+      )
+    else:
+      revenue_per_kpi = self._context.revenue_per_kpi
+    incremental_outcome_gx_over_beta_gx = backend.einsum(
+        "...gtx,gt,g,->...gx",
+        linear_predictor_counterfactual_difference,
+        revenue_per_kpi,
+        self._context.population,
+        self._context.kpi_transformer.population_scaled_stdev,
+    )
+    if random_effects_normal:
+      numerator_term_x = backend.einsum(
+          "...gx,...gx,...x->...x",
+          incremental_outcome_gx_over_beta_gx,
+          beta_gx_dev,
+          eta_x,
+      )
+      denominator_term_x = backend.einsum(
+          "...gx->...x", incremental_outcome_gx_over_beta_gx
+      )
+      return (incremental_outcome_x - numerator_term_x) / denominator_term_x
+    # For log-normal random effects, beta_x and eta_x are not mean & std.
+    # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
+    denominator_term_x = backend.einsum(
+        "...gx,...gx->...x",
+        incremental_outcome_gx_over_beta_gx,
+        backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
+    )
+    return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
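
Editor's note: the sketch below is a minimal NumPy illustration of the transformation order that `adstock_hill_media` above selects via `model_spec.hill_before_adstock`; it is not Meridian code. The helper names (`geometric_adstock`, `hill`) and the simple weight normalization are illustrative assumptions, not the library's `AdstockTransformer`/`HillTransformer` implementations.

import numpy as np

def geometric_adstock(x: np.ndarray, alpha: float, max_lag: int) -> np.ndarray:
  # Toy geometric-decay adstock over the time axis (axis 0), with normalized weights.
  out = np.zeros_like(x, dtype=float)
  weights = alpha ** np.arange(max_lag + 1)
  for t in range(x.shape[0]):
    lags = np.arange(max(0, t - max_lag), t + 1)   # time periods contributing to t
    w = weights[t - lags]                          # decay weight per contributing period
    out[t] = (w[:, None] * x[lags]).sum(axis=0) / w.sum()
  return out

def hill(x: np.ndarray, ec: float, slope: float) -> np.ndarray:
  # Toy Hill saturation curve; well defined at x == 0.
  return x**slope / (x**slope + ec**slope)

media = np.random.default_rng(0).uniform(size=(20, 3))  # (n_times, n_channels), made-up data

# Default order in the code above: adstock first, then hill.
adstock_then_hill = hill(geometric_adstock(media, alpha=0.5, max_lag=4), ec=0.6, slope=1.5)
# With `hill_before_adstock=True`, the two transforms are applied in the reverse order.
hill_then_adstock = geometric_adstock(hill(media, ec=0.6, slope=1.5), alpha=0.5, max_lag=4)

Because the Hill function is nonlinear, the two orderings generally produce different transformed media, which is why the order is a model specification choice rather than an implementation detail.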
meridian/model/knots.py CHANGED
@@ -14,18 +14,17 @@
 
 """Auxiliary functions for knots calculations."""
 
-import bisect
 from collections.abc import Collection, Sequence
 import copy
 import dataclasses
 import math
 import pprint
 from typing import Any
+
 from meridian import constants
 from meridian.data import input_data
 import numpy as np
-# TODO: b/437393442 - migrate patsy
-from patsy import highlevel
+from scipy import interpolate
 from statsmodels.regression import linear_model
 
 
@@ -36,40 +35,35 @@ __all__ = [
 ]
 
 
-# TODO: Reimplement with a more readable method.
-def _find_neighboring_knots_indices(
+def _find_left_knot_indices(
+    *,
     times: np.ndarray,
     knot_locations: np.ndarray,
-) -> Sequence[Sequence[int] | None]:
-  """Return indices of neighboring knot locations.
-
-  Returns indices in `knot_locations` that correspond to the neighboring knot
-  locations for each time period. If a time point is at or before the first
-  knot, the first knot is the only neighboring knot. If a time point is after
-  the last knot, the last knot is the only neighboring knot.
+) -> Sequence[int]:
+  """Return the index of the left neighboring knot for each time point.
 
   Args:
     times: Times `0, 1, 2,..., (n_times-1)`.
     knot_locations: The location of knots within `0, 1, 2,..., (n_times-1)`.
 
   Returns:
-    List of length `n_times`. Each element is the indices of the neighboring
-    knot locations for the respective time period. If a time point is at or
-    before the first knot, the first knot is the only neighboring knot. If a
-    time point is after the last knot, the last knot is the only neighboring
-    knot.
+    A list of indices of the left neighboring knot for each time point. The
+    length of the list is equal to the length of `times`.
+    - If a time point is at or before the first knot, the index is 0.
+    - If a time point is at or after the last knot, the index is `n_knots - 1`.
+    - Otherwise, it's the index of the knot just to the left.
   """
-  neighboring_knots_indices = [None] * len(times)
-  for t in times:
-    # knot_locations assumed to be sorted.
-    if t <= knot_locations[0]:
-      neighboring_knots_indices[t] = [0]
-    elif t >= knot_locations[-1]:
-      neighboring_knots_indices[t] = [len(knot_locations) - 1]
-    else:
-      bisect_index = bisect.bisect_left(knot_locations, t)
-      neighboring_knots_indices[t] = [bisect_index - 1, bisect_index]
-  return neighboring_knots_indices
+  n_knots = len(knot_locations)
+  # Find indices such that knot_locations[i-1] < times <= knot_locations[i]
+  insert_indices = np.searchsorted(knot_locations, times, side='right')
+  left_knot_indices = insert_indices - 1
+
+  # Handle edge cases for times before the first knot
+  left_knot_indices[times < knot_locations[0]] = 0
+  # Handle edge cases for times at or after the last knot
+  left_knot_indices[times >= knot_locations[-1]] = n_knots - 1
+
+  return left_knot_indices
 
 
 def l1_distance_weights(
@@ -115,16 +109,31 @@ def l1_distance_weights(
   time_minus_knot = abs(knot_locations[:, np.newaxis] - times[np.newaxis, :])
 
   w = np.zeros(time_minus_knot.shape, dtype=np.float32)
-  neighboring_knots_indices = _find_neighboring_knots_indices(
-      times, knot_locations
+  left_knot_indices = _find_left_knot_indices(
+      times=times, knot_locations=knot_locations
   )
+
   for t in times:
-    idx = neighboring_knots_indices[t]
-    if len(idx) == 1:
-      w[idx, t] = 1
+    left_idx = left_knot_indices[t]
+    current_time = times[t]
+
+    if current_time in knot_locations:
+      # If time is exactly at a knot, give all weight to that knot.
+      knot_idx = np.where(knot_locations == current_time)[0][0]
+      w[knot_idx, t] = 1.0
+    elif current_time < knot_locations[0] or current_time > knot_locations[-1]:
+      # Outside the knot range, assign full weight to the closest endpoint knot.
+      w[left_idx, t] = 1.0
     else:
-      # Weight is in proportion to how close the two neighboring knots are.
-      w[idx, t] = 1 - (time_minus_knot[idx, t] / time_minus_knot[idx, t].sum())
+      # Time is between left_idx and left_idx + 1.
+      left_dist = time_minus_knot[left_idx, t]
+      right_dist = time_minus_knot[left_idx + 1, t]
+      total_dist = left_dist + right_dist
+
+      # Assign weight inversely proportional to distance.
+      # The closer knot gets more weight.
+      w[left_idx, t] = right_dist / total_dist
+      w[left_idx + 1, t] = left_dist / total_dist
 
   return w
 
@@ -319,6 +328,17 @@ class AKS:
 
     return AKSResult(knots_sel[opt_idx], model[opt_idx])
 
+  def _get_bspline_matrix(self, x, knots):
+    """Replaces patsy.highlevel.dmatrix('bs(...)', ...)"""
+    # Pad knots at boundaries to match patsy's 'bs()' behavior
+    t = np.concatenate((
+        [x.min()] * (self._DEGREE + 1),
+        np.sort(knots),
+        [x.max()] * (self._DEGREE + 1),
+    ))
+    # Create design matrix (standard numpy array)
+    return interpolate.BSpline.design_matrix(x, t, self._DEGREE).toarray()
+
   def _calculate_initial_knots(
       self,
       x: np.ndarray,
@@ -433,14 +453,7 @@ class AKS:
           'Provided x and y args for aspline must both be 1 dimensional!'
       )
 
-    bs_cmd = (
-        'bs(x,knots=['
-        + ','.join(map(str, knots))
-        + '],degree='
-        + str(self._DEGREE)
-        + ',include_intercept=True)-1'
-    )
-    xmat = highlevel.dmatrix(bs_cmd, {'x': x})
+    xmat = self._get_bspline_matrix(x, knots)
     nrow = xmat.shape[0]
     ncol = xmat.shape[1]
 
@@ -472,10 +485,8 @@ class AKS:
       if converge:
         sel_ls[index_penalty] = sel
         knots_sel[index_penalty] = knots[sel > 0.99]
-        bs_cmd_iter = (
-            f"bs(x,knots=[{','.join(map(str, knots_sel[index_penalty]))}],degree={self._DEGREE},include_intercept=True)-1"
-        )
-        design_mat = highlevel.dmatrix(bs_cmd_iter, {'x': x})
+
+        design_mat = self._get_bspline_matrix(x, knots_sel[index_penalty])
         x_sel[index_penalty] = design_mat
         bs_model = linear_model.OLS(y, x_sel[index_penalty]).fit()
         model[index_penalty] = bs_model
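
Editor's note: the knots.py change replaces patsy's `bs()` design matrices with `scipy.interpolate.BSpline.design_matrix`. The sketch below is an illustration of that migration, not part of the package; it assumes `AKS._DEGREE` is the usual cubic degree of 3 and uses made-up sample points and knots.

import numpy as np
from scipy import interpolate

degree = 3                                  # assumed value of AKS._DEGREE (cubic B-splines)
x = np.linspace(0.0, 10.0, 101)             # sample points (illustrative)
interior_knots = np.array([2.5, 5.0, 7.5])  # interior knots (illustrative)

# Pad each boundary knot (degree + 1) times, as _get_bspline_matrix does above.
t = np.concatenate((
    [x.min()] * (degree + 1),
    np.sort(interior_knots),
    [x.max()] * (degree + 1),
))
design = interpolate.BSpline.design_matrix(x, t, degree).toarray()

# m interior knots with degree k give m + k + 1 basis columns, the same count
# that patsy's bs(..., include_intercept=True) produced in the removed code path.
assert design.shape == (x.size, interior_knots.size + degree + 1)
# Clamped B-spline bases form a partition of unity over the knot span.
assert np.allclose(design.sum(axis=1), 1.0)

The boundary padding is what makes the basis "clamped" at `x.min()` and `x.max()`, so the spline can take arbitrary values at the endpoints; downstream, the resulting dense matrix feeds `statsmodels` OLS exactly as the patsy design matrix did.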