sgptools 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sgptools/__init__.py CHANGED
@@ -12,7 +12,7 @@ The library includes python code for the following:
12
12
 
13
13
  """
14
14
 
15
- __version__ = "1.1.6"
15
+ __version__ = "1.1.8"
16
16
  __author__ = 'Kalvik'
17
17
 
18
18
  from .models.core import *
sgptools/models/core/osgpr.py CHANGED
@@ -57,28 +57,42 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
57
57
  self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
58
58
  self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)
59
59
 
60
- def update(self, data):
60
+ def init_Z(self):
61
+ M = self.inducing_variable.Z.shape[0]
62
+ M_old = int(0.7 * M)
63
+ M_new = M - M_old
64
+ old_Z = self.Z_old.numpy()[np.random.permutation(M)[0:M_old], :]
65
+ new_Z = self.X.numpy()[np.random.permutation(self.X.shape[0])[0:M_new], :]
66
+ Z = np.vstack((old_Z, new_Z))
67
+ return Z
68
+
69
+ def update(self, data, inducing_variable=None, update_inducing=True):
61
70
  """Configure the OSGPR to adapt to a new batch of data.
62
71
  Note: The OSGPR needs to be trained using gradient-based approaches after update.
63
72
 
64
73
  Args:
65
- data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, 1)
74
+ data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
75
+ inducing_variable (ndarray): (m_new, d): New initial inducing points
76
+ update_inducing (bool): Whether to update the inducing points
66
77
  """
67
78
  self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
68
79
  self.num_data = self.X.shape[0]
69
80
 
70
- self.Z_old = tf.Variable(self.inducing_variable.Z.numpy(),
71
- shape=tf.TensorShape(None),
72
- trainable=False)
81
+ # Update the inducing points
82
+ self.Z_old.assign(self.inducing_variable.Z.numpy())
83
+ if inducing_variable is None and update_inducing:
84
+ inducing_variable = self.init_Z()
85
+ if inducing_variable is not None:
86
+ self.inducing_variable.Z.assign(inducing_variable)
73
87
 
74
88
  # Get posterior mean and covariance for the old inducing points
75
89
  mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
76
- self.mu_old = tf.Variable(mu_old, shape=tf.TensorShape(None), trainable=False)
77
- self.Su_old = tf.Variable(Su_old, shape=tf.TensorShape(None), trainable=False)
78
-
90
+ self.mu_old.assign(mu_old.numpy())
91
+ self.Su_old.assign(Su_old.numpy())
92
+
79
93
  # Get the prior covariance matrix for the old inducing points
80
94
  Kaa_old = self.kernel(self.Z_old)
81
- self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
95
+ self.Kaa_old.assign(Kaa_old.numpy())
82
96
 
83
97
  def _common_terms(self):
84
98
  Mb = self.inducing_variable.num_inducing
@@ -228,7 +242,8 @@ def init_osgpr(X_train,
228
242
  lengthscales=1.0,
229
243
  variance=1.0,
230
244
  noise_variance=0.001,
231
- kernel=None):
245
+ kernel=None,
246
+ ndim=1):
232
247
  """Initialize a VFE OSGPR model with an RBF kernel with
233
248
  unit variance and lengthcales, and 0.001 noise variance.
234
249
  Used in the Online Continuous SGP approach.
@@ -243,6 +258,7 @@ def init_osgpr(X_train,
243
258
  variance (float): Kernel variance
244
259
  noise_variance (float): Data noise variance
245
260
  kernel (gpflow.kernels.Kernel): gpflow kernel function
261
+ ndim (int): Number of output dimensions
246
262
 
247
263
  Returns:
248
264
  online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
@@ -252,7 +268,7 @@ def init_osgpr(X_train,
252
268
  kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
253
269
  variance=variance)
254
270
 
255
- y_train = np.zeros((len(X_train), 1), dtype=X_train.dtype)
271
+ y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
256
272
  Z_init = get_inducing_pts(X_train, num_inducing)
257
273
  init_param = gpflow.models.SGPR((X_train, y_train),
258
274
  kernel,
@@ -261,8 +277,8 @@ def init_osgpr(X_train,
261
277
 
262
278
  # Initialize the OSGPR model using the parameters from the SGPR model
263
279
  # The X_train and y_train here will be overwritten in the online phase
264
- X_train = np.array([[0, 0], [0, 0]])
265
- y_train = np.array([0, 0]).reshape(-1, 1)
280
+ X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
281
+ y_train = np.zeros([2, ndim], dtype=X_train.dtype)
266
282
  Zopt = init_param.inducing_variable.Z.numpy()
267
283
  mu, Su = init_param.predict_f(Zopt, full_cov=True)
268
284
  Kaa = init_param.kernel(Zopt)
sgptools-{1.1.6 → 1.1.8}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sgptools
3
- Version: 1.1.6
3
+ Version: 1.1.8
4
4
  Summary: Software Suite for Sensor Placement and Informative Path Planning
5
5
  Home-page: https://www.itskalvik.com/sgp-tools
6
6
  Author: Kalvik
sgptools-{1.1.6 → 1.1.8}.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
1
- sgptools/__init__.py,sha256=ETsbfpEnORTg0xjJyWQSUNI7cSLKYCAr6QRluBeGzRs,449
1
+ sgptools/__init__.py,sha256=fITCCOmfRYm3Ow6I6EW7ASAf_anvnmJJBA7SVc5K7as,449
2
2
  sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
3
3
  sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
4
4
  sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
@@ -10,7 +10,7 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
10
10
  sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
11
11
  sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
12
12
  sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
13
- sgptools/models/core/osgpr.py,sha256=gqliUdXdnt3fea206LP0rqGIggmIdKh8WP2DtFWzdBw,11798
13
+ sgptools/models/core/osgpr.py,sha256=trUwUOLX82BRY2KyMWFygBQkc2PItxGlWPXaAgOhpE4,12442
14
14
  sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
15
15
  sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
16
16
  sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
@@ -18,8 +18,8 @@ sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,1004
18
18
  sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
19
19
  sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
20
20
  sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
21
- sgptools-1.1.6.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
- sgptools-1.1.6.dist-info/METADATA,sha256=SXtfy54sXEKKFd7uPKkpDA7CCQUbAunhjl9YHE_hTRs,944
23
- sgptools-1.1.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
- sgptools-1.1.6.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
- sgptools-1.1.6.dist-info/RECORD,,
21
+ sgptools-1.1.8.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
+ sgptools-1.1.8.dist-info/METADATA,sha256=oDf2T6ldEux4_WQxEPVi1DD6a-_fS9jS1wG7rMhl_8g,944
23
+ sgptools-1.1.8.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
+ sgptools-1.1.8.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
+ sgptools-1.1.8.dist-info/RECORD,,