sgptools 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sgptools/__init__.py CHANGED
@@ -12,7 +12,7 @@ The library includes python code for the following:
12
12
 
13
13
  """
14
14
 
15
- __version__ = "1.1.2"
15
+ __version__ = "1.1.4"
16
16
  __author__ = 'Kalvik'
17
17
 
18
18
  from .models.core import *
sgptools/utils/gpflow.py CHANGED
@@ -21,6 +21,8 @@ import tensorflow_probability as tfp
21
21
  import numpy as np
22
22
  import matplotlib.pyplot as plt
23
23
 
24
+ from .misc import get_inducing_pts
25
+
24
26
 
25
27
  def plot_loss(losses, save_file=None):
26
28
  """Helper function to plot the training loss
@@ -50,8 +52,10 @@ def get_model_params(X_train, y_train,
50
52
  variance=1.0,
51
53
  noise_variance=0.1,
52
54
  kernel=None,
55
+ return_gp=False,
53
56
  **kwargs):
54
- """Train a GP on the given training set
57
+ """Train a GP on the given training set.
58
+ Trains a sparse GP if the training set is larger than 1000 samples.
55
59
 
56
60
  Args:
57
61
  X_train (ndarray): (n, d); Training set inputs
@@ -64,29 +68,47 @@ def get_model_params(X_train, y_train,
64
68
  variance (float): Kernel variance
65
69
  noise_variance (float): Data noise variance
66
70
  kernel (gpflow.kernels.Kernel): gpflow kernel function
71
+ return_gp (bool): If True, returns the trained GP model
67
72
 
68
73
  Returns:
69
74
  loss (list): Loss values obtained during training
70
75
  variance (float): Optimized data noise variance
71
76
  kernel (gpflow.kernels.Kernel): Optimized gpflow kernel function
77
+ gp (gpflow.models.GPR): Optimized gpflow GP model.
78
+ Returned only if ```return_gp=True```.
79
+
72
80
  """
73
81
  if kernel is None:
74
82
  kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
75
83
  variance=variance)
76
84
 
77
- gpr_gt = gpflow.models.GPR(data=(X_train, y_train),
78
- kernel=kernel,
79
- noise_variance=noise_variance)
85
+ if len(X_train) <= 1500:
86
+ gpr = gpflow.models.GPR(data=(X_train, y_train),
87
+ kernel=kernel,
88
+ noise_variance=noise_variance)
89
+ trainable_variables=gpr.trainable_variables
90
+ else:
91
+ inducing_pts = get_inducing_pts(X_train, 500)
92
+ gpr = gpflow.models.SGPR(data=(X_train, y_train),
93
+ kernel=kernel,
94
+ inducing_variable=inducing_pts,
95
+ noise_variance=noise_variance)
96
+ trainable_variables=gpr.trainable_variables[1:]
80
97
 
81
98
  if max_steps > 0:
82
- loss = optimize_model(gpr_gt, max_steps=max_steps, lr=lr, **kwargs)
99
+ loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
100
+ trainable_variables=trainable_variables,
101
+ **kwargs)
83
102
  else:
84
103
  loss = 0
85
104
 
86
105
  if print_params:
87
- print_summary(gpr_gt)
106
+ print_summary(gpr)
88
107
 
89
- return loss, gpr_gt.likelihood.variance, kernel
108
+ if return_gp:
109
+ return loss, gpr.likelihood.variance, kernel, gpr
110
+ else:
111
+ return loss, gpr.likelihood.variance, kernel
90
112
 
91
113
 
92
114
  class TraceInducingPts(gpflow.monitor.MonitorTask):
sgptools/utils/tsp.py CHANGED
@@ -24,21 +24,23 @@ def run_tsp(nodes,
24
24
  max_dist=25,
25
25
  depth=1,
26
26
  resample=None,
27
- start_idx=None,
28
- end_idx=None,
27
+ start_nodes=None,
28
+ end_nodes=None,
29
29
  time_limit=10):
30
30
  """Method to run TSP/VRP with arbitrary start and end nodes,
31
31
  and without any distance constraint
32
32
 
33
33
  Args:
34
- nodes (ndarray): (# nodes, n_dim); Nodes to visit
34
+ nodes (ndarray): (# nodes, ndim); Nodes to visit
35
35
  num_vehicles (int): Number of robots/vehicles
36
36
  max_dist (float): Maximum distance allowed for each path when handling mutli-robot case
37
37
  depth (int): Internal parameter used to track re-try recursion depth
38
38
  resample (int): Each solution path will be resampled to have
39
39
  `resample` number of points
40
- start_idx (list): Optionl list of start node indices from which to start the solution path
41
- end_idx (list): Optionl list of end node indices from which to start the solution path
40
+ start_nodes (ndarray): (# num_vehicles, ndim); Optionl array of start nodes from which
41
+ to start each vehicle's solution path
42
+ end_nodes (ndarray): (# num_vehicles, ndim); Optionl array of end nodes at which
43
+ to end each vehicle's solution path
42
44
  time_limit (int): TSP runtime time limit in seconds
43
45
 
44
46
  Returns:
@@ -48,28 +50,42 @@ def run_tsp(nodes,
48
50
  if depth > 5:
49
51
  print('Warning: Max depth reached')
50
52
  return None, None
51
-
53
+
54
+ # Add the start and end nodes to the node list
55
+ if end_nodes is not None:
56
+ assert end_nodes.shape == (num_vehicles, nodes.shape[-1]), \
57
+ "Incorrect end_nodes shape, should be (num_vehicles, ndim)!"
58
+ nodes = np.concatenate([end_nodes, nodes])
59
+ if start_nodes is not None:
60
+ assert start_nodes.shape == (num_vehicles, nodes.shape[-1]), \
61
+ "Incorrect start_nodes shape, should be (num_vehicles, ndim)!"
62
+ nodes = np.concatenate([start_nodes, nodes])
63
+
52
64
  # Add dummy 0 location to get arbitrary start and end node sols
53
- if start_idx is None or end_idx is None:
65
+ if start_nodes is None or end_nodes is None:
54
66
  distance_mat = np.zeros((len(nodes)+1, len(nodes)+1))
55
67
  distance_mat[1:, 1:] = pairwise_distances(nodes, nodes)*1e4
56
- trim_paths = True
68
+ trim_paths = True #shift to account for dummy node
57
69
  else:
58
70
  distance_mat = pairwise_distances(nodes, nodes)*1e4
59
71
  trim_paths = False
60
72
  distance_mat = distance_mat.astype(int)
61
73
  max_dist = int(max_dist*1e4)
62
74
 
63
- if start_idx is None:
64
- start_idx = [0]*num_vehicles
65
- elif trim_paths:
66
- start_idx = [i+1 for i in start_idx]
75
+ # Get start and end node indices for ortools
76
+ if start_nodes is None:
77
+ start_idx = np.zeros(num_vehicles, dtype=int)
78
+ num_start_nodes = 0
79
+ else:
80
+ start_idx = np.arange(num_vehicles)+int(trim_paths)
81
+ num_start_nodes = len(start_nodes)
67
82
 
68
- if end_idx is None:
69
- end_idx = [0]*num_vehicles
70
- elif trim_paths:
71
- end_idx = [i+1 for i in end_idx]
83
+ if end_nodes is None:
84
+ end_idx = np.zeros(num_vehicles, dtype=int)
85
+ else:
86
+ end_idx = np.arange(num_vehicles)+num_start_nodes+int(trim_paths)
72
87
 
88
+ # used by ortools
73
89
  def distance_callback(from_index, to_index):
74
90
  from_node = manager.IndexToNode(from_index)
75
91
  to_node = manager.IndexToNode(to_index)
@@ -78,8 +94,8 @@ def run_tsp(nodes,
78
94
  # num_locations, num vehicles, start, end
79
95
  manager = pywrapcp.RoutingIndexManager(len(distance_mat),
80
96
  num_vehicles,
81
- start_idx,
82
- end_idx)
97
+ start_idx.tolist(),
98
+ end_idx.tolist())
83
99
  routing = pywrapcp.RoutingModel(manager)
84
100
  transit_callback_index = routing.RegisterTransitCallback(distance_callback)
85
101
  routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
@@ -165,14 +181,20 @@ def resample_path(waypoints, num_inducing=10):
165
181
  inducing points path with fixed number of waypoints
166
182
 
167
183
  Args:
168
- waypoints (ndarray): (num_waypoints, n_dim); waypoints of path from vrp solver
184
+ waypoints (ndarray): (num_waypoints, ndim); waypoints of path from vrp solver
169
185
  num_inducing (int): Number of inducing points (waypoints) in the returned path
170
186
 
171
187
  Returns:
172
- points (ndarray): (num_inducing, n_dim); Resampled path
188
+ points (ndarray): (num_inducing, ndim); Resampled path
173
189
  """
190
+ ndim = np.shape(waypoints)[-1]
191
+ if not (ndim==2 or ndim==3):
192
+ raise Exception(f"ndim={ndim} is not supported for path resampling!")
174
193
  line = LineString(waypoints)
175
194
  distances = np.linspace(0, line.length, num_inducing)
176
195
  points = [line.interpolate(distance) for distance in distances]
177
- points = np.array([[p.x, p.y] for p in points])
196
+ if ndim==2:
197
+ points = np.array([[p.x, p.y] for p in points])
198
+ elif ndim==3:
199
+ points = np.array([[p.x, p.y, p.z] for p in points])
178
200
  return points
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sgptools
3
- Version: 1.1.2
3
+ Version: 1.1.4
4
4
  Summary: Software Suite for Sensor Placement and Informative Path Planning
5
5
  Home-page: https://www.itskalvik.com/sgp-tools
6
6
  Author: Kalvik
@@ -1,4 +1,4 @@
1
- sgptools/__init__.py,sha256=v20FWREF4GkLOyADat_DukZaWOptzI1Yt0RxT4JSY3o,449
1
+ sgptools/__init__.py,sha256=99q8oLBb0hxIdFFxf9AWGmgcNMrJePEkcCRtn9K-iDI,449
2
2
  sgptools/kernels/__init__.py,sha256=zRf4y-wJwjXKt1uOnmI5MbzCA6pRlyA7C-eagLfb3d0,190
3
3
  sgptools/kernels/neural_kernel.py,sha256=9XEjcwwi1Gwj4D5cAZwq5QdWqMaI-Vu2DKgYO58DmPg,6709
4
4
  sgptools/models/__init__.py,sha256=X2lIg9kf1-2MHUswk-VW2dHHcbSLxf6_IuV7lc_kvDc,682
@@ -14,12 +14,12 @@ sgptools/models/core/osgpr.py,sha256=gqliUdXdnt3fea206LP0rqGIggmIdKh8WP2DtFWzdBw
14
14
  sgptools/models/core/transformations.py,sha256=X7WEKo_lFAYB5HKnFvxFsxfz6CB-jzPfVWcx1sWe2lI,18313
15
15
  sgptools/utils/__init__.py,sha256=jgWqzSDgUbqOTFo8mkqZaTlyz44l3v2XYPJfcHYHjqM,376
16
16
  sgptools/utils/data.py,sha256=oTXq4oRuzJdXpZC6frUfja8jhwy_ZdDDi7L1BYZcdQs,7309
17
- sgptools/utils/gpflow.py,sha256=LnFYufnMW4ch7qsKnru53QUxEtIzJqE822qj6w8ssRg,8576
17
+ sgptools/utils/gpflow.py,sha256=3Zw7iFg1pdrzAVih3dmt4y4NtJF4EMSioNVDR2DZ_sg,9575
18
18
  sgptools/utils/metrics.py,sha256=tu8H129n8GuxV5fQIKLcfzPUxd7sp8zEF9qZBOZjNKo,5834
19
19
  sgptools/utils/misc.py,sha256=LdAFJS7-xubWpRnrgdLOorCa9vB_8vRrvL5cahxHYNA,5442
20
- sgptools/utils/tsp.py,sha256=RJAQ4_uE7CUtR1ei3nSnGy-1kNhw82E9P_HyaCkc4iI,7007
21
- sgptools-1.1.2.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
- sgptools-1.1.2.dist-info/METADATA,sha256=Ig-Y6qyVRqARqO4FODhkSOHUsjrrNI6Eaw96uVTNDmg,944
23
- sgptools-1.1.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
- sgptools-1.1.2.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
- sgptools-1.1.2.dist-info/RECORD,,
20
+ sgptools/utils/tsp.py,sha256=b1Lx1Pj-sv7siX-f0S6d25C3RtvszCl3IP4QbvBckqY,8151
21
+ sgptools-1.1.4.dist-info/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
+ sgptools-1.1.4.dist-info/METADATA,sha256=o43iAz3oaqmk1E7uqGVyw4ya-YRgRwJJN-O3Y2wL_lQ,944
23
+ sgptools-1.1.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
24
+ sgptools-1.1.4.dist-info/top_level.txt,sha256=2NWH6uQLAOuLB9fG7o1pqf6Jvpe1_hEcuqfSqtUw3gw,9
25
+ sgptools-1.1.4.dist-info/RECORD,,