pcntoolkit 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pcntoolkit/__init__.py +4 -0
- pcntoolkit/configs.py +9 -0
- pcntoolkit/dataio/__init__.py +1 -0
- pcntoolkit/dataio/fileio.py +608 -0
- pcntoolkit/model/KnuOp.py +48 -0
- pcntoolkit/model/NP.py +88 -0
- pcntoolkit/model/NPR.py +86 -0
- pcntoolkit/model/SHASH.py +509 -0
- pcntoolkit/model/__init__.py +6 -0
- pcntoolkit/model/architecture.py +219 -0
- pcntoolkit/model/bayesreg.py +585 -0
- pcntoolkit/model/core.21290 +0 -0
- pcntoolkit/model/gp.py +489 -0
- pcntoolkit/model/hbr.py +1584 -0
- pcntoolkit/model/rfa.py +245 -0
- pcntoolkit/normative.py +1647 -0
- pcntoolkit/normative_NP.py +336 -0
- pcntoolkit/normative_model/__init__.py +6 -0
- pcntoolkit/normative_model/norm_base.py +62 -0
- pcntoolkit/normative_model/norm_blr.py +303 -0
- pcntoolkit/normative_model/norm_gpr.py +112 -0
- pcntoolkit/normative_model/norm_hbr.py +752 -0
- pcntoolkit/normative_model/norm_np.py +333 -0
- pcntoolkit/normative_model/norm_rfa.py +109 -0
- pcntoolkit/normative_model/norm_utils.py +29 -0
- pcntoolkit/normative_parallel.py +1420 -0
- pcntoolkit/regression_model/blr/warp.py +1 -0
- pcntoolkit/trendsurf.py +315 -0
- pcntoolkit/util/__init__.py +1 -0
- pcntoolkit/util/bspline.py +149 -0
- pcntoolkit/util/hbr_utils.py +242 -0
- pcntoolkit/util/utils.py +1698 -0
- pcntoolkit-0.32.0.dist-info/LICENSE +674 -0
- pcntoolkit-0.32.0.dist-info/METADATA +134 -0
- pcntoolkit-0.32.0.dist-info/RECORD +37 -0
- pcntoolkit-0.32.0.dist-info/WHEEL +4 -0
- pcntoolkit-0.32.0.dist-info/entry_points.txt +5 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
from __future__ import print_function
|
|
2
|
+
from __future__ import division
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from ast import literal_eval
|
|
9
|
+
|
|
10
|
+
try: # run as a package if installed
|
|
11
|
+
from pcntoolkit.model.bayesreg import BLR
|
|
12
|
+
from pcntoolkit.normative_model.norm_base import NormBase
|
|
13
|
+
from pcntoolkit.dataio import fileio
|
|
14
|
+
from pcntoolkit.util.utils import create_poly_basis, WarpBoxCox, \
|
|
15
|
+
WarpAffine, WarpCompose, WarpSinArcsinh
|
|
16
|
+
except ImportError:
|
|
17
|
+
pass
|
|
18
|
+
|
|
19
|
+
path = os.path.abspath(os.path.dirname(__file__))
|
|
20
|
+
if path not in sys.path:
|
|
21
|
+
sys.path.append(path)
|
|
22
|
+
del path
|
|
23
|
+
|
|
24
|
+
from model.bayesreg import BLR
|
|
25
|
+
from norm_base import NormBase
|
|
26
|
+
from dataio import fileio
|
|
27
|
+
from util.utils import create_poly_basis, WarpBoxCox, \
|
|
28
|
+
WarpAffine, WarpCompose, WarpSinArcsinh
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NormBLR(NormBase):
    """Normative modelling based on Bayesian linear regression (BLR).

    Covariates are expanded with a polynomial basis of order ``model_order``
    (via ``create_poly_basis``) and the resulting design matrix is fitted
    with a ``BLR`` object. Keyword arguments can additionally configure:

    - heteroskedastic noise, either through variance covariates
      (``varcovfile``) or variance groups such as sites (``vargroupfile``);
    - automatic relevance determination (``use_ard``);
    - a warped likelihood (``warp``: the name of a warp class).
    """

    @staticmethod
    def _load_array(filename):
        """Load data from ``filename``: pickle if it ends in '.pkl', else text.

        NOTE: ``pd.read_pickle`` can execute arbitrary code while loading,
        so only pickle files from trusted sources should be passed here.
        """
        if filename.endswith('.pkl'):
            return pd.read_pickle(filename)
        return np.loadtxt(filename)

    @staticmethod
    def _make_warp(warp_name):
        """Instantiate the warp class named ``warp_name`` with no arguments.

        Only the warp classes imported by this module are accepted. A name
        lookup is used instead of exec()-ing a string built from user input,
        which could run arbitrary code.

        :param warp_name: class name, e.g. 'WarpSinArcsinh' (surrounding
            whitespace is ignored).
        :raises ValueError: if ``warp_name`` is not a known warp class.
        """
        warps = {
            'WarpBoxCox': WarpBoxCox,
            'WarpAffine': WarpAffine,
            'WarpCompose': WarpCompose,
            'WarpSinArcsinh': WarpSinArcsinh,
        }
        name = warp_name.strip()
        if name not in warps:
            raise ValueError("Unknown warp '%s'; expected one of %s"
                             % (warp_name, sorted(warps)))
        return warps[name]()

    def __init__(self, **kwargs):
        """
        Initialize the NormBLR object.

        The number of hyperparameters is n_alpha + n_beta + n_gamma, where
        n_alpha is 1 (or D * model_order with ARD), n_beta is 1 (or the
        number of variance covariates / distinct variance groups) and
        n_gamma is the number of warp parameters (0 without a warp).

        :param kwargs: Keyword arguments which may include:
            - 'X': Data matrix. Must be specified.
            - 'y': Target values. Optional.
            - 'theta': Hyperparameters (array, or a string parsed with
              ``ast.literal_eval``). Optional; replaced by zeros when absent
              or of the wrong length.
            - 'optimizer': The optimization algorithm to use. Default is 'powell'.
            - 'configparam' (deprecated) or 'model_order': Polynomial order
              of the basis expansion. Default is 1.
            - 'varcovfile': File with variance covariates for
              heteroskedastic noise ('.pkl' or text). Optional.
            - 'vargroupfile': File with variance group labels, e.g. site
              ('.pkl' or text). Used only if 'varcovfile' is absent. Optional.
            - 'use_ard': Whether to use automatic relevance determination.
              Default is False.
            - 'warp': Name of the warp class for a warped likelihood, or None.
            Remaining keyword arguments are forwarded to ``BLR``.
        :raises ValueError: If 'X' is not specified in kwargs, or if 'warp'
            names an unknown warp class.
        """
        X = kwargs.pop('X', None)
        y = kwargs.pop('y', None)
        theta = kwargs.pop('theta', None)
        if isinstance(theta, str):
            theta = np.array(literal_eval(theta))
        self.optim_alg = kwargs.get('optimizer', 'powell')

        if X is None:
            raise ValueError("Data matrix must be specified")

        self.D = 1 if len(X.shape) == 1 else X.shape[1]

        # Parse model order ('configparam' is the deprecated spelling)
        if 'configparam' in kwargs:
            model_order = kwargs.pop('configparam')
        else:
            model_order = kwargs.pop('model_order', 1)

        # Force a default model order and coerce it to int (it may arrive
        # as a string, e.g. from a command line)
        if model_order is None:
            model_order = 1
        model_order = int(model_order)

        # configure heteroskedastic noise
        if 'varcovfile' in kwargs:
            self.var_covariates = self._load_array(kwargs.get('varcovfile'))
            if len(self.var_covariates.shape) == 1:
                self.var_covariates = self.var_covariates[:, np.newaxis]
            n_beta = self.var_covariates.shape[1]
            self.var_groups = None
        elif 'vargroupfile' in kwargs:
            # configure variance groups (e.g. site specific variance)
            self.var_groups = self._load_array(kwargs.pop('vargroupfile'))
            # estimate() reads this attribute, so it must always be defined
            self.var_covariates = None
            n_beta = len(set(self.var_groups))
        else:
            self.var_groups = None
            self.var_covariates = None
            n_beta = 1

        # are we using ARD? (one alpha per basis column if so)
        self.use_ard = kwargs.pop('use_ard', False)
        n_alpha = self.D * model_order if self.use_ard else 1

        # Configure warped likelihood
        warp_str = kwargs.pop('warp', None)
        if warp_str is None:
            self.warp = None
            n_gamma = 0
        else:
            self.warp = self._make_warp(warp_str)
            n_gamma = self.warp.get_n_params()

        self._n_params = n_alpha + n_beta + n_gamma
        self._model_order = model_order

        print("configuring BLR ( order", model_order, ")")
        if (theta is None) or (len(theta) != self._n_params):
            print("Using default hyperparameters")
            self.theta0 = np.zeros(self._n_params)
        else:
            self.theta0 = theta
        self.theta = self.theta0

        # initialise the BLR object if the required parameters are present.
        # theta0 is used (not the raw theta) so a wrong-length theta is
        # replaced by the defaults announced above; var_groups is passed
        # for consistency with estimate().
        if (theta is not None) and (y is not None):
            Phi = create_poly_basis(X, self._model_order)
            self.blr = BLR(theta=self.theta0, X=Phi, y=y,
                           var_groups=self.var_groups,
                           warp=self.warp, **kwargs)
        else:
            self.blr = BLR(**kwargs)

    @property
    def n_params(self):
        """Total number of hyperparameters (n_alpha + n_beta + n_gamma)."""
        return self._n_params

    @property
    def neg_log_lik(self):
        """Negative log marginal likelihood stored on the BLR object (``nlZ``)."""
        return self.blr.nlZ

    def estimate(self, X, y, **kwargs):
        """
        Estimate the hyperparameters of the model.

        A fresh ``BLR`` object is created on the polynomial basis expansion
        of ``X`` and its ``estimate`` method is run; the result is stored in
        ``self.theta``.

        :param X: Data matrix.
        :param y: Target values (flattened to 1-D if multi-dimensional).
        :param kwargs: Keyword arguments which may include:
            - 'theta': Initial guess for the hyperparameters (array, or a
              string parsed with ``ast.literal_eval``). Defaults to
              ``self.theta0``.
            - 'warp': Ignored here (the warp configured in ``__init__`` is
              used); removed so it is not forwarded to the BLR object.
            Remaining keyword arguments are forwarded to ``BLR``.
        :return: The instance of the NormBLR object.
        """
        theta = kwargs.pop('theta', None)
        if isinstance(theta, str):
            theta = np.array(literal_eval(theta))

        # remove warp string to prevent it being passed to the blr object
        kwargs.pop('warp', None)

        Phi = create_poly_basis(X, self._model_order)
        if len(y.shape) > 1:
            y = y.ravel()

        if theta is None:
            theta = self.theta0

        # (re-)initialize BLR object for the new data
        self.blr = BLR(theta=theta, X=Phi, y=y,
                       var_groups=self.var_groups,
                       warp=self.warp, **kwargs)

        self.theta = self.blr.estimate(theta, Phi, y,
                                       var_covariates=self.var_covariates,
                                       **kwargs)

        return self

    def predict(self, Xs, X=None, y=None, **kwargs):
        """
        Predict the target values for the given test data.

        The estimated hyperparameters ``self.theta`` are always used; any
        'theta' keyword is discarded. If adaptation responses are supplied
        ('adaptresp' / 'adaptrespfile'), ``BLR.predict_and_adjust`` is used
        with the adaptation covariates and groups; otherwise ``BLR.predict``
        is called with the (optional) training data X, y.

        :param Xs: Test data matrix.
        :param X: Training data matrix. Optional.
        :param y: Training target values. Optional.
        :param kwargs: Keyword arguments which may include:
            - 'testvargroup' / 'testvargroupfile': Variance groups for the
              test data (array, or '.pkl' / text file).
            - 'testvarcov' / 'testvarcovfile': Variance covariates for the
              test data (array, or '.pkl' / text file).
            - 'adaptresp' / 'adaptrespfile': Responses to adapt to (the file
              is read with ``fileio.load``).
            - 'adaptcov' / 'adaptcovfile': Covariates to adapt to (the file
              is read with ``fileio.load``).
            - 'adaptvargroup' / 'adaptvargroupfile': Variance groups to adapt
              to (array, or '.pkl' / text file).
            Remaining keyword arguments are forwarded to the BLR call.
        :return: Tuple ``(yhat, s2)`` as returned by the BLR prediction call.
        """
        theta = self.theta  # always use the estimated coefficients
        # remove from kwargs to avoid downstream problems
        kwargs.pop('theta', None)

        Phis = create_poly_basis(Xs, self._model_order)
        Phi = None if X is None else create_poly_basis(X, self._model_order)

        # process variance groups for the test data
        if 'testvargroup' in kwargs:
            var_groups_te = kwargs.pop('testvargroup')
        elif 'testvargroupfile' in kwargs:
            var_groups_te = self._load_array(kwargs.pop('testvargroupfile'))
        else:
            var_groups_te = None

        # process test variance covariates
        if 'testvarcov' in kwargs:
            var_cov_te = kwargs.pop('testvarcov')
        elif 'testvarcovfile' in kwargs:
            var_cov_te = self._load_array(kwargs.get('testvarcovfile'))
        else:
            var_cov_te = None

        # do we want to adjust the responses?
        if 'adaptresp' in kwargs:
            y_adapt = kwargs.pop('adaptresp')
        elif 'adaptrespfile' in kwargs:
            y_adapt = fileio.load(kwargs.pop('adaptrespfile'))
            if len(y_adapt.shape) == 1:
                y_adapt = y_adapt[:, np.newaxis]
        else:
            y_adapt = None

        if 'adaptcov' in kwargs:
            Phi_adapt = create_poly_basis(kwargs.pop('adaptcov'),
                                          self._model_order)
        elif 'adaptcovfile' in kwargs:
            Phi_adapt = create_poly_basis(
                fileio.load(kwargs.pop('adaptcovfile')), self._model_order)
        else:
            Phi_adapt = None

        if 'adaptvargroup' in kwargs:
            var_groups_ad = kwargs.pop('adaptvargroup')
        elif 'adaptvargroupfile' in kwargs:
            var_groups_ad = self._load_array(kwargs.pop('adaptvargroupfile'))
        else:
            var_groups_ad = None

        if y_adapt is None:
            yhat, s2 = self.blr.predict(theta, Phi, y, Phis,
                                        var_groups_test=var_groups_te,
                                        var_covariates_test=var_cov_te,
                                        **kwargs)
        else:
            yhat, s2 = self.blr.predict_and_adjust(theta, Phi_adapt, y_adapt,
                                                   Phis,
                                                   var_groups_test=var_groups_te,
                                                   var_groups_adapt=var_groups_ad,
                                                   **kwargs)

        return yhat, s2
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from __future__ import print_function
|
|
2
|
+
from __future__ import division
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
try: # run as a package if installed
|
|
9
|
+
from pcntoolkit.model.gp import GPR, CovSum
|
|
10
|
+
from pcntoolkit.normative_model.norm_base import NormBase
|
|
11
|
+
except ImportError:
|
|
12
|
+
pass
|
|
13
|
+
|
|
14
|
+
path = os.path.abspath(os.path.dirname(__file__))
|
|
15
|
+
if path not in sys.path:
|
|
16
|
+
sys.path.append(path)
|
|
17
|
+
del path
|
|
18
|
+
|
|
19
|
+
from model.gp import GPR, CovSum
|
|
20
|
+
from norm_base import NormBase
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NormGPR(NormBase):
    """Normative model based on classical Gaussian process regression.

    The covariance function is a ``CovSum`` of ``'CovLin'`` and
    ``'CovSqExpARD'``, constructed from the training covariates.
    """

    def __init__(self, **kwargs):
        """Build the covariance function and the wrapped ``GPR`` object.

        :param kwargs: keyword arguments, of which these are consumed:

            - ``'X'``: covariate matrix used to construct the covariance
              function.
            - ``'y'``: response values. Optional.
            - ``'theta'``: hyperparameters. Optional.

        ``GPR`` receives ``theta``, the covariance function, ``X`` and ``y``
        only when all three of ``theta``, ``X`` and ``y`` are supplied;
        otherwise an argument-free ``GPR()`` is created. Any other keyword
        arguments are ignored.
        """
        covariates = kwargs.pop('X', None)
        responses = kwargs.pop('y', None)
        hyp = kwargs.pop('theta', None)

        self.covfunc = CovSum(covariates, ('CovLin', 'CovSqExpARD'))
        # hyperparameter vector = covariance parameters plus one extra entry
        self.theta0 = np.zeros(self.covfunc.get_n_params() + 1)
        self.theta = self.theta0

        have_everything = all(
            value is not None for value in (hyp, covariates, responses))
        if not have_everything:
            self.gpr = GPR()
            return

        self.gpr = GPR(hyp, self.covfunc, covariates, responses)
        self._n_params = self.covfunc.get_n_params() + 1

    @property
    def n_params(self):
        """Number of hyperparameters: covariance parameters plus one.

        Computed lazily and cached when ``__init__`` did not set it.
        """
        try:
            return self._n_params
        except AttributeError:
            self._n_params = self.covfunc.get_n_params() + 1
            return self._n_params

    @property
    def neg_log_lik(self):
        """Negative log marginal likelihood stored on the GPR object (``nlZ``)."""
        return self.gpr.nlZ

    def estimate(self, X, y, **kwargs):
        """Optimise the GPR hyperparameters on the data ``X``, ``y``.

        A new ``GPR`` object is created, and the result of its ``estimate``
        call is stored in ``self.theta``.

        :param X: covariate matrix.
        :param y: response values.
        :param kwargs: may contain ``'theta'``, the starting hyperparameters;
            ``self.theta0`` is used when it is absent or None.
        :return: this NormGPR instance.
        """
        start = kwargs.pop('theta', None)
        start = self.theta0 if start is None else start

        self.gpr = GPR(start, self.covfunc, X, y)
        self.theta = self.gpr.estimate(start, self.covfunc, X, y)
        return self

    def predict(self, Xs, X, y, **kwargs):
        """Predict means and marginal variances at the test covariates ``Xs``.

        :param Xs: test covariate matrix.
        :param X: training covariate matrix (required; passed to
            ``GPR.predict``).
        :param y: training responses (required; passed to ``GPR.predict``).
        :param kwargs: may contain ``'theta'``, the hyperparameters to use;
            ``self.theta`` is used when it is absent or None.
        :return: ``(yhat, s2)``; if ``GPR.predict`` yields a 2-D covariance
            matrix, only its diagonal (the marginal variances) is returned.
        """
        hyp = kwargs.pop('theta', None)
        if hyp is None:
            hyp = self.theta

        mean, var = self.gpr.predict(hyp, X, y, Xs)

        # reduce a full predictive covariance matrix to its diagonal
        marginal = np.diag(var) if len(var.shape) == 2 else var
        return mean, marginal
|