pcntoolkit 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pcntoolkit/__init__.py +4 -0
- pcntoolkit/configs.py +9 -0
- pcntoolkit/dataio/__init__.py +1 -0
- pcntoolkit/dataio/fileio.py +608 -0
- pcntoolkit/model/KnuOp.py +48 -0
- pcntoolkit/model/NP.py +88 -0
- pcntoolkit/model/NPR.py +86 -0
- pcntoolkit/model/SHASH.py +509 -0
- pcntoolkit/model/__init__.py +6 -0
- pcntoolkit/model/architecture.py +219 -0
- pcntoolkit/model/bayesreg.py +585 -0
- pcntoolkit/model/core.21290 +0 -0
- pcntoolkit/model/gp.py +489 -0
- pcntoolkit/model/hbr.py +1584 -0
- pcntoolkit/model/rfa.py +245 -0
- pcntoolkit/normative.py +1647 -0
- pcntoolkit/normative_NP.py +336 -0
- pcntoolkit/normative_model/__init__.py +6 -0
- pcntoolkit/normative_model/norm_base.py +62 -0
- pcntoolkit/normative_model/norm_blr.py +303 -0
- pcntoolkit/normative_model/norm_gpr.py +112 -0
- pcntoolkit/normative_model/norm_hbr.py +752 -0
- pcntoolkit/normative_model/norm_np.py +333 -0
- pcntoolkit/normative_model/norm_rfa.py +109 -0
- pcntoolkit/normative_model/norm_utils.py +29 -0
- pcntoolkit/normative_parallel.py +1420 -0
- pcntoolkit/regression_model/blr/warp.py +1 -0
- pcntoolkit/trendsurf.py +315 -0
- pcntoolkit/util/__init__.py +1 -0
- pcntoolkit/util/bspline.py +149 -0
- pcntoolkit/util/hbr_utils.py +242 -0
- pcntoolkit/util/utils.py +1698 -0
- pcntoolkit-0.32.0.dist-info/LICENSE +674 -0
- pcntoolkit-0.32.0.dist-info/METADATA +134 -0
- pcntoolkit-0.32.0.dist-info/RECORD +37 -0
- pcntoolkit-0.32.0.dist-info/WHEEL +4 -0
- pcntoolkit-0.32.0.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
# -*- coding: utf-8 -*-
|
|
4
|
+
"""
|
|
5
|
+
Created on Tue Jun 18 09:47:01 2019
|
|
6
|
+
|
|
7
|
+
@author: seykia
|
|
8
|
+
"""
|
|
9
|
+
# ------------------------------------------------------------------------------
|
|
10
|
+
# Usage:
|
|
11
|
+
# python normative_NP.py -r /home/preclineu/andmar/data/seykia/ds000030_R1.0.5/responses.nii.gz
|
|
12
|
+
# -c /home/preclineu/andmar/data/seykia/ds000030_R1.0.5/covariates.pickle
|
|
13
|
+
# --tr /home/preclineu/andmar/data/seykia/ds000030_R1.0.5/test_responses.nii.gz
|
|
14
|
+
# --tc /home/preclineu/andmar/data/seykia/ds000030_R1.0.5/test_covariates.pickle
|
|
15
|
+
# -o /home/preclineu/andmar/data/seykia/ds000030_R1.0.5/Results
|
|
16
|
+
#
|
|
17
|
+
#
|
|
18
|
+
# Written by S. M. Kia
|
|
19
|
+
# ------------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
from __future__ import print_function
|
|
22
|
+
from __future__ import division
|
|
23
|
+
|
|
24
|
+
import sys
|
|
25
|
+
import argparse
|
|
26
|
+
import torch
|
|
27
|
+
from torch import optim
|
|
28
|
+
import numpy as np
|
|
29
|
+
import pickle
|
|
30
|
+
from pcntoolkit.model.NP import NP, apply_dropout_test, np_loss
|
|
31
|
+
from sklearn.preprocessing import MinMaxScaler, StandardScaler
|
|
32
|
+
from sklearn.linear_model import LinearRegression, MultiTaskLasso
|
|
33
|
+
from pcntoolkit.model.architecture import Encoder, Decoder
|
|
34
|
+
from pcntoolkit.util.utils import compute_pearsonr, explained_var, compute_MSLL
|
|
35
|
+
from pcntoolkit.util.utils import extreme_value_prob, extreme_value_prob_fit, ravel_2D, unravel_2D
|
|
36
|
+
from pcntoolkit.dataio import fileio
|
|
37
|
+
import os
|
|
38
|
+
|
|
39
|
+
# Make the `configs` module importable both when pcntoolkit is installed as a
# package and when this script is run directly from the source tree.
try:  # run as a package if installed
    from pcntoolkit import configs
except ImportError:
    pass

# Fallback: add this file's directory to sys.path so the plain top-level
# `import configs` below resolves to pcntoolkit/configs.py when the package
# import above failed (e.g. running from a source checkout).
path = os.path.abspath(os.path.dirname(__file__))
if path not in sys.path:
    sys.path.append(path)
del path
# NOTE(review): this rebinds `configs` even when the package import above
# succeeded; both bindings point at the same module file, so behavior is
# unchanged — confirm before simplifying.
import configs
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_args(*args):
    """
    Parse command-line arguments for the Neural Processes (NP) for Deep
    Normative Modeling script.

    Parameters:
        *args: Variable length argument list. NOTE(review): these are
            currently ignored — ``argparse`` always reads ``sys.argv``;
            kept for backward compatibility with existing callers.

    Returns:
        argparse.Namespace: An object that holds the command-line arguments
        as attributes. The arguments include:
            - respfile: Training response nifti file address.
            - covfile: Training covariates pickle file address.
            - testcovfile: Test covariates pickle file address.
            - testrespfile: Test response nifti file address.
            - mask: Mask nifti file address.
            - outdir: Output directory address (defaults to the CWD).
            - m: Number of fixed-effect estimations.
            - batchnum: Input batch size for training.
            - epochs: Number of epochs to train.
            - device: torch.device — cuda when requested and available,
              otherwise cpu.
            - estimator: Fixed-effect estimator type ('ST' or 'MT').
            - kwargs: DataLoader keyword arguments (pinned memory on cuda).
            - type: Fixed to 'MT'.

    Raises:
        ValueError: If any of the three mandatory file arguments is missing.
    """

    ############################ Parsing inputs ###############################

    parser = argparse.ArgumentParser(
        description='Neural Processes (NP) for Deep Normative Modeling')
    parser.add_argument("-r", help="Training response nifti file address",
                        required=True, dest="respfile", default=None)
    parser.add_argument("-c", help="Training covariates pickle file address",
                        required=True, dest="covfile", default=None)
    parser.add_argument("--tc", help="Test covariates pickle file address",
                        required=True, dest="testcovfile", default=None)
    parser.add_argument("--tr", help="Test response nifti file address",
                        dest="testrespfile", default=None)
    parser.add_argument("--mask", help="Mask nifti file address",
                        dest="mask", default=None)
    parser.add_argument("-o", help="Output directory address",
                        dest="outdir", default=None)
    parser.add_argument('-m', type=int, default=10, dest='m',
                        help='number of fixed-effect estimations')
    parser.add_argument('--batchnum', type=int, default=10, dest='batchnum',
                        help='input batch size for training')
    parser.add_argument('--epochs', type=int, default=100, dest='epochs',
                        help='number of epochs to train')
    parser.add_argument('--device', type=str, default='cuda', dest='device',
                        help='Either cpu or cuda')
    parser.add_argument('--fxestimator', type=str, default='ST', dest='estimator',
                        help='Fixed-effect estimator type.')

    args = parser.parse_args()

    # These three are already enforced by required=True above; the explicit
    # check is kept as a belt-and-braces guard for programmatic callers.
    if args.respfile is None or args.covfile is None or args.testcovfile is None:
        raise ValueError("Training response nifti file, Training covariates pickle file, and \
                         Test covariates pickle file must be specified.")
    if args.outdir is None:
        args.outdir = os.getcwd()

    # Only use cuda when it was both requested and is actually available.
    cuda = args.device == 'cuda' and torch.cuda.is_available()
    args.device = torch.device("cuda" if cuda else "cpu")
    # DataLoader-style kwargs: pin memory only makes sense on cuda.
    args.kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
    args.type = 'MT'

    return args
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def estimate(args):
    """
    Estimate a Neural Processes (NP) deep normative model.

    Parameters:
        args (argparse.Namespace): Parsed command-line arguments (see
        ``get_args``). Fields used here:
            - respfile: Training response nifti file address.
            - covfile: Training covariates pickle file address.
            - testcovfile: Test covariates pickle file address.
            - testrespfile: Optional test response nifti file address.
            - mask: Optional mask nifti file address.
            - outdir: Output directory address.
            - m: Number of fixed-effect estimations.
            - batchnum: Number of training mini-batches per epoch.
            - epochs: Total number of training epochs.
            - device: torch.device (cpu or cuda).
            - estimator: Fixed-effect estimator type ('ST' or 'MT').

    Returns:
        None. Results (yhat/ys2 maps, evaluation maps when test responses
        are supplied, and the pickled model) are written to ``args.outdir``.

    This function loads the input data, normalizes it, estimates the
    fixed-effects using either single-task (ST) or multi-task (MT)
    regression, trains the NP encoder/decoder on the random effects, and
    evaluates on the test data when available.
    """

    torch.set_default_dtype(torch.float32)
    args.type = 'MT'
    print('Loading the input Data ...')
    # Move the 4th nifti dimension first -> (samples, x, y, z).
    # assumes the 4th dimension indexes subjects/volumes — TODO confirm.
    responses = fileio.load_nifti(
        args.respfile, vol=True).transpose([3, 0, 1, 2])
    response_shape = responses.shape
    with open(args.covfile, 'rb') as handle:
        covariates = pickle.load(handle)['covariates']
    with open(args.testcovfile, 'rb') as handle:
        test_covariates = pickle.load(handle)['test_covariates']
    if args.mask is not None:
        mask = fileio.load_nifti(args.mask, vol=True)
        mask = fileio.create_mask(mask, mask=None)
    else:
        # No explicit mask: derive one from the first response volume.
        mask = fileio.create_mask(responses[0, :, :, :], mask=None)
    if args.testrespfile is not None:
        test_responses = fileio.load_nifti(
            args.testrespfile, vol=True).transpose([3, 0, 1, 2])
        test_responses_shape = test_responses.shape

    print('Normalizing the input Data ...')
    # Covariates: z-scored with statistics fit on the training set only.
    covariates_scaler = StandardScaler()
    covariates = covariates_scaler.fit_transform(covariates)
    test_covariates = covariates_scaler.transform(test_covariates)
    # Responses: min-max scaled per voxel (2D ravel/unravel round-trip).
    response_scaler = MinMaxScaler()
    responses = unravel_2D(response_scaler.fit_transform(
        ravel_2D(responses)), response_shape)
    if args.testrespfile is not None:
        test_responses = unravel_2D(response_scaler.transform(
            ravel_2D(test_responses)), test_responses_shape)
        test_responses = np.expand_dims(test_responses, axis=1)

    # One fixed-effect estimation per bootstrap replicate.
    factor = args.m

    x_context = np.zeros([covariates.shape[0], factor,
                          covariates.shape[1]], dtype=np.float32)
    y_context = np.zeros([responses.shape[0], factor, responses.shape[1],
                          responses.shape[2], responses.shape[3]], dtype=np.float32)
    x_context_test = np.zeros(
        [test_covariates.shape[0], factor, test_covariates.shape[1]], dtype=np.float32)
    y_context_test = np.zeros([test_covariates.shape[0], factor, responses.shape[1],
                               responses.shape[2], responses.shape[3]], dtype=np.float32)

    print('Estimating the fixed-effects ...')
    for i in range(factor):
        x_context[:, i, :] = covariates[:, :]
        x_context_test[:, i, :] = test_covariates[:, :]
        # Bootstrap resample (with replacement) of the training subjects.
        idx = np.random.randint(0, covariates.shape[0], covariates.shape[0])
        if args.estimator == 'ST':
            # Single-task: an independent linear regression per voxel.
            for j in range(responses.shape[1]):
                for k in range(responses.shape[2]):
                    for l in range(responses.shape[3]):
                        reg = LinearRegression()
                        reg.fit(x_context[idx, i, :], responses[idx, j, k, l])
                        y_context[:, i, j, k, l] = reg.predict(
                            x_context[:, i, :])
                        y_context_test[:, i, j, k, l] = reg.predict(
                            x_context_test[:, i, :])
        elif args.estimator == 'MT':
            # Multi-task: one Lasso jointly over all voxels.
            reg = MultiTaskLasso(alpha=0.1)
            reg.fit(x_context[idx, i, :], np.reshape(responses[idx, :, :, :], [
                    covariates.shape[0], np.prod(responses.shape[1:])]))
            y_context[:, i, :, :, :] = np.reshape(reg.predict(x_context[:, i, :]),
                                                  [x_context.shape[0], responses.shape[1], responses.shape[2], responses.shape[3]])
            y_context_test[:, i, :, :, :] = np.reshape(reg.predict(x_context_test[:, i, :]),
                                                       [x_context_test.shape[0], responses.shape[1], responses.shape[2], responses.shape[3]])
        print('Fixed-effect %d of %d is computed!' % (i+1, factor))

    # The target set is the full covariate set; responses are replicated
    # across the `factor` axis to match the context tensors.
    x_all = x_context
    responses = np.expand_dims(responses, axis=1).repeat(factor, axis=1)

    ################################## TRAINING #################################

    encoder = Encoder(x_context, y_context, args).to(args.device)
    args.cnn_feature_num = encoder.cnn_feature_num
    decoder = Decoder(x_context, y_context, args).to(args.device)
    model = NP(encoder, decoder, args).to(args.device)

    print('Estimating the Random-effect ...')
    k = 1
    # Four training phases with a decreasing learning rate (1e-2 .. 1e-5);
    # the last phase absorbs the rounding remainder of args.epochs.
    epochs = [int(args.epochs/4), int(args.epochs/2), int(args.epochs/5),
              int(args.epochs-args.epochs/4-args.epochs/2-args.epochs/5)]
    mini_batch_num = args.batchnum
    batch_size = int(x_context.shape[0]/mini_batch_num)
    model.train()
    for e in range(len(epochs)):
        optimizer = optim.Adam(model.parameters(), lr=10**(-e-2))
        for j in range(epochs[e]):
            train_loss = 0
            rand_idx = np.random.permutation(x_context.shape[0])
            for i in range(mini_batch_num):
                optimizer.zero_grad()
                idx = rand_idx[i*batch_size:(i+1)*batch_size]
                y_hat, z_all, z_context, dummy = model(torch.tensor(x_context[idx, :, :], device=args.device),
                                                       torch.tensor(
                                                           y_context[idx, :, :, :, :], device=args.device),
                                                       torch.tensor(
                                                           x_all[idx, :, :], device=args.device),
                                                       torch.tensor(responses[idx, :, :, :, :], device=args.device))
                loss = np_loss(y_hat, torch.tensor(
                    responses[idx, :, :, :, :], device=args.device), z_all, z_context)
                loss.backward()
                train_loss += loss.item()
                optimizer.step()
            print('Epoch: %d, Loss:%f, Average Loss:%f' %
                  (k, train_loss, train_loss/responses.shape[0]))
            k += 1

    ################################## Evaluation #################################

    # Only populated when test responses are supplied; initialised here so
    # the model can still be pickled below without a NameError otherwise.
    EVD_params = None
    abnormal_probs = None

    print('Predicting on Test Data ...')
    model.eval()
    # Keep dropout active at test time for MC-dropout-style uncertainty.
    model.apply(apply_dropout_test)
    with torch.no_grad():
        y_hat, z_all, z_context, y_sigma = model(torch.tensor(x_context_test, device=args.device),
                                                 torch.tensor(y_context_test, device=args.device), n=15)
    if args.testrespfile is not None:
        test_loss = np_loss(y_hat[0:test_responses_shape[0], :],
                            torch.tensor(test_responses, device=args.device),
                            z_all, z_context).item()
        print('Average Test Loss:%f' % (test_loss/test_responses_shape[0]))

        # Standard normative-model evaluation maps, masked to the brain.
        RMSE = np.sqrt(np.mean(
            (test_responses - y_hat[0:test_responses_shape[0], :].cpu().numpy())**2, axis=0)).squeeze() * mask
        SMSE = RMSE ** 2 / np.var(test_responses, axis=0).squeeze()
        Rho, pRho = compute_pearsonr(test_responses.squeeze(
        ), y_hat[0:test_responses_shape[0], :].cpu().numpy().squeeze())
        EXPV = explained_var(test_responses.squeeze(
        ), y_hat[0:test_responses_shape[0], :].cpu().numpy().squeeze()) * mask
        MSLL = compute_MSLL(test_responses.squeeze(), y_hat[0:test_responses_shape[0], :].cpu().numpy().squeeze(),
                            y_sigma[0:test_responses_shape[0], :].cpu().numpy().squeeze()**2, train_mean=test_responses.mean(0),
                            train_var=test_responses.var(0)).squeeze() * mask

        # Normative probability maps: standardized deviations of the test
        # responses from the predictive distribution.
        NPMs = (test_responses - y_hat[0:test_responses_shape[0], :].cpu().numpy()) / (
            y_sigma[0:test_responses_shape[0], :].cpu().numpy())
        NPMs = NPMs.squeeze()
        NPMs = NPMs * mask
        NPMs = np.nan_to_num(NPMs)

        temp = NPMs.reshape([NPMs.shape[0], NPMs.shape[1]
                             * NPMs.shape[2]*NPMs.shape[3]])
        EVD_params = extreme_value_prob_fit(temp, 0.01)
        abnormal_probs = extreme_value_prob(EVD_params, temp, 0.01)

    ############################## SAVING RESULTS #################################

    print('Saving Results to: %s' % (args.outdir))
    exfile = args.respfile
    # Undo the min-max scaling before writing predictions back to nifti.
    y_hat = y_hat.squeeze().cpu().numpy()
    y_hat = response_scaler.inverse_transform(ravel_2D(y_hat))
    y_hat = y_hat[:, mask.flatten()]
    fileio.save(y_hat.T, args.outdir +
                '/yhat.nii.gz', example=exfile, mask=mask)
    # Predictive variance, rescaled to the original response units.
    ys2 = y_sigma.squeeze().cpu().numpy()
    ys2 = ravel_2D(ys2) * (response_scaler.data_max_ -
                           response_scaler.data_min_)
    ys2 = ys2**2
    ys2 = ys2[:, mask.flatten()]
    fileio.save(ys2.T, args.outdir +
                '/ys2.nii.gz', example=exfile, mask=mask)
    if args.testrespfile is not None:
        NPMs = ravel_2D(NPMs)[:, mask.flatten()]
        fileio.save(NPMs.T, args.outdir +
                    '/Z.nii.gz', example=exfile, mask=mask)
        fileio.save(Rho.flatten()[mask.flatten()], args.outdir +
                    '/Rho.nii.gz', example=exfile, mask=mask)
        fileio.save(pRho.flatten()[mask.flatten()], args.outdir +
                    '/pRho.nii.gz', example=exfile, mask=mask)
        fileio.save(RMSE.flatten()[mask.flatten()], args.outdir +
                    '/rmse.nii.gz', example=exfile, mask=mask)
        fileio.save(SMSE.flatten()[mask.flatten()], args.outdir +
                    '/smse.nii.gz', example=exfile, mask=mask)
        fileio.save(EXPV.flatten()[mask.flatten()], args.outdir +
                    '/expv.nii.gz', example=exfile, mask=mask)
        fileio.save(MSLL.flatten()[mask.flatten()], args.outdir +
                    '/msll.nii.gz', example=exfile, mask=mask)

    # os.path.join fixes the original `args.outdir + 'model.pkl'`, which
    # dropped the path separator and wrote e.g. '.../Resultsmodel.pkl'.
    with open(os.path.join(args.outdir, 'model.pkl'), 'wb') as handle:
        pickle.dump({'model': model, 'covariates_scaler': covariates_scaler,
                     'response_scaler': response_scaler, 'EVD_params': EVD_params,
                     'abnormal_probs': abnormal_probs}, handle, protocol=configs.PICKLE_PROTOCOL)

    ###############################################################################
    print('DONE!')
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def main(*args):
    """Parse command-line arguments and estimate the normative model."""
    # Silence warnings from invalid floating-point operations (NaNs can
    # legitimately occur, e.g. in voxels outside the mask).
    np.seterr(invalid='ignore')
    parsed = get_args(args)
    estimate(parsed)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# Script entry point: forward the command-line arguments to main().
if __name__ == "__main__":
    main(sys.argv[1:])
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pickle
|
|
3
|
+
import sys
|
|
4
|
+
from abc import ABCMeta, abstractmethod
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from six import with_metaclass
|
|
8
|
+
|
|
9
|
+
# Make the `configs` module importable both when pcntoolkit is installed as a
# package and when this module is used directly from the source tree.
try:  # run as a package if installed
    from pcntoolkit import configs
except ImportError:
    pass

# Fallback: add this file's directory to sys.path so the plain top-level
# `import configs` below resolves to pcntoolkit/configs.py when the package
# import above failed.
path = os.path.abspath(os.path.dirname(__file__))
if path not in sys.path:
    sys.path.append(path)
del path
# NOTE(review): this rebinds `configs` even when the package import above
# succeeded; both bindings point at the same module file, so behavior is
# unchanged — confirm before simplifying.
import configs
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NormBase(metaclass=ABCMeta):
    """ Base class for normative model back-end.

    All normative modelling approaches must define the following methods::

        NormativeModel.estimate()
        NormativeModel.predict()

    Subclasses must also provide the ``n_params`` property reporting the
    number of parameters the model requires.

    Note: this package targets Python 3 only, so the native ``metaclass``
    syntax replaces the former ``six.with_metaclass`` compatibility shim
    (semantically identical under Python 3).
    """

    def __init__(self, x=None):
        # The base class holds no state; subclasses define their own
        # constructors.
        pass

    @abstractmethod
    def estimate(self, X, y):
        """ Estimate the normative model """

    @abstractmethod
    def predict(self, Xs, X, y):
        """ Make predictions for new data """

    @property
    @abstractmethod
    def n_params(self):
        """ Report the number of parameters required by the model """

    def save(self, save_path):
        """Pickle this model to ``save_path``.

        Returns True on success; prints and re-raises any exception.
        """
        try:
            with open(save_path, 'wb') as handle:
                pickle.dump(self, handle, protocol=configs.PICKLE_PROTOCOL)
            return True
        except Exception as err:
            print('Error:', err)
            raise

    def load(self, load_path):
        """Load and return a pickled model from ``load_path``.

        Note: returns the loaded object rather than mutating ``self``.
        Prints and re-raises any exception.
        """
        try:
            with open(load_path, 'rb') as handle:
                nm = pd.read_pickle(handle)
            return nm
        except Exception as err:
            print('Error:', err)
            raise
|