sgptools 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/utils/data.py +3 -0
- sgptools/utils/gpflow.py +11 -2
- sgptools/utils/misc.py +22 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.6.dist-info}/METADATA +1 -1
- {sgptools-1.1.4.dist-info → sgptools-1.1.6.dist-info}/RECORD +9 -9
- {sgptools-1.1.4.dist-info → sgptools-1.1.6.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.6.dist-info}/WHEEL +0 -0
- {sgptools-1.1.4.dist-info → sgptools-1.1.6.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/utils/data.py
CHANGED
@@ -120,6 +120,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
120
120
|
min_height=0.0,
|
121
121
|
max_height=30.0,
|
122
122
|
roughness=0.5,
|
123
|
+
random_seed=None,
|
123
124
|
**kwargs):
|
124
125
|
'''Generates a 50x50 grid of synthetic elevation data using the diamond square algorithm.
|
125
126
|
|
@@ -131,6 +132,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
131
132
|
min_height (float): Minimum allowed height in the sampled data
|
132
133
|
max_height (float): Maximum allowed height in the sampled data
|
133
134
|
roughness (float): Roughness of the sampled data
|
135
|
+
random_seed (int): Random seed for reproducibility
|
134
136
|
|
135
137
|
Returns:
|
136
138
|
X (ndarray): (n, d); Dataset input features
|
@@ -140,6 +142,7 @@ def prep_synthetic_dataset(shape=(50, 50),
|
|
140
142
|
min_height=min_height,
|
141
143
|
max_height=max_height,
|
142
144
|
roughness=roughness,
|
145
|
+
random_seed=random_seed,
|
143
146
|
**kwargs)
|
144
147
|
|
145
148
|
# create x and y coordinates from the extent
|
sgptools/utils/gpflow.py
CHANGED
@@ -53,6 +53,8 @@ def get_model_params(X_train, y_train,
|
|
53
53
|
noise_variance=0.1,
|
54
54
|
kernel=None,
|
55
55
|
return_gp=False,
|
56
|
+
train_inducing_pts=False,
|
57
|
+
num_inducing_pts=500,
|
56
58
|
**kwargs):
|
57
59
|
"""Train a GP on the given training set.
|
58
60
|
Trains a sparse GP if the training set is larger than 1000 samples.
|
@@ -69,6 +71,10 @@ def get_model_params(X_train, y_train,
|
|
69
71
|
noise_variance (float): Data noise variance
|
70
72
|
kernel (gpflow.kernels.Kernel): gpflow kernel function
|
71
73
|
return_gp (bool): If True, returns the trained GP model
|
74
|
+
train_inducing_pts (bool): If True, trains the inducing points when
|
75
|
+
using a sparse GP model
|
76
|
+
num_inducing_pts (int): Number of inducing points to use when training
|
77
|
+
a sparse GP model
|
72
78
|
|
73
79
|
Returns:
|
74
80
|
loss (list): Loss values obtained during training
|
@@ -88,12 +94,15 @@ def get_model_params(X_train, y_train,
|
|
88
94
|
noise_variance=noise_variance)
|
89
95
|
trainable_variables=gpr.trainable_variables
|
90
96
|
else:
|
91
|
-
inducing_pts = get_inducing_pts(X_train,
|
97
|
+
inducing_pts = get_inducing_pts(X_train, num_inducing_pts)
|
92
98
|
gpr = gpflow.models.SGPR(data=(X_train, y_train),
|
93
99
|
kernel=kernel,
|
94
100
|
inducing_variable=inducing_pts,
|
95
101
|
noise_variance=noise_variance)
|
96
|
-
|
102
|
+
if train_inducing_pts:
|
103
|
+
trainable_variables=gpr.trainable_variables
|
104
|
+
else:
|
105
|
+
trainable_variables=gpr.trainable_variables[1:]
|
97
106
|
|
98
107
|
if max_steps > 0:
|
99
108
|
loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
|
sgptools/utils/misc.py
CHANGED
@@ -3,6 +3,8 @@ from .metrics import get_distance
|
|
3
3
|
from scipy.optimize import linear_sum_assignment
|
4
4
|
from sklearn.metrics import pairwise_distances
|
5
5
|
from scipy.cluster.vq import kmeans2
|
6
|
+
from shapely import geometry
|
7
|
+
import geopandas as gpd
|
6
8
|
|
7
9
|
import matplotlib.pyplot as plt
|
8
10
|
import numpy as np
|
@@ -138,3 +140,23 @@ def project_waypoints(waypoints, candidates):
|
|
138
140
|
waypoints_disc = cont2disc(waypoints, candidates)
|
139
141
|
waypoints_valid = _reoder_path(waypoints, waypoints_disc)
|
140
142
|
return waypoints_valid
|
143
|
+
|
144
|
+
def ploygon2candidats(vertices,
|
145
|
+
num_samples=5000,
|
146
|
+
random_seed=2024):
|
147
|
+
"""Sample unlabeled candidates within a polygon
|
148
|
+
|
149
|
+
Args:
|
150
|
+
vertices (ndarray): (v, 2) of vertices that define the polygon
|
151
|
+
num_samples (int): Number of samples to generate
|
152
|
+
random_seed (int): Random seed for reproducibility
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
candidates (ndarray): (n, 2); Candidate sensor placement locations
|
156
|
+
"""
|
157
|
+
poly = geometry.Polygon(vertices)
|
158
|
+
sampler = gpd.GeoSeries([poly])
|
159
|
+
candidates = sampler.sample_points(size=num_samples,
|
160
|
+
rng=random_seed)
|
161
|
+
candidates = candidates.get_coordinates().to_numpy()
|
162
|
+
return candidates
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sgptools/__init__.py,sha256=
|
1
|
+
sgptools/__init__.py,sha256=ETsbfpEnORTg0xjJyWQSUNI7cSLKYCAr6QRluBeGzRs,449
|
2
2
|
sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
|
3
3
|
sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
|
4
4
|
sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
|
@@ -13,13 +13,13 @@ sgptools/models/core/augmented_sgpr.py,sha256=qMP9J4AnOUx9AEZfaPhoyb3RP_2AOhOUCU
|
|
13
13
|
sgptools/models/core/osgpr.py,sha256=gqliUdXdnt3fea206LP0rqGIggmIdKh8WP2DtFWzdBw,11798
|
14
14
|
sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
|
15
15
|
sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
|
16
|
-
sgptools/utils/data.py,sha256=
|
17
|
-
sgptools/utils/gpflow.py,sha256=
|
16
|
+
sgptools/utils/data.py,sha256=ojDq6KzBXbAl5CdpA6A6me0sg5Sah9ZTl2TpFCqgR4c,7464
|
17
|
+
sgptools/utils/gpflow.py,sha256=46-_Tl-suxvuX3Y9KI_uiixfyCWQ2T-7BUn-7hesdVM,10047
|
18
18
|
sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
|
19
|
-
sgptools/utils/misc.py,sha256=
|
19
|
+
sgptools/utils/misc.py,sha256=11nsDEU3imnrvH7ywGMiwtNBcBnJfHX3KaGxFS3eq6w,6223
|
20
20
|
sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
|
21
|
-
sgptools-1.1.
|
22
|
-
sgptools-1.1.
|
23
|
-
sgptools-1.1.
|
24
|
-
sgptools-1.1.
|
25
|
-
sgptools-1.1.
|
21
|
+
sgptools-1.1.6.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
22
|
+
sgptools-1.1.6.dist-info/METADATA,sha256=SXtfy54sXEKKFd7uPKkpDA7CCQUbAunhjl9YHE_hTRs,944
|
23
|
+
sgptools-1.1.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
24
|
+
sgptools-1.1.6.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
|
25
|
+
sgptools-1.1.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|