sgptools 1.1.7__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/models/core/osgpr.py +6 -3
- {sgptools-1.1.7.dist-info → sgptools-1.1.8.dist-info}/METADATA +1 -1
- {sgptools-1.1.7.dist-info → sgptools-1.1.8.dist-info}/RECORD +7 -7
- {sgptools-1.1.7.dist-info → sgptools-1.1.8.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.7.dist-info → sgptools-1.1.8.dist-info}/WHEEL +0 -0
- {sgptools-1.1.7.dist-info → sgptools-1.1.8.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/models/core/osgpr.py
CHANGED
@@ -66,21 +66,24 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
|
|
66
66
|
Z = np.vstack((old_Z, new_Z))
|
67
67
|
return Z
|
68
68
|
|
69
|
-
def update(self, data, inducing_variable=None):
|
69
|
+
def update(self, data, inducing_variable=None, update_inducing=True):
|
70
70
|
"""Configure the OSGPR to adapt to a new batch of data.
|
71
71
|
Note: The OSGPR needs to be trained using gradient-based approaches after update.
|
72
72
|
|
73
73
|
Args:
|
74
74
|
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
|
75
|
+
inducing_variable (ndarray): (m_new, d): New initial inducing points
|
76
|
+
update_inducing (bool): Whether to update the inducing points
|
75
77
|
"""
|
76
78
|
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
|
77
79
|
self.num_data = self.X.shape[0]
|
78
80
|
|
79
81
|
# Update the inducing points
|
80
82
|
self.Z_old.assign(self.inducing_variable.Z.numpy())
|
81
|
-
if inducing_variable is None:
|
83
|
+
if inducing_variable is None and update_inducing:
|
82
84
|
inducing_variable = self.init_Z()
|
83
|
-
|
85
|
+
if inducing_variable is not None:
|
86
|
+
self.inducing_variable.Z.assign(inducing_variable)
|
84
87
|
|
85
88
|
# Get posterior mean and covariance for the old inducing points
|
86
89
|
mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sgptools/__init__.py,sha256=
|
1
|
+
sgptools/__init__.py,sha256=fITCCOmfRYm3Ow6I6EW7ASAf_anvnmJJBA7SVc5K7as,449
|
2
2
|
sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
|
3
3
|
sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
|
4
4
|
sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
|
@@ -10,7 +10,7 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
|
|
10
10
|
sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
|
11
11
|
sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
|
12
12
|
sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
|
13
|
-
sgptools/models/core/osgpr.py,sha256=
|
13
|
+
sgptools/models/core/osgpr.py,sha256=trUwUOLX82BRY2KyMWFygBQkc2PItxGlWPXaAgOhpE4,12442
|
14
14
|
sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
|
15
15
|
sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
|
16
16
|
sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
|
@@ -18,8 +18,8 @@ sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,1004
|
|
18
18
|
sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
|
19
19
|
sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
|
20
20
|
sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
|
21
|
-
sgptools-1.1.
|
22
|
-
sgptools-1.1.
|
23
|
-
sgptools-1.1.
|
24
|
-
sgptools-1.1.
|
25
|
-
sgptools-1.1.
|
21
|
+
sgptools-1.1.8.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
22
|
+
sgptools-1.1.8.dist-info/METADATA,sha256=oDf2T6ldEux4_WQxEPVi1DD6a-_fS9jS1wG7rMhl_8g,944
|
23
|
+
sgptools-1.1.8.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
24
|
+
sgptools-1.1.8.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
|
25
|
+
sgptools-1.1.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|