pcntoolkit 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pcntoolkit/__init__.py +4 -0
- pcntoolkit/configs.py +9 -0
- pcntoolkit/dataio/__init__.py +1 -0
- pcntoolkit/dataio/fileio.py +608 -0
- pcntoolkit/model/KnuOp.py +48 -0
- pcntoolkit/model/NP.py +88 -0
- pcntoolkit/model/NPR.py +86 -0
- pcntoolkit/model/SHASH.py +509 -0
- pcntoolkit/model/__init__.py +6 -0
- pcntoolkit/model/architecture.py +219 -0
- pcntoolkit/model/bayesreg.py +585 -0
- pcntoolkit/model/core.21290 +0 -0
- pcntoolkit/model/gp.py +489 -0
- pcntoolkit/model/hbr.py +1584 -0
- pcntoolkit/model/rfa.py +245 -0
- pcntoolkit/normative.py +1647 -0
- pcntoolkit/normative_NP.py +336 -0
- pcntoolkit/normative_model/__init__.py +6 -0
- pcntoolkit/normative_model/norm_base.py +62 -0
- pcntoolkit/normative_model/norm_blr.py +303 -0
- pcntoolkit/normative_model/norm_gpr.py +112 -0
- pcntoolkit/normative_model/norm_hbr.py +752 -0
- pcntoolkit/normative_model/norm_np.py +333 -0
- pcntoolkit/normative_model/norm_rfa.py +109 -0
- pcntoolkit/normative_model/norm_utils.py +29 -0
- pcntoolkit/normative_parallel.py +1420 -0
- pcntoolkit/regression_model/blr/warp.py +1 -0
- pcntoolkit/trendsurf.py +315 -0
- pcntoolkit/util/__init__.py +1 -0
- pcntoolkit/util/bspline.py +149 -0
- pcntoolkit/util/hbr_utils.py +242 -0
- pcntoolkit/util/utils.py +1698 -0
- pcntoolkit-0.32.0.dist-info/LICENSE +674 -0
- pcntoolkit-0.32.0.dist-info/METADATA +134 -0
- pcntoolkit-0.32.0.dist-info/RECORD +37 -0
- pcntoolkit-0.32.0.dist-info/WHEEL +4 -0
- pcntoolkit-0.32.0.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
pcntoolkit/trendsurf.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
#!/opt/conda/bin/python
|
|
2
|
+
|
|
3
|
+
# ------------------------------------------------------------------------------
|
|
4
|
+
# Usage:
|
|
5
|
+
# python trendsurf.py -m [maskfile] -b [basis] -c [covariates] <infile>
|
|
6
|
+
#
|
|
7
|
+
# Written by A. Marquand
|
|
8
|
+
# ------------------------------------------------------------------------------
|
|
9
|
+
|
|
10
|
+
from __future__ import division, print_function
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
import nibabel as nib
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
try: # Run as a package if installed
|
|
20
|
+
from pcntoolkit.dataio import fileio
|
|
21
|
+
from pcntoolkit.model.bayesreg import BLR
|
|
22
|
+
except ImportError:
|
|
23
|
+
pass
|
|
24
|
+
path = os.path.abspath(os.path.dirname(__file__))
|
|
25
|
+
if path not in sys.path:
|
|
26
|
+
sys.path.append(path)
|
|
27
|
+
del path
|
|
28
|
+
|
|
29
|
+
from dataio import fileio
|
|
30
|
+
from model.bayesreg import BLR
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def load_data(datafile, maskfile=None):
    """
    Load a nifti image from disk and vectorise it

    Only nifti input is handled: any ``datafile`` whose name does not end in
    "nii" or "nii.gz" is rejected with a ValueError. The image is read in
    its volumetric form so that voxel coordinates can be generated for
    every voxel retained by the mask.

    :param datafile: path to the nifti file containing the images
    :param maskfile: nifti mask passed on to ``fileio.create_mask``
    :returns: * dat - data in vectorised form (output of ``fileio.vol2vec``)
              * world - coordinates of the in-mask voxels after applying
                the image affine (one row per voxel, columns x, y, z)
              * mask - mask returned by ``fileio.create_mask``
    :raises ValueError: if ``datafile`` is not a nifti file name
    """
    if not datafile.endswith(("nii.gz", "nii")):
        raise ValueError("No routine to handle non-nifti data")

    # fileio.load_nifti(..., vol=True) is used rather than fileio.load()
    # because the volumetric shape is needed to build the coordinate grid
    volume = fileio.load_nifti(datafile, vol=True)
    shape = volume.shape
    if len(shape) <= 3:
        shape = shape + (1,)

    mask = fileio.create_mask(volume, mask=maskfile)
    dat = fileio.vol2vec(volume, mask)
    in_mask = np.where(mask.ravel())[0]

    # voxel index grid over the three spatial axes
    axes = [np.linspace(0, shape[ax] - 1, shape[ax]) for ax in range(3)]
    i, j, k = np.meshgrid(*axes, indexing='ij')

    # voxel-to-world mapping: homogeneous voxel indices times the affine
    img = nib.load(datafile)
    homogeneous = np.vstack((i.ravel(),
                             j.ravel(),
                             k.ravel(),
                             np.ones(np.prod(i.shape), float))).T
    world = np.dot(homogeneous, img.affine.T)[in_mask, 0:3]

    return dat, world, mask
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def create_basis(X, basis, mask):
    """
    Create a basis set

    This will create the basis set for the trend surface model. Two kinds
    of ``basis`` are understood:

    * an integer (Python or NumPy), or a string made up only of decimal
      digits such as "3" or "10": a polynomial basis of that degree is
      generated. Column block ``d`` (1-based) holds ``X ** d``, so the
      result has ``X.shape[1] * degree`` columns and no intercept column.
    * any other string: treated as the path of a nifti file holding a
      custom basis set, which is loaded volumetrically and vectorised with
      the same mask as the data.

    :param X: covariates, 2-d array with one column per covariate
    :param basis: polynomial degree, or path to a custom basis set
    :param mask: mask applied to a custom basis set (unused for polynomials)
    :returns: * Phi - basis set
    :raises ValueError: if ``basis`` is neither an integer nor a string
    """
    # bool is a subclass of int but is never a meaningful polynomial degree
    is_integer = (isinstance(basis, (int, np.integer))
                  and not isinstance(basis, bool))
    # isdecimal() guarantees that int(basis) below cannot fail
    is_degree_string = isinstance(basis, str) and basis.isdecimal()

    if is_integer or is_degree_string:
        dimpoly = int(basis)
        dimx = X.shape[1]
        print('Generating polynomial basis set of degree', dimpoly, '...')
        Phi = np.zeros((X.shape[0], dimx * dimpoly))
        for d in range(1, dimpoly + 1):
            # block d-1 of width dimx receives the d-th power of X
            Phi[:, (d - 1) * dimx:d * dimx] = X ** d
    elif isinstance(basis, str):  # custom basis set
        print('Loading custom basis set from', basis)

        # loaded volumetrically (not via fileio.load_data) so that the same
        # mask as for the data can be applied
        Phi_vol = fileio.load_nifti(basis, vol=True)
        Phi = fileio.vol2vec(Phi_vol, mask)
        print('Basis set consists of', Phi.shape[1], 'basis functions.')
    else:
        raise ValueError(
            "I don't know what to do with basis: {!r}".format(basis))

    return Phi
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def write_nii(data, filename, examplenii, mask):
    """
    Write data to nifti file

    This will write data to a nifti file, re-using the affine and header
    of an example nifti file. The values in ``data`` are scattered into a
    zero-filled array covering the full spatial grid of the example image
    and reshaped to ``(x, y, z, nvol)`` before saving.

    :param data: 2-d array to be written, one column per output volume
    :param filename: name of file to be written
    :param examplenii: example nifti file supplying shape, affine and header
    :param mask: mask selecting the rows of the flattened grid that receive
                 ``data`` (flattened and used as an index; presumably a
                 boolean volume matching the example image - confirm with
                 the caller)
    :returns: None; the image is saved to ``filename``
    """
    # load example image
    ex_img = nib.load(examplenii)
    dim = ex_img.shape[0:3]
    nvol = int(data.shape[1])

    # scatter the vectorised data back into the full grid
    array_data = np.zeros((np.prod(dim), nvol))
    array_data[mask.flatten(), :] = data
    array_data = np.reshape(array_data, dim + (nvol,))

    # use the .affine / .header attributes (as load_data does) instead of
    # the deprecated get_affine() / get_header() accessors
    array_img = nib.Nifti1Image(array_data, ex_img.affine, ex_img.header)
    nib.save(array_img, filename)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _flatten_cli_args(args):
    """Flatten arbitrarily nested lists/tuples of tokens into a list of str."""
    flat = []
    for item in args:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten_cli_args(item))
        else:
            flat.append(str(item))
    return flat


def get_args(*args):
    """
    Parse command line arguments

    This will parse the command line arguments for the trend surface model.
    Tokens may be passed in directly, either as separate strings or as
    (possibly nested) lists/tuples of strings; nesting is flattened. When
    no tokens are supplied the process command line (``sys.argv``) is
    parsed instead.

    The recognised arguments are:

    * ``filename``: 4-d nifti file containing the images to be estimated
    * ``-m``: nifti mask used to apply to the data
    * ``-b``: basis set (default 3)
    * ``-c``: covariates file (not implemented)
    * ``-a``: use ARD
    * ``-o``: output all measures

    Relative file names are joined onto the current working directory.

    :param args: optional command line tokens
    :returns: * filename - path of the input file
              * maskfile - path of the mask file, or None
              * basis - basis set specification (str from the command
                line, or the integer default 3)
              * ard - True if ``-a`` was given
              * outputall - True if ``-o`` was given
    :raises NotImplementedError: if a covariates file is specified
    """
    parser = argparse.ArgumentParser(description="Trend surface model")
    parser.add_argument("filename")
    parser.add_argument("-b", help="basis set", dest="basis", default=3)
    parser.add_argument("-m", help="mask file", dest="maskfile", default=None)
    parser.add_argument("-c", help="covariates file", dest="covfile",
                        default=None)
    parser.add_argument("-a", help="use ARD", action='store_true')
    parser.add_argument("-o", help="output all measures", action='store_true')

    # main() forwards its own *args tuple, so tokens can arrive nested;
    # an empty result means "nothing supplied" and argparse then falls
    # back to sys.argv[1:]
    argv = _flatten_cli_args(args)
    parsed = parser.parse_args(argv if argv else None)

    wdir = os.path.realpath(os.path.curdir)
    filename = os.path.join(wdir, parsed.filename)
    if parsed.maskfile is None:
        maskfile = None
    else:
        maskfile = os.path.join(wdir, parsed.maskfile)
    basis = parsed.basis
    if parsed.covfile is not None:
        raise NotImplementedError("Covariates not implemented yet.")

    return filename, maskfile, basis, parsed.a, parsed.o
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def estimate(filename, maskfile, basis, ard=False, outputall=False,
             saveoutput=True, **kwargs):
    """ Estimate a trend surface model

    This will estimate a trend surface model, independently for each subject
    (i.e. each column of the vectorised data). The voxel coordinates are
    standardised and expanded with ``create_basis`` (a polynomial of the
    given degree, or a custom basis set loaded from disk), and a Bayesian
    linear regression (BLR) is fit to the standardised responses.

    Basic usage::

        estimate(filename, maskfile, basis)

    :param filename: 4-d nifti file containing the images to be estimated
    :param maskfile: nifti mask used to apply to the data
    :param basis: model order for the interpolating polynomial, or the
                  path of a custom basis set
    :param ard: if truthy, one hyperparameter per basis function plus one
                extra is initialised (``Phi.shape[1] + 1`` values) instead
                of two hyperparameters in total
    :param outputall: also compute the per-coefficient measure derived
                      from the inverse of ``breg.A``
    :param saveoutput: write results to disk (True) or return them (False)
    :param optimizer: (keyword) optimiser passed to ``BLR.estimate``;
                      defaults to 'powell'

    With ``saveoutput=True`` the results are written to the current working
    directory and nothing is returned. These are:

    :outputs: * yhat.nii.gz - predictive mean, in the units of the input
              * ys2.nii.gz - predictive variance as returned by
                ``BLR.predict`` for the standardised responses (it is not
                rescaled)
              * trendcoeff.txt - coefficients from the trend surface model
              * negloglik.txt - negative log marginal likelihood
              * hyp.txt - hyperparameters
              * explainedvar.txt - explained variance (percent)
              * rmse.txt - root mean squared error
              * trendcoeffvar.txt - square root of the diagonal of
                ``inv(breg.A)`` (only if ``outputall``)

    With ``saveoutput=False`` nothing is written and the list
    ``[yhat, ys2, nlZ, hyp, rmse, ev, m]`` is returned, with ``bs2``
    appended if ``outputall`` is set.
    """

    # parse arguments
    optim = kwargs.get('optimizer', 'powell')

    # load data
    print("Processing data in", filename)
    Y, X, mask = load_data(filename, maskfile)
    Y = np.round(10000*Y)/10000  # truncate precision to avoid numerical probs
    if len(Y.shape) == 1:
        Y = Y[:, np.newaxis]
    N = Y.shape[1]

    # standardize responses and covariates
    mY = np.mean(Y, axis=0)
    sY = np.std(Y, axis=0)
    Yz = (Y - mY) / sY
    mX = np.mean(X, axis=0)
    sX = np.std(X, axis=0)
    Xz = (X - mX) / sX

    # create basis set and set starting hyperparameters; the same
    # truthiness test now drives both the message and the size of hyp0
    Phi = create_basis(Xz, basis, mask)
    if ard:
        print('ARD is enabled')
        hyp0 = np.zeros(Phi.shape[1]+1)
    else:
        hyp0 = np.zeros(2)

    # estimate the models for all subjects
    yhat = np.zeros_like(Yz)
    ys2 = np.zeros_like(Yz)
    nlZ = np.zeros(N)
    hyp = np.zeros((N, len(hyp0)))
    rmse = np.zeros(N)
    ev = np.zeros(N)
    m = np.zeros((N, Phi.shape[1]))
    bs2 = np.zeros((N, Phi.shape[1]))
    for i in range(0, N):
        print("Estimating model ", i+1, "of", N)
        breg = BLR()
        hyp[i, :] = breg.estimate(hyp0, Phi, Yz[:, i], optimizer=optim)
        m[i, :] = breg.m
        nlZ[i] = breg.nlZ

        # compute extra measures (e.g. marginal variances)?
        if outputall:
            bs2[i] = np.sqrt(np.diag(np.linalg.inv(breg.A)))

        # compute predictions and errors
        yhat[:, i], ys2[:, i] = breg.predict(hyp[i, :], Phi, Yz[:, i], Phi)
        # map the predictive mean back to the original units of Y
        yhat[:, i] = yhat[:, i]*sY[i] + mY[i]
        rmse[i] = np.sqrt(np.mean((Y[:, i] - yhat[:, i]) ** 2))
        ev[i] = 100*(1 - (np.var(Y[:, i] - yhat[:, i]) / np.var(Y[:, i])))

        print("Variance explained =", ev[i], "% RMSE =", rmse[i])

    print("Mean (std) variance explained =", ev.mean(), "(", ev.std(), ")")
    print("Mean (std) RMSE =", rmse.mean(), "(", rmse.std(), ")")

    # Write output
    if saveoutput:
        print("Writing output ...")
        np.savetxt("trendcoeff.txt", m, delimiter='\t', fmt='%5.8f')
        np.savetxt("negloglik.txt", nlZ, delimiter='\t', fmt='%5.8f')
        np.savetxt("hyp.txt", hyp, delimiter='\t', fmt='%5.8f')
        np.savetxt("explainedvar.txt", ev, delimiter='\t', fmt='%5.8f')
        np.savetxt("rmse.txt", rmse, delimiter='\t', fmt='%5.8f')
        fileio.save_nifti(yhat, 'yhat.nii.gz', filename, mask)
        fileio.save_nifti(ys2, 'ys2.nii.gz', filename, mask)

        if outputall:
            np.savetxt("trendcoeffvar.txt", bs2, delimiter='\t', fmt='%5.8f')
    else:
        out = [yhat, ys2, nlZ, hyp, rmse, ev, m]
        if outputall:
            out.append(bs2)
        return out
|
|
302
|
+
|
|
303
|
+
def entrypoint(*args):
    """Forward all positional arguments unchanged to ``main``."""
    main(*args)
|
|
305
|
+
|
|
306
|
+
def main(*args):
    """
    Run the trend surface model from command line style arguments

    Floating point "invalid" warnings from NumPy are silenced, the
    arguments are resolved with ``get_args`` and the model is estimated
    with ``estimate`` (which, by default, writes its results to disk).

    :param args: arguments forwarded to ``get_args`` as a single tuple
    """
    # suppress warnings about invalid floating point operations
    np.seterr(invalid='ignore')

    settings = get_args(args)
    filename, maskfile, basis, ard, outputall = settings
    estimate(filename, maskfile, basis, ard, outputall)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
# For running from the command line: hand everything after the script name
# to main() as a single list.
if __name__ == "__main__":
    main(sys.argv[1:])
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import utils
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.interpolate import BSpline
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BSplineBasis:
    """
    B-spline basis expansion applied separately to each column of a
    covariate matrix.

    ``fit`` builds one knot vector per feature, ``transform`` evaluates the
    corresponding B-spline design matrices and concatenates them column-wise,
    and ``adapt`` rescales the fitted knot vectors so that they span the
    combined range of the fitted and a target dataset.
    """

    def __init__(
        self, order, nknots, knot_method="uniform", left_expand=0.05, right_expand=0.05
    ):
        """
        Initialize the BSplineBasis object.

        :param order: Degree of the B-spline
        :param nknots: Number of knots placed per feature. With the 'uniform'
            method they are spread evenly over the expanded range, end points
            included; with 'percentile' they are evenly spaced percentiles
            (0 to 100) of the data. ``order`` extra copies of each boundary
            of the expanded range are added on either side, so every knot
            vector has ``nknots + 2 * order`` entries.
        :param knot_method: 'uniform' or 'percentile' for knot placement
        :param left_expand: Fraction of the feature range by which the lower
            boundary is extended (default 0.05)
        :param right_expand: Fraction of the feature range by which the upper
            boundary is extended (default 0.05)
        :raises ValueError: if ``nknots + 2 < order + 1``, if ``knot_method``
            is unknown, or if an expansion fraction lies outside [0, 1]
        """
        if nknots + 2 < order + 1:
            raise ValueError("Number of knots+2 must be at least degree + 1.")
        if knot_method not in ["uniform", "percentile"]:
            raise ValueError("knot_method must be 'uniform' or 'percentile'.")
        if not (0 <= left_expand <= 1 and 0 <= right_expand <= 1):
            raise ValueError("left_expand and right_expand must be between 0 and 1.")

        self.degree = order
        self.nknots = nknots
        self.knot_method = knot_method
        self.left_expand = left_expand
        self.right_expand = right_expand
        # populated by fit(): one knot vector per feature and the per-feature
        # range the knots were derived from
        self.knots = None
        self.feature_min = None
        self.feature_max = None

    @staticmethod
    def _check_2d_array(X, name):
        """Raise ValueError unless ``X`` is a 2-d NumPy array."""
        if not isinstance(X, np.ndarray):
            raise ValueError("Input {} must be a NumPy array.".format(name))
        if X.ndim != 2:
            raise ValueError("Input {} must be a 2D array.".format(name))

    @staticmethod
    def _per_feature(values, feature_num):
        """Return ``values`` as an array, repeating a scalar once per feature."""
        arr = np.array(values)
        if arr.ndim == 0:
            # a bare scalar would otherwise fail when indexed per feature
            arr = np.full(feature_num, arr.item())
        return arr

    def _expanded_range(self, i):
        """Return (t_min, t_max): the range of feature ``i`` after expansion."""
        minx = self.feature_min[i]
        maxx = self.feature_max[i]
        delta = maxx - minx
        return minx - self.left_expand * delta, maxx + self.right_expand * delta

    def fit(self, X, feature_min=None, feature_max=None):
        """
        Fit B-spline basis functions to the dataset.

        :param X: [N×P] array of covariates
        :param feature_min: Minimum values for features (optional); either
            one value per feature or a single scalar applied to all features.
            Defaults to the column-wise minimum of ``X``.
        :param feature_max: Maximum values for features (optional); same
            conventions as ``feature_min``. Defaults to the column-wise
            maximum of ``X``.
        :raises ValueError: if ``X`` is not a 2-d NumPy array
        """
        self._check_2d_array(X, "X")

        feature_num = X.shape[1]
        self.feature_min = (
            np.min(X, axis=0)
            if feature_min is None
            else self._per_feature(feature_min, feature_num)
        )
        self.feature_max = (
            np.max(X, axis=0)
            if feature_max is None
            else self._per_feature(feature_max, feature_num)
        )

        self.knots = []
        for i in range(feature_num):
            # Determine range of bspline basis
            t_min, t_max = self._expanded_range(i)

            # Determine knot locations
            if self.knot_method == "uniform":
                interior_knots = np.linspace(t_min, t_max, self.nknots)
            elif self.knot_method == "percentile":
                interior_knots = np.percentile(
                    X[:, i], np.linspace(0, 100, self.nknots)
                )

            # Add boundary knots: `degree` copies of each end of the range
            t = np.concatenate(
                ([t_min] * self.degree, interior_knots, [t_max] * self.degree)
            )

            self.knots.append(t)

    def transform(self, X):
        """
        Transform the dataset using the fitted B-spline basis functions.

        Each column of ``X`` is expanded with
        ``scipy.interpolate.BSpline.design_matrix`` using that feature's knot
        vector (extrapolation enabled), and the dense per-feature matrices
        are concatenated along the column axis.

        :param X: [N×P] array of clinical covariates
        :return: [N×(P×n_basis)] array of transformed data
        :raises ValueError: if ``fit`` has not been called, if ``X`` is not
            a 2-d NumPy array, or if its number of columns differs from the
            number of fitted features
        """
        if self.knots is None:
            raise ValueError(
                "B-spline basis functions have not been fitted. Call 'fit' first."
            )
        self._check_2d_array(X, "X")
        if len(self.knots) != X.shape[1]:
            raise ValueError(
                "Number of B-spline basis functions must match the number of features in X."
            )

        transformed_features = [
            BSpline.design_matrix(
                x=X[:, f], t=knots, k=self.degree, extrapolate=True
            ).toarray()
            for f, knots in enumerate(self.knots)
        ]
        return np.concatenate(transformed_features, axis=1)

    def adapt(self, target_X):
        """
        Adapt the fitted B-spline basis functions to a target dataset.

        The stored feature range is widened to cover both the fitted range
        and the range of ``target_X``; every fitted knot vector is then
        mapped linearly so that its first and last knots coincide with the
        expanded combined range. The relative spacing of the knots is kept.

        :param target_X: [N×P] array of target clinical covariates
        :raises ValueError: if ``fit`` has not been called, if ``target_X``
            is not a 2-d NumPy array, or if its number of columns differs
            from the number of fitted features
        """
        if self.knots is None:
            raise ValueError(
                "B-spline basis functions have not been fitted. Call 'fit' first."
            )
        self._check_2d_array(target_X, "target_X")
        if len(self.knots) != target_X.shape[1]:
            raise ValueError(
                "Number of B-spline basis functions must match the number of features in target_X."
            )

        # Update feature_min and feature_max using the combined datasets
        self.feature_min = np.minimum(self.feature_min, np.min(target_X, axis=0))
        self.feature_max = np.maximum(self.feature_max, np.max(target_X, axis=0))

        new_knots = []
        for i in range(target_X.shape[1]):
            t_min, t_max = self._expanded_range(i)

            # Adapt knots: linear map of [first, last] onto [t_min, t_max]
            source_knots = self.knots[i]
            target_knots = t_min + (source_knots - source_knots[0]) * (
                t_max - t_min
            ) / (source_knots[-1] - source_knots[0])

            new_knots.append(target_knots)

        self.knots = new_knots
|