sgptools 1.1.6__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/models/core/osgpr.py +26 -13
- {sgptools-1.1.6.dist-info → sgptools-1.1.7.dist-info}/METADATA +1 -1
- {sgptools-1.1.6.dist-info → sgptools-1.1.7.dist-info}/RECORD +7 -7
- {sgptools-1.1.6.dist-info → sgptools-1.1.7.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.6.dist-info → sgptools-1.1.7.dist-info}/WHEEL +0 -0
- {sgptools-1.1.6.dist-info → sgptools-1.1.7.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/models/core/osgpr.py
CHANGED
@@ -57,28 +57,39 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
|
|
57
57
|
self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
|
58
58
|
self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)
|
59
59
|
|
60
|
-
def
|
60
|
+
def init_Z(self):
|
61
|
+
M = self.inducing_variable.Z.shape[0]
|
62
|
+
M_old = int(0.7 * M)
|
63
|
+
M_new = M - M_old
|
64
|
+
old_Z = self.Z_old.numpy()[np.random.permutation(M)[0:M_old], :]
|
65
|
+
new_Z = self.X.numpy()[np.random.permutation(self.X.shape[0])[0:M_new], :]
|
66
|
+
Z = np.vstack((old_Z, new_Z))
|
67
|
+
return Z
|
68
|
+
|
69
|
+
def update(self, data, inducing_variable=None):
|
61
70
|
"""Configure the OSGPR to adapt to a new batch of data.
|
62
71
|
Note: The OSGPR needs to be trained using gradient-based approaches after update.
|
63
72
|
|
64
73
|
Args:
|
65
|
-
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n,
|
74
|
+
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
|
66
75
|
"""
|
67
76
|
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
|
68
77
|
self.num_data = self.X.shape[0]
|
69
78
|
|
70
|
-
|
71
|
-
|
72
|
-
|
79
|
+
# Update the inducing points
|
80
|
+
self.Z_old.assign(self.inducing_variable.Z.numpy())
|
81
|
+
if inducing_variable is None:
|
82
|
+
inducing_variable = self.init_Z()
|
83
|
+
self.inducing_variable.Z.assign(inducing_variable)
|
73
84
|
|
74
85
|
# Get posterior mean and covariance for the old inducing points
|
75
86
|
mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
|
76
|
-
self.mu_old
|
77
|
-
self.Su_old
|
78
|
-
|
87
|
+
self.mu_old.assign(mu_old.numpy())
|
88
|
+
self.Su_old.assign(Su_old.numpy())
|
89
|
+
|
79
90
|
# Get the prior covariance matrix for the old inducing points
|
80
91
|
Kaa_old = self.kernel(self.Z_old)
|
81
|
-
self.Kaa_old
|
92
|
+
self.Kaa_old.assign(Kaa_old.numpy())
|
82
93
|
|
83
94
|
def _common_terms(self):
|
84
95
|
Mb = self.inducing_variable.num_inducing
|
@@ -228,7 +239,8 @@ def init_osgpr(X_train,
|
|
228
239
|
lengthscales=1.0,
|
229
240
|
variance=1.0,
|
230
241
|
noise_variance=0.001,
|
231
|
-
kernel=None
|
242
|
+
kernel=None,
|
243
|
+
ndim=1):
|
232
244
|
"""Initialize a VFE OSGPR model with an RBF kernel with
|
233
245
|
unit variance and lengthcales, and 0.001 noise variance.
|
234
246
|
Used in the Online Continuous SGP approach.
|
@@ -243,6 +255,7 @@ def init_osgpr(X_train,
|
|
243
255
|
variance (float): Kernel variance
|
244
256
|
noise_variance (float): Data noise variance
|
245
257
|
kernel (gpflow.kernels.Kernel): gpflow kernel function
|
258
|
+
ndim (int): Number of output dimensions
|
246
259
|
|
247
260
|
Returns:
|
248
261
|
online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
|
@@ -252,7 +265,7 @@ def init_osgpr(X_train,
|
|
252
265
|
kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
|
253
266
|
variance=variance)
|
254
267
|
|
255
|
-
y_train = np.zeros((len(X_train),
|
268
|
+
y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
|
256
269
|
Z_init = get_inducing_pts(X_train, num_inducing)
|
257
270
|
init_param = gpflow.models.SGPR((X_train, y_train),
|
258
271
|
kernel,
|
@@ -261,8 +274,8 @@ def init_osgpr(X_train,
|
|
261
274
|
|
262
275
|
# Initialize the OSGPR model using the parameters from the SGPR model
|
263
276
|
# The X_train and y_train here will be overwritten in the online phase
|
264
|
-
X_train = np.
|
265
|
-
y_train =
|
277
|
+
X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
|
278
|
+
y_train = np.zeros([2, ndim], dtype=X_train.dtype)
|
266
279
|
Zopt = init_param.inducing_variable.Z.numpy()
|
267
280
|
mu, Su = init_param.predict_f(Zopt, full_cov=True)
|
268
281
|
Kaa = init_param.kernel(Zopt)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sgptools/__init__.py,sha256=
|
1
|
+
sgptools/__init__.py,sha256=dgJa5rLAj4qD6De6iK-MHAvh7CS7qJkPKrdj5zIXdgY,449
|
2
2
|
sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
|
3
3
|
sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
|
4
4
|
sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
|
@@ -10,7 +10,7 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
|
|
10
10
|
sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
|
11
11
|
sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
|
12
12
|
sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
|
13
|
-
sgptools/models/core/osgpr.py,sha256=
|
13
|
+
sgptools/models/core/osgpr.py,sha256=fyIRtNGWZeRRuojQJQAxDhMCUTKlmh5mXK3iddrPC8A,12199
|
14
14
|
sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
|
15
15
|
sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
|
16
16
|
sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
|
@@ -18,8 +18,8 @@ sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,1004
|
|
18
18
|
sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
|
19
19
|
sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
|
20
20
|
sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
|
21
|
-
sgptools-1.1.
|
22
|
-
sgptools-1.1.
|
23
|
-
sgptools-1.1.
|
24
|
-
sgptools-1.1.
|
25
|
-
sgptools-1.1.
|
21
|
+
sgptools-1.1.7.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
22
|
+
sgptools-1.1.7.dist-info/METADATA,sha256=SE-g1VWus5SyoA0PUmEpgHcQ-UX1EUpVXHPSMW1yACU,944
|
23
|
+
sgptools-1.1.7.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
24
|
+
sgptools-1.1.7.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
|
25
|
+
sgptools-1.1.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|