iflow-mcp_xrds76354_sumo-mcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iflow_mcp_xrds76354_sumo_mcp-0.1.0.dist-info/METADATA +402 -0
- iflow_mcp_xrds76354_sumo_mcp-0.1.0.dist-info/RECORD +27 -0
- iflow_mcp_xrds76354_sumo_mcp-0.1.0.dist-info/WHEEL +4 -0
- iflow_mcp_xrds76354_sumo_mcp-0.1.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_xrds76354_sumo_mcp-0.1.0.dist-info/licenses/LICENSE +21 -0
- mcp_tools/__init__.py +0 -0
- mcp_tools/analysis.py +33 -0
- mcp_tools/network.py +94 -0
- mcp_tools/py.typed +0 -0
- mcp_tools/rl.py +425 -0
- mcp_tools/route.py +91 -0
- mcp_tools/signal.py +96 -0
- mcp_tools/simulation.py +79 -0
- mcp_tools/vehicle.py +52 -0
- resources/__init__.py +0 -0
- server.py +493 -0
- utils/__init__.py +0 -0
- utils/connection.py +145 -0
- utils/output.py +26 -0
- utils/sumo.py +185 -0
- utils/timeout.py +364 -0
- utils/traci.py +82 -0
- workflows/__init__.py +0 -0
- workflows/py.typed +0 -0
- workflows/rl_train.py +34 -0
- workflows/signal_opt.py +210 -0
- workflows/sim_gen.py +70 -0
mcp_tools/vehicle.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import traci
|
|
2
|
+
from typing import List, Tuple
|
|
3
|
+
from utils.connection import connection_manager
|
|
4
|
+
|
|
5
|
+
def get_vehicles() -> List[str]:
    """Return the IDs of every vehicle currently active in the simulation.

    Unlike the other getters in this module, this does not raise when there is
    no active connection; it returns an empty list instead.
    """
    if connection_manager.is_connected():
        return list(traci.vehicle.getIDList())
    return []
|
|
10
|
+
|
|
11
|
+
def get_vehicle_speed(vehicle_id: str) -> float:
    """Return the current speed of ``vehicle_id`` in m/s.

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    connected = connection_manager.is_connected()
    if not connected:
        raise RuntimeError("Not connected to SUMO.")
    speed = traci.vehicle.getSpeed(vehicle_id)
    return float(speed)
|
|
16
|
+
|
|
17
|
+
def get_vehicle_position(vehicle_id: str) -> Tuple[float, float]:
    """Return the ``(x, y)`` position of ``vehicle_id`` as a pair of floats.

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    if not connection_manager.is_connected():
        raise RuntimeError("Not connected to SUMO.")
    # Unpack explicitly so an unexpected tuple length still fails loudly.
    pos_x, pos_y = traci.vehicle.getPosition(vehicle_id)
    return (float(pos_x), float(pos_y))
|
|
23
|
+
|
|
24
|
+
def get_vehicle_acceleration(vehicle_id: str) -> float:
    """Return the current acceleration of ``vehicle_id`` in m/s^2.

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    if connection_manager.is_connected():
        return float(traci.vehicle.getAcceleration(vehicle_id))
    raise RuntimeError("Not connected to SUMO.")
|
|
29
|
+
|
|
30
|
+
def get_vehicle_lane(vehicle_id: str) -> str:
    """Return the ID of the lane ``vehicle_id`` is currently on.

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    if not connection_manager.is_connected():
        raise RuntimeError("Not connected to SUMO.")
    lane_id = traci.vehicle.getLaneID(vehicle_id)
    return str(lane_id)
|
|
35
|
+
|
|
36
|
+
def get_vehicle_route(vehicle_id: str) -> List[str]:
    """Return the route of ``vehicle_id`` as a list of edge IDs (strings).

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    if not connection_manager.is_connected():
        raise RuntimeError("Not connected to SUMO.")
    edges = traci.vehicle.getRoute(vehicle_id)
    return list(map(str, edges))
|
|
41
|
+
|
|
42
|
+
def get_simulation_info() -> dict[str, float | int]:
    """Return a snapshot of global simulation counters.

    Keys (in this order): ``time`` (float), then the integer counts
    ``loaded_vehicles``, ``departed_vehicles``, ``arrived_vehicles`` and
    ``min_expected_vehicles`` as reported by ``traci.simulation``.

    Raises:
        RuntimeError: if there is no active SUMO connection.
    """
    if not connection_manager.is_connected():
        raise RuntimeError("Not connected to SUMO.")
    sim = traci.simulation
    info = {}
    info["time"] = float(sim.getTime())
    info["loaded_vehicles"] = int(sim.getLoadedNumber())
    info["departed_vehicles"] = int(sim.getDepartedNumber())
    info["arrived_vehicles"] = int(sim.getArrivedNumber())
    info["min_expected_vehicles"] = int(sim.getMinExpectedNumber())
    return info
|
resources/__init__.py
ADDED
|
File without changes
|
server.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import subprocess
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
|
+
|
|
5
|
+
from mcp.server.fastmcp import FastMCP
|
|
6
|
+
|
|
7
|
+
from utils.traci import ensure_traci_start_stdout_suppressed
|
|
8
|
+
from mcp_tools.simulation import run_simple_simulation
|
|
9
|
+
from mcp_tools.network import netconvert, netgenerate, osm_get
|
|
10
|
+
from mcp_tools.route import random_trips, duarouter, od2trips
|
|
11
|
+
from mcp_tools.signal import tls_cycle_adaptation, tls_coordinator
|
|
12
|
+
from mcp_tools.analysis import analyze_fcd
|
|
13
|
+
from mcp_tools.vehicle import (
|
|
14
|
+
get_vehicles, get_vehicle_speed, get_vehicle_position,
|
|
15
|
+
get_vehicle_acceleration, get_vehicle_lane, get_vehicle_route,
|
|
16
|
+
get_simulation_info
|
|
17
|
+
)
|
|
18
|
+
from mcp_tools.rl import find_sumo_rl_scenario_files, list_rl_scenarios, run_rl_training
|
|
19
|
+
from utils.connection import connection_manager
|
|
20
|
+
from utils.sumo import find_sumo_binary, find_sumo_home, find_sumo_tools_dir
|
|
21
|
+
from workflows.sim_gen import sim_gen_workflow
|
|
22
|
+
from workflows.signal_opt import signal_opt_workflow
|
|
23
|
+
from workflows.rl_train import rl_train_workflow
|
|
24
|
+
|
|
25
|
+
# The MCP stdio transport uses stdout, so logging must not write there.
# logging.basicConfig installs a StreamHandler, which defaults to stderr.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Keep TraCI's output off stdout as well (MCP stdio safety).
ensure_traci_start_stdout_suppressed()

# Single FastMCP server instance (official SDK); tools are registered on it
# via the @server.tool decorators that follow.
server = FastMCP("SUMO-MCP-Server")
|
|
34
|
+
|
|
35
|
+
# --- 1. Network Management ---
|
|
36
|
+
@server.tool(description="Manage SUMO network (generate, convert, or download OSM).")
def manage_network(action: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    actions:
    - generate: params={'grid': bool, 'grid_number': int}
      spider networks instead: params={'spider': bool, 'arms'|'arm_number': int,
      'circles'|'circle_number': int, 'ring_radius'|'space_radius': float,
      'radial_distance'|'attach_length': float, 'omit_center': bool}
    - convert (alias: convert_osm): params={'osm_file': str}
    - download_osm: output_file is treated as output_dir. params={'bbox': str, 'prefix': str}

    Every action also accepts params['options'], which is forwarded to the
    underlying tool. Invalid parameters are reported as "Error: ..." strings.
    """
    params = params or {}
    options = params.get("options")

    def _as_bool(value: Any) -> bool:
        # JSON clients sometimes send flags as strings; plain bool("false") is
        # True, so treat the usual "off" spellings as False. Non-strings keep
        # ordinary truthiness.
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off")
        return bool(value)

    if action == "generate":
        spider = _as_bool(params.get("spider", False))
        grid = _as_bool(params.get("grid", True))
        grid_number = params.get("grid_number", 3)

        if spider:
            # Spider network takes precedence over grid settings.
            grid = False
            options_list = list(options or [])

            def _strip_flag(flag: str, has_value: bool = False) -> None:
                # Remove every occurrence of `flag` (and its value, if it takes one).
                while flag in options_list:
                    idx = options_list.index(flag)
                    options_list.pop(idx)
                    if has_value and idx < len(options_list):
                        options_list.pop(idx)

            def _set_option(flag: str, value: str) -> None:
                # Overwrite the value of an existing flag, or append flag + value.
                if flag in options_list:
                    idx = options_list.index(flag)
                    if idx + 1 < len(options_list):
                        options_list[idx + 1] = value
                    else:
                        options_list.append(value)
                else:
                    options_list.extend([flag, value])

            # Enforce Spider/Grid mutual exclusion even when the user provided `options`.
            _strip_flag("--grid")
            _strip_flag("--grid.number", has_value=True)

            if "--spider" not in options_list:
                options_list.insert(0, "--spider")

            arms_raw = params.get("arms", params.get("arm_number"))
            if arms_raw is not None:
                try:
                    arms = int(arms_raw)
                except (TypeError, ValueError):
                    return f"Error: arms must be a positive integer, got {arms_raw!r}"
                if arms <= 0:
                    return "Error: arms must be > 0"
                _set_option("--spider.arm-number", str(arms))

            circles_raw = params.get("circles", params.get("circle_number"))
            if circles_raw is not None:
                try:
                    circles = int(circles_raw)
                except (TypeError, ValueError):
                    return f"Error: circles must be a positive integer, got {circles_raw!r}"
                if circles <= 0:
                    return "Error: circles must be > 0"
                _set_option("--spider.circle-number", str(circles))

            space_radius_raw = params.get("ring_radius", params.get("space_radius"))
            if space_radius_raw is not None:
                try:
                    space_radius = float(space_radius_raw)
                except (TypeError, ValueError):
                    return f"Error: ring_radius must be a number, got {space_radius_raw!r}"
                if space_radius <= 0:
                    return "Error: ring_radius must be > 0"
                _set_option("--spider.space-radius", str(space_radius))

            attach_length_raw = params.get("radial_distance", params.get("attach_length"))
            if attach_length_raw is not None:
                try:
                    attach_length = float(attach_length_raw)
                except (TypeError, ValueError):
                    return f"Error: radial_distance must be a number, got {attach_length_raw!r}"
                if attach_length < 0:
                    return "Error: radial_distance must be >= 0"
                _set_option("--spider.attach-length", str(attach_length))

            omit_center_raw = params.get("omit_center")
            if _as_bool(omit_center_raw):
                if "--spider.omit-center" not in options_list:
                    options_list.append("--spider.omit-center")

            options = options_list

        if grid:
            # Validate the grid size like the spider parameters above instead of
            # handing an arbitrary value (e.g. "abc" or 3.5) to netgenerate.
            grid_number_raw = grid_number
            try:
                grid_number = int(grid_number_raw)
            except (TypeError, ValueError):
                return f"Error: grid_number must be a positive integer, got {grid_number_raw!r}"
            if grid_number <= 0:
                return "Error: grid_number must be > 0"

        return netgenerate(output_file, grid, grid_number, options)

    elif action == "convert" or action == "convert_osm":
        osm_file = params.get("osm_file")
        if not osm_file:
            return "Error: osm_file required for convert action"
        return netconvert(osm_file, output_file, options)

    elif action == "download_osm":
        # output_file here acts as output_dir
        bbox = params.get("bbox")
        prefix = params.get("prefix", "osm")
        if not bbox:
            return "Error: bbox required for download_osm action"
        return osm_get(bbox, output_file, prefix, options)

    return f"Unknown action: {action}"
|
|
143
|
+
|
|
144
|
+
# --- 2. Demand Management ---
|
|
145
|
+
@server.tool(description="Manage traffic demand (random trips, OD matrix, routing).")
def manage_demand(action: str, net_file: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    actions:
    - generate_random (alias: random_trips): params={'end_time'|'end': int, 'period': float}
      Both must be > 0.
    - convert_od (alias: od_matrix): params={'od_file': str} (net_file unused but kept for consistency)
    - compute_routes (alias: routing): params={'route_files': str} (input trips)

    Every action also accepts params['options'], which is forwarded to the
    underlying tool. Invalid parameters are reported as "Error: ..." strings.
    """
    params = params or {}
    options = params.get("options")

    if action in ("generate_random", "random_trips"):
        # Backward/compat aliases: some clients use `end` instead of `end_time`.
        end_time_raw = params.get("end_time", params.get("end", 3600))
        period_raw = params.get("period", 1.0)
        try:
            end_time = int(end_time_raw)
        except (TypeError, ValueError):
            return f"Error: end_time must be an integer, got {end_time_raw!r}"
        try:
            period = float(period_raw)
        except (TypeError, ValueError):
            return f"Error: period must be a number, got {period_raw!r}"
        if end_time <= 0:
            return "Error: end_time must be > 0"
        # `not period > 0` also rejects NaN, which float() accepts.
        if not period > 0:
            return "Error: period must be > 0"
        return random_trips(net_file, output_file, end_time, period, options)

    elif action in ("convert_od", "od_matrix"):
        od_file = params.get("od_file")
        if not od_file:
            return "Error: od_file required for convert_od"
        return od2trips(od_file, output_file, options)

    elif action in ("compute_routes", "routing"):
        route_files = params.get("route_files")  # Input trips file
        if not route_files:
            return "Error: route_files required for compute_routes"
        return duarouter(net_file, route_files, output_file, options)

    return f"Unknown action: {action}"
|
|
181
|
+
|
|
182
|
+
# --- 3. Simulation Control ---
|
|
183
|
+
@server.tool(description="Control SUMO simulation (connect, step, disconnect).")
def control_simulation(action: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    actions:
    - connect: params={'config_file': str, 'gui': bool, 'port': int (default 8813),
      'host': str (default 'localhost')}
    - step: params={'step': float}
    - disconnect: no params

    Every action accepts an optional 'timeout_s' (alias 'timeout'). When given it
    is forwarded to the connection manager; otherwise the keyword is omitted.
    Exceptions are caught and returned as an "Error in control_simulation ..." string.
    """
    params = params or {}

    try:
        timeout_s_raw = params.get("timeout_s", params.get("timeout"))
        timeout_s: Optional[float] = None
        if timeout_s_raw is not None:
            try:
                timeout_s = float(timeout_s_raw)
            except (TypeError, ValueError):
                return f"Error: timeout_s must be a number, got {timeout_s_raw!r}"

        if action == "connect":
            config_file = params.get("config_file")
            gui = params.get("gui", False)
            port = params.get("port", 8813)
            host = params.get("host", "localhost")
            # Clients may send the port as a string; coerce it here. An explicit
            # None is passed through unchanged.
            if port is not None:
                try:
                    port = int(port)
                except (TypeError, ValueError):
                    return f"Error: port must be an integer, got {port!r}"
            if timeout_s is None:
                connection_manager.connect(config_file, gui, port, host)
            else:
                connection_manager.connect(config_file, gui, port, host, timeout_s=timeout_s)
            return "Successfully connected to SUMO."

        elif action == "step":
            step_raw = params.get("step", 0)
            # Parse like timeout_s so a string such as "10" does not reach TraCI raw.
            try:
                step = float(step_raw)
            except (TypeError, ValueError):
                return f"Error: step must be a number, got {step_raw!r}"
            if timeout_s is None:
                connection_manager.simulation_step(step)
            else:
                connection_manager.simulation_step(step, timeout_s=timeout_s)
            return "Simulation advanced."

        elif action == "disconnect":
            if timeout_s is None:
                connection_manager.disconnect()
            else:
                connection_manager.disconnect(timeout_s=timeout_s)
            return "Successfully disconnected from SUMO."

    except Exception as e:
        return f"Error in control_simulation ({action}): {type(e).__name__}: {e}"

    return f"Unknown action: {action}"
|
|
232
|
+
|
|
233
|
+
# --- 4. Query State ---
|
|
234
|
+
@server.tool(description="Query simulation state (vehicles, speed, position). Requires active connection.")
def query_simulation_state(target: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    Read values from the live simulation and return them as text.

    targets:
    - vehicle_list (alias: vehicles): no params
    - vehicle_variable: params={'vehicle_id': str, 'variable': 'speed'|'position'|'lane'|'acceleration'|'route'}
    - simulation: no params; returns the counters from get_simulation_info()

    Exceptions raised by the getters are returned as an "Error querying state: ..." string.
    """
    opts = params or {}

    try:
        if target in ("vehicle_list", "vehicles"):
            return f"Active vehicles: {get_vehicles()}"

        if target == "vehicle_variable":
            vehicle_id = opts.get("vehicle_id")
            variable = opts.get("variable")
            if not (vehicle_id and variable):
                return "Error: vehicle_id and variable required"

            renderers = (
                ("speed", lambda vid: f"Speed: {get_vehicle_speed(vid)}"),
                ("position", lambda vid: f"Position: {get_vehicle_position(vid)}"),
                ("acceleration", lambda vid: f"Acceleration: {get_vehicle_acceleration(vid)}"),
                ("lane", lambda vid: f"Lane: {get_vehicle_lane(vid)}"),
                ("route", lambda vid: f"Route: {get_vehicle_route(vid)}"),
            )
            # Linear == comparison (not a dict lookup) so odd, unhashable
            # `variable` values still fall through to "Unknown variable".
            for name, render in renderers:
                if variable == name:
                    return render(vehicle_id)
            return f"Unknown variable: {variable}"

        if target == "simulation":
            return f"Simulation Info: {get_simulation_info()}"

    except Exception as e:
        return f"Error querying state: {type(e).__name__}: {e}"

    return f"Unknown target: {target}"
|
|
269
|
+
|
|
270
|
+
# --- 5. Optimize Signals ---
|
|
271
|
+
@server.tool(description="Optimize traffic signals.")
def optimize_traffic_signals(method: str, net_file: str, route_file: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    Optimize traffic-light programs for an existing network and demand.

    methods:
    - cycle_adaptation (alias: Websters): adapt TLS cycles
    - coordination: TLS coordination; params['options'] is forwarded
    """
    extra_options = (params or {}).get("options")

    if method in ("cycle_adaptation", "Websters"):
        # NOTE(review): params['options'] is not forwarded on this path; confirm
        # whether tls_cycle_adaptation accepts extra options.
        return tls_cycle_adaptation(net_file, route_file, output_file)
    if method == "coordination":
        return tls_coordinator(net_file, route_file, output_file, extra_options)

    return f"Unknown method: {method}"
|
|
287
|
+
|
|
288
|
+
# --- 6. Workflows ---
|
|
289
|
+
@server.tool(
description="""Run high-level SUMO workflows. Available workflows:

**sim_gen_eval** - Generate grid network, simulate traffic, analyze results.
params:
- grid_number (int): Grid size NxN. Default=3. Aliases: grid_size, size
- sim_seconds (int): Simulation duration in seconds. Default=100. Aliases: steps, duration, end_time
- output_dir (str): Output directory. Default="output"
Example: run_workflow("sim_gen_eval", {"grid_number": 3, "sim_seconds": 1000})

**signal_opt** - Optimize traffic signals for existing network.
params:
- net_file (str): Path to .net.xml file. REQUIRED
- route_file (str): Path to .rou.xml file. REQUIRED
- sim_seconds (int): Simulation duration. Default=3600. Aliases: steps, duration
- use_coordinator (bool): Use tlsCoordinator instead of tlsCycleAdaptation. Default=false
- output_dir (str): Output directory. Default="output"

**rl_train** - Train RL agent for traffic signal control.
params:
- scenario_name (str): Built-in scenario name (use manage_rl_task("list_scenarios") to see options). Aliases: scenario
- episodes (int): Number of training episodes. Default=5. Aliases: num_episodes
- steps (int): Steps per episode. Default=1000. Aliases: steps_per_episode
- output_dir (str): Output directory. Default="output"
"""
)
def run_workflow(workflow_name: str, params: Dict[str, Any]) -> str:
    """Execute a high-level workflow (parameters are listed in the tool description).

    Integer and boolean parameters are coerced here. Values that cannot be
    coerced are returned as "Error: ..." strings, matching the other tools,
    instead of escaping as a raw ValueError.
    """
    params = params or {}

    def get_param(keys: list, default=None):
        # The first alias present in params wins.
        for k in keys:
            if k in params:
                return params[k]
        return default

    def get_int(keys: list, default: int) -> int:
        raw = get_param(keys, default)
        try:
            return int(raw)
        except (TypeError, ValueError):
            raise ValueError(f"{keys[0]} must be an integer, got {raw!r}") from None

    def get_bool(keys: list, default: bool) -> bool:
        raw = get_param(keys, default)
        # bool("false") is True; treat the usual "off" spellings as False.
        if isinstance(raw, str):
            return raw.strip().lower() not in ("", "0", "false", "no", "off")
        return bool(raw)

    if workflow_name in ("sim_gen_eval", "sim_gen_workflow", "sim_gen"):
        try:
            grid_number = get_int(["grid_number", "grid_size", "size"], 3)
            sim_seconds = get_int(["sim_seconds", "steps", "duration", "end_time"], 100)
        except ValueError as e:
            return f"Error: {e}"
        output_dir = get_param(["output_dir"], "output")

        return sim_gen_workflow(output_dir, grid_number, sim_seconds)

    elif workflow_name in ("signal_opt", "signal_opt_workflow"):
        net_file = get_param(["net_file"], "")
        route_file = get_param(["route_file"], "")

        if not net_file or not route_file:
            return "Error: signal_opt requires net_file and route_file parameters."

        try:
            sim_seconds = get_int(["sim_seconds", "steps", "duration"], 3600)
        except ValueError as e:
            return f"Error: {e}"
        use_coordinator = get_bool(["use_coordinator"], False)
        output_dir = get_param(["output_dir"], "output")

        return signal_opt_workflow(net_file, route_file, output_dir, sim_seconds, use_coordinator)

    elif workflow_name in ("rl_train", "rl_train_workflow"):
        scenario_name = get_param(["scenario_name", "scenario"], "")
        try:
            episodes = get_int(["episodes", "num_episodes"], 5)
            steps = get_int(["steps", "steps_per_episode"], 1000)
        except ValueError as e:
            return f"Error: {e}"
        output_dir = get_param(["output_dir"], "output")

        return rl_train_workflow(scenario_name, output_dir, episodes, steps)

    return f"Unknown workflow: {workflow_name}. Available: sim_gen_eval, signal_opt, rl_train"
|
|
354
|
+
|
|
355
|
+
# --- 7. RL Task Management ---
|
|
356
|
+
@server.tool(description="Manage RL tasks (list scenarios, custom training).")
def manage_rl_task(action: str, params: Optional[Dict[str, Any]] = None) -> str:
    """
    Inspect built-in sumo-rl scenarios or launch a training run.

    actions:
    - list_scenarios: no params
    - train_custom: params={'scenario'|'scenario_name'} OR {'net_file', 'route_file'};
      optional 'out_dir'|'output_dir' (default "output"), 'episodes'|'num_episodes'
      (default 1), 'steps'|'steps_per_episode' (default 1000), 'algorithm'
      (default "ql"), 'reward_type' (default "diff-waiting-time").
      A scenario name, when given, overrides net_file/route_file.
    """
    opts = params or {}

    if action == "list_scenarios":
        return str(list_rl_scenarios())

    if action != "train_custom":
        return f"Unknown action: {action}"

    scenario = opts.get("scenario") or opts.get("scenario_name")
    net_file = opts.get("net_file")
    route_file = opts.get("route_file")

    if scenario:
        net_file, route_file, lookup_error = find_sumo_rl_scenario_files(str(scenario))
        if lookup_error:
            return lookup_error

    if not (net_file and route_file):
        return (
            "Error: train_custom requires either:\n"
            " - scenario/scenario_name (built-in sumo-rl scenario), OR\n"
            " - net_file + route_file (custom files)\n"
            "Hint: Use manage_rl_task(list_scenarios) to see available built-in scenarios."
        )

    out_dir = opts.get("out_dir") or opts.get("output_dir") or "output"

    raw_episodes = opts.get("episodes", opts.get("num_episodes", 1))
    raw_steps = opts.get("steps", opts.get("steps_per_episode", 1000))

    # Parse both values before range-checking either, so a bad `steps` string is
    # reported even when `episodes` is out of range.
    try:
        n_episodes = int(raw_episodes)
    except (TypeError, ValueError):
        return f"Error: episodes must be an integer, got {raw_episodes!r}"
    try:
        n_steps = int(raw_steps)
    except (TypeError, ValueError):
        return f"Error: steps must be an integer, got {raw_steps!r}"

    if n_episodes <= 0:
        return "Error: episodes must be > 0"
    if n_steps <= 0:
        return "Error: steps must be > 0"

    return run_rl_training(
        net_file=str(net_file),
        route_file=str(route_file),
        out_dir=str(out_dir),
        episodes=n_episodes,
        steps_per_episode=n_steps,
        algorithm=str(opts.get("algorithm", "ql")),
        reward_type=str(opts.get("reward_type", "diff-waiting-time")),
    )
|
|
418
|
+
|
|
419
|
+
# --- Legacy/Misc ---
|
|
420
|
+
@server.tool(name="get_sumo_info", description="Get the version and path of the installed SUMO.")
def get_sumo_info() -> str:
    """Report the SUMO binary path, its version line, SUMO_HOME and the tools dir.

    Failures are returned as "Error ..." strings rather than raised. When
    ``sumo --version`` exits non-zero, the captured output is appended, because
    the CalledProcessError message alone carries only the exit status.
    """
    try:
        sumo_binary = find_sumo_binary("sumo")
        if not sumo_binary:
            return (
                "Error: Could not locate SUMO executable. "
                "Please ensure SUMO is installed and either `sumo` is available in PATH or `SUMO_HOME` is set."
            )

        result = subprocess.run(
            [sumo_binary, "--version"],
            capture_output=True,
            text=True,
            check=True,
            timeout=10,
        )
        # First line of the version banner, or "Unknown" if nothing was printed.
        version_output = (result.stdout.splitlines() or ["Unknown"])[0]

        sumo_home = find_sumo_home()
        tools_dir = find_sumo_tools_dir()
        return "\n".join(
            [
                f"SUMO Binary: {sumo_binary}",
                f"SUMO Version: {version_output}",
                f"SUMO_HOME: {sumo_home or 'Not Set'}",
                f"SUMO Tools Dir: {tools_dir or 'Not Found'}",
            ]
        )
    except subprocess.CalledProcessError as e:
        message = f"Error checking SUMO: {e}"
        detail = (e.stderr or e.stdout or "").strip()
        return f"{message}\n{detail}" if detail else message
    except Exception as e:
        return f"Error checking SUMO: {e}"
|
|
451
|
+
|
|
452
|
+
@server.tool(name="run_simple_simulation", description="Run a SUMO simulation using a config file.")
def run_simple_simulation_tool(config_path: str, steps: int = 100) -> str:
    """MCP wrapper around ``mcp_tools.simulation.run_simple_simulation``.

    Exposed under the tool name ``run_simple_simulation``. The Python name
    differs so that it does not shadow the imported implementation it calls.
    """
    output = run_simple_simulation(config_path, steps)
    return output
|
|
455
|
+
|
|
456
|
+
@server.tool(description="Analyze FCD output.")
def run_analysis(fcd_file: str) -> str:
    """MCP wrapper that delegates to ``mcp_tools.analysis.analyze_fcd``."""
    report = analyze_fcd(fcd_file)
    return report
|
|
459
|
+
|
|
460
|
+
def main() -> None:
    """Entry point for running the MCP server via uvx (also used when run as a script).

    NOTE: MCP stdio transport relies on AnyIO/asyncio to process thread callbacks.
    In some environments, a lack of scheduled timers can cause the event loop to
    block indefinitely while waiting for stdio worker-thread results. A small
    periodic sleep keeps the loop responsive without emitting any stdout output.
    """
    import anyio

    async def _wakeup_task() -> None:
        while True:
            await anyio.sleep(0.1)

    async def _run_stdio_with_wakeup() -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(_wakeup_task)
            await server.run_stdio_async()
            # The wakeup loop never ends on its own; without this cancel the
            # task group (and therefore the process) would hang after the
            # stdio server has finished.
            tg.cancel_scope.cancel()

    anyio.run(_run_stdio_with_wakeup)


if __name__ == "__main__":
    main()
|
utils/__init__.py
ADDED
|
File without changes
|