multimodalrouter 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multimodalrouter/__init__.py +7 -0
- multimodalrouter/graph/__init__.py +2 -0
- multimodalrouter/graph/dataclasses.py +220 -0
- multimodalrouter/graph/graph.py +798 -0
- multimodalrouter/graphics/__init__.py +1 -0
- multimodalrouter/graphics/graphicsWrapper.py +323 -0
- multimodalrouter/router/__init__.py +0 -0
- multimodalrouter/router/build.py +97 -0
- multimodalrouter/router/route.py +71 -0
- multimodalrouter/utils/__init__.py +1 -0
- multimodalrouter/utils/preprocessor.py +177 -0
- multimodalrouter-0.1.14.dist-info/METADATA +131 -0
- multimodalrouter-0.1.14.dist-info/RECORD +18 -0
- multimodalrouter-0.1.14.dist-info/WHEEL +5 -0
- multimodalrouter-0.1.14.dist-info/entry_points.txt +3 -0
- multimodalrouter-0.1.14.dist-info/licenses/LICENSE.md +10 -0
- multimodalrouter-0.1.14.dist-info/licenses/NOTICE.md +44 -0
- multimodalrouter-0.1.14.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .graphicsWrapper import GraphDisplay # noqa: F401
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# graphicsWrapper.py
|
|
2
|
+
# Copyright (c) 2025 Tobias Karusseit
|
|
3
|
+
# Licensed under the MIT License. See LICENSE file in the project root for full license information.
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from ..graph import RouteGraph
|
|
7
|
+
import plotly.graph_objects as go
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GraphDisplay:
    """
    interactive plotly visualisation of a RouteGraph

    hubs are drawn as markers colored by hub type and edges as grey lines; the plot is
    2D or 3D depending on the number of node coordinates (after an optional nodeTransform)
    """

    def __init__(self, graph: "RouteGraph", name: str = "Graph", iconSize: int = 10) -> None:
        """
        args:
        - graph: the RouteGraph to display
        - name: title of the plot (default = "Graph")
        - iconSize: marker size used for the hubs (default = 10)
        """
        self.graph: "RouteGraph" = graph
        self.name: str = name
        self.iconSize: int = iconSize

    def _toPlotlyFormat(
        self,
        nodeTransform=None,
        edgeTransform=None
    ):
        """
        transform the graph data into the format used by the display function

        args:
        - nodeTransform: function that receives the list of node coordinates (each padded
          with zeros to the largest node dimension) and returns new coordinates (default = None)
        - edgeTransform: function that receives (startCoords, endCoords) lists and returns
          one list of points ("curve") per edge (default = None)
        returns:
        - None (sets self.nodes, self.edges and self.dim)
        """
        hubs = list(self.graph._allHubs())

        self.nodes = {
            f"{hub.hubType}-{hub.id}": {
                "coords": hub.coords,
                "hubType": hub.hubType,
                "id": hub.id
            }
            for hub in hubs
        }

        # hub.outgoing holds one {destination id: edge} dict per key
        # (presumably per transport mode); every edge exposes its metrics via .allMetrics
        self.edges = [
            {
                "from": f"{hub.hubType}-{hub.id}",
                "to": f"{self.graph.getHubById(destId).hubType}-{destId}",
                **edgeData.allMetrics
            }
            for hub in hubs
            for modeEdges in hub.outgoing.values()
            for destId, edgeData in modeEdges.items()
        ]
        # an empty graph has dimension 0 instead of crashing in max()
        self.dim = max((len(node["coords"]) for node in self.nodes.values()), default=0)

        if nodeTransform is not None:
            # pad lower dimensional coordinates with zeros so the transform gets uniform input;
            # list() makes this work for tuple / array coordinates as well
            paddedCoords = [
                list(node["coords"]) + [0] * (self.dim - len(node["coords"]))
                for node in self.nodes.values()
            ]
            for node, coords in zip(self.nodes.values(), nodeTransform(paddedCoords)):
                node["coords"] = coords

            self.dim = max((len(node["coords"]) for node in self.nodes.values()), default=0)

        if edgeTransform is not None:
            startCoords = [self.nodes[edge["from"]]["coords"] for edge in self.edges]
            endCoords = [self.nodes[edge["to"]]["coords"] for edge in self.edges]

            for edge, curve in zip(self.edges, edgeTransform(startCoords, endCoords)):
                edge["curve"] = curve

    def _nodePoints(self):
        """
        collect marker positions, hover texts and color indices for all nodes

        returns:
        - (xs, ys, zs, texts, colors); zs stays empty unless self.dim == 3
        """
        # stable color index per hub type (hash() of a str is salted per process,
        # so it changed colors between runs and could map two types to one color)
        typeColor = {
            hubType: index
            for index, hubType in enumerate(
                sorted({node["hubType"] for node in self.nodes.values()}, key=str)
            )
        }

        xs, ys, zs, texts, colors = [], [], [], [], []
        for node in self.nodes.values():
            coords = node["coords"]
            xs.append(coords[0])
            ys.append(coords[1])
            if self.dim == 3:
                # 2D hubs inside an otherwise 3D graph are placed at z = 0
                zs.append(coords[2] if len(coords) > 2 else 0)
            texts.append(f"{node['id']}<br>Type: {node['hubType']}")
            colors.append(typeColor[node["hubType"]])
        return xs, ys, zs, texts, colors

    def _edgePoints(self):
        """
        collect line coordinates and hover texts for all edges

        every edge contributes its points followed by a None separator; the hover text is
        repeated for every point because plotly matches text entries to points one by one

        returns:
        - (xs, ys, zs, texts); zs stays empty unless self.dim == 3
        """
        is3d = self.dim == 3
        xs, ys, zs, texts = [], [], [], []
        for edge in self.edges:
            # transformed edges carry their own points, otherwise draw a straight line
            if "curve" in edge:
                points = edge["curve"]
            else:
                points = [self.nodes[edge["from"]]["coords"], self.nodes[edge["to"]]["coords"]]

            hover = f"{edge['from']} → {edge['to']}"
            metrics = {k: v for k, v in edge.items() if k not in ("from", "to", "curve")}
            if metrics:
                hover += "<br>" + "<br>".join(f"{k}: {v}" for k, v in metrics.items())

            for point in points:
                xs.append(point[0])
                ys.append(point[1])
                if is3d:
                    zs.append(point[2] if len(point) > 2 else 0)
                texts.append(hover)

            # the None gap stops plotly from joining consecutive edges into one line
            xs.append(None)
            ys.append(None)
            if is3d:
                zs.append(None)
            texts.append("")
        return xs, ys, zs, texts

    @staticmethod
    def _earthSurface():
        """
        build the blue sphere surface used as earth background in 3D plots

        raises:
        - ImportError: if numpy is not installed
        """
        try:
            import numpy as np
        except ImportError as err:
            raise ImportError("numpy is required to display the earth") from err

        # slightly below the 6371.0 radius used by degreesToCartesian3D / curvedEdges,
        # presumably so that hubs and edges are drawn on top of the surface
        R = 6369.9
        u, v = np.meshgrid(
            np.linspace(0, 2 * np.pi, 50),  # azimuthal angle
            np.linspace(0, np.pi, 50)  # polar angle
        )
        return go.Surface(
            x=R * np.cos(u) * np.sin(v),
            y=R * np.sin(u) * np.sin(v),
            z=R * np.cos(v),
            colorscale='Blues',
            opacity=1,
            showscale=False,
            hoverinfo='skip'
        )

    def display(
        self,
        nodeTransform=None,
        edgeTransform=None,
        displayEarth=False
    ):
        """
        function to display any 2D or 3D RouteGraph

        args:
        - nodeTransform: function to transform the node coordinates (default = None)
        - edgeTransform: function to transform the edge coordinates (default = None)
        - displayEarth: whether to display the earth as a background (default = False, only in 3D)

        returns:
        - None (modifies self.nodes and self.edges and opens the plot via plotly's fig.show())

        raises:
        - ValueError: if the (transformed) node coordinates are neither 2D nor 3D
        - ImportError: if displayEarth is requested without numpy installed
        """
        # transform the graph
        self._toPlotlyFormat(nodeTransform, edgeTransform)
        if self.dim not in (2, 3):
            raise ValueError(
                f"only 2D or 3D graphs can be displayed, got {self.dim} coordinate(s) per node"
            )
        is3d = self.dim == 3

        node_x, node_y, node_z, text, colors = self._nodePoints()
        edge_x, edge_y, edge_z, edge_text = self._edgePoints()

        marker = dict(
            size=self.iconSize,
            color=colors,
            colorscale="Viridis",
            showscale=True
        )

        if is3d:
            node_trace = go.Scatter3d(
                x=node_x,
                y=node_y,
                z=node_z,
                mode="markers",
                hoverinfo="text",
                text=text,
                marker=marker
            )
            edge_trace = go.Scatter3d(
                x=edge_x,
                y=edge_y,
                z=edge_z,
                line=dict(width=1, color="#888"),
                hoverinfo="text",
                text=edge_text,
                mode="lines",
                opacity=0.6
            )
        else:
            node_trace = go.Scatter(
                x=node_x,
                y=node_y,
                mode="markers",
                hoverinfo="text",
                text=text,
                marker=marker
            )
            edge_trace = go.Scatter(
                x=edge_x,
                y=edge_y,
                line=dict(width=2, color="#888"),
                hoverinfo="text",
                text=edge_text,
                mode="lines"
            )

        fig = go.Figure(data=[edge_trace, node_trace])
        if is3d and displayEarth:
            fig.add_trace(self._earthSurface())

        fig.update_layout(title=self.name, showlegend=False, hovermode="closest")
        fig.show()

    @staticmethod
    def degreesToCartesian3D(coords):
        """
        convert [lat, lng] coordinates in degrees to [x, y, z] points on a sphere of radius 6371.0

        args:
        - coords: a single [lat, lng, ...] coordinate or a sequence of them
          (components after lng are ignored)

        returns:
        - list of [x, y, z] float lists, one per coordinate
        """
        import math

        R = 6371.0
        if len(coords) == 0:
            return []
        # a single coordinate is handled like a list holding that coordinate
        if not hasattr(coords[0], "__len__"):
            coords = [coords]

        points = []
        for coord in coords:
            lat = math.radians(coord[0])
            lng = math.radians(coord[1])
            points.append([
                R * math.cos(lat) * math.cos(lng),
                R * math.cos(lat) * math.sin(lng),
                R * math.sin(lat)
            ])
        return points

    @staticmethod
    def curvedEdges(start, end, R=6371.0, H=0.05, n=20):
        """
        build arched edges between points on a sphere (usable as edgeTransform)

        uses vectorized great circle arcs when numpy and torch are installed, otherwise
        quadratic bezier curves computed with the math module

        args:
        - start: sequence of start points, one per edge
        - end: sequence of end points, one per edge
        - R: radius of the sphere the curves follow (default = 6371.0)
        - H: relative lift of the curve (default = 0.05)
        - n: number of points per curve (default = 20)

        returns:
        - list with one curve per edge, each a list of n point lists
        """
        if len(start) == 0:
            return []
        try:
            return GraphDisplay._greatCircleCurves(start, end, R, H, n)
        except ImportError:
            return GraphDisplay._bezierCurves(start, end, R, H, n)

    @staticmethod
    def _greatCircleCurves(start, end, R, H, n):
        """
        vectorized great circle arcs (slerp) lifted radially by up to H at their center

        raises:
        - ImportError: if numpy or torch is not installed
        """
        import numpy as np
        import torch

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        startT = torch.tensor(np.asarray(start, dtype=np.float32), device=device)
        endT = torch.tensor(np.asarray(end, dtype=np.float32), device=device)

        # project both endpoints onto the sphere
        startN = R * startT / startT.norm(dim=1, keepdim=True)
        endN = R * endT / endT.norm(dim=1, keepdim=True)

        # angle between the endpoints, shape (num_edges, 1, 1)
        dot = torch.clamp((startN * endN).sum(dim=1, keepdim=True) / (R ** 2), -1.0, 1.0)
        theta = torch.acos(dot).unsqueeze(2)

        # interpolation parameter, shape (1, n, 1)
        t = torch.linspace(0, 1, n, device=device).view(1, n, 1)
        sinTheta = torch.sin(theta)
        safeSin = torch.where(sinTheta == 0, torch.full_like(sinTheta, 1e-6), sinTheta)

        # for (nearly) identical endpoints both slerp weights collapse to 0, which sent the
        # curve through the origin (NaN after normalization); use linear weights there
        coincident = theta < 1e-6
        weightStart = torch.where(coincident, 1 - t, torch.sin((1 - t) * theta) / safeSin)
        weightEnd = torch.where(coincident, t, torch.sin(t * theta) / safeSin)

        curve = weightStart * startN.unsqueeze(1) + weightEnd * endN.unsqueeze(1)

        # normalize to radius
        curve = R * curve / curve.norm(dim=2, keepdim=True)

        # radial lift: 0 at the endpoints, H at the center
        curve = curve * (1 + H * torch.sin(torch.pi * t))

        # plain python lists (same type as the fallback, and usable by plotly even from GPU)
        return curve.cpu().tolist()

    @staticmethod
    def _bezierCurves(start, end, R, H, n):
        """
        pure python fallback: quadratic bezier curves whose control point is the edge
        midpoint projected onto the sphere of radius R and lifted by the factor (1 + H)
        """
        import math

        curves = []
        for startP, endP in zip(start, end):
            mid = [(s + e) / 2 for s, e in zip(startP, endP)]
            norm = math.sqrt(sum(c ** 2 for c in mid))
            if norm:
                control = [(R * c / norm) * (1 + H) for c in mid]
            else:
                # antipodal endpoints: the midpoint is the origin and cannot be projected
                control = mid

            curve = []
            for i in range(n):
                # n == 1 yields just the start point (same as torch.linspace(0, 1, 1))
                t = i / (n - 1) if n > 1 else 0.0
                u = 1 - t
                curve.append([
                    u * u * s + 2 * u * t * m + t * t * e
                    for s, m, e in zip(startP, control, endP)
                ])
            curves.append(curve)

        return curves
|
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# build.py
|
|
2
|
+
# Copyright (c) 2025 Tobias Karusseit
|
|
3
|
+
# Licensed under the MIT License. See LICENSE file in the project root for full license information.
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from ..graph import RouteGraph
|
|
7
|
+
import argparse
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main(argv=None):
    """
    command line entry point: build a RouteGraph from hub datasets and save it

    args:
    - argv: list of command line arguments to parse (default = None, i.e. sys.argv[1:])
    """
    parser = argparse.ArgumentParser(
        description="Build a RouteGraph from hubType/transportMode/dataPath triplets and save it"
    )
    parser.add_argument(
        "data",
        nargs="+",
        help="Arguments in groups of 3: hubType transportMode dataPath"
    )
    parser.add_argument(
        "--maxDist",
        type=int,
        default=50,
        help="Maximum distance to connect hubs with driving edges"
    )
    parser.add_argument(
        "--compressed",
        action="store_true",
        help="Whether to compress the saved graph (default: False)"
    )

    parser.add_argument(
        "--extraMetrics",
        nargs="+",
        default=[],
        help="Extra metrics to add to the edge metadata"
    )

    # driving edges are enabled by default. store_true with default=True can never turn
    # them off, so --drivingEnabled is only kept for backwards compatibility and
    # --noDriving (sharing the same dest) disables them
    parser.add_argument(
        "--drivingEnabled",
        dest="drivingEnabled",
        action="store_true",
        default=True,
        help="Whether to connect hubs with driving edges (default: True)"
    )
    parser.add_argument(
        "--noDriving",
        dest="drivingEnabled",
        action="store_false",
        help="Do not connect hubs with driving edges"
    )
    path = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument(
        "--Dir",
        type=str,
        default=os.path.join(path, "..", "..", "..", "data"),
        help="Directory to save the graph in (default: %(default)s)"
    )
    parser.add_argument(
        "--sourceKeys",
        nargs="+",
        default=["source_lat", "source_lng"],
        help="Source keys to search the source coordinates for (default: ['source_lat', 'source_lng'])"
    )
    parser.add_argument(
        "--destKeys",
        nargs="+",
        default=["destination_lat", "destination_lng"],
        help="Destination keys to search the destination coordinates for (default: ['destination_lat', 'destination_lng'])"
    )

    args = parser.parse_args(argv)

    if len(args.data) % 3 != 0:
        parser.error("Arguments must be in groups of 3: hubType transportMode dataPath")

    transportModes = {}
    dataPaths = {}

    for i in range(0, len(args.data), 3):
        hubType, transportMode, dataPath = args.data[i:i + 3]
        # a repeated hubType would silently overwrite the earlier dataset
        if hubType in transportModes:
            parser.error(f"hubType '{hubType}' was given more than once")
        transportModes[hubType] = transportMode
        dataPaths[hubType] = dataPath

    print("Building graph...")
    graph = RouteGraph(
        maxDistance=args.maxDist,
        transportModes=transportModes,
        dataPaths=dataPaths,
        compressed=args.compressed,
        extraMetricsKeys=args.extraMetrics,
        drivingEnabled=args.drivingEnabled,
        sourceCoordKeys=args.sourceKeys,
        destCoordKeys=args.destKeys
    )

    graph.build()
    graph.save(filepath=args.Dir, compressed=args.compressed)

    print("Graph built and saved.")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# route.py
|
|
2
|
+
# Copyright (c) 2025 Tobias Karusseit
|
|
3
|
+
# Licensed under the MIT License. See LICENSE file in the project root for full license information.
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from ..graph import RouteGraph
|
|
7
|
+
import argparse
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main(argv=None):
    """
    command line entry point: find a route between the hubs closest to two coordinates

    args:
    - argv: list of command line arguments to parse (default = None, i.e. sys.argv[1:])
    """
    defaultGraphPath = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "data", "graph.dill"
    )

    parser = argparse.ArgumentParser(
        description="parse the arguments"
    )
    parser.add_argument(
        "--start",
        nargs="+",
        type=float,
        required=True,
        help="Start coordinates"
    )
    parser.add_argument(
        "--end",
        nargs="+",
        type=float,
        required=True,
        help="End coordinates"
    )
    parser.add_argument(
        "--allowedModes",
        nargs="+",
        type=str,
        default=["car"],
        help="Allowed transport modes"
    )
    parser.add_argument(
        "--maxSegments",
        type=int,
        default=10,
        help="Maximum number of segments in the route"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose output for the paths"
    )
    parser.add_argument(
        "--hubTypes",
        nargs="+",
        type=str,
        default=["airport"],
        help="Hub types searched for the closest start and end hub (default: %(default)s)"
    )
    parser.add_argument(
        "--graphPath",
        type=str,
        default=defaultGraphPath,
        help="Path of the saved graph (default: %(default)s)"
    )
    parser.add_argument(
        "--compressed",
        action="store_true",
        help="Whether the saved graph is compressed (default: False)"
    )
    args = parser.parse_args(argv)

    # load only after the arguments are valid, so --help and usage errors stay fast
    graph = RouteGraph.load(args.graphPath, compressed=args.compressed)

    start_hub = graph.findClosestHub(args.hubTypes, args.start)
    end_hub = graph.findClosestHub(args.hubTypes, args.end)

    if start_hub is None or end_hub is None:
        print(f"One of the hubs ({', '.join(args.hubTypes)}) does not exist in the graph")
        return

    route = graph.find_shortest_path(start_id=start_hub.id,
                                     end_id=end_hub.id,
                                     allowed_modes=args.allowedModes,
                                     max_segments=args.maxSegments,
                                     verbose=args.verbose)

    print(route.flatPath if route else "No route found")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .preprocessor import preprocessor # noqa: F401
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# preprocessor.py
|
|
2
|
+
# Copyright (c) 2025 Tobias Karusseit
|
|
3
|
+
# Licensed under the MIT License. See LICENSE file in the project root for full license information.
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
# all datasets need:
# 1. source
# 2. source name
# 3. destination
# 4. destination name
# 5. source lat
# 6. source lng
# 7. destination lat
# 8. destination lng
# optionally a distance column; when it is missing it is computed with the haversine formula
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class preprocessor:
    """static helpers that bring route datasets into one common column layout"""

    # columns (in this order) of the dataframe returned by preprocess
    _OUTPUT_COLUMNS = [
        "source",
        "source_name",
        "destination",
        "destination_name",
        "distance",
        "source_lat",
        "source_lng",
        "destination_lat",
        "destination_lng",
    ]

    @staticmethod
    def _save(
        df: pd.DataFrame,
        targetType: str = "parquet"
    ) -> None:
        """
        Save the DataFrame as fullDataset.<targetType> in the data directory.

        The data directory is "data" inside the parent folder of this file's folder
        and is created if it doesn't exist.

        Parameters:
            df (pd.DataFrame): The DataFrame to be saved.
            targetType (str): "csv" writes a csv file; any other value writes parquet
                (via pyarrow). Defaults to "parquet".
        """
        from pathlib import Path

        data_dir = Path(__file__).parent.parent / "data"
        data_dir.mkdir(parents=True, exist_ok=True)
        file_path = data_dir / f"fullDataset.{targetType}"

        if targetType == "csv":
            df.to_csv(file_path, index=False)
        else:
            df.to_parquet(file_path, engine="pyarrow")

    @staticmethod
    def preprocess(
        path: str,
        sourceKey: str = "source",
        sourceNameKey: str = "source_name",
        destinationKey: str = "destination",
        destinationNameKey: str = "destination_name",
        distanceKey: str = "distance",
        sourceLatKey: str = "source_lat",
        sourceLngKey: str = "source_lng",
        destinationLatKey: str = "destination_lat",
        destinationLngKey: str = "destination_lng",
        targetType: str = "parquet"
    ) -> pd.DataFrame:
        """
        Preprocess a dataset by renaming columns to the desired format,
        calculating distances if they are missing, and saving the result.

        Parameters:
            path (str): path to the dataset (.csv or .parquet)
            sourceKey (str): key for the source column (default: "source")
            sourceNameKey (str): key for the source name column (default: "source_name")
            destinationKey (str): key for the destination column (default: "destination")
            destinationNameKey (str): key for the destination name column (default: "destination_name")
            distanceKey (str): key for the optional distance column (default: "distance")
            sourceLatKey (str): key for the source latitude column (default: "source_lat")
            sourceLngKey (str): key for the source longitude column (default: "source_lng")
            destinationLatKey (str): key for the destination latitude column (default: "destination_lat")
            destinationLngKey (str): key for the destination longitude column (default: "destination_lng")
            targetType (str): file type passed to _save (default: "parquet")

        Returns:
            pd.DataFrame: the preprocessed dataframe restricted to the standard columns

        Raises:
            ValueError: if the file type is unsupported or a required column is missing
        """
        _, fType = os.path.splitext(path)
        if fType == ".csv":
            df = pd.read_csv(path)
        elif fType == ".parquet":
            df = pd.read_parquet(path)
        else:
            raise ValueError(f"Unsupported file type '{fType}', expected .csv or .parquet")

        cols = set(df.columns)
        required = [
            sourceKey,
            sourceNameKey,
            destinationKey,
            destinationNameKey,
            sourceLatKey,
            sourceLngKey,
            destinationLatKey,
            destinationLngKey,
        ]
        missing = [key for key in required if key not in cols]
        if missing:
            raise ValueError(f"Invalid dataset: missing columns {missing}")

        hasDistance = distanceKey in cols

        # rename columns to the desired format
        df.rename(columns={
            sourceKey: "source",
            sourceNameKey: "source_name",
            destinationKey: "destination",
            destinationNameKey: "destination_name",
            sourceLatKey: "source_lat",
            sourceLngKey: "source_lng",
            destinationLatKey: "destination_lat",
            destinationLngKey: "destination_lng",
            **({distanceKey: "distance"} if hasDistance else {})
        }, inplace=True)

        if not hasDistance:
            df["distance"] = preprocessor.haversine(df)

        preprocessor._save(df, targetType=targetType)
        return df[preprocessor._OUTPUT_COLUMNS]

    @staticmethod
    def haversine(df: pd.DataFrame) -> "numpy.ndarray":
        """
        Great-circle distance (sphere radius 6371, i.e. km) between source and destination of every row.

        Uses torch (on the GPU when available) and falls back to numpy when torch is not installed.

        Parameters:
            df (pd.DataFrame): dataframe with source_lat, source_lng, destination_lat and
                destination_lng columns in degrees

        Returns:
            numpy.ndarray: one distance per row
        """
        values = [
            df[column].to_numpy(dtype=float)
            for column in ("source_lat", "source_lng", "destination_lat", "destination_lng")
        ]

        try:
            import torch
        except ImportError:
            import numpy as np

            lat1, lng1, lat2, lng2 = (np.radians(v) for v in values)
            a = np.sin((lat2 - lat1) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lng2 - lng1) / 2) ** 2
            # rounding can push a marginally outside [0, 1], which would make sqrt return NaN
            a = np.clip(a, 0.0, 1.0)
            return 6371 * (2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)))

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            lat1, lng1, lat2, lng2 = (torch.deg2rad(torch.tensor(v, device=device)) for v in values)

            dlat = lat2 - lat1
            dlng = lng2 - lng1
            a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlng / 2) ** 2
            # rounding can push a marginally outside [0, 1], which would make sqrt return NaN
            a = torch.clamp(a, 0.0, 1.0)
            c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))

            distances = 6371 * c

        return distances.cpu().numpy()

    @staticmethod
    def combine(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
        """Stack two dataframes row-wise (the original indices are kept, so they may repeat)."""
        return pd.concat([df1, df2], axis=0)
|