sgptools 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/models/core/osgpr.py +29 -13
- {sgptools-1.1.6.dist-info → sgptools-1.1.8.dist-info}/METADATA +1 -1
- {sgptools-1.1.6.dist-info → sgptools-1.1.8.dist-info}/RECORD +7 -7
- {sgptools-1.1.6.dist-info → sgptools-1.1.8.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.6.dist-info → sgptools-1.1.8.dist-info}/WHEEL +0 -0
- {sgptools-1.1.6.dist-info → sgptools-1.1.8.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/models/core/osgpr.py
CHANGED
@@ -57,28 +57,42 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
|
|
57
57
|
self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
|
58
58
|
self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)
|
59
59
|
|
60
|
-
def
|
60
|
+
def init_Z(self):
|
61
|
+
M = self.inducing_variable.Z.shape[0]
|
62
|
+
M_old = int(0.7 * M)
|
63
|
+
M_new = M - M_old
|
64
|
+
old_Z = self.Z_old.numpy()[np.random.permutation(M)[0:M_old], :]
|
65
|
+
new_Z = self.X.numpy()[np.random.permutation(self.X.shape[0])[0:M_new], :]
|
66
|
+
Z = np.vstack((old_Z, new_Z))
|
67
|
+
return Z
|
68
|
+
|
69
|
+
def update(self, data, inducing_variable=None, update_inducing=True):
|
61
70
|
"""Configure the OSGPR to adapt to a new batch of data.
|
62
71
|
Note: The OSGPR needs to be trained using gradient-based approaches after update.
|
63
72
|
|
64
73
|
Args:
|
65
|
-
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n,
|
74
|
+
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
|
75
|
+
inducing_variable (ndarray): (m_new, d): New initial inducing points
|
76
|
+
update_inducing (bool): Whether to update the inducing points
|
66
77
|
"""
|
67
78
|
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
|
68
79
|
self.num_data = self.X.shape[0]
|
69
80
|
|
70
|
-
|
71
|
-
|
72
|
-
|
81
|
+
# Update the inducing points
|
82
|
+
self.Z_old.assign(self.inducing_variable.Z.numpy())
|
83
|
+
if inducing_variable is None and update_inducing:
|
84
|
+
inducing_variable = self.init_Z()
|
85
|
+
if inducing_variable is not None:
|
86
|
+
self.inducing_variable.Z.assign(inducing_variable)
|
73
87
|
|
74
88
|
# Get posterior mean and covariance for the old inducing points
|
75
89
|
mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
|
76
|
-
self.mu_old
|
77
|
-
self.Su_old
|
78
|
-
|
90
|
+
self.mu_old.assign(mu_old.numpy())
|
91
|
+
self.Su_old.assign(Su_old.numpy())
|
92
|
+
|
79
93
|
# Get the prior covariance matrix for the old inducing points
|
80
94
|
Kaa_old = self.kernel(self.Z_old)
|
81
|
-
self.Kaa_old
|
95
|
+
self.Kaa_old.assign(Kaa_old.numpy())
|
82
96
|
|
83
97
|
def _common_terms(self):
|
84
98
|
Mb = self.inducing_variable.num_inducing
|
@@ -228,7 +242,8 @@ def init_osgpr(X_train,
|
|
228
242
|
lengthscales=1.0,
|
229
243
|
variance=1.0,
|
230
244
|
noise_variance=0.001,
|
231
|
-
kernel=None
|
245
|
+
kernel=None,
|
246
|
+
ndim=1):
|
232
247
|
"""Initialize a VFE OSGPR model with an RBF kernel with
|
233
248
|
unit variance and lengthcales, and 0.001 noise variance.
|
234
249
|
Used in the Online Continuous SGP approach.
|
@@ -243,6 +258,7 @@ def init_osgpr(X_train,
|
|
243
258
|
variance (float): Kernel variance
|
244
259
|
noise_variance (float): Data noise variance
|
245
260
|
kernel (gpflow.kernels.Kernel): gpflow kernel function
|
261
|
+
ndim (int): Number of output dimensions
|
246
262
|
|
247
263
|
Returns:
|
248
264
|
online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
|
@@ -252,7 +268,7 @@ def init_osgpr(X_train,
|
|
252
268
|
kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
|
253
269
|
variance=variance)
|
254
270
|
|
255
|
-
y_train = np.zeros((len(X_train),
|
271
|
+
y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
|
256
272
|
Z_init = get_inducing_pts(X_train, num_inducing)
|
257
273
|
init_param = gpflow.models.SGPR((X_train, y_train),
|
258
274
|
kernel,
|
@@ -261,8 +277,8 @@ def init_osgpr(X_train,
|
|
261
277
|
|
262
278
|
# Initialize the OSGPR model using the parameters from the SGPR model
|
263
279
|
# The X_train and y_train here will be overwritten in the online phase
|
264
|
-
X_train = np.
|
265
|
-
y_train =
|
280
|
+
X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
|
281
|
+
y_train = np.zeros([2, ndim], dtype=X_train.dtype)
|
266
282
|
Zopt = init_param.inducing_variable.Z.numpy()
|
267
283
|
mu, Su = init_param.predict_f(Zopt, full_cov=True)
|
268
284
|
Kaa = init_param.kernel(Zopt)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sgptools/__init__.py,sha256=
|
1
|
+
sgptools/__init__.py,sha256=fITCCOmfRYm3Ow6I6EW7ASAf_anvnmJJBA7SVc5K7as,449
|
2
2
|
sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
|
3
3
|
sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
|
4
4
|
sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
|
@@ -10,7 +10,7 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
|
|
10
10
|
sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
|
11
11
|
sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
|
12
12
|
sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
|
13
|
-
sgptools/models/core/osgpr.py,sha256=
|
13
|
+
sgptools/models/core/osgpr.py,sha256=trUwUOLX82BRY2KyMWFygBQkc2PItxGlWPXaAgOhpE4,12442
|
14
14
|
sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
|
15
15
|
sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
|
16
16
|
sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
|
@@ -18,8 +18,8 @@ sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,1004
|
|
18
18
|
sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
|
19
19
|
sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
|
20
20
|
sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
|
21
|
-
sgptools-1.1.
|
22
|
-
sgptools-1.1.
|
23
|
-
sgptools-1.1.
|
24
|
-
sgptools-1.1.
|
25
|
-
sgptools-1.1.
|
21
|
+
sgptools-1.1.8.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
22
|
+
sgptools-1.1.8.dist-info/METADATA,sha256=oDf2T6ldEux4_WQxEPVi1DD6a-_fS9jS1wG7rMhl_8g,944
|
23
|
+
sgptools-1.1.8.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
24
|
+
sgptools-1.1.8.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
|
25
|
+
sgptools-1.1.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|