ptlasso 0.1.0__tar.gz

@@ -0,0 +1,26 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ .eggs/
8
+
9
+ # Virtual environments
10
+ .venv/
11
+ venv/
12
+ env/
13
+
14
+ # Testing
15
+ .pytest_cache/
16
+ .coverage
17
+ htmlcov/
18
+
19
+ # Tools
20
+ .ruff_cache/
21
+
22
+ # Jupyter
23
+ .ipynb_checkpoints/
24
+
25
+ # macOS
26
+ .DS_Store
ptlasso-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 ptlasso authors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ptlasso-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,328 @@
1
+ Metadata-Version: 2.4
2
+ Name: ptlasso
3
+ Version: 0.1.0
4
+ Summary: Pretrained Lasso: a two-step procedure for sparse linear models with grouped samples
5
+ Project-URL: Homepage, https://github.com/tlemenestrel/ptlasso
6
+ Project-URL: Repository, https://github.com/tlemenestrel/ptlasso
7
+ Project-URL: Bug Tracker, https://github.com/tlemenestrel/ptlasso/issues
8
+ Author: Erin Craig, Thomas Le Menestrel, Robert Tibshirani
9
+ License: MIT License
10
+
11
+ Copyright (c) 2026 ptlasso authors
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy
14
+ of this software and associated documentation files (the "Software"), to deal
15
+ in the Software without restriction, including without limitation the rights
16
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17
+ copies of the Software, and to permit persons to whom the Software is
18
+ furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in all
21
+ copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29
+ SOFTWARE.
30
+ License-File: LICENSE
31
+ Keywords: grouped data,lasso,machine learning,pretraining,regularization,sparse,statistics
32
+ Classifier: Development Status :: 3 - Alpha
33
+ Classifier: Intended Audience :: Science/Research
34
+ Classifier: License :: OSI Approved :: MIT License
35
+ Classifier: Programming Language :: Python :: 3
36
+ Classifier: Programming Language :: Python :: 3.9
37
+ Classifier: Programming Language :: Python :: 3.10
38
+ Classifier: Programming Language :: Python :: 3.11
39
+ Classifier: Programming Language :: Python :: 3.12
40
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
41
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
42
+ Requires-Python: >=3.9
43
+ Requires-Dist: adelie>=1.0
44
+ Requires-Dist: matplotlib>=3.5
45
+ Requires-Dist: numpy>=1.22
46
+ Requires-Dist: scikit-learn>=1.1
47
+ Provides-Extra: dev
48
+ Requires-Dist: pytest; extra == 'dev'
49
+ Requires-Dist: pytest-cov; extra == 'dev'
50
+ Requires-Dist: ruff; extra == 'dev'
51
+ Description-Content-Type: text/markdown
52
+
53
+ # ptlasso
54
+
55
+ Python implementation of the **Pretrained Lasso** — a two-step procedure for fitting sparse linear models when samples belong to distinct groups, leveraging shared structure across groups via pretraining.
56
+
57
+ Based on:
58
+ > Craig, E., Pilanci, M., Le Menestrel, T., Narasimhan, B., Rivas, M. A., Gullaksen, S. E., ... & Tibshirani, R. (2025). Pretraining and the lasso. *Journal of the Royal Statistical Society Series B: Statistical Methodology*, qkaf050.
59
+
60
+ ---
61
+
62
+ ## The idea
63
+
64
+ Standard group-specific Lasso models are fit independently per group, ignoring shared signal. The Pretrained Lasso fits in two steps:
65
+
66
+ **Step 1 — Overall model.** Fit a Lasso on all samples to capture shared structure:
67
+
68
+ $$\hat{\beta}^{\text{overall}} = \arg\min_\beta \frac{1}{2n}\|y - X\beta\|^2 + \lambda\|\beta\|_1$$
69
+
70
+ **Step 2 — Group models.** For each group $k$, fit a Lasso with an offset equal to $\alpha$ times the overall model's linear predictor:
71
+
72
+ $$\hat{\beta}^{(k)} = \arg\min_\beta \frac{1}{2n_k}\|y^{(k)} - \underbrace{\alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}}}_{\text{offset}} - X^{(k)}\beta\|^2 + \lambda_k\|\beta\|_1$$
73
+
74
+ The parameter $\alpha \in [0, 1]$ controls the pretraining strength:
75
+ - $\alpha = 0$: pure group-specific models, no pretraining
76
+ - $\alpha = 1$: group models explain residuals from the overall model
77
+ - $\alpha \in (0, 1)$: group models fit what remains after subtracting a fraction $\alpha$ of the overall fit, so they are partly anchored to it
78
+
79
+ Final prediction for group $k$: $\hat{y}^{(k)} = \alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}} + X^{(k)}\hat{\beta}^{(k)}$
80
+
81
+ Supports **gaussian**, **binomial**, and **multinomial** families.
82
+
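The two steps can be sketched in plain NumPy for the gaussian family. This is an illustration only, not ptlasso's implementation: the package delegates to adelie and selects $\lambda$ with a CV rule, whereas the sketch below uses a hand-written coordinate-descent solver, fixed penalties (`lam_overall`, `lam_group`), and no intercept. All function names here are hypothetical.

```python
import numpy as np


def soft_threshold(z, t):
    """Soft-thresholding operator S(z, t) = sign(z) * max(|z| - t, 0)."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)


def lasso_cd(X, r, lam, n_iter=200):
    """Minimise (1/2n)||r - X b||^2 + lam * ||b||_1 by cyclic coordinate descent."""
    n, p = X.shape
    b = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n  # (1/n) * x_j' x_j for each feature
    resid = r.astype(float).copy()     # current residual r - X b
    for _ in range(n_iter):
        for j in range(p):
            if col_sq[j] == 0.0:
                continue
            # Correlation of feature j with the partial residual (feature j removed).
            rho = X[:, j] @ resid / n + col_sq[j] * b[j]
            new_bj = soft_threshold(rho, lam) / col_sq[j]
            resid += X[:, j] * (b[j] - new_bj)
            b[j] = new_bj
    return b


def fit_two_step(X, y, groups, alpha, lam_overall, lam_group):
    """Step 1: overall lasso. Step 2: per-group lasso on y minus the offset."""
    beta_overall = lasso_cd(X, y, lam_overall)
    group_betas = {}
    for g in np.unique(groups):
        mask = groups == g
        offset = alpha * (X[mask] @ beta_overall)
        # For squared-error loss, fitting with an offset is the same as
        # fitting an ordinary lasso to the offset-adjusted response.
        group_betas[int(g)] = lasso_cd(X[mask], y[mask] - offset, lam_group)
    return beta_overall, group_betas


def predict_two_step(X, groups, alpha, beta_overall, group_betas):
    """Final prediction: alpha * X b_overall + X b_group, group by group."""
    y_hat = np.empty(X.shape[0])
    for g, b in group_betas.items():
        mask = groups == g
        y_hat[mask] = alpha * (X[mask] @ beta_overall) + X[mask] @ b
    return y_hat


rng = np.random.default_rng(0)
n, p, k = 150, 20, 3
X = rng.standard_normal((n, p))
groups = rng.integers(0, k, size=n)
beta = np.zeros(p)
beta[:3] = [2.0, -1.5, 1.0]
# Shared signal plus a group-specific effect on feature 5.
y = X @ beta + 0.8 * groups * X[:, 5] + 0.5 * rng.standard_normal(n)

beta_overall, group_betas = fit_two_step(X, y, groups, 0.5, 0.05, 0.05)
y_hat = predict_two_step(X, groups, 0.5, beta_overall, group_betas)
mse = float(np.mean((y - y_hat) ** 2))
print("training MSE:", round(mse, 3))
```

The only place the groups interact is through `beta_overall`: each group model sees the shared fit as a fixed offset and spends its own penalty budget on whatever that offset leaves unexplained.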
83
+ ---
84
+
85
+ ## Installation
86
+
87
+ ```bash
88
+ pip install ptlasso
89
+ ```
90
+
91
+ Requires Python ≥ 3.9 and [adelie](https://github.com/JamesYang007/adelie) for the underlying Lasso solver, which supports fitting with offsets (unlike scikit-learn).
92
+
93
+ ---
94
+
95
+ ## Quick start
96
+
97
+ ```python
98
+ import numpy as np
99
+ from ptlasso import PretrainedLasso, PretrainedLassoCV
100
+
101
+ rng = np.random.default_rng(42)
102
+ n, p, k = 300, 100, 3
103
+
104
+ X = rng.standard_normal((n, p))
105
+ groups = rng.integers(0, k, size=n)
106
+ beta = np.zeros(p)
107
+ beta[:5] = [2, -1.5, 1, -0.8, 0.5]
108
+ y = X @ beta + 0.5 * rng.standard_normal(n)
109
+
110
+ # Fixed alpha
111
+ model = PretrainedLasso(alpha=0.5)
112
+ model.fit(X, y, groups)
113
+ print(model)
114
+ # PretrainedLasso(alpha=0.5, family='gaussian', overall_lambda='lambda.1se', ...)
115
+ # family : gaussian
116
+ # n_features : 100
117
+ # n_groups : 3
118
+ # overall |Ŝ| : |Ŝ| = 5 / 100 [0, 1, 2, 3, 4]
119
+ # pretrain |Ŝ| : 0: |Ŝ|=5, 1: |Ŝ|=4, 2: |Ŝ|=6
120
+
121
+ y_pred = model.predict(X, groups)
122
+ print("R²:", model.score(X, y, groups))
123
+
124
+ # Cross-validate over alpha
125
+ cv = PretrainedLassoCV(alphas=[0.0, 0.25, 0.5, 0.75, 1.0])
126
+ cv.fit(X, y, groups)
127
+ print("Best alpha:", cv.alpha_)
128
+ ```
129
+
130
+ ---
131
+
132
+ ## Families
133
+
134
+ ```python
135
+ # Binary classification (y_binary: binary labels, one per row of X; not defined in the quick start)
136
+ model = PretrainedLasso(alpha=0.5, family="binomial")
137
+ model.fit(X, y_binary, groups)
138
+ probs = model.predict(X, groups) # shape (n,), P(y=1)
139
+
140
+ # Multi-class classification (integer labels 0..K-1)
141
+ model = PretrainedLasso(alpha=0.5, family="multinomial")
142
+ model.fit(X, y_multiclass, groups)
143
+ probs = model.predict(X, groups) # shape (n, K)
144
+ ```
145
+
146
+ ---
147
+
148
+ ## Feature names and group labels
149
+
150
+ `PretrainedLasso.fit()` and `PretrainedLassoCV.fit()` both accept human-readable names through `feature_names` and `group_labels`. pandas DataFrames are supported natively: column names are picked up automatically as feature names.
151
+
152
+ ```python
153
+ import pandas as pd
154
+
155
+ X_df = pd.DataFrame(X, columns=[f"gene_{i}" for i in range(p)])
156
+ group_labels = {0: "control", 1: "treated_A", 2: "treated_B"}
157
+
158
+ model = PretrainedLasso(alpha=0.5)
159
+ model.fit(X_df, y, groups, group_labels=group_labels)
160
+ # overall |Ŝ| : |Ŝ| = 5 / 100 [gene_0, gene_1, gene_2, gene_3, gene_4]
161
+ # pretrain |Ŝ| : control: |Ŝ|=5, treated_A: |Ŝ|=4, treated_B: |Ŝ|=6
162
+ ```
163
+
164
+ ---
165
+
166
+ ## Inspecting the support
167
+
168
+ ```python
169
+ from ptlasso import (
170
+ get_overall_support,
171
+ get_pretrain_support,
172
+ get_pretrain_support_split,
173
+ get_individual_support,
174
+ )
175
+
176
+ get_overall_support(model) # features from the overall model
177
+ get_pretrain_support(model) # union across pretrained group models
178
+ get_pretrain_support(model, common_only=True) # features selected by >50% of groups
179
+ get_pretrain_support(model, groups=[0, 1]) # restrict to specific groups
180
+ get_individual_support(model) # features from no-pretraining baselines
181
+
182
+ common, indiv = get_pretrain_support_split(model)
183
+ # common : features from the overall model (stage 1)
184
+ # indiv : additional features picked up by group models (stage 2)
185
+ ```
186
+
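The comments in the snippet describe set operations on the nonzero coefficients of each sub-model. A minimal NumPy sketch of those operations follows. It works on hand-made coefficient vectors rather than a fitted model, the helper name `support` is hypothetical, and ptlasso's own functions may differ in details such as return type or how the overall support is folded in.

```python
import numpy as np


def support(coef):
    """Indices of nonzero coefficients, as a set of ints."""
    return set(np.flatnonzero(coef).tolist())


# Hypothetical coefficient vectors over p = 6 features.
overall_coef = np.array([1.2, 0.0, -0.7, 0.0, 0.0, 0.0])
group_coefs = {
    "control":   np.array([0.3, 0.0, 0.0, 0.5, 0.0, 0.0]),
    "treated_A": np.array([0.0, 0.0, 0.1, 0.4, 0.0, 0.0]),
    "treated_B": np.array([0.2, 0.0, 0.0, 0.0, 0.0, -0.9]),
}

overall_support = support(overall_coef)
group_supports = {g: support(c) for g, c in group_coefs.items()}

# Union across the group models.
union_support = set().union(*group_supports.values())

# Features selected by more than 50% of the groups.
counts = np.zeros(overall_coef.size, dtype=int)
for s in group_supports.values():
    counts[sorted(s)] += 1
common_support = set(np.flatnonzero(counts > 0.5 * len(group_supports)).tolist())

# Split: stage-1 features versus features added only in stage 2.
stage1 = overall_support
stage2_only = union_support - overall_support

print(sorted(union_support), sorted(common_support), sorted(stage2_only))
```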
187
+ ---
188
+
189
+ ## Evaluating all sub-models at once
190
+
191
+ ```python
192
+ result = model.evaluate(X_test, y_test, groups_test)
193
+ # {"pretrain": {"predictions": ..., "score": ...},
194
+ # "individual": {"predictions": ..., "score": ...},
195
+ # "overall": {"predictions": ..., "score": ...}}
196
+ ```
197
+
198
+ ---
199
+
200
+ ## Retrieving coefficients
201
+
202
+ ```python
203
+ coefs = model.get_coef() # all sub-models
204
+ coefs["overall"] # {"coef": ndarray, "intercept": ndarray}
205
+ coefs["pretrain"]["control"] # {"coef": ndarray, "intercept": ndarray}
206
+ coefs["individual"]["treated_A"]
207
+
208
+ model.get_coef(model="pretrain") # just pretrain sub-dict
209
+ ```
210
+
211
+ ---
212
+
213
+ ## CV details
214
+
215
+ ```python
216
+ cv = PretrainedLassoCV(
217
+ alphas=[0.0, 0.25, 0.5, 0.75, 1.0],
218
+ cv=5,
219
+ alphahat_choice="overall", # or "mean" (unweighted mean of per-group CV errors)
220
+ family="gaussian",
221
+ overall_lambda="lambda.1se", # or "lambda.min"
222
+ foldid=my_foldid, # optional: custom integer fold assignments
223
+ )
224
+ cv.fit(X, y, groups)
225
+
226
+ cv.alpha_ # globally best alpha
227
+ cv.varying_alphahat_ # {group: best_alpha} per group
228
+ cv.cv_results_ # {alpha: mean CV loss}
229
+ cv.cv_results_se_ # {alpha: SE of CV loss}
230
+ cv.cv_results_per_group_ # {alpha: {group: mean CV loss}}
231
+ cv.cv_results_mean_ # {alpha: unweighted mean of per-group losses}
232
+ cv.cv_results_wtd_mean_ # {alpha: size-weighted mean of per-group losses}
233
+ cv.cv_results_individual_ # CV loss for individual (no-pretraining) baseline
234
+ cv.cv_results_overall_ # CV loss for overall model baseline
235
+ cv.best_estimator_ # PretrainedLasso fitted with alpha_
236
+ cv.all_estimators_ # {alpha: PretrainedLasso} for varying-alpha prediction
237
+
238
+ # Predict using each group's own best alpha
239
+ cv.predict(X, groups, alphatype="varying")
240
+ cv.evaluate(X, y, groups, alphatype="varying")
241
+ ```
242
+
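The `overall_lambda` options share their names with glmnet's two selection rules. Assuming ptlasso follows that convention (this README does not spell it out), `"lambda.min"` is the penalty with the lowest mean CV loss and `"lambda.1se"` is the largest penalty whose mean CV loss is within one standard error of that minimum. A sketch with made-up CV numbers:

```python
import numpy as np

# Hypothetical CV summary along a decreasing lambda path (not ptlasso output).
lambdas = np.array([1.0, 0.5, 0.25, 0.1, 0.05, 0.01])
cv_mean = np.array([4.0, 2.1, 1.3, 1.0, 1.05, 1.2])
cv_se = np.array([0.3, 0.25, 0.2, 0.35, 0.2, 0.2])

# lambda.min: the penalty with the smallest mean CV loss.
i_min = int(np.argmin(cv_mean))
lambda_min = lambdas[i_min]

# lambda.1se: the largest penalty whose loss is within one SE of the minimum.
threshold = cv_mean[i_min] + cv_se[i_min]
within = cv_mean <= threshold
lambda_1se = lambdas[within].max()

print("lambda.min:", lambda_min, "lambda.1se:", lambda_1se)
```

Under this rule `"lambda.1se"` never picks a smaller penalty than `"lambda.min"`, so the stage-1 model it yields tends to be sparser.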
243
+ ---
244
+
245
+ ## Plotting
246
+
247
+ ```python
248
+ from ptlasso import plot_cv, plot_paths
249
+
250
+ plot_cv(cv) # CV loss curve over alpha with ±1 SE band
251
+ plot_paths(model) # regularisation paths for all sub-models
252
+ ```
253
+
254
+ ---
255
+
256
+ ## API reference
257
+
258
+ ### `PretrainedLasso`
259
+
260
+ | Parameter | Default | Description |
261
+ |-----------|---------|-------------|
262
+ | `alpha` | `0.5` | Pretraining strength $\in [0, 1]$ |
263
+ | `family` | `"gaussian"` | `"gaussian"`, `"binomial"`, or `"multinomial"` |
264
+ | `overall_lambda` | `"lambda.1se"` | Lambda rule for stage-1 offset: `"lambda.1se"` or `"lambda.min"` |
265
+ | `fit_intercept` | `True` | Fit an intercept in all sub-models |
266
+ | `lmda_path_size` | `100` | Number of $\lambda$ values in the regularisation path |
267
+ | `min_ratio` | `0.01` | Ratio of smallest to largest $\lambda$ |
268
+ | `verbose` | `False` | Show adelie progress bar |
269
+
270
+ **Methods:**
271
+ - `fit(X, y, groups, group_labels=None, feature_names=None)`
272
+ - `predict(X, groups, model="pretrain", lmda_idx=None)` — `model` ∈ `{"pretrain", "individual", "overall"}`
273
+ - `score(X, y, groups)` — R² or accuracy
274
+ - `evaluate(X, y, groups)` — predict + score for all three sub-models
275
+ - `get_coef(model="all", lmda_idx=None)`
276
+
277
+ ### `PretrainedLassoCV`
278
+
279
+ | Parameter | Default | Description |
280
+ |-----------|---------|-------------|
281
+ | `alphas` | `[0, 0.25, 0.5, 0.75, 1.0]` | Candidate $\alpha$ values |
282
+ | `cv` | `5` | Number of CV folds |
283
+ | `alphahat_choice` | `"overall"` | `"overall"` or `"mean"` (unweighted per-group mean) |
284
+ | `family` | `"gaussian"` | Same as `PretrainedLasso` |
285
+ | `overall_lambda` | `"lambda.1se"` | Same as `PretrainedLasso` |
286
+ | `fit_intercept` | `True` | |
287
+ | `lmda_path_size` | `100` | |
288
+ | `min_ratio` | `0.01` | |
289
+ | `verbose` | `False` | |
290
+ | `foldid` | `None` | Integer array of fold assignments (overrides `cv`) |
291
+
292
+ Same `fit` / `predict` / `score` / `evaluate` / `get_coef` interface as `PretrainedLasso`, plus:
293
+
294
+ | Fitted attribute | Description |
295
+ |-----------------|-------------|
296
+ | `alpha_` | Best $\alpha$ selected by CV |
297
+ | `varying_alphahat_` | `{group: alpha}` — per-group best $\alpha$ |
298
+ | `cv_results_` | `{alpha: mean CV loss}` |
299
+ | `cv_results_se_` | `{alpha: SE of CV loss}` |
300
+ | `cv_results_per_group_` | `{alpha: {group: mean CV loss}}` |
301
+ | `cv_results_mean_` | `{alpha: unweighted mean of per-group losses}` |
302
+ | `cv_results_wtd_mean_` | `{alpha: size-weighted mean of per-group losses}` |
303
+ | `cv_results_individual_` | CV loss for individual baseline |
304
+ | `cv_results_overall_` | CV loss for overall baseline |
305
+ | `best_estimator_` | `PretrainedLasso` fitted with `alpha_` |
306
+ | `all_estimators_` | `{alpha: PretrainedLasso}` for each unique varying alpha |
307
+
308
+ `predict` and `evaluate` also accept `alphatype="varying"` to route each group through its own best alpha.
309
+
310
+ ---
311
+
312
+ ## Citation
313
+
314
+ ```bibtex
315
+ @article{craig2025pretraining,
316
+ title = {Pretraining and the lasso},
317
+ author = {Craig, Erin and Pilanci, Mert and Le Menestrel, Thomas and Narasimhan, Balasubramanian and Rivas, Manuel A. and Gullaksen, Stein-Erik and Tibshirani, Robert},
318
+ journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
319
+ pages = {qkaf050},
320
+ year = {2025}
321
+ }
322
+ ```
323
+
324
+ ---
325
+
326
+ ## License
327
+
328
+ MIT
@@ -0,0 +1,276 @@
1
+ # ptlasso
2
+
3
+ Python implementation of the **Pretrained Lasso** — a two-step procedure for fitting sparse linear models when samples belong to distinct groups, leveraging shared structure across groups via pretraining.
4
+
5
+ Based on:
6
+ > Craig, E., Pilanci, M., Le Menestrel, T., Narasimhan, B., Rivas, M. A., Gullaksen, S. E., ... & Tibshirani, R. (2025). Pretraining and the lasso. *Journal of the Royal Statistical Society Series B: Statistical Methodology*, qkaf050.
7
+
8
+ ---
9
+
10
+ ## The idea
11
+
12
+ Standard group-specific Lasso models are fit independently per group, ignoring shared signal. The Pretrained Lasso fits in two steps:
13
+
14
+ **Step 1 — Overall model.** Fit a Lasso on all samples to capture shared structure:
15
+
16
+ $$\hat{\beta}^{\text{overall}} = \arg\min_\beta \frac{1}{2n}\|y - X\beta\|^2 + \lambda\|\beta\|_1$$
17
+
18
+ **Step 2 — Group models.** For each group $k$, fit a Lasso with an offset equal to $\alpha$ times the overall model's linear predictor:
19
+
20
+ $$\hat{\beta}^{(k)} = \arg\min_\beta \frac{1}{2n_k}\|y^{(k)} - \underbrace{\alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}}}_{\text{offset}} - X^{(k)}\beta\|^2 + \lambda_k\|\beta\|_1$$
21
+
22
+ The parameter $\alpha \in [0, 1]$ controls the pretraining strength:
23
+ - $\alpha = 0$: pure group-specific models, no pretraining
24
+ - $\alpha = 1$: group models explain residuals from the overall model
25
+ - $\alpha \in (0, 1)$: group models fit what remains after subtracting a fraction $\alpha$ of the overall fit, so they are partly anchored to it
26
+
27
+ Final prediction for group $k$: $\hat{y}^{(k)} = \alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}} + X^{(k)}\hat{\beta}^{(k)}$
28
+
29
+ Supports **gaussian**, **binomial**, and **multinomial** families.
30
+
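For the gaussian family, the Step 2 objective can be rewritten as an ordinary Lasso on an offset-adjusted response. The symbol $\tilde{y}^{(k)}$ is introduced here only for this restatement; everything else is the notation used above.

$$\tilde{y}^{(k)} = y^{(k)} - \alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}}, \qquad \hat{\beta}^{(k)} = \arg\min_\beta \frac{1}{2n_k}\|\tilde{y}^{(k)} - X^{(k)}\beta\|^2 + \lambda_k\|\beta\|_1$$

The two endpoints then read off directly:

$$\alpha = 0:\ \tilde{y}^{(k)} = y^{(k)} \qquad\qquad \alpha = 1:\ \tilde{y}^{(k)} = y^{(k)} - X^{(k)}\hat{\beta}^{\text{overall}}$$

At $\alpha = 0$ the group model is a plain Lasso on the group's own data, and at $\alpha = 1$ it is a Lasso on the residuals of the overall model.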
31
+ ---
32
+
33
+ ## Installation
34
+
35
+ ```bash
36
+ pip install ptlasso
37
+ ```
38
+
39
+ Requires Python ≥ 3.9 and [adelie](https://github.com/JamesYang007/adelie) for the underlying Lasso solver, which supports fitting with offsets (unlike scikit-learn).
40
+
41
+ ---
42
+
43
+ ## Quick start
44
+
45
+ ```python
46
+ import numpy as np
47
+ from ptlasso import PretrainedLasso, PretrainedLassoCV
48
+
49
+ rng = np.random.default_rng(42)
50
+ n, p, k = 300, 100, 3
51
+
52
+ X = rng.standard_normal((n, p))
53
+ groups = rng.integers(0, k, size=n)
54
+ beta = np.zeros(p)
55
+ beta[:5] = [2, -1.5, 1, -0.8, 0.5]
56
+ y = X @ beta + 0.5 * rng.standard_normal(n)
57
+
58
+ # Fixed alpha
59
+ model = PretrainedLasso(alpha=0.5)
60
+ model.fit(X, y, groups)
61
+ print(model)
62
+ # PretrainedLasso(alpha=0.5, family='gaussian', overall_lambda='lambda.1se', ...)
63
+ # family : gaussian
64
+ # n_features : 100
65
+ # n_groups : 3
66
+ # overall |Ŝ| : |Ŝ| = 5 / 100 [0, 1, 2, 3, 4]
67
+ # pretrain |Ŝ| : 0: |Ŝ|=5, 1: |Ŝ|=4, 2: |Ŝ|=6
68
+
69
+ y_pred = model.predict(X, groups)
70
+ print("R²:", model.score(X, y, groups))
71
+
72
+ # Cross-validate over alpha
73
+ cv = PretrainedLassoCV(alphas=[0.0, 0.25, 0.5, 0.75, 1.0])
74
+ cv.fit(X, y, groups)
75
+ print("Best alpha:", cv.alpha_)
76
+ ```
77
+
78
+ ---
79
+
80
+ ## Families
81
+
82
+ ```python
83
+ # Binary classification (y_binary: binary labels, one per row of X; not defined in the quick start)
84
+ model = PretrainedLasso(alpha=0.5, family="binomial")
85
+ model.fit(X, y_binary, groups)
86
+ probs = model.predict(X, groups) # shape (n,), P(y=1)
87
+
88
+ # Multi-class classification (integer labels 0..K-1)
89
+ model = PretrainedLasso(alpha=0.5, family="multinomial")
90
+ model.fit(X, y_multiclass, groups)
91
+ probs = model.predict(X, groups) # shape (n, K)
92
+ ```
93
+
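The README states the offset formula only for squared-error loss. A natural reading for the classification families, and an assumption here rather than something confirmed by the package, is that the offset enters on the link scale: the combined linear predictor is passed through a sigmoid (binomial) or a softmax (multinomial). The sketch below uses made-up coefficients and omits intercepts.

```python
import numpy as np


def sigmoid(eta):
    """Logistic function mapping a linear predictor to P(y = 1)."""
    return 1.0 / (1.0 + np.exp(-eta))


def softmax(eta):
    """Row-wise softmax with the usual max-shift for numerical stability."""
    z = eta - eta.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1)
n_k, p, K = 4, 3, 3
alpha = 0.5
X_k = rng.standard_normal((n_k, p))  # rows belonging to one group

# Binomial: one hypothetical coefficient vector per model.
beta_overall = np.array([1.0, 0.0, -0.5])
beta_group = np.array([0.0, 0.3, 0.0])
eta = alpha * (X_k @ beta_overall) + X_k @ beta_group
probs_binary = sigmoid(eta)          # shape (n_k,), P(y = 1)

# Multinomial: one hypothetical coefficient column per class.
B_overall = rng.standard_normal((p, K))
B_group = 0.1 * rng.standard_normal((p, K))
eta_multi = alpha * (X_k @ B_overall) + X_k @ B_group
probs_multi = softmax(eta_multi)     # shape (n_k, K), rows sum to 1

print(probs_binary.shape, probs_multi.shape)
```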
94
+ ---
95
+
96
+ ## Feature names and group labels
97
+
98
+ `PretrainedLasso.fit()` and `PretrainedLassoCV.fit()` both accept human-readable names through `feature_names` and `group_labels`. pandas DataFrames are supported natively: column names are picked up automatically as feature names.
99
+
100
+ ```python
101
+ import pandas as pd
102
+
103
+ X_df = pd.DataFrame(X, columns=[f"gene_{i}" for i in range(p)])
104
+ group_labels = {0: "control", 1: "treated_A", 2: "treated_B"}
105
+
106
+ model = PretrainedLasso(alpha=0.5)
107
+ model.fit(X_df, y, groups, group_labels=group_labels)
108
+ # overall |Ŝ| : |Ŝ| = 5 / 100 [gene_0, gene_1, gene_2, gene_3, gene_4]
109
+ # pretrain |Ŝ| : control: |Ŝ|=5, treated_A: |Ŝ|=4, treated_B: |Ŝ|=6
110
+ ```
111
+
112
+ ---
113
+
114
+ ## Inspecting the support
115
+
116
+ ```python
117
+ from ptlasso import (
118
+ get_overall_support,
119
+ get_pretrain_support,
120
+ get_pretrain_support_split,
121
+ get_individual_support,
122
+ )
123
+
124
+ get_overall_support(model) # features from the overall model
125
+ get_pretrain_support(model) # union across pretrained group models
126
+ get_pretrain_support(model, common_only=True) # features selected by >50% of groups
127
+ get_pretrain_support(model, groups=[0, 1]) # restrict to specific groups
128
+ get_individual_support(model) # features from no-pretraining baselines
129
+
130
+ common, indiv = get_pretrain_support_split(model)
131
+ # common : features from the overall model (stage 1)
132
+ # indiv : additional features picked up by group models (stage 2)
133
+ ```
134
+
135
+ ---
136
+
137
+ ## Evaluating all sub-models at once
138
+
139
+ ```python
140
+ result = model.evaluate(X_test, y_test, groups_test)
141
+ # {"pretrain": {"predictions": ..., "score": ...},
142
+ # "individual": {"predictions": ..., "score": ...},
143
+ # "overall": {"predictions": ..., "score": ...}}
144
+ ```
145
+
146
+ ---
147
+
148
+ ## Retrieving coefficients
149
+
150
+ ```python
151
+ coefs = model.get_coef() # all sub-models
152
+ coefs["overall"] # {"coef": ndarray, "intercept": ndarray}
153
+ coefs["pretrain"]["control"] # {"coef": ndarray, "intercept": ndarray}
154
+ coefs["individual"]["treated_A"]
155
+
156
+ model.get_coef(model="pretrain") # just pretrain sub-dict
157
+ ```
158
+
159
+ ---
160
+
161
+ ## CV details
162
+
163
+ ```python
164
+ cv = PretrainedLassoCV(
165
+ alphas=[0.0, 0.25, 0.5, 0.75, 1.0],
166
+ cv=5,
167
+ alphahat_choice="overall", # or "mean" (unweighted mean of per-group CV errors)
168
+ family="gaussian",
169
+ overall_lambda="lambda.1se", # or "lambda.min"
170
+ foldid=my_foldid, # optional: custom integer fold assignments
171
+ )
172
+ cv.fit(X, y, groups)
173
+
174
+ cv.alpha_ # globally best alpha
175
+ cv.varying_alphahat_ # {group: best_alpha} per group
176
+ cv.cv_results_ # {alpha: mean CV loss}
177
+ cv.cv_results_se_ # {alpha: SE of CV loss}
178
+ cv.cv_results_per_group_ # {alpha: {group: mean CV loss}}
179
+ cv.cv_results_mean_ # {alpha: unweighted mean of per-group losses}
180
+ cv.cv_results_wtd_mean_ # {alpha: size-weighted mean of per-group losses}
181
+ cv.cv_results_individual_ # CV loss for individual (no-pretraining) baseline
182
+ cv.cv_results_overall_ # CV loss for overall model baseline
183
+ cv.best_estimator_ # PretrainedLasso fitted with alpha_
184
+ cv.all_estimators_ # {alpha: PretrainedLasso} for varying-alpha prediction
185
+
186
+ # Predict using each group's own best alpha
187
+ cv.predict(X, groups, alphatype="varying")
188
+ cv.evaluate(X, y, groups, alphatype="varying")
189
+ ```
190
+
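`my_foldid` is not defined in the snippet above. One way to build such an array, sketched here with NumPy only, is to deal fold labels out round-robin within each group so that every fold contains samples from every group. The helper name is hypothetical, and this README does not say whether ptlasso expects 0-based or 1-based fold labels, so check the package before relying on the 0-based labels used here.

```python
import numpy as np


def make_group_balanced_foldid(groups, n_folds, seed=0):
    """Assign fold labels 0..n_folds-1, balanced within each group."""
    rng = np.random.default_rng(seed)
    groups = np.asarray(groups)
    foldid = np.empty(groups.shape[0], dtype=int)
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)
        rng.shuffle(idx)  # randomise which samples land in which fold
        foldid[idx] = np.arange(idx.size) % n_folds
    return foldid


groups = np.repeat([0, 1, 2], [40, 25, 35])
my_foldid = make_group_balanced_foldid(groups, n_folds=5)

# Per-group fold counts: each row should be as even as the group size allows.
for g in np.unique(groups):
    print(g, np.bincount(my_foldid[groups == g], minlength=5))
```

Balancing within groups keeps small groups from being absent in some training folds, which matters because every fold refits one model per group.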
191
+ ---
192
+
193
+ ## Plotting
194
+
195
+ ```python
196
+ from ptlasso import plot_cv, plot_paths
197
+
198
+ plot_cv(cv) # CV loss curve over alpha with ±1 SE band
199
+ plot_paths(model) # regularisation paths for all sub-models
200
+ ```
201
+
202
+ ---
203
+
204
+ ## API reference
205
+
206
+ ### `PretrainedLasso`
207
+
208
+ | Parameter | Default | Description |
209
+ |-----------|---------|-------------|
210
+ | `alpha` | `0.5` | Pretraining strength $\in [0, 1]$ |
211
+ | `family` | `"gaussian"` | `"gaussian"`, `"binomial"`, or `"multinomial"` |
212
+ | `overall_lambda` | `"lambda.1se"` | Lambda rule for stage-1 offset: `"lambda.1se"` or `"lambda.min"` |
213
+ | `fit_intercept` | `True` | Fit an intercept in all sub-models |
214
+ | `lmda_path_size` | `100` | Number of $\lambda$ values in the regularisation path |
215
+ | `min_ratio` | `0.01` | Ratio of smallest to largest $\lambda$ |
216
+ | `verbose` | `False` | Show adelie progress bar |
217
+
218
+ **Methods:**
219
+ - `fit(X, y, groups, group_labels=None, feature_names=None)`
220
+ - `predict(X, groups, model="pretrain", lmda_idx=None)` — `model` ∈ `{"pretrain", "individual", "overall"}`
221
+ - `score(X, y, groups)` — R² or accuracy
222
+ - `evaluate(X, y, groups)` — predict + score for all three sub-models
223
+ - `get_coef(model="all", lmda_idx=None)`
224
+
225
+ ### `PretrainedLassoCV`
226
+
227
+ | Parameter | Default | Description |
228
+ |-----------|---------|-------------|
229
+ | `alphas` | `[0, 0.25, 0.5, 0.75, 1.0]` | Candidate $\alpha$ values |
230
+ | `cv` | `5` | Number of CV folds |
231
+ | `alphahat_choice` | `"overall"` | `"overall"` or `"mean"` (unweighted per-group mean) |
232
+ | `family` | `"gaussian"` | Same as `PretrainedLasso` |
233
+ | `overall_lambda` | `"lambda.1se"` | Same as `PretrainedLasso` |
234
+ | `fit_intercept` | `True` | |
235
+ | `lmda_path_size` | `100` | |
236
+ | `min_ratio` | `0.01` | |
237
+ | `verbose` | `False` | |
238
+ | `foldid` | `None` | Integer array of fold assignments (overrides `cv`) |
239
+
240
+ Same `fit` / `predict` / `score` / `evaluate` / `get_coef` interface as `PretrainedLasso`, plus:
241
+
242
+ | Fitted attribute | Description |
243
+ |-----------------|-------------|
244
+ | `alpha_` | Best $\alpha$ selected by CV |
245
+ | `varying_alphahat_` | `{group: alpha}` — per-group best $\alpha$ |
246
+ | `cv_results_` | `{alpha: mean CV loss}` |
247
+ | `cv_results_se_` | `{alpha: SE of CV loss}` |
248
+ | `cv_results_per_group_` | `{alpha: {group: mean CV loss}}` |
249
+ | `cv_results_mean_` | `{alpha: unweighted mean of per-group losses}` |
250
+ | `cv_results_wtd_mean_` | `{alpha: size-weighted mean of per-group losses}` |
251
+ | `cv_results_individual_` | CV loss for individual baseline |
252
+ | `cv_results_overall_` | CV loss for overall baseline |
253
+ | `best_estimator_` | `PretrainedLasso` fitted with `alpha_` |
254
+ | `all_estimators_` | `{alpha: PretrainedLasso}` for each unique varying alpha |
255
+
256
+ `predict` and `evaluate` also accept `alphatype="varying"` to route each group through its own best alpha.
257
+
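The unweighted and size-weighted summaries in the table can rank the candidate $\alpha$ values differently when group sizes are unequal. The sketch below uses made-up per-group losses and group sizes (not output from ptlasso) to show both means, along with the per-group best $\alpha$ of the kind stored in `varying_alphahat_`.

```python
import numpy as np

alphas = [0.0, 0.5, 1.0]
group_sizes = {"control": 200, "treated_A": 30, "treated_B": 20}

# Hypothetical {alpha: {group: mean CV loss}}, mirroring cv_results_per_group_.
per_group = {
    0.0: {"control": 1.00, "treated_A": 2.00, "treated_B": 2.20},
    0.5: {"control": 1.10, "treated_A": 1.50, "treated_B": 1.60},
    1.0: {"control": 0.95, "treated_A": 1.90, "treated_B": 2.10},
}

names = list(group_sizes)
sizes = np.array([group_sizes[g] for g in names], dtype=float)
# Rows index alpha, columns index group.
loss = np.array([[per_group[a][g] for g in names] for a in alphas])

unweighted = loss.mean(axis=1)         # every group counts equally
weighted = loss @ sizes / sizes.sum()  # large groups dominate

best_unweighted = alphas[int(np.argmin(unweighted))]
best_weighted = alphas[int(np.argmin(weighted))]
varying = {g: alphas[int(np.argmin(loss[:, j]))] for j, g in enumerate(names)}

print(best_unweighted, best_weighted, varying)
```

Here the large `control` group prefers $\alpha = 1$ while the two small groups prefer $\alpha = 0.5$, so the size-weighted mean and the unweighted mean disagree. That is the situation in which per-group selection is worth considering.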
258
+ ---
259
+
260
+ ## Citation
261
+
262
+ ```bibtex
263
+ @article{craig2025pretraining,
264
+ title = {Pretraining and the lasso},
265
+ author = {Craig, Erin and Pilanci, Mert and Le Menestrel, Thomas and Narasimhan, Balasubramanian and Rivas, Manuel A. and Gullaksen, Stein-Erik and Tibshirani, Robert},
266
+ journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
267
+ pages = {qkaf050},
268
+ year = {2025}
269
+ }
270
+ ```
271
+
272
+ ---
273
+
274
+ ## License
275
+
276
+ MIT