rustystats 0.1.5 (rustystats-0.1.5-cp313-cp313-manylinux_2_34_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rustystats/__init__.py +151 -0
- rustystats/_rustystats.cpython-313-x86_64-linux-gnu.so +0 -0
- rustystats/diagnostics.py +2471 -0
- rustystats/families.py +423 -0
- rustystats/formula.py +1074 -0
- rustystats/glm.py +249 -0
- rustystats/interactions.py +1246 -0
- rustystats/links.py +221 -0
- rustystats/splines.py +367 -0
- rustystats/target_encoding.py +375 -0
- rustystats-0.1.5.dist-info/METADATA +476 -0
- rustystats-0.1.5.dist-info/RECORD +14 -0
- rustystats-0.1.5.dist-info/WHEEL +4 -0
- rustystats-0.1.5.dist-info/licenses/LICENSE +21 -0
rustystats/links.py
ADDED
@@ -0,0 +1,221 @@
"""
Link Functions for GLMs
=======================

Link functions connect the linear predictor (η = Xβ) to the mean (μ).
They're written as:

    η = g(μ) or equivalently μ = g⁻¹(η)

Why Do We Need Link Functions?
------------------------------

Different types of responses need different transformations:

1. **Continuous data**: Use identity link (η = μ)
   - No transformation needed
   - Predictions can be any real value

2. **Count data**: Use log link (η = log(μ))
   - Ensures predictions are always positive (μ = exp(η) > 0)
   - Gives multiplicative interpretation to coefficients

3. **Binary data**: Use logit link (η = log(μ/(1-μ)))
   - Ensures predictions are probabilities (0 < μ < 1)
   - Coefficients are log-odds ratios

Choosing a Link Function
------------------------

+------------------+-------------------+-------------------+--------------------+
| Family           | Canonical Link    | Common Alternative| Interpretation     |
+==================+===================+===================+====================+
| Gaussian         | Identity          | Log               | Additive effects   |
| Poisson          | Log               | -                 | Multiplicative     |
| Binomial         | Logit             | Probit, Cloglog   | Odds ratios        |
| Gamma            | Inverse (1/μ)     | Log               | Multiplicative     |
+------------------+-------------------+-------------------+--------------------+

In actuarial practice, the **log link** is extremely common because:
- It ensures positive predictions (important for counts and amounts)
- Coefficients have multiplicative interpretation (rate relativities!)
- It's consistent across frequency and severity models

Examples
--------
>>> import rustystats as rs
>>> import numpy as np
>>>
>>> # Log link example
>>> log_link = rs.links.Log()
>>> eta = np.array([0.0, 0.5, 1.0])  # Linear predictor values
>>> mu = log_link.inverse(eta)
>>> print(mu)  # [1.0, 1.649, 2.718] - always positive!
>>>
>>> # Logit link example
>>> logit_link = rs.links.Logit()
>>> eta = np.array([-2.0, 0.0, 2.0])
>>> mu = logit_link.inverse(eta)
>>> print(mu)  # [0.119, 0.5, 0.881] - always between 0 and 1!
"""

# Import the Rust implementations
from rustystats._rustystats import (
    IdentityLink as _IdentityLink,
    LogLink as _LogLink,
    LogitLink as _LogitLink,
)


def Identity():
    """
    Identity link function: η = μ

    The simplest link - no transformation at all.

    Properties
    ----------
    - Link: η = μ
    - Inverse: μ = η
    - Derivative: dη/dμ = 1

    When to Use
    -----------
    - Gaussian family (standard linear regression)
    - When you want to model the mean directly
    - When predictions can be any real value

    Interpretation
    --------------
    Coefficients have an additive interpretation:

    - If β = 10 for variable X
    - Then a 1-unit increase in X increases the predicted mean by 10

    Example
    -------
    >>> link = rs.links.Identity()
    >>> mu = np.array([1.0, 2.0, 3.0])
    >>> eta = link.link(mu)  # [1.0, 2.0, 3.0] - unchanged
    """
    return _IdentityLink()


def Log():
    """
    Log link function: η = log(μ)

    Ensures predictions are always positive. The workhorse of actuarial GLMs.

    Properties
    ----------
    - Link: η = log(μ)
    - Inverse: μ = exp(η)
    - Derivative: dη/dμ = 1/μ

    When to Use
    -----------
    - Poisson family (claim frequency)
    - Gamma family (claim severity)
    - Whenever the response must be positive

    Multiplicative Interpretation (Important!)
    ------------------------------------------
    Coefficients represent MULTIPLICATIVE effects (rate relativities):

    - If β = 0.2 for "young driver" indicator
    - Then exp(0.2) ≈ 1.22
    - Young drivers have 1.22× the expected count/amount (22% higher)

    This is why log link is standard in insurance pricing:
    - Base rate × relativity_1 × relativity_2 × ...
    - On log scale: log(base) + β₁ + β₂ + ...

    Combining Frequency and Severity
    --------------------------------
    If both models use log link:
    - Frequency: log(μ_freq) = X β_freq
    - Severity: log(μ_sev) = X β_sev
    - Pure Premium: log(μ_freq × μ_sev) = X (β_freq + β_sev)

    The pure premium coefficients are just the SUM!

    Example
    -------
    >>> link = rs.links.Log()
    >>>
    >>> # If linear predictor η = 0, predicted count/amount = exp(0) = 1
    >>> # If η increases by 0.1, prediction multiplied by exp(0.1) ≈ 1.105
    >>>
    >>> eta = np.array([0.0, 0.1, 0.2])
    >>> mu = link.inverse(eta)
    >>> print(mu)  # [1.0, 1.105, 1.221]
    """
    return _LogLink()


def Logit():
    """
    Logit link function: η = log(μ/(1-μ))

    Transforms probabilities to the log-odds scale.
    The foundation of logistic regression.

    Properties
    ----------
    - Link: η = log(μ/(1-μ))  [the "log-odds"]
    - Inverse: μ = 1/(1+exp(-η))  [the "sigmoid" or "logistic" function]
    - Derivative: dη/dμ = 1/(μ(1-μ))

    When to Use
    -----------
    - Binomial family (binary outcomes)
    - Modeling probabilities
    - Yes/no, claim/no-claim type questions

    Understanding Log-Odds
    ----------------------
    If μ = 0.8 (80% probability):
    - Odds = μ/(1-μ) = 0.8/0.2 = 4 ("4-to-1 odds")
    - Log-odds = log(4) ≈ 1.39

    The logit function maps:
    - μ = 0.5 → η = 0 (even odds)
    - μ → 0 maps to η → -∞
    - μ → 1 maps to η → +∞

    Odds Ratio Interpretation
    -------------------------
    Coefficients represent LOG odds ratios:

    - If β = 0.5 for "previous claims" indicator
    - Then exp(0.5) ≈ 1.65
    - People with previous claims have 1.65× the ODDS of claiming
    - This is NOT the same as 1.65× the probability!

    Converting to Probability Change
    --------------------------------
    The effect on probability depends on the baseline probability:

    - At baseline μ=0.1: a β=0.5 coefficient changes probability to ~0.15
    - At baseline μ=0.5: the same β changes probability to ~0.62

    This is why we report odds ratios, not probability ratios.

    Example
    -------
    >>> link = rs.links.Logit()
    >>>
    >>> # η = 0 means 50% probability
    >>> # η = 2 means high probability (about 88%)
    >>> # η = -2 means low probability (about 12%)
    >>>
    >>> eta = np.array([-2.0, 0.0, 2.0])
    >>> mu = link.inverse(eta)
    >>> print(mu)  # [0.119, 0.5, 0.881]
    """
    return _LogitLink()


# For backwards compatibility and convenience
__all__ = ["Identity", "Log", "Logit"]
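The multiplicative claims in the Log and Logit docstrings above are easy to sanity-check with plain numpy. The sketch below is illustrative only and is not part of the packaged source; it assumes nothing beyond numpy itself.

import numpy as np

# Log link: frequency and severity coefficients add on the log scale, so the
# pure-premium relativity is the product of the two individual relativities.
beta_freq, beta_sev = 0.2, 0.1
assert np.isclose(np.exp(beta_freq) * np.exp(beta_sev), np.exp(beta_freq + beta_sev))

# Logit link: exp(beta) scales the odds, not the probability.
def shift_probability(p0, beta):
    # Apply a log-odds shift of beta to a baseline probability p0.
    odds = p0 / (1.0 - p0)
    new_odds = odds * np.exp(beta)
    return new_odds / (1.0 + new_odds)

print(round(shift_probability(0.1, 0.5), 3))  # 0.155 (not 0.1 * 1.65 = 0.165)
print(round(shift_probability(0.5, 0.5), 3))  # 0.622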
rustystats/splines.py
ADDED
@@ -0,0 +1,367 @@
"""
Spline basis functions for non-linear continuous effects in GLMs.

This module provides B-splines and natural splines, which are essential
for modeling non-linear relationships between continuous predictors
and the response variable.

Key Functions
-------------
- `bs()` - B-spline basis (flexible piecewise polynomials)
- `ns()` - Natural spline basis (linear extrapolation at boundaries)

Example
-------
>>> import rustystats as rs
>>> import numpy as np
>>>
>>> # Create spline basis for age with 5 degrees of freedom
>>> age = np.array([25, 35, 45, 55, 65])
>>> age_basis = rs.bs(age, df=5)
>>> print(age_basis.shape)
(5, 4)

>>> # Use in formula API
>>> result = rs.glm(
...     "y ~ bs(age, df=5) + C(region)",
...     data=data,
...     family="poisson"
... ).fit()

When to Use Each Type
---------------------
**B-splines (`bs`):**
- More flexible at boundaries
- Good when you don't need to extrapolate
- Standard choice for most applications

**Natural splines (`ns`):**
- Linear extrapolation beyond boundaries
- Better for prediction on new data outside training range
- More stable parameter estimates at boundaries
- Recommended for actuarial applications

Performance Note
----------------
Spline basis computation is implemented in Rust with parallel
evaluation over observations, making it very fast even for
large datasets.
"""

from __future__ import annotations

from typing import Optional, Union, List, Tuple, TYPE_CHECKING
import numpy as np

if TYPE_CHECKING:
    import polars as pl

# Import Rust implementations
from rustystats._rustystats import (
    bs_py as _bs_rust,
    ns_py as _ns_rust,
    bs_knots_py as _bs_knots_rust,
    bs_names_py as _bs_names_rust,
    ns_names_py as _ns_names_rust,
)


def bs(
    x: np.ndarray,
    df: int = 5,
    degree: int = 3,
    knots: Optional[List[float]] = None,
    boundary_knots: Optional[Tuple[float, float]] = None,
    include_intercept: bool = False,
) -> np.ndarray:
    """
    Compute B-spline basis matrix.

    B-splines (basis splines) are piecewise polynomial functions that provide
    a flexible way to model non-linear relationships. They are the foundation
    for many modern smoothing techniques.

    Parameters
    ----------
    x : array-like
        Data points to evaluate the basis at. Will be converted to 1D numpy array.
    df : int, default=5
        Degrees of freedom, i.e., number of basis functions to generate.
        Higher df = more flexible fit but risk of overfitting.
        Typical range: 3-10 for most applications.
    degree : int, default=3
        Polynomial degree of the splines:
        - 0: Step functions (not smooth)
        - 1: Linear splines (continuous but not smooth)
        - 2: Quadratic splines (smooth first derivative)
        - 3: Cubic splines (smooth first and second derivatives, most common)
    knots : list, optional
        Interior knot positions. If not provided, knots are placed at
        quantiles of x based on the df parameter.
    boundary_knots : tuple, optional
        (min, max) defining the boundary of the spline basis.
        If not provided, uses the range of x.
    include_intercept : bool, default=False
        Whether to include the intercept (constant) basis function.
        Usually False when used in regression models that already have
        an intercept term.

    Returns
    -------
    numpy.ndarray
        Basis matrix of shape (n, k) where n is the length of x and
        k is the number of basis functions (df or df-1 depending on
        include_intercept).

    Notes
    -----
    The number of basis functions is:
    - df if include_intercept=True
    - df-1 if include_intercept=False (default)

    B-splines have the "partition of unity" property: the basis functions
    sum to 1 at any point. They are also non-negative and have local support
    (each basis function is non-zero only over a limited range).

    Examples
    --------
    >>> import rustystats as rs
    >>> import numpy as np
    >>>
    >>> # Basic usage
    >>> x = np.linspace(0, 10, 100)
    >>> basis = rs.bs(x, df=5)
    >>> print(basis.shape)
    (100, 4)

    >>> # With explicit knots
    >>> basis = rs.bs(x, knots=[2.5, 5.0, 7.5], degree=3)
    >>> print(basis.shape)
    (100, 7)

    >>> # For use in regression with intercept already present
    >>> X = np.column_stack([np.ones(100), rs.bs(x, df=4)])

    See Also
    --------
    ns : Natural spline basis (linear at boundaries)
    """
    # Convert to numpy array
    x = np.asarray(x, dtype=np.float64).ravel()

    if knots is not None:
        # Use explicit knots
        return _bs_knots_rust(x, knots, degree, boundary_knots)
    else:
        # Compute knots automatically based on df
        return _bs_rust(x, df, degree, boundary_knots, include_intercept)


def ns(
    x: np.ndarray,
    df: int = 5,
    knots: Optional[List[float]] = None,
    boundary_knots: Optional[Tuple[float, float]] = None,
    include_intercept: bool = False,
) -> np.ndarray:
    """
    Compute natural cubic spline basis matrix.

    Natural splines are cubic splines with the additional constraint that
    the function is linear beyond the boundary knots. This constraint:
    - Reduces the effective degrees of freedom by 2
    - Provides more sensible extrapolation behavior
    - Often gives more stable parameter estimates

    Parameters
    ----------
    x : array-like
        Data points to evaluate the basis at.
    df : int, default=5
        Degrees of freedom. The number of basis functions generated
        will be df (or df-1 if include_intercept=False).
    knots : list, optional
        Interior knot positions. If not provided, knots are placed at
        quantiles of x.
    boundary_knots : tuple, optional
        (min, max) defining the boundary. Beyond these points, the
        spline is constrained to be linear.
    include_intercept : bool, default=False
        Whether to include an intercept basis function.

    Returns
    -------
    numpy.ndarray
        Basis matrix of shape (n, k).

    Notes
    -----
    Natural splines impose the constraint that the second derivative
    is zero at the boundaries. This means:

    1. The spline is linear (not curved) outside the boundary knots
    2. Extrapolation beyond the data range is more sensible
    3. The fit is often more stable near the boundaries

    For these reasons, natural splines are often preferred for:
    - Prediction on new data that may be outside the training range
    - Actuarial applications where extrapolation is common
    - When boundary behavior needs to be controlled

    Examples
    --------
    >>> import rustystats as rs
    >>> import numpy as np
    >>>
    >>> # Basic usage
    >>> age = np.array([20, 30, 40, 50, 60, 70])
    >>> basis = rs.ns(age, df=4)
    >>> print(basis.shape)
    (6, 3)

    >>> # For an age effect in a GLM
    >>> # The spline will be linear for ages below 20 and above 70
    >>> basis = rs.ns(age, df=4, boundary_knots=(20, 70))

    See Also
    --------
    bs : B-spline basis (more flexible at boundaries)
    """
    # Convert to numpy array
    x = np.asarray(x, dtype=np.float64).ravel()

    # Natural splines don't support explicit interior knots in our implementation
    # (knots are computed from df)
    return _ns_rust(x, df, boundary_knots, include_intercept)


def bs_names(
    var_name: str,
    df: int,
    include_intercept: bool = False,
) -> List[str]:
    """
    Generate column names for B-spline basis functions.

    Parameters
    ----------
    var_name : str
        Name of the original variable (e.g., "age")
    df : int
        Degrees of freedom used
    include_intercept : bool, default=False
        Whether intercept was included

    Returns
    -------
    list of str
        Names like ['bs(age, 1/5)', 'bs(age, 2/5)', ...]

    Example
    -------
    >>> rs.bs_names("age", df=5)
    ['bs(age, 2/5)', 'bs(age, 3/5)', 'bs(age, 4/5)', 'bs(age, 5/5)']
    """
    return _bs_names_rust(var_name, df, include_intercept)


def ns_names(
    var_name: str,
    df: int,
    include_intercept: bool = False,
) -> List[str]:
    """
    Generate column names for natural spline basis functions.

    Parameters
    ----------
    var_name : str
        Name of the original variable
    df : int
        Degrees of freedom used
    include_intercept : bool, default=False
        Whether intercept was included

    Returns
    -------
    list of str
        Names like ['ns(age, 1/5)', 'ns(age, 2/5)', ...]
    """
    return _ns_names_rust(var_name, df, include_intercept)


class SplineTerm:
    """
    Represents a spline term for use in formula parsing.

    This class stores the specification for a spline transformation
    and can compute the basis matrix when given data.

    Attributes
    ----------
    var_name : str
        Name of the variable to transform
    spline_type : str
        Either 'bs' or 'ns'
    df : int
        Degrees of freedom
    degree : int
        Polynomial degree (for B-splines)
    boundary_knots : tuple or None
        Boundary knot positions
    """

    def __init__(
        self,
        var_name: str,
        spline_type: str = "bs",
        df: int = 5,
        degree: int = 3,
        boundary_knots: Optional[Tuple[float, float]] = None,
    ):
        self.var_name = var_name
        self.spline_type = spline_type.lower()
        self.df = df
        self.degree = degree
        self.boundary_knots = boundary_knots

        if self.spline_type not in ("bs", "ns"):
            raise ValueError(f"spline_type must be 'bs' or 'ns', got '{spline_type}'")

    def transform(self, x: np.ndarray) -> Tuple[np.ndarray, List[str]]:
        """
        Compute the spline basis for the given data.

        Parameters
        ----------
        x : np.ndarray
            Data values to transform

        Returns
        -------
        basis : np.ndarray
            Basis matrix
        names : list of str
            Column names for the basis
        """
        if self.spline_type == "bs":
            basis = bs(x, df=self.df, degree=self.degree,
                       boundary_knots=self.boundary_knots, include_intercept=False)
            names = bs_names(self.var_name, self.df, include_intercept=False)
        else:
            basis = ns(x, df=self.df, boundary_knots=self.boundary_knots,
                       include_intercept=False)
            names = ns_names(self.var_name, self.df, include_intercept=False)

        # Ensure names match columns
        if len(names) != basis.shape[1]:
            names = [f"{self.spline_type}({self.var_name}, {i+1}/{basis.shape[1]})"
                     for i in range(basis.shape[1])]

        return basis, names

    def __repr__(self) -> str:
        if self.spline_type == "bs":
            return f"bs({self.var_name}, df={self.df}, degree={self.degree})"
        else:
            return f"ns({self.var_name}, df={self.df})"
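SplineTerm is the piece the formula parser drives, but its docstring carries no usage example. A minimal sketch of exercising it directly (illustrative only, not part of the packaged source; it assumes the wheel is installed, and the data values and knot choices are made up):

import numpy as np
from rustystats.splines import SplineTerm, ns

age = np.array([22.0, 31.0, 45.0, 58.0, 67.0, 73.0])

# Natural-spline term for "age" with 4 degrees of freedom and fixed boundary
# knots, so the fitted effect extrapolates linearly outside 20-75.
term = SplineTerm("age", spline_type="ns", df=4, boundary_knots=(20.0, 75.0))
print(term)                       # ns(age, df=4)

basis, names = term.transform(age)
print(basis.shape, names)         # one column per basis function, names aligned

# Equivalent hand-built design matrix with an explicit intercept column:
X = np.column_stack([np.ones(len(age)), ns(age, df=4, boundary_knots=(20.0, 75.0))])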