sgptools 1.1.4__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/models/core/osgpr.py +26 -13
- sgptools/utils/data.py +3 -0
- sgptools/utils/gpflow.py +11 -2
- sgptools/utils/misc.py +22 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.7.dist-info}/METADATA +1 -1
- {sgptools-1.1.4.dist-info → sgptools-1.1.7.dist-info}/RECORD +10 -10
- {sgptools-1.1.4.dist-info → sgptools-1.1.7.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.7.dist-info}/WHEEL +0 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.7.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/models/core/osgpr.py
CHANGED
@@ -57,28 +57,39 @@ class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
|
|
57
57
|
self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
|
58
58
|
self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)
|
59
59
|
|
60
|
-
def
|
60
|
+
def init_Z(self, old_fraction=0.7):
    """Build initial inducing points for the next streaming batch.

    A random subset of the previous inducing points (``self.Z_old``) is kept
    and the remaining slots are filled with randomly chosen inputs from the
    new batch (``self.X``).

    Args:
        old_fraction (float): Fraction of the inducing points drawn from
            ``self.Z_old``; the rest are drawn from the new batch.
            Defaults to 0.7 (the previously hard-coded value).

    Returns:
        Z (ndarray): (m, d); Initial inducing points. m equals the current
            number of inducing points whenever ``self.Z_old`` and ``self.X``
            together hold enough rows.
    """
    M = self.inducing_variable.Z.shape[0]
    Z_old = self.Z_old.numpy()
    X_new = self.X.numpy()
    num_old_avail = Z_old.shape[0]
    num_new_avail = X_new.shape[0]

    M_old = int(old_fraction * M)
    M_new = M - M_old
    # A small batch may not contain M_new points. Take every new point and
    # back-fill from the old inducing points instead, so the total stays M.
    # update() assigns the result into the current inducing variable, so its
    # row count should match.
    if num_new_avail < M_new:
        M_new = num_new_avail
        M_old = M - M_new
    # Never ask for more old points than exist (the original indexed Z_old
    # with permutation(M), assuming exactly M rows).
    M_old = min(M_old, num_old_avail)

    old_Z = Z_old[np.random.permutation(num_old_avail)[:M_old], :]
    new_Z = X_new[np.random.permutation(num_new_avail)[:M_new], :]
    return np.vstack((old_Z, new_Z))
|
68
|
+
|
69
|
+
def update(self, data, inducing_variable=None):
    """Prepare the streaming model for a fresh batch of observations.

    The current inducing points are stored as the "old" set. The inducing
    variable is re-initialised, and the posterior mean/covariance and the
    prior covariance at the old inducing points are stored in the model's
    ``mu_old`` / ``Su_old`` / ``Kaa_old`` variables.

    Note: The OSGPR needs to be trained using gradient-based approaches after update.

    Args:
        data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
        inducing_variable (ndarray): Optional initial inducing points for the
            new batch. When None, ``self.init_Z()`` supplies them.
    """
    batch = gpflow.models.util.data_input_to_tensor(data)
    self.X, self.Y = batch
    self.data = batch
    self.num_data = self.X.shape[0]

    # The inducing points learnt so far become the "old" set, then the live
    # inducing variable is re-initialised for the new batch.
    live_Z = self.inducing_variable.Z
    self.Z_old.assign(live_Z.numpy())
    fresh_Z = self.init_Z() if inducing_variable is None else inducing_variable
    live_Z.assign(fresh_Z)

    # NOTE(review): data, Z_old and Z are all replaced *before* predict_f runs,
    # while mu_old/Su_old/Kaa_old still hold the previous step's values. If
    # predict_f reads those attributes, this snapshot already includes the new
    # batch. The reference streaming-SGP recipe predicts before swapping in
    # new data. Confirm this ordering is intended.
    mean_at_old, cov_at_old = self.predict_f(self.Z_old, full_cov=True)
    self.mu_old.assign(mean_at_old.numpy())
    self.Su_old.assign(cov_at_old.numpy())

    # Prior covariance of the old inducing points under the current kernel.
    prior_cov_old = self.kernel(self.Z_old)
    self.Kaa_old.assign(prior_cov_old.numpy())
|
82
93
|
|
83
94
|
def _common_terms(self):
|
84
95
|
Mb = self.inducing_variable.num_inducing
|
@@ -228,7 +239,8 @@ def init_osgpr(X_train,
|
|
228
239
|
lengthscales=1.0,
|
229
240
|
variance=1.0,
|
230
241
|
noise_variance=0.001,
|
231
|
-
kernel=None
|
242
|
+
kernel=None,
|
243
|
+
ndim=1):
|
232
244
|
"""Initialize a VFE OSGPR model with an RBF kernel with
|
233
245
|
unit variance and lengthcales, and 0.001 noise variance.
|
234
246
|
Used in the Online Continuous SGP approach.
|
@@ -243,6 +255,7 @@ def init_osgpr(X_train,
|
|
243
255
|
variance (float): Kernel variance
|
244
256
|
noise_variance (float): Data noise variance
|
245
257
|
kernel (gpflow.kernels.Kernel): gpflow kernel function
|
258
|
+
ndim (int): Number of output dimensions
|
246
259
|
|
247
260
|
Returns:
|
248
261
|
online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
|
@@ -252,7 +265,7 @@ def init_osgpr(X_train,
|
|
252
265
|
kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
|
253
266
|
variance=variance)
|
254
267
|
|
255
|
-
y_train = np.zeros((len(X_train),
|
268
|
+
y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
|
256
269
|
Z_init = get_inducing_pts(X_train, num_inducing)
|
257
270
|
init_param = gpflow.models.SGPR((X_train, y_train),
|
258
271
|
kernel,
|
@@ -261,8 +274,8 @@ def init_osgpr(X_train,
|
|
261
274
|
|
262
275
|
# Initialize the OSGPR model using the parameters from the SGPR model
|
263
276
|
# The X_train and y_train here will be overwritten in the online phase
|
264
|
-
X_train = np.
|
265
|
-
y_train =
|
277
|
+
X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
|
278
|
+
y_train = np.zeros([2, ndim], dtype=X_train.dtype)
|
266
279
|
Zopt = init_param.inducing_variable.Z.numpy()
|
267
280
|
mu, Su = init_param.predict_f(Zopt, full_cov=True)
|
268
281
|
Kaa = init_param.kernel(Zopt)
|
sgptools/utils/data.py
CHANGED
@@ -120,6 +120,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
120
120
|
min_height=0.0,
|
121
121
|
max_height=30.0,
|
122
122
|
roughness=0.5,
|
123
|
+
random_seed=None,
|
123
124
|
**kwargs):
|
124
125
|
'''Generates a 50x50 grid of synthetic elevation data using the diamond square algorithm.
|
125
126
|
|
@@ -131,6 +132,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
131
132
|
min_height (float): Minimum allowed height in the sampled data
|
132
133
|
max_height (float): Maximum allowed height in the sampled data
|
133
134
|
roughness (float): Roughness of the sampled data
|
135
|
+
random_seed (int): Random seed for reproducibility
|
134
136
|
|
135
137
|
Returns:
|
136
138
|
X (ndarray): (n, d); Dataset input features
|
@@ -140,6 +142,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
140
142
|
min_height=min_height,
|
141
143
|
max_height=max_height,
|
142
144
|
roughness=roughness,
|
145
|
+
random_seed=random_seed,
|
143
146
|
**kwargs)
|
144
147
|
|
145
148
|
# create x and y coordinates from the extent
|
sgptools/utils/gpflow.py
CHANGED
@@ -53,6 +53,8 @@ def get_model_params(X_train, y_train,
|
|
53
53
|
noise_variance=0.1,
|
54
54
|
kernel=None,
|
55
55
|
return_gp=False,
|
56
|
+
train_inducing_pts=False,
|
57
|
+
num_inducing_pts=500,
|
56
58
|
**kwargs):
|
57
59
|
"""Train a GP on the given training set.
|
58
60
|
Trains a sparse GP if the training set is larger than 1000 samples.
|
@@ -69,6 +71,10 @@ def get_model_params(X_train, y_train,
|
|
69
71
|
noise_variance (float): Data noise variance
|
70
72
|
kernel (gpflow.kernels.Kernel): gpflow kernel function
|
71
73
|
return_gp (bool): If True, returns the trained GP model
|
74
|
+
train_inducing_pts (bool): If True, trains the inducing points when
|
75
|
+
using a sparse GP model
|
76
|
+
num_inducing_pts (int): Number of inducing points to use when training
|
77
|
+
a sparse GP model
|
72
78
|
|
73
79
|
Returns:
|
74
80
|
loss (list): Loss values obtained during training
|
@@ -88,12 +94,15 @@ def get_model_params(X_train, y_train,
|
|
88
94
|
noise_variance=noise_variance)
|
89
95
|
trainable_variables=gpr.trainable_variables
|
90
96
|
else:
|
91
|
-
inducing_pts = get_inducing_pts(X_train,
|
97
|
+
inducing_pts = get_inducing_pts(X_train, num_inducing_pts)
|
92
98
|
gpr = gpflow.models.SGPR(data=(X_train, y_train),
|
93
99
|
kernel=kernel,
|
94
100
|
inducing_variable=inducing_pts,
|
95
101
|
noise_variance=noise_variance)
|
96
|
-
|
102
|
+
if train_inducing_pts:
|
103
|
+
trainable_variables=gpr.trainable_variables
|
104
|
+
else:
|
105
|
+
trainable_variables=gpr.trainable_variables[1:]
|
97
106
|
|
98
107
|
if max_steps > 0:
|
99
108
|
loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
|
sgptools/utils/misc.py
CHANGED
@@ -3,6 +3,8 @@ from .metrics import get_distance
|
|
3
3
|
from scipy.optimize import linear_sum_assignment
|
4
4
|
from sklearn.metrics import pairwise_distances
|
5
5
|
from scipy.cluster.vq import kmeans2
|
6
|
+
from shapely import geometry
|
7
|
+
import geopandas as gpd
|
6
8
|
|
7
9
|
import matplotlib.pyplot as plt
|
8
10
|
import numpy as np
|
@@ -138,3 +140,23 @@ def project_waypoints(waypoints, candidates):
|
|
138
140
|
waypoints_disc = cont2disc(waypoints, candidates)
|
139
141
|
waypoints_valid = _reoder_path(waypoints, waypoints_disc)
|
140
142
|
return waypoints_valid
|
143
|
+
|
144
|
+
def ploygon2candidats(vertices,
                      num_samples=5000,
                      random_seed=2024):
    """Sample unlabeled candidate locations inside a polygon.

    The polygon is built with shapely. Points are drawn with geopandas'
    ``GeoSeries.sample_points``, seeded through its ``rng`` argument.

    Args:
        vertices (ndarray): (v, 2); Vertices that define the polygon boundary
        num_samples (int): Number of points to sample inside the polygon
        random_seed (int): Random seed for reproducibility

    Returns:
        candidates (ndarray): (n, 2); Candidate sensor placement locations
    """
    polygon = geometry.Polygon(vertices)
    sampler = gpd.GeoSeries([polygon])
    sampled = sampler.sample_points(size=num_samples,
                                    rng=random_seed)
    # get_coordinates() flattens the sampled multipoint into x/y rows.
    candidates = sampled.get_coordinates().to_numpy()
    return candidates


# Correctly spelled alias. The misspelled name above is kept so existing
# callers continue to work.
polygon2candidates = ploygon2candidats
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sgptools/__init__.py,sha256=
|
1
|
+
sgptools/__init__.py,sha256=dgJa5rLAj4qD6De6iK-MHAvh7CS7qJkPKrdj5zIXdgY,449
|
2
2
|
sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
|
3
3
|
sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
|
4
4
|
sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
|
@@ -10,16 +10,16 @@ sgptools/models/greedy_sgp.py,sha256=giddMbU3ohePTdLTcH4fDx-bS9upq1T_K8KUW_Ag6HI
|
|
10
10
|
sgptools/models/core/__init__.py,sha256=TlUdvrM0A7vSzc5IM8C2Y2kliB1ip7YLEcHHzvuw-C4,482
|
11
11
|
sgptools/models/core/augmented_gpr.py,sha256=NuYwlggz7ho7pvW4-so3ghos5vZ8oK7nRZqvHpAt0Zk,3497
|
12
12
|
sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCUY4eh7uOi0,7185
|
13
|
-
sgptools/models/core/osgpr.py,sha256=
|
13
|
+
sgptools/models/core/osgpr.py,sha256=fyIRtNGWZeRRuojQJQAxDhMCUTKlmh5mXK3iddrPC8A,12199
|
14
14
|
sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
|
15
15
|
sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
|
16
|
-
sgptools/utils/data.py,sha256=
|
17
|
-
sgptools/utils/gpflow.py,sha256=
|
16
|
+
sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
|
17
|
+
sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,10047
|
18
18
|
sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
|
19
|
-
sgptools/utils/misc.py,sha256=
|
19
|
+
sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
|
20
20
|
sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
|
21
|
-
sgptools-1.1.
|
22
|
-
sgptools-1.1.
|
23
|
-
sgptools-1.1.
|
24
|
-
sgptools-1.1.
|
25
|
-
sgptools-1.1.
|
21
|
+
sgptools-1.1.7.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
22
|
+
sgptools-1.1.7.dist-info/METADATA,sha256=SE-g1VWus5SyoA0PUmEpgHcQ-UX1EUpVXHPSMW1yACU,944
|
23
|
+
sgptools-1.1.7.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
24
|
+
sgptools-1.1.7.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
|
25
|
+
sgptools-1.1.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|