google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
meridian/model/knots.py
CHANGED
|
@@ -16,8 +16,16 @@
|
|
|
16
16
|
|
|
17
17
|
import bisect
|
|
18
18
|
from collections.abc import Collection, Sequence
|
|
19
|
+
import copy
|
|
19
20
|
import dataclasses
|
|
21
|
+
import math
|
|
22
|
+
from typing import Any
|
|
23
|
+
from meridian import constants
|
|
24
|
+
from meridian.data import input_data
|
|
20
25
|
import numpy as np
|
|
26
|
+
# TODO: b/437393442 - migrate patsy
|
|
27
|
+
from patsy import highlevel
|
|
28
|
+
from statsmodels.regression import linear_model
|
|
21
29
|
|
|
22
30
|
|
|
23
31
|
__all__ = [
|
|
@@ -144,6 +152,8 @@ class KnotInfo:
|
|
|
144
152
|
def get_knot_info(
|
|
145
153
|
n_times: int,
|
|
146
154
|
knots: int | Collection[int] | None,
|
|
155
|
+
enable_aks: bool = False,
|
|
156
|
+
data: input_data.InputData | None = None,
|
|
147
157
|
is_national: bool = False,
|
|
148
158
|
) -> KnotInfo:
|
|
149
159
|
"""Returns the number of knots, knot locations, and weights.
|
|
@@ -161,6 +171,12 @@ def get_knot_info(
|
|
|
161
171
|
coefficient used for all time periods. If `knots` is `None`, then the
|
|
162
172
|
numbers of knots used is equal to the number of time periods. This is
|
|
163
173
|
equivalent to each time period having its own regression coefficient.
|
|
174
|
+
enable_aks: A boolean indicating whether to use the Automatic Knot Selection
|
|
175
|
+
algorithm to select optimal number of knots for running the model instead
|
|
176
|
+
of the default 1 for national and n_times for non-national models.
|
|
177
|
+
data: An Optional InputData object used by the Automatic Knot Selection
|
|
178
|
+
algorithm to calculate optimal number of knots from the provided Input
|
|
179
|
+
Data.
|
|
164
180
|
is_national: A boolean indicator whether to adapt the knot information for a
|
|
165
181
|
national model.
|
|
166
182
|
|
|
@@ -169,8 +185,17 @@ def get_knot_info(
|
|
|
169
185
|
weights used to multiply with the knot values to get time-varying
|
|
170
186
|
coefficients.
|
|
171
187
|
"""
|
|
172
|
-
|
|
173
|
-
|
|
188
|
+
if enable_aks:
|
|
189
|
+
if data is None:
|
|
190
|
+
raise ValueError(
|
|
191
|
+
'If enable_aks is true then input data must be provided.'
|
|
192
|
+
)
|
|
193
|
+
else:
|
|
194
|
+
aks = AKS(data)
|
|
195
|
+
knots = aks.automatic_knot_selection().knots
|
|
196
|
+
n_knots = len(knots)
|
|
197
|
+
knot_locations = knots
|
|
198
|
+
elif isinstance(knots, int):
|
|
174
199
|
if knots < 1:
|
|
175
200
|
raise ValueError('If knots is an integer, it must be at least 1.')
|
|
176
201
|
elif knots > n_times:
|
|
@@ -208,3 +233,501 @@ def get_knot_info(
|
|
|
208
233
|
weights = l1_distance_weights(n_times, knot_locations)
|
|
209
234
|
|
|
210
235
|
return KnotInfo(n_knots, knot_locations, weights)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@dataclasses.dataclass(frozen=True)
class AKSResult:
  """Outcome of one run of the Automatic Knot Selection (AKS) algorithm.

  Attributes:
    knots: The knot locations retained by the selection procedure.
    model: The fitted OLS regression on the B-spline basis built from `knots`.
  """

  knots: np.ndarray[int, np.dtype[int]]
  # `AKS.aspline` stores the object returned by `linear_model.OLS(...).fit()`,
  # i.e. a fitted results wrapper, not an unfitted `OLS` model.
  model: linear_model.RegressionResultsWrapper
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class AKS:
  """Class for automatically selecting knots in Meridian Core Library.

  The selection is driven by the data passed to the constructor; see
  `automatic_knot_selection` for the entry point.
  """

  # Grid of 100 log-spaced base penalties between 1e-1 and 1e2. It is rescaled
  # by 1 / sqrt(number of geos) before being handed to `aspline`.
  _BASE_PENALTY = np.logspace(-1, 2, 100)
  # Degree of the B-spline basis used by `aspline`.
  _DEGREE = 1

  def __init__(self, data: input_data.InputData):
    """Stores the input data the knot selection is computed from.

    Args:
      data: The input data whose time/geo coordinates, KPI and channel
        coordinates are read by the selection algorithm.
    """
    self._data = data
|
|
252
|
+
|
|
253
|
+
def automatic_knot_selection(self) -> AKSResult:
  """Selects knots for the Meridian model with an adaptive spline (A-spline).

  The KPI is regressed on the time index for a grid of penalties. Among the
  fits whose number of selected knots lies within the feasible range returned
  by `_calculate_initial_knots`, the one with the lowest AIC is returned; ties
  are resolved in favor of the last (largest) penalty.

  Returns:
    Selected knots and the corresponding B-spline model.

  Raises:
    ValueError: If no penalty yields a feasible number of selected knots.
  """
  n_times = len(self._data.time)
  n_geos = len(self._data.geo)
  n_obs = n_geos * n_times

  # Assumes `scaled_centered_kpi` is laid out as (n_geos, n_times) so that the
  # flattened KPI lines up with the tiled time index below - TODO confirm.
  y = np.reshape(self._data.scaled_centered_kpi, (n_obs,))
  x = np.reshape(np.repeat([range(n_times)], n_geos, axis=0), (n_obs,))

  knots, min_internal_knots, max_internal_knots = (
      self._calculate_initial_knots(x)
  )
  geo_scaling_factor = 1 / np.sqrt(n_geos)
  penalty = geo_scaling_factor * self._BASE_PENALTY

  aspline = self.aspline(x=x, y=y, knots=knots, penalty=penalty)
  candidate_knots = aspline[constants.KNOTS_SELECTED]
  n_knots = np.array([len(selected) for selected in candidate_knots])
  feasible_idx = np.flatnonzero(
      (n_knots >= min_internal_knots) & (n_knots <= max_internal_knots)
  )
  if feasible_idx.size == 0:
    # Previously surfaced as an opaque "min() arg is an empty sequence".
    raise ValueError(
        'Automatic knot selection found no penalty whose number of selected'
        f' knots lies in [{min_internal_knots}, {max_internal_knots}].'
    )

  information_criterion = aspline[constants.AIC][feasible_idx]
  # Position (within the feasible subset) of the last fit attaining the
  # minimum AIC.
  best = np.flatnonzero(information_criterion == min(information_criterion))
  opt_idx = feasible_idx[best[-1]]

  return AKSResult(candidate_knots[opt_idx], aspline[constants.MODEL][opt_idx])
|
|
287
|
+
|
|
288
|
+
def _calculate_initial_knots(
|
|
289
|
+
self,
|
|
290
|
+
x: np.ndarray,
|
|
291
|
+
) -> tuple[np.ndarray, int, int]:
|
|
292
|
+
"""Calculates initial knots based on unique x values.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
x: A flattened array of indexed time coordinates, repeated n_geos times.
|
|
296
|
+
e.g. [0, 1, 2, 3, ..., 0, 1, 2, 3, ...].
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
A tuple containing:
|
|
300
|
+
- The calculated knots.
|
|
301
|
+
- The minimum number of internal knots.
|
|
302
|
+
- The maximum number of internal knots.
|
|
303
|
+
"""
|
|
304
|
+
n_media = (
|
|
305
|
+
len(self._data.media_channel)
|
|
306
|
+
if self._data.media_channel is not None
|
|
307
|
+
else 0
|
|
308
|
+
)
|
|
309
|
+
n_rf = (
|
|
310
|
+
len(self._data.rf_channel) if self._data.rf_channel is not None else 0
|
|
311
|
+
)
|
|
312
|
+
n_organic_media = (
|
|
313
|
+
len(self._data.organic_media_channel)
|
|
314
|
+
if self._data.organic_media_channel is not None
|
|
315
|
+
else 0
|
|
316
|
+
)
|
|
317
|
+
n_organic_rf = (
|
|
318
|
+
len(self._data.organic_rf_channel)
|
|
319
|
+
if self._data.organic_rf_channel is not None
|
|
320
|
+
else 0
|
|
321
|
+
)
|
|
322
|
+
n_non_media = (
|
|
323
|
+
len(self._data.non_media_channel)
|
|
324
|
+
if self._data.non_media_channel is not None
|
|
325
|
+
else 0
|
|
326
|
+
)
|
|
327
|
+
n_controls = (
|
|
328
|
+
len(self._data.control_variable)
|
|
329
|
+
if self._data.control_variable is not None
|
|
330
|
+
else 0
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
x_vals_unique = np.unique(x)
|
|
334
|
+
min_x_data, max_x_data = x_vals_unique.min(), x_vals_unique.max()
|
|
335
|
+
knots = x_vals_unique[
|
|
336
|
+
(x_vals_unique > min_x_data) & (x_vals_unique < max_x_data)
|
|
337
|
+
]
|
|
338
|
+
knots = np.sort(np.unique(knots))
|
|
339
|
+
# Drop one knot from the set of all knots because the algorithm requires one
|
|
340
|
+
# fewer degree of freedom than the total number of knots to function.
|
|
341
|
+
# Dropping the final knot is a natural and practical choice because it
|
|
342
|
+
# often has minimal impact on the overall model fit.
|
|
343
|
+
knots = knots[:-1]
|
|
344
|
+
min_internal_knots = 1
|
|
345
|
+
|
|
346
|
+
max_internal_knots = (
|
|
347
|
+
len(knots)
|
|
348
|
+
- n_media
|
|
349
|
+
- n_rf
|
|
350
|
+
- n_organic_media
|
|
351
|
+
- n_organic_rf
|
|
352
|
+
- n_non_media
|
|
353
|
+
- n_controls
|
|
354
|
+
)
|
|
355
|
+
if min_internal_knots > len(knots):
|
|
356
|
+
raise ValueError(
|
|
357
|
+
'The minimum number of internal knots cannot be greater than the'
|
|
358
|
+
' total number of initial knots.'
|
|
359
|
+
)
|
|
360
|
+
if max_internal_knots < min_internal_knots:
|
|
361
|
+
raise ValueError(
|
|
362
|
+
'The maximum number of internal knots cannot be less than the minimum'
|
|
363
|
+
' number of internal knots.'
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
return knots, min_internal_knots, max_internal_knots
|
|
367
|
+
|
|
368
|
+
def aspline(
    self,
    x: np.ndarray,
    y: np.ndarray,
    knots: np.ndarray,
    penalty: np.ndarray,
    max_iterations: int = 1000,
    epsilon: float = 1e-5,
    tol: float = 1e-6,
) -> dict[str, Any]:
  """Fits B-splines with automatic knot selection.

  For each penalty value (visited in order) the adaptive ridge iteration is
  run until the selection coefficients stabilize; the knots whose selection
  coefficient exceeds 0.99 are kept and an unpenalized OLS model is refit on
  the B-spline basis built from them.

  Args:
    x: A flattened array of indexed time coordinates, repeated n_geos times.
      e.g. [0, 1, 2, 3, ..., 0, 1, 2, 3, ...].
    y: The flattened array of KPI values that have been population-scaled and
      mean-centered by geo.
    knots: Internal knots used for spline regression.
    penalty: A vector of positive penalty values. The adaptive spline
      regression is performed for every value of penalty.
    max_iterations: Maximum number of iterations in the main loop. The budget
      is shared by all penalty values.
    epsilon: Value of the constant in the adaptive ridge procedure (see
      Frommlet, F., Nuel, G. (2016) An Adaptive Ridge Procedure for L0
      Regularization.)
    tol: The tolerance chosen to diagnose convergence of the adaptive ridge
      procedure.

  Returns:
    A dictionary of the following items:
      selection_coefs: A list of selection coefficients for every value of
        penalty.
      knots_selected: A list of selected knots for every value of penalty.
      model: A list of fitted models for every value of penalty.
      regression_coefs: A list of estimated regression coefficients for every
        value of penalty.
      selected_matrix: A matrix of selected knots for every value of penalty.
      aic: A list of AIC values for every value of penalty.
      bic: A list of BIC values for every value of penalty.
      ebic: A list of EBIC values for every value of penalty.

  Raises:
    ValueError: If `x` or `y` is not 1 dimensional, or if the adaptive ridge
      procedure does not converge for every penalty value within
      `max_iterations` iterations.
  """
  if x.ndim != 1 or y.ndim != 1:
    raise ValueError(
        'Provided x and y args for aspline must both be 1 dimensional!'
    )

  def bs_formula(knot_values) -> str:
    """Returns the patsy formula of the B-spline basis on `knot_values`."""
    return (
        f"bs(x,knots=[{','.join(map(str, knot_values))}],"
        f'degree={self._DEGREE},include_intercept=True)-1'
    )

  n_penalties = len(penalty)
  diff_order = self._DEGREE + 1

  xmat = highlevel.dmatrix(bs_formula(knots), {'x': x})
  nrow, ncol = xmat.shape

  xx = xmat.T.dot(xmat)
  xy = xmat.T.dot(y)
  # The tiny ridge keeps every diagonal entry non-zero; the extra zero column
  # gives the band the diff_order + 1 columns the penalty matrix needs.
  xx_rot = np.concatenate(
      [
          self._mat2rot(xx + (1e-20 * np.identity(ncol))),
          np.zeros((ncol, 1)),
      ],
      axis=1,
  )
  # NOTE(review): `mse_resid` is already a residual variance estimate, so this
  # is its square. Kept as-is because changing it alters the information
  # criteria; confirm against the reference A-spline implementation.
  sigma0sq = linear_model.OLS(y, xmat).fit().mse_resid ** 2

  model, knots_sel, sel_ls, par_ls, aic, bic, ebic = (
      [None] * n_penalties for _ in range(7)
  )
  old_sel = np.ones(ncol - diff_order)
  w = np.ones(ncol - diff_order)
  par = np.ones(ncol)
  index_penalty = 0
  for _ in range(max_iterations):
    par = self._wridge_solver(
        xx_rot, xy, self._DEGREE, penalty[index_penalty], w, old_par=par
    )
    par_diff = np.diff(par, n=diff_order)

    w = 1 / (par_diff**2 + epsilon**2)
    sel = w * par_diff**2
    converge = max(abs(old_sel - sel)) < tol
    if converge:
      keep = sel > 0.99
      selected_knots = knots[keep]
      design_mat = highlevel.dmatrix(bs_formula(selected_knots), {'x': x})
      bs_model = linear_model.OLS(y, design_mat).fit()

      # Scatter the refit coefficients back to the positions of the full
      # basis; dropped knots keep a zero coefficient.
      coefs = np.zeros(ncol, dtype=np.float32)
      coefs[np.concatenate([keep, np.repeat(True, diff_order)])] = (
          bs_model.params
      )

      # Half of the residual sum of squares scaled by `sigma0sq`.
      loglik = sum(bs_model.resid**2 / sigma0sq) / 2
      dim = len(selected_knots) + self._DEGREE + 1

      sel_ls[index_penalty] = sel
      knots_sel[index_penalty] = selected_knots
      model[index_penalty] = bs_model
      par_ls[index_penalty] = coefs
      aic[index_penalty] = 2 * dim + 2 * loglik
      bic[index_penalty] = np.log(nrow) * dim + 2 * loglik
      # `math.log` accepts arbitrarily large integers, whereas converting the
      # binomial coefficient to float32 first overflows to inf for
      # realistically sized bases.
      ebic[index_penalty] = bic[index_penalty] + 2 * math.log(
          math.comb(ncol, design_mat.shape[1])
      )

      index_penalty += 1
      if index_penalty >= n_penalties:
        break
    old_sel = sel

  if index_penalty < n_penalties:
    # Otherwise the unfilled `None` entries fail later with an unrelated
    # message.
    raise ValueError(
        'The adaptive ridge procedure converged for only'
        f' {index_penalty} of {n_penalties} penalty values within'
        f' {max_iterations} iterations.'
    )

  sel_mat = np.round(np.stack(sel_ls, axis=-1), 1)
  return {
      constants.SELECTION_COEFS: sel_ls,
      constants.KNOTS_SELECTED: knots_sel,
      constants.MODEL: model,
      constants.REGRESSION_COEFS: par_ls,
      constants.SELECTED_MATRIX: sel_mat,
      constants.AIC: np.array(aic),
      constants.BIC: np.array(bic),
      constants.EBIC: np.array(ebic),
  }
|
|
489
|
+
|
|
490
|
+
def _mat2rot(self, band_mat: np.ndarray) -> np.ndarray:
|
|
491
|
+
"""Rotates a symmetric band matrix to get the rotated matrix associated.
|
|
492
|
+
|
|
493
|
+
Each column of the rotated matrix corresponds to a diagonal. The first
|
|
494
|
+
column is the main diagonal, the second one is the upper-diagonal and so on.
|
|
495
|
+
Artificial 0s are placed at the end of each column if necessary.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
band_mat: The band square matrix to be rotated.
|
|
499
|
+
|
|
500
|
+
Returns:
|
|
501
|
+
The rotated matrix of band_mat.
|
|
502
|
+
"""
|
|
503
|
+
p = band_mat.shape[1]
|
|
504
|
+
l = 0
|
|
505
|
+
for i in range(p):
|
|
506
|
+
lprime = np.where(band_mat[i, :] != 0)[0]
|
|
507
|
+
l = np.maximum(l, lprime[len(lprime) - 1] - i)
|
|
508
|
+
|
|
509
|
+
rot_mat = np.zeros([p, l + 1])
|
|
510
|
+
rot_mat[:, 0] = np.diag(band_mat)
|
|
511
|
+
if l > 0:
|
|
512
|
+
for j in range(l):
|
|
513
|
+
rot_mat[:, j + 1] = np.concat([
|
|
514
|
+
np.diag(band_mat[range(p - j - 1), :][:, range(j + 1, p)]),
|
|
515
|
+
np.zeros(j + 1),
|
|
516
|
+
])
|
|
517
|
+
return rot_mat
|
|
518
|
+
|
|
519
|
+
def _band_weight(self, w: np.ndarray, diff: int) -> np.ndarray:
|
|
520
|
+
"""Creates the penalty matrix for A-Spline.
|
|
521
|
+
|
|
522
|
+
Args:
|
|
523
|
+
w: Vector of weights.
|
|
524
|
+
diff: Order of the differences to be applied to the parameters. Must be a
|
|
525
|
+
strictly positive integer.
|
|
526
|
+
|
|
527
|
+
Returns:
|
|
528
|
+
Weighted penalty matrix D'diag(w)D, where
|
|
529
|
+
D = diff(diag(len(w) + diff), differences = diff)}. Only the non-null
|
|
530
|
+
superdiagonals of the weight matrix are returned, each column
|
|
531
|
+
corresponding to a diagonal.
|
|
532
|
+
"""
|
|
533
|
+
ws = len(w)
|
|
534
|
+
rows = ws + diff
|
|
535
|
+
cols = diff + 1
|
|
536
|
+
|
|
537
|
+
# Compute the entries of the difference matrix
|
|
538
|
+
binom = np.zeros(cols, dtype=np.int32)
|
|
539
|
+
for i in range(cols):
|
|
540
|
+
binom[i] = math.comb(diff, i) * (-1) ** i
|
|
541
|
+
|
|
542
|
+
# Compute the limit indices
|
|
543
|
+
ind_mat = np.zeros([rows, 2], dtype=np.int32)
|
|
544
|
+
for ind in range(rows):
|
|
545
|
+
ind_mat[ind, 0] = 0 if ind - diff < 0 else ind - diff
|
|
546
|
+
ind_mat[ind, 1] = ind if ind < ws - 1 else ws - 1
|
|
547
|
+
|
|
548
|
+
# Main loop
|
|
549
|
+
result = np.zeros([rows, cols])
|
|
550
|
+
for j in range(cols):
|
|
551
|
+
for i in range(rows - j):
|
|
552
|
+
temp = 0.0
|
|
553
|
+
for k in range(ind_mat[i + j, 0], ind_mat[i, 1] + 1):
|
|
554
|
+
temp += binom[i - k] * binom[i + j - k] * w[k]
|
|
555
|
+
result[i, j] = temp
|
|
556
|
+
|
|
557
|
+
return result
|
|
558
|
+
|
|
559
|
+
def _ldl(self, rot_mat: np.ndarray) -> np.ndarray:
|
|
560
|
+
"""Solves the Fast LDL decomposition of symmetric band matrix of length k.
|
|
561
|
+
|
|
562
|
+
Args:
|
|
563
|
+
rot_mat: Rotated row-wised matrix of dimensions n*k, with first column
|
|
564
|
+
corresponding to the diagonal, the second to the first super-diagonal
|
|
565
|
+
and so on.
|
|
566
|
+
|
|
567
|
+
Returns:
|
|
568
|
+
Solution of the LDL decomposition.
|
|
569
|
+
"""
|
|
570
|
+
n = rot_mat.shape[0]
|
|
571
|
+
m = rot_mat.shape[1] - 1
|
|
572
|
+
rot_mat_new = copy.deepcopy(rot_mat)
|
|
573
|
+
for i in range(1, n + 1):
|
|
574
|
+
j0 = np.maximum(1, i - m)
|
|
575
|
+
for j in range(j0, i + 1):
|
|
576
|
+
for k in range(j0, j):
|
|
577
|
+
rot_mat_new[j - 1, i - j] -= (
|
|
578
|
+
rot_mat_new[k - 1, i - k]
|
|
579
|
+
* rot_mat_new[k - 1, j - k]
|
|
580
|
+
* rot_mat_new[k - 1, 0]
|
|
581
|
+
)
|
|
582
|
+
if i > j:
|
|
583
|
+
rot_mat_new[j - 1, i - j] /= rot_mat_new[j - 1, 0]
|
|
584
|
+
|
|
585
|
+
return rot_mat_new
|
|
586
|
+
|
|
587
|
+
def _bandsolve_kernel(
|
|
588
|
+
self, rot_mat: np.ndarray, rhs_mat: np.ndarray
|
|
589
|
+
) -> np.ndarray:
|
|
590
|
+
"""Solves the symmetric bandlinear system Ax = b.
|
|
591
|
+
|
|
592
|
+
This is the kernel function that solves the system, where A is the rotated
|
|
593
|
+
form of the band matrix and b is the right hand side.
|
|
594
|
+
|
|
595
|
+
Args:
|
|
596
|
+
rot_mat: Band square matrix in the rotated form. It's the visual rotation
|
|
597
|
+
by 90 degrees of the matrix, where subdiagonal are discarded.
|
|
598
|
+
rhs_mat: right hand side of the equation. Can be either a vector or a
|
|
599
|
+
matrix. If not supplied, the function return the inverse of rot_mat.
|
|
600
|
+
|
|
601
|
+
Returns:
|
|
602
|
+
Solution of the linear problem.
|
|
603
|
+
"""
|
|
604
|
+
rot_mat_ldl = self._ldl(rot_mat)
|
|
605
|
+
x = copy.deepcopy(rhs_mat)
|
|
606
|
+
n = rot_mat.shape[0]
|
|
607
|
+
k = rot_mat_ldl.shape[1] - 1
|
|
608
|
+
l = rhs_mat.shape[1]
|
|
609
|
+
|
|
610
|
+
for l in range(l):
|
|
611
|
+
# solve b=inv(L)b
|
|
612
|
+
for i in range(2, n + 1):
|
|
613
|
+
jmax = np.minimum(i - 1, k)
|
|
614
|
+
for j in range(1, jmax + 1):
|
|
615
|
+
x[i - 1, l] -= rot_mat_ldl[i - j - 1, j] * x[i - j - 1, l]
|
|
616
|
+
|
|
617
|
+
# solve b=b/D
|
|
618
|
+
for i in range(n):
|
|
619
|
+
x[i, l] /= rot_mat_ldl[i, 0]
|
|
620
|
+
|
|
621
|
+
# solve b=inv(t(L))b=inv(L*D*t(L))b
|
|
622
|
+
for i in range(n - 1, 0, -1):
|
|
623
|
+
jmax = np.minimum(n - i, k)
|
|
624
|
+
for j in range(1, jmax + 1):
|
|
625
|
+
x[i - 1, l] -= rot_mat_ldl[i - 1, j] * x[i + j - 1, l]
|
|
626
|
+
|
|
627
|
+
return x
|
|
628
|
+
|
|
629
|
+
def _bandsolve(self, rot_mat: np.ndarray, rhs_mat: np.ndarray) -> np.ndarray:
|
|
630
|
+
"""Solves the symmetric bandlinear system Ax = b.
|
|
631
|
+
|
|
632
|
+
Here A is the rotated form of the band matrix and b is the right hand side.
|
|
633
|
+
|
|
634
|
+
Args:
|
|
635
|
+
rot_mat: Band square matrix in the rotated form. It's the visual rotation
|
|
636
|
+
by 90 degrees of the matrix, where subdiagonal are discarded.
|
|
637
|
+
rhs_mat: right hand side of the equation. Can be either a vector or a
|
|
638
|
+
matrix. If not supplied, the function return the inverse of rot_mat.
|
|
639
|
+
|
|
640
|
+
Returns:
|
|
641
|
+
Solution of the linear problem.
|
|
642
|
+
"""
|
|
643
|
+
|
|
644
|
+
nrow = rot_mat.shape[0]
|
|
645
|
+
ncol = rot_mat.shape[1]
|
|
646
|
+
if (nrow == ncol) & (rot_mat[nrow - 1, ncol - 1] != 0):
|
|
647
|
+
raise ValueError('rot_mat should be a rotated matrix!')
|
|
648
|
+
if rot_mat[nrow - 1, 1] != 0:
|
|
649
|
+
raise ValueError('rot_mat should be a rotated matrix!')
|
|
650
|
+
if len(rhs_mat) != nrow:
|
|
651
|
+
raise ValueError('Dimension problem!')
|
|
652
|
+
|
|
653
|
+
return self._bandsolve_kernel(rot_mat, rhs_mat[:, np.newaxis])
|
|
654
|
+
|
|
655
|
+
def _wridge_solver(
|
|
656
|
+
self,
|
|
657
|
+
xx_rot: np.ndarray,
|
|
658
|
+
xy: np.ndarray,
|
|
659
|
+
degree: int,
|
|
660
|
+
penalty: float,
|
|
661
|
+
w: np.ndarray,
|
|
662
|
+
old_par: np.ndarray,
|
|
663
|
+
max_iterations: int = 1000,
|
|
664
|
+
tol: float = 1e-8,
|
|
665
|
+
) -> np.ndarray | None:
|
|
666
|
+
"""Fits B-Splines with weighted penalization over differences of parameters.
|
|
667
|
+
|
|
668
|
+
Args:
|
|
669
|
+
xx_rot: The matrix X'X where X is the design matrix. This argument is
|
|
670
|
+
given in the form of a band matrix, i.e., successive columns represent
|
|
671
|
+
superdiagonals.
|
|
672
|
+
xy: The vector of currently estimated points X'y, where y is the
|
|
673
|
+
y-coordinate of the data.
|
|
674
|
+
degree: The degree of the B-splines.
|
|
675
|
+
penalty: Positive penalty constant.
|
|
676
|
+
w: Vector of weights. The case w = np.ones(xx_rot.shape[0] - degree - 1)
|
|
677
|
+
corresponds to fitting P-splines with difference order degree + 1. See
|
|
678
|
+
Eilers, P., Marx, B. (1996) Flexible smoothing with B-splines and
|
|
679
|
+
penalties.
|
|
680
|
+
old_par: The previous parameter vector.
|
|
681
|
+
max_iterations: Maximum number of Newton-Raphson iterations to be
|
|
682
|
+
computed.
|
|
683
|
+
tol: The tolerance chosen to diagnose convergence of the adaptive ridge
|
|
684
|
+
procedure.
|
|
685
|
+
|
|
686
|
+
Returns:
|
|
687
|
+
The estimated parameter of the spline regression.
|
|
688
|
+
"""
|
|
689
|
+
|
|
690
|
+
def _hessian_solver(par, xx_rot, xy, penalty, w, diff):
|
|
691
|
+
"""Inverts the hessian and multiplies it by the score.
|
|
692
|
+
|
|
693
|
+
Args:
|
|
694
|
+
par: The parameter vector.
|
|
695
|
+
xx_rot: The matrix X'X where X is the design matrix. This argument is
|
|
696
|
+
given in the form of a rotated band matrix, i.e., successive columns
|
|
697
|
+
represent superdiagonals.
|
|
698
|
+
xy: The vector of currently estimated points X'y, where y is the
|
|
699
|
+
y-coordinate of the data.
|
|
700
|
+
penalty: Positive penalty constant.
|
|
701
|
+
w: Vector of weights.
|
|
702
|
+
diff: The order of the differences of the parameter. Equals degree + 1
|
|
703
|
+
in adaptive spline regression.
|
|
704
|
+
|
|
705
|
+
Returns:
|
|
706
|
+
The solution of the linear system: (X'X + penalty*D'WD)^{-1} X'y - par
|
|
707
|
+
"""
|
|
708
|
+
if xx_rot.shape[1] != diff + 1:
|
|
709
|
+
raise ValueError('Error: xx_rot must have diff + 1 columns')
|
|
710
|
+
return (
|
|
711
|
+
self._bandsolve(xx_rot + penalty * self._band_weight(w, diff), xy)[
|
|
712
|
+
:, 0
|
|
713
|
+
]
|
|
714
|
+
- par
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
par = None
|
|
718
|
+
for _ in range(max_iterations):
|
|
719
|
+
par = old_par + _hessian_solver(
|
|
720
|
+
par=old_par,
|
|
721
|
+
xx_rot=xx_rot,
|
|
722
|
+
xy=xy,
|
|
723
|
+
penalty=penalty,
|
|
724
|
+
w=w,
|
|
725
|
+
diff=degree + 1,
|
|
726
|
+
)
|
|
727
|
+
index = old_par != 0
|
|
728
|
+
rel_error = max(abs(par - old_par)[index] / abs(old_par)[index])
|
|
729
|
+
if rel_error < tol:
|
|
730
|
+
break
|
|
731
|
+
old_par = par
|
|
732
|
+
|
|
733
|
+
return par
|