sgptools-1.1.6-py3-none-any.whl → sgptools-1.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sgptools/__init__.py CHANGED
@@ -12,7 +12,7 @@ The library includes python code for the following:
12
12
 
13
13
  """
14
14
 
15
- __version__ = "1.1.6"
15
+ __version__ = "1.1.7"
16
16
  __author__ = 'Kalvik'
17
17
 
18
18
  from .models.core import *
@@ -57,28 +57,39 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
57
57
  self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
58
58
  self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)
59
59
 
60
- def update(self, data):
60
def init_Z(self):
    """Build an initial inducing-point set for the next online update.

    About 70% of the points are sampled (without replacement) from the
    previous inducing points ``self.Z_old``. The rest are sampled from the
    current batch of inputs ``self.X``. If the batch has fewer rows than its
    nominal share, every batch row is used and the shortfall is drawn from
    ``self.Z_old`` instead. The result therefore always has as many rows as
    ``self.inducing_variable.Z``, which is what ``update`` assigns it to.

    Returns:
        ndarray: (M, d) array of inducing points, with the points taken from
        ``self.Z_old`` first and the points taken from ``self.X`` last.

    Raises:
        ValueError: if ``self.Z_old`` and ``self.X`` together cannot supply
            M distinct rows.
    """
    M = self.inducing_variable.Z.shape[0]
    Z_prev = self.Z_old.numpy()
    X_new = self.X.numpy()

    # Nominal split keeps int(0.7 * M) old points. When the new batch is too
    # small to fill its share, take all of it and use extra old points.
    M_new = min(M - int(0.7 * M), X_new.shape[0])
    M_old = M - M_new
    if M_old > Z_prev.shape[0]:
        raise ValueError(
            f"init_Z needs {M} inducing points but only {Z_prev.shape[0]} old "
            f"inducing points and {X_new.shape[0]} new inputs are available")

    # Permute over Z_old's own row count (not M) so the indices are always
    # valid for Z_old, whatever its size.
    old_Z = Z_prev[np.random.permutation(Z_prev.shape[0])[:M_old], :]
    new_Z = X_new[np.random.permutation(X_new.shape[0])[:M_new], :]
    return np.vstack((old_Z, new_Z))
68
+
69
def update(self, data, inducing_variable=None):
    """Configure the OSGPR to adapt to a new batch of data.

    Note: The OSGPR needs to be trained using gradient-based approaches after update.

    The new batch replaces ``self.X``/``self.Y``. The current inducing points
    are copied into ``self.Z_old``, and the inducing points are replaced by
    ``inducing_variable`` or, when that is ``None``, by ``self.init_Z()``.
    ``self.mu_old``, ``self.Su_old`` and ``self.Kaa_old`` are then refreshed
    in place at ``self.Z_old``.

    Args:
        data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
        inducing_variable (ndarray, optional): inducing points to use for
            the new batch. It is written into the existing inducing points
            with ``assign``, so it presumably must have the same shape as
            the current inducing points (confirm). Defaults to ``None``, in
            which case ``self.init_Z()`` chooses them.
    """
    self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
    self.num_data = self.X.shape[0]

    # Update the inducing points: remember the current set, then install the
    # new one (caller-supplied or sampled by init_Z).
    self.Z_old.assign(self.inducing_variable.Z.numpy())
    if inducing_variable is None:
        inducing_variable = self.init_Z()
    self.inducing_variable.Z.assign(inducing_variable)

    # Get posterior mean and covariance for the old inducing points.
    # NOTE(review): at this point self.X/self.Y already hold the new batch
    # and the inducing points have already been replaced, so predict_f sees
    # both. Confirm this is the intended input for the mu_old/Su_old summary.
    mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
    self.mu_old.assign(mu_old.numpy())
    self.Su_old.assign(Su_old.numpy())

    # Get the prior covariance matrix for the old inducing points
    Kaa_old = self.kernel(self.Z_old)
    self.Kaa_old.assign(Kaa_old.numpy())
82
93
 
83
94
  def _common_terms(self):
84
95
  Mb = self.inducing_variable.num_inducing
@@ -228,7 +239,8 @@ def init_osgpr(X_train,
228
239
  lengthscales=1.0,
229
240
  variance=1.0,
230
241
  noise_variance=0.001,
231
- kernel=None):
242
+ kernel=None,
243
+ ndim=1):
232
244
  """Initialize a VFE OSGPR model with an RBF kernel with
233
245
  unit variance and lengthcales, and 0.001 noise variance.
234
246
  Used in the Online Continuous SGP approach.
@@ -243,6 +255,7 @@ def init_osgpr(X_train,
243
255
  variance (float): Kernel variance
244
256
  noise_variance (float): Data noise variance
245
257
  kernel (gpflow.kernels.Kernel): gpflow kernel function
258
+ ndim (int): Number of output dimensions
246
259
 
247
260
  Returns:
248
261
  online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
@@ -252,7 +265,7 @@ def init_osgpr(X_train,
252
265
  kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
253
266
  variance=variance)
254
267
 
255
- y_train = np.zeros((len(X_train), 1), dtype=X_train.dtype)
268
+ y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
256
269
  Z_init = get_inducing_pts(X_train, num_inducing)
257
270
  init_param = gpflow.models.SGPR((X_train, y_train),
258
271
  kernel,
@@ -261,8 +274,8 @@ def init_osgpr(X_train,
261
274
 
262
275
  # Initialize the OSGPR model using the parameters from the SGPR model
263
276
  # The X_train and y_train here will be overwritten in the online phase
264
- X_train = np.array([[0, 0], [0, 0]])
265
- y_train = np.array([0, 0]).reshape(-1, 1)
277
+ X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
278
+ y_train = np.zeros([2, ndim], dtype=X_train.dtype)
266
279
  Zopt = init_param.inducing_variable.Z.numpy()
267
280
  mu, Su = init_param.predict_f(Zopt, full_cov=True)
268
281
  Kaa = init_param.kernel(Zopt)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sgptools
3
- Version: 1.1.6
3
+ Version: 1.1.7
4
4
  Summary: Software Suite for Sensor Placement and Informative Path Planning
5
5
  Home-page: https://www.itskalvik.com/sgp-tools
6
6
  Author: Kalvik
@@ -1,4 +1,4 @@
1
- sgptools/__init__.py,sha256=ETsbfpEnORTg0xjJyWQSUNI7cSLKYCAr6QRluBeGzRs,449
1
+ sgptools/__init__.py,sha256=dgJa5rLAj4qD6De6iK-MHAvh7CS7qJkPKrdj5zIXdgY,449
2
2
  sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
3
3
  sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
4
4
  sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
@@ -10,7 +10,7 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
10
10
  sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
11
11
  sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
12
12
  sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
13
- sgptools/models/core/osgpr.py,sha256=gqliUdXdnt3fea206LP0rqGIggmIdKh8WP2DtFWzdBw,11798
13
+ sgptools/models/core/osgpr.py,sha256=fyIRtNGWZeRRuojQJQAxDhMCUTKlmh5mXK3iddrPC8A,12199
14
14
  sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
15
15
  sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
16
16
  sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
@@ -18,8 +18,8 @@ sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,1004
18
18
  sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
19
19
  sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
20
20
  sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
21
- sgptools-1.1.6.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
- sgptools-1.1.6.dist-info/METADATA,sha256=SXtfy54sXEKKFd7uPKkpDA7CCQUbAunhjl9YHE_hTRs,944
23
- sgptools-1.1.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
- sgptools-1.1.6.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
- sgptools-1.1.6.dist-info/RECORD,,
21
+ sgptools-1.1.7.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
+ sgptools-1.1.7.dist-info/METADATA,sha256=SE-g1VWus5SyoA0PUmEpgHcQ-UX1EUpVXHPSMW1yACU,944
23
+ sgptools-1.1.7.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
+ sgptools-1.1.7.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
+ sgptools-1.1.7.dist-info/RECORD,,