sgptools 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgptools/__init__.py +1 -1
- sgptools/utils/gpflow.py +29 -7
- sgptools/utils/tsp.py +43 -21
- {sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/METADATA +1 -1
- {sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/RECORD +8 -8
- {sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/LICENSE.txt +0 -0
- {sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/WHEEL +0 -0
- {sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/top_level.txt +0 -0
sgptools/__init__.py
CHANGED
sgptools/utils/gpflow.py
CHANGED
```diff
@@ -21,6 +21,8 @@ import tensorflow_probability as tfp
 import numpy as np
 import matplotlib.pyplot as plt
 
+from .misc import get_inducing_pts
+
 
 def plot_loss(losses, save_file=None):
     """Helper function to plot the training loss
@@ -50,8 +52,10 @@ def get_model_params(X_train, y_train,
                      variance=1.0,
                      noise_variance=0.1,
                      kernel=None,
+                     return_gp=False,
                      **kwargs):
-    """Train a GP on the given training set
+    """Train a GP on the given training set.
+       Trains a sparse GP if the training set is larger than 1000 samples.
 
     Args:
         X_train (ndarray): (n, d); Training set inputs
@@ -64,29 +68,47 @@ def get_model_params(X_train, y_train,
         variance (float): Kernel variance
         noise_variance (float): Data noise variance
         kernel (gpflow.kernels.Kernel): gpflow kernel function
+        return_gp (bool): If True, returns the trained GP model
 
     Returns:
         loss (list): Loss values obtained during training
         variance (float): Optimized data noise variance
         kernel (gpflow.kernels.Kernel): Optimized gpflow kernel function
+        gp (gpflow.models.GPR): Optimized gpflow GP model.
+                                Returned only if ```return_gp=True```.
+
     """
     if kernel is None:
         kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
                                                    variance=variance)
 
-
-
-
+    if len(X_train) <= 1500:
+        gpr = gpflow.models.GPR(data=(X_train, y_train),
+                                kernel=kernel,
+                                noise_variance=noise_variance)
+        trainable_variables=gpr.trainable_variables
+    else:
+        inducing_pts = get_inducing_pts(X_train, 500)
+        gpr = gpflow.models.SGPR(data=(X_train, y_train),
+                                 kernel=kernel,
+                                 inducing_variable=inducing_pts,
+                                 noise_variance=noise_variance)
+        trainable_variables=gpr.trainable_variables[1:]
 
     if max_steps > 0:
-        loss = optimize_model(
+        loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
+                              trainable_variables=trainable_variables,
+                              **kwargs)
     else:
         loss = 0
 
     if print_params:
-        print_summary(
+        print_summary(gpr)
 
-
+    if return_gp:
+        return loss, gpr.likelihood.variance, kernel, gpr
+    else:
+        return loss, gpr.likelihood.variance, kernel
 
 
 class TraceInducingPts(gpflow.monitor.MonitorTask):
```
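For context, here is a minimal usage sketch of the updated `get_model_params` helper. It assumes only the keyword names visible in the hunks above (`lengthscales`, `return_gp`) and uses synthetic placeholder data; per the new code path, more than 1500 training samples fits a `gpflow.models.SGPR` on 500 inducing points instead of a dense `GPR`, and `return_gp=True` additionally returns the fitted model.

```python
import numpy as np
from sgptools.utils.gpflow import get_model_params

# Synthetic placeholder data; >1500 samples exercises the new sparse-GP branch
X_train = np.random.uniform(-1.0, 1.0, size=(2000, 1))
y_train = np.sin(4 * X_train) + 0.1 * np.random.standard_normal((2000, 1))

# return_gp=True is the flag added in 1.1.4; the fourth return value is the
# trained model (GPR for <=1500 samples, SGPR with 500 inducing points otherwise)
loss, noise_variance, kernel, gp = get_model_params(X_train, y_train,
                                                    lengthscales=1.0,
                                                    return_gp=True)
print(noise_variance)
```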
sgptools/utils/tsp.py
CHANGED
```diff
@@ -24,21 +24,23 @@ def run_tsp(nodes,
             max_dist=25,
             depth=1,
             resample=None,
-
-
+            start_nodes=None,
+            end_nodes=None,
             time_limit=10):
     """Method to run TSP/VRP with arbitrary start and end nodes,
     and without any distance constraint
 
     Args:
-        nodes (ndarray): (# nodes,
+        nodes (ndarray): (# nodes, ndim); Nodes to visit
         num_vehicles (int): Number of robots/vehicles
         max_dist (float): Maximum distance allowed for each path when handling mutli-robot case
         depth (int): Internal parameter used to track re-try recursion depth
         resample (int): Each solution path will be resampled to have
                         `resample` number of points
-
-
+        start_nodes (ndarray): (# num_vehicles, ndim); Optionl array of start nodes from which
+                               to start each vehicle's solution path
+        end_nodes (ndarray): (# num_vehicles, ndim); Optionl array of end nodes at which
+                             to end each vehicle's solution path
         time_limit (int): TSP runtime time limit in seconds
 
     Returns:
@@ -48,28 +50,42 @@ def run_tsp(nodes,
     if depth > 5:
         print('Warning: Max depth reached')
         return None, None
-
+
+    # Add the start and end nodes to the node list
+    if end_nodes is not None:
+        assert end_nodes.shape == (num_vehicles, nodes.shape[-1]), \
+            "Incorrect end_nodes shape, should be (num_vehicles, ndim)!"
+        nodes = np.concatenate([end_nodes, nodes])
+    if start_nodes is not None:
+        assert start_nodes.shape == (num_vehicles, nodes.shape[-1]), \
+            "Incorrect start_nodes shape, should be (num_vehicles, ndim)!"
+        nodes = np.concatenate([start_nodes, nodes])
+
     # Add dummy 0 location to get arbitrary start and end node sols
-    if
+    if start_nodes is None or end_nodes is None:
         distance_mat = np.zeros((len(nodes)+1, len(nodes)+1))
         distance_mat[1:, 1:] = pairwise_distances(nodes, nodes)*1e4
-        trim_paths = True
+        trim_paths = True #shift to account for dummy node
     else:
         distance_mat = pairwise_distances(nodes, nodes)*1e4
         trim_paths = False
     distance_mat = distance_mat.astype(int)
     max_dist = int(max_dist*1e4)
 
-
-
-
-
+    # Get start and end node indices for ortools
+    if start_nodes is None:
+        start_idx = np.zeros(num_vehicles, dtype=int)
+        num_start_nodes = 0
+    else:
+        start_idx = np.arange(num_vehicles)+int(trim_paths)
+        num_start_nodes = len(start_nodes)
 
-    if
-        end_idx =
-
-        end_idx =
+    if end_nodes is None:
+        end_idx = np.zeros(num_vehicles, dtype=int)
+    else:
+        end_idx = np.arange(num_vehicles)+num_start_nodes+int(trim_paths)
 
+    # used by ortools
     def distance_callback(from_index, to_index):
         from_node = manager.IndexToNode(from_index)
         to_node = manager.IndexToNode(to_index)
```
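The dummy-node construction in this hunk can be sketched in isolation. The snippet below only illustrates the branch taken when `start_nodes`/`end_nodes` are `None` (the 2D node coordinates are made up): the matrix gains an extra zero-cost row and column, and distances are scaled and cast to integers because OR-Tools expects integer arc costs.

```python
import numpy as np
from sklearn.metrics import pairwise_distances

# Made-up 2D nodes to visit
nodes = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])

# Prepend a dummy location (index 0) with zero cost to/from every real node,
# then scale to integers for OR-Tools
distance_mat = np.zeros((len(nodes) + 1, len(nodes) + 1))
distance_mat[1:, 1:] = pairwise_distances(nodes, nodes) * 1e4
distance_mat = distance_mat.astype(int)
print(distance_mat)
# Routes that open and close at the free dummy node can effectively start and
# end at any real node; the dummy is later trimmed from each solution path
# (the trim_paths flag in the hunk above)
```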
```diff
@@ -78,8 +94,8 @@ def run_tsp(nodes,
     # num_locations, num vehicles, start, end
     manager = pywrapcp.RoutingIndexManager(len(distance_mat),
                                            num_vehicles,
-                                           start_idx,
-                                           end_idx)
+                                           start_idx.tolist(),
+                                           end_idx.tolist())
     routing = pywrapcp.RoutingModel(manager)
     transit_callback_index = routing.RegisterTransitCallback(distance_callback)
     routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
```
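A hedged sketch of calling `run_tsp` with the new per-vehicle start and end locations. The coordinates, vehicle count, and the names given to the two returned values are illustrative only (the `return None, None` early exit above shows the function returns a pair); the argument names follow the hunks above.

```python
import numpy as np
from sgptools.utils.tsp import run_tsp

# Illustrative 2D waypoints to visit
nodes = np.random.uniform(0.0, 10.0, size=(30, 2))

# New in 1.1.4: optional per-vehicle start/end locations, shape (num_vehicles, ndim)
start_nodes = np.array([[0.0, 0.0], [10.0, 0.0]])
end_nodes = np.array([[0.0, 10.0], [10.0, 10.0]])

# Two values are returned (named here as paths and costs for illustration)
paths, costs = run_tsp(nodes,
                       num_vehicles=2,
                       max_dist=100,
                       resample=20,
                       start_nodes=start_nodes,
                       end_nodes=end_nodes)
```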
```diff
@@ -165,14 +181,20 @@ def resample_path(waypoints, num_inducing=10):
        inducing points path with fixed number of waypoints
 
     Args:
-        waypoints (ndarray): (num_waypoints,
+        waypoints (ndarray): (num_waypoints, ndim); waypoints of path from vrp solver
         num_inducing (int): Number of inducing points (waypoints) in the returned path
 
     Returns:
-        points (ndarray): (num_inducing,
+        points (ndarray): (num_inducing, ndim); Resampled path
     """
+    ndim = np.shape(waypoints)[-1]
+    if not (ndim==2 or ndim==3):
+        raise Exception(f"ndim={ndim} is not supported for path resampling!")
     line = LineString(waypoints)
     distances = np.linspace(0, line.length, num_inducing)
     points = [line.interpolate(distance) for distance in distances]
-
+    if ndim==2:
+        points = np.array([[p.x, p.y] for p in points])
+    elif ndim==3:
+        points = np.array([[p.x, p.y, p.z] for p in points])
     return points
```
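Finally, a small sketch of the updated `resample_path`, which now checks the waypoint dimensionality and handles both 2D and 3D paths; the 3D waypoints below are made up, and any dimensionality other than 2 or 3 now raises an `Exception`, as shown in the hunk above.

```python
import numpy as np
from sgptools.utils.tsp import resample_path

# Made-up 3D path, e.g. from a planner; resampling returns num_inducing points
# spaced evenly along the piecewise-linear path (2D input works the same way)
waypoints_3d = np.array([[0.0, 0.0, 0.0],
                         [1.0, 0.0, 0.5],
                         [1.0, 1.0, 1.0],
                         [2.0, 1.0, 1.5]])
resampled = resample_path(waypoints_3d, num_inducing=10)
print(resampled.shape)  # (10, 3)
```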
{sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/RECORD
CHANGED
```diff
@@ -1,4 +1,4 @@
-sgptools/__init__.py,sha256=
+sgptools/__init__.py,sha256=99q8oLBb0hxIdFFxf9AWGmgcNMrJePEkcCRtn9K-iDI,449
 sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
 sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
 sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
@@ -14,12 +14,12 @@ sgptools/models/core/osgpr.py,sha256=gqliUdXdnt3fea206LP0rqGIggmIdKh8WP2DtFWzdBw
 sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
 sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
 sgptools/utils/data.py,sha256=oTXq4oRuzJdXpZC6frUfja8jhwy_ZdDDi7L1BYZcdQs,7309
-sgptools/utils/gpflow.py,sha256=
+sgptools/utils/gpflow.py,sha256=3Zw7iFg1pdrzAVih3dmt4y4NtJF4EMSioNVDR2DZ_sg,9575
 sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
 sgptools/utils/misc.py,sha256=LdAFJS7-xubWpRnrgdLOorCa9vB_8vRrvL5cahxHYNA,5442
-sgptools/utils/tsp.py,sha256=
-sgptools-1.1.
-sgptools-1.1.
-sgptools-1.1.
-sgptools-1.1.
-sgptools-1.1.
+sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
+sgptools-1.1.4.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sgptools-1.1.4.dist-info/METADATA,sha256=o43iAz3oaqmk1E7uqGVyw4ya-YRgRwJJN-O3Y2wL_lQ,944
+sgptools-1.1.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+sgptools-1.1.4.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
+sgptools-1.1.4.dist-info/RECORD,,
```
{sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/LICENSE.txt
File without changes
{sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/WHEEL
File without changes
{sgptools-1.1.2.dist-info → sgptools-1.1.4.dist-info}/top_level.txt
File without changes