iflow-mcp_xrds76354_sumo-mcp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_tools/vehicle.py ADDED
@@ -0,0 +1,52 @@
1
+ import traci
2
+ from typing import List, Tuple
3
+ from utils.connection import connection_manager
4
+
5
+ def get_vehicles() -> List[str]:
6
+ """Get the list of all active vehicle IDs."""
7
+ if not connection_manager.is_connected():
8
+ return []
9
+ return list(traci.vehicle.getIDList())
10
+
11
+ def get_vehicle_speed(vehicle_id: str) -> float:
12
+ """Get the speed of a specific vehicle (m/s)."""
13
+ if not connection_manager.is_connected():
14
+ raise RuntimeError("Not connected to SUMO.")
15
+ return float(traci.vehicle.getSpeed(vehicle_id))
16
+
17
+ def get_vehicle_position(vehicle_id: str) -> Tuple[float, float]:
18
+ """Get the (x, y) position of a specific vehicle."""
19
+ if not connection_manager.is_connected():
20
+ raise RuntimeError("Not connected to SUMO.")
21
+ x, y = traci.vehicle.getPosition(vehicle_id)
22
+ return float(x), float(y)
23
+
24
+ def get_vehicle_acceleration(vehicle_id: str) -> float:
25
+ """Get the acceleration of a specific vehicle (m/s^2)."""
26
+ if not connection_manager.is_connected():
27
+ raise RuntimeError("Not connected to SUMO.")
28
+ return float(traci.vehicle.getAcceleration(vehicle_id))
29
+
30
+ def get_vehicle_lane(vehicle_id: str) -> str:
31
+ """Get the lane ID of a specific vehicle."""
32
+ if not connection_manager.is_connected():
33
+ raise RuntimeError("Not connected to SUMO.")
34
+ return str(traci.vehicle.getLaneID(vehicle_id))
35
+
36
+ def get_vehicle_route(vehicle_id: str) -> List[str]:
37
+ """Get the route (list of edge IDs) of a specific vehicle."""
38
+ if not connection_manager.is_connected():
39
+ raise RuntimeError("Not connected to SUMO.")
40
+ return [str(edge) for edge in traci.vehicle.getRoute(vehicle_id)]
41
+
42
+ def get_simulation_info() -> dict[str, float | int]:
43
+ """Get general simulation statistics."""
44
+ if not connection_manager.is_connected():
45
+ raise RuntimeError("Not connected to SUMO.")
46
+ return {
47
+ "time": float(traci.simulation.getTime()),
48
+ "loaded_vehicles": int(traci.simulation.getLoadedNumber()),
49
+ "departed_vehicles": int(traci.simulation.getDepartedNumber()),
50
+ "arrived_vehicles": int(traci.simulation.getArrivedNumber()),
51
+ "min_expected_vehicles": int(traci.simulation.getMinExpectedNumber()),
52
+ }
resources/__init__.py ADDED
File without changes
server.py ADDED
@@ -0,0 +1,493 @@
1
+ import logging
2
+ import subprocess
3
+ from typing import Any, Dict, Optional
4
+
5
+ from mcp.server.fastmcp import FastMCP
6
+
7
+ from utils.traci import ensure_traci_start_stdout_suppressed
8
+ from mcp_tools.simulation import run_simple_simulation
9
+ from mcp_tools.network import netconvert, netgenerate, osm_get
10
+ from mcp_tools.route import random_trips, duarouter, od2trips
11
+ from mcp_tools.signal import tls_cycle_adaptation, tls_coordinator
12
+ from mcp_tools.analysis import analyze_fcd
13
+ from mcp_tools.vehicle import (
14
+ get_vehicles, get_vehicle_speed, get_vehicle_position,
15
+ get_vehicle_acceleration, get_vehicle_lane, get_vehicle_route,
16
+ get_simulation_info
17
+ )
18
+ from mcp_tools.rl import find_sumo_rl_scenario_files, list_rl_scenarios, run_rl_training
19
+ from utils.connection import connection_manager
20
+ from utils.sumo import find_sumo_binary, find_sumo_home, find_sumo_tools_dir
21
+ from workflows.sim_gen import sim_gen_workflow
22
+ from workflows.signal_opt import signal_opt_workflow
23
+ from workflows.rl_train import rl_train_workflow
24
+
25
+ # Configure logging to stderr to not interfere with MCP stdio transport
26
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Ensure TraCI never writes to stdout by default (MCP stdio safety).
30
+ ensure_traci_start_stdout_suppressed()
31
+
32
+ # Initialize MCP Server (official SDK)
33
+ server = FastMCP("SUMO-MCP-Server")
34
+
35
+ # --- 1. Network Management ---
36
+ @server.tool(description="Manage SUMO network (generate, convert, or download OSM).")
37
+ def manage_network(action: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
38
+ """
39
+ actions:
40
+ - generate: params={'grid': bool, 'grid_number': int}
41
+ - convert: params={'osm_file': str}
42
+ - download_osm: output_file is treated as output_dir. params={'bbox': str, 'prefix': str}
43
+ """
44
+ params = params or {}
45
+ options = params.get("options")
46
+
47
+ if action == "generate":
48
+ spider = bool(params.get("spider", False))
49
+ grid = bool(params.get("grid", True))
50
+ grid_number = params.get("grid_number", 3)
51
+
52
+ if spider:
53
+ # Spider network takes precedence over grid settings.
54
+ grid = False
55
+ options_list = list(options or [])
56
+
57
+ def _strip_flag(flag: str, has_value: bool = False) -> None:
58
+ while flag in options_list:
59
+ idx = options_list.index(flag)
60
+ options_list.pop(idx)
61
+ if has_value and idx < len(options_list):
62
+ options_list.pop(idx)
63
+
64
+ def _set_option(flag: str, value: str) -> None:
65
+ if flag in options_list:
66
+ idx = options_list.index(flag)
67
+ if idx + 1 < len(options_list):
68
+ options_list[idx + 1] = value
69
+ else:
70
+ options_list.append(value)
71
+ else:
72
+ options_list.extend([flag, value])
73
+
74
+ # Enforce Spider/Grid mutual exclusion even when the user provided `options`.
75
+ _strip_flag("--grid")
76
+ _strip_flag("--grid.number", has_value=True)
77
+
78
+ if "--spider" not in options_list:
79
+ options_list.insert(0, "--spider")
80
+
81
+ arms_raw = params.get("arms", params.get("arm_number"))
82
+ if arms_raw is not None:
83
+ try:
84
+ arms = int(arms_raw)
85
+ except (TypeError, ValueError):
86
+ return f"Error: arms must be a positive integer, got {arms_raw!r}"
87
+ if arms <= 0:
88
+ return "Error: arms must be > 0"
89
+ _set_option("--spider.arm-number", str(arms))
90
+
91
+ circles_raw = params.get("circles", params.get("circle_number"))
92
+ if circles_raw is not None:
93
+ try:
94
+ circles = int(circles_raw)
95
+ except (TypeError, ValueError):
96
+ return f"Error: circles must be a positive integer, got {circles_raw!r}"
97
+ if circles <= 0:
98
+ return "Error: circles must be > 0"
99
+ _set_option("--spider.circle-number", str(circles))
100
+
101
+ space_radius_raw = params.get("ring_radius", params.get("space_radius"))
102
+ if space_radius_raw is not None:
103
+ try:
104
+ space_radius = float(space_radius_raw)
105
+ except (TypeError, ValueError):
106
+ return f"Error: ring_radius must be a number, got {space_radius_raw!r}"
107
+ if space_radius <= 0:
108
+ return "Error: ring_radius must be > 0"
109
+ _set_option("--spider.space-radius", str(space_radius))
110
+
111
+ attach_length_raw = params.get("radial_distance", params.get("attach_length"))
112
+ if attach_length_raw is not None:
113
+ try:
114
+ attach_length = float(attach_length_raw)
115
+ except (TypeError, ValueError):
116
+ return f"Error: radial_distance must be a number, got {attach_length_raw!r}"
117
+ if attach_length < 0:
118
+ return "Error: radial_distance must be >= 0"
119
+ _set_option("--spider.attach-length", str(attach_length))
120
+
121
+ omit_center_raw = params.get("omit_center")
122
+ if omit_center_raw:
123
+ if "--spider.omit-center" not in options_list:
124
+ options_list.append("--spider.omit-center")
125
+
126
+ options = options_list
127
+
128
+ return netgenerate(output_file, grid, grid_number, options)
129
+
130
+ elif action == "convert" or action == "convert_osm":
131
+ osm_file = params.get("osm_file")
132
+ if not osm_file: return "Error: osm_file required for convert action"
133
+ return netconvert(osm_file, output_file, options)
134
+
135
+ elif action == "download_osm":
136
+ # output_file here acts as output_dir
137
+ bbox = params.get("bbox")
138
+ prefix = params.get("prefix", "osm")
139
+ if not bbox: return "Error: bbox required for download_osm action"
140
+ return osm_get(bbox, output_file, prefix, options)
141
+
142
+ return f"Unknown action: {action}"
143
+
144
+ # --- 2. Demand Management ---
145
+ @server.tool(description="Manage traffic demand (random trips, OD matrix, routing).")
146
+ def manage_demand(action: str, net_file: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
147
+ """
148
+ actions:
149
+ - generate_random: params={'end_time': int, 'period': float}
150
+ - convert_od: params={'od_file': str} (net_file unused but kept for consistency)
151
+ - compute_routes: params={'route_files': str} (input trips)
152
+ """
153
+ params = params or {}
154
+ options = params.get("options")
155
+
156
+ if action == "generate_random" or action == "random_trips":
157
+ # Backward/compat aliases: some clients use `end` instead of `end_time`.
158
+ end_time_raw = params.get("end_time", params.get("end", 3600))
159
+ period_raw = params.get("period", 1.0)
160
+ try:
161
+ end_time = int(end_time_raw)
162
+ except (TypeError, ValueError):
163
+ return f"Error: end_time must be an integer, got {end_time_raw!r}"
164
+ try:
165
+ period = float(period_raw)
166
+ except (TypeError, ValueError):
167
+ return f"Error: period must be a number, got {period_raw!r}"
168
+ return random_trips(net_file, output_file, end_time, period, options)
169
+
170
+ elif action == "convert_od" or action == "od_matrix":
171
+ od_file = params.get("od_file")
172
+ if not od_file: return "Error: od_file required for convert_od"
173
+ return od2trips(od_file, output_file, options)
174
+
175
+ elif action == "compute_routes" or action == "routing":
176
+ route_files = params.get("route_files") # Input trips file
177
+ if not route_files: return "Error: route_files required for compute_routes"
178
+ return duarouter(net_file, route_files, output_file, options)
179
+
180
+ return f"Unknown action: {action}"
181
+
182
+ # --- 3. Simulation Control ---
183
+ @server.tool(description="Control SUMO simulation (connect, step, disconnect).")
184
+ def control_simulation(action: str, params: Optional[Dict[str, Any]] = None) -> str:
185
+ """
186
+ actions:
187
+ - connect: params={'config_file': str, 'gui': bool}
188
+ - step: params={'step': float}
189
+ - disconnect: no params
190
+ """
191
+ params = params or {}
192
+
193
+ try:
194
+ timeout_s_raw = params.get("timeout_s", params.get("timeout"))
195
+ timeout_s: Optional[float] = None
196
+ if timeout_s_raw is not None:
197
+ try:
198
+ timeout_s = float(timeout_s_raw)
199
+ except (TypeError, ValueError):
200
+ return f"Error: timeout_s must be a number, got {timeout_s_raw!r}"
201
+
202
+ if action == "connect":
203
+ config_file = params.get("config_file")
204
+ gui = params.get("gui", False)
205
+ port = params.get("port", 8813)
206
+ host = params.get("host", "localhost")
207
+ if timeout_s is None:
208
+ connection_manager.connect(config_file, gui, port, host)
209
+ else:
210
+ connection_manager.connect(config_file, gui, port, host, timeout_s=timeout_s)
211
+ return "Successfully connected to SUMO."
212
+
213
+ elif action == "step":
214
+ step = params.get("step", 0)
215
+ if timeout_s is None:
216
+ connection_manager.simulation_step(step)
217
+ else:
218
+ connection_manager.simulation_step(step, timeout_s=timeout_s)
219
+ return "Simulation advanced."
220
+
221
+ elif action == "disconnect":
222
+ if timeout_s is None:
223
+ connection_manager.disconnect()
224
+ else:
225
+ connection_manager.disconnect(timeout_s=timeout_s)
226
+ return "Successfully disconnected from SUMO."
227
+
228
+ except Exception as e:
229
+ return f"Error in control_simulation ({action}): {type(e).__name__}: {e}"
230
+
231
+ return f"Unknown action: {action}"
232
+
233
+ # --- 4. Query State ---
234
+ @server.tool(description="Query simulation state (vehicles, speed, position). Requires active connection.")
235
+ def query_simulation_state(target: str, params: Optional[Dict[str, Any]] = None) -> str:
236
+ """
237
+ targets:
238
+ - vehicle_list: no params
239
+ - vehicle_variable: params={'vehicle_id': str, 'variable': 'speed'|'position'|'lane'|'acceleration'|'route'}
240
+ """
241
+ params = params or {}
242
+
243
+ try:
244
+ if target == "vehicle_list" or target == "vehicles":
245
+ vehs = get_vehicles()
246
+ return f"Active vehicles: {vehs}"
247
+
248
+ elif target == "vehicle_variable":
249
+ v_id = params.get("vehicle_id")
250
+ var = params.get("variable")
251
+ if not v_id or not var: return "Error: vehicle_id and variable required"
252
+
253
+ if var == "speed": return f"Speed: {get_vehicle_speed(v_id)}"
254
+ if var == "position": return f"Position: {get_vehicle_position(v_id)}"
255
+ if var == "acceleration": return f"Acceleration: {get_vehicle_acceleration(v_id)}"
256
+ if var == "lane": return f"Lane: {get_vehicle_lane(v_id)}"
257
+ if var == "route": return f"Route: {get_vehicle_route(v_id)}"
258
+
259
+ return f"Unknown variable: {var}"
260
+
261
+ elif target == "simulation":
262
+ info = get_simulation_info()
263
+ return f"Simulation Info: {info}"
264
+
265
+ except Exception as e:
266
+ return f"Error querying state: {type(e).__name__}: {e}"
267
+
268
+ return f"Unknown target: {target}"
269
+
270
+ # --- 5. Optimize Signals ---
271
+ @server.tool(description="Optimize traffic signals.")
272
+ def optimize_traffic_signals(method: str, net_file: str, route_file: str, output_file: str, params: Optional[Dict[str, Any]] = None) -> str:
273
+ """
274
+ methods:
275
+ - cycle_adaptation: adapt TLS cycles
276
+ - coordination: TLS coordination
277
+ """
278
+ params = params or {}
279
+ options = params.get("options")
280
+
281
+ if method == "cycle_adaptation" or method == "Websters":
282
+ return tls_cycle_adaptation(net_file, route_file, output_file)
283
+ elif method == "coordination":
284
+ return tls_coordinator(net_file, route_file, output_file, options)
285
+
286
+ return f"Unknown method: {method}"
287
+
288
+ # --- 6. Workflows ---
289
+ @server.tool(
290
+ description="""Run high-level SUMO workflows. Available workflows:
291
+
292
+ **sim_gen_eval** - Generate grid network, simulate traffic, analyze results.
293
+ params:
294
+ - grid_number (int): Grid size NxN. Default=3. Aliases: grid_size, size
295
+ - sim_seconds (int): Simulation duration in seconds. Default=100. Aliases: steps, duration, end_time
296
+ - output_dir (str): Output directory. Default="output"
297
+ Example: run_workflow("sim_gen_eval", {"grid_number": 3, "sim_seconds": 1000})
298
+
299
+ **signal_opt** - Optimize traffic signals for existing network.
300
+ params:
301
+ - net_file (str): Path to .net.xml file. REQUIRED
302
+ - route_file (str): Path to .rou.xml file. REQUIRED
303
+ - sim_seconds (int): Simulation duration. Default=3600. Aliases: steps, duration
304
+ - use_coordinator (bool): Use tlsCoordinator instead of tlsCycleAdaptation. Default=false
305
+ - output_dir (str): Output directory. Default="output"
306
+
307
+ **rl_train** - Train RL agent for traffic signal control.
308
+ params:
309
+ - scenario_name (str): Built-in scenario name (use manage_rl_task("list_scenarios") to see options). Aliases: scenario
310
+ - episodes (int): Number of training episodes. Default=5. Aliases: num_episodes
311
+ - steps (int): Steps per episode. Default=1000. Aliases: steps_per_episode
312
+ - output_dir (str): Output directory. Default="output"
313
+ """
314
+ )
315
+ def run_workflow(workflow_name: str, params: Dict[str, Any]) -> str:
316
+ """Execute a high-level workflow."""
317
+
318
+ # Helper to get param with aliases
319
+ def get_param(keys: list, default=None):
320
+ for k in keys:
321
+ if k in params:
322
+ return params[k]
323
+ return default
324
+
325
+ if workflow_name in ("sim_gen_eval", "sim_gen_workflow", "sim_gen"):
326
+ grid_number = get_param(["grid_number", "grid_size", "size"], 3)
327
+ sim_seconds = get_param(["sim_seconds", "steps", "duration", "end_time"], 100)
328
+ output_dir = get_param(["output_dir"], "output")
329
+
330
+ return sim_gen_workflow(output_dir, int(grid_number), int(sim_seconds))
331
+
332
+ elif workflow_name in ("signal_opt", "signal_opt_workflow"):
333
+ net_file = get_param(["net_file"], "")
334
+ route_file = get_param(["route_file"], "")
335
+
336
+ if not net_file or not route_file:
337
+ return "Error: signal_opt requires net_file and route_file parameters."
338
+
339
+ sim_seconds = get_param(["sim_seconds", "steps", "duration"], 3600)
340
+ use_coordinator = get_param(["use_coordinator"], False)
341
+ output_dir = get_param(["output_dir"], "output")
342
+
343
+ return signal_opt_workflow(net_file, route_file, output_dir, int(sim_seconds), bool(use_coordinator))
344
+
345
+ elif workflow_name == "rl_train":
346
+ scenario_name = get_param(["scenario_name", "scenario"], "")
347
+ episodes = get_param(["episodes", "num_episodes"], 5)
348
+ steps = get_param(["steps", "steps_per_episode"], 1000)
349
+ output_dir = get_param(["output_dir"], "output")
350
+
351
+ return rl_train_workflow(scenario_name, output_dir, int(episodes), int(steps))
352
+
353
+ return f"Unknown workflow: {workflow_name}. Available: sim_gen_eval, signal_opt, rl_train"
354
+
355
+ # --- 7. RL Task Management ---
356
+ @server.tool(description="Manage RL tasks (list scenarios, custom training).")
357
+ def manage_rl_task(action: str, params: Optional[Dict[str, Any]] = None) -> str:
358
+ """
359
+ actions:
360
+ - list_scenarios: no params
361
+ - train_custom: params={'net_file', 'route_file', 'out_dir', 'episodes', 'steps', 'algorithm', 'reward_type'}
362
+ """
363
+ params = params or {}
364
+
365
+ if action == "list_scenarios":
366
+ return str(list_rl_scenarios())
367
+
368
+ elif action == "train_custom":
369
+ scenario_name = params.get("scenario") or params.get("scenario_name")
370
+ net_file = params.get("net_file")
371
+ route_file = params.get("route_file")
372
+
373
+ if scenario_name:
374
+ net_file, route_file, err = find_sumo_rl_scenario_files(str(scenario_name))
375
+ if err:
376
+ return err
377
+
378
+ if not net_file or not route_file:
379
+ return (
380
+ "Error: train_custom requires either:\n"
381
+ " - scenario/scenario_name (built-in sumo-rl scenario), OR\n"
382
+ " - net_file + route_file (custom files)\n"
383
+ "Hint: Use manage_rl_task(list_scenarios) to see available built-in scenarios."
384
+ )
385
+
386
+ out_dir = params.get("out_dir") or params.get("output_dir") or "output"
387
+
388
+ episodes_raw = params.get("episodes", params.get("num_episodes", 1))
389
+ steps_raw = params.get("steps", params.get("steps_per_episode", 1000))
390
+ try:
391
+ episodes = int(episodes_raw)
392
+ except (TypeError, ValueError):
393
+ return f"Error: episodes must be an integer, got {episodes_raw!r}"
394
+ try:
395
+ steps_per_episode = int(steps_raw)
396
+ except (TypeError, ValueError):
397
+ return f"Error: steps must be an integer, got {steps_raw!r}"
398
+
399
+ if episodes <= 0:
400
+ return "Error: episodes must be > 0"
401
+ if steps_per_episode <= 0:
402
+ return "Error: steps must be > 0"
403
+
404
+ algorithm = str(params.get("algorithm", "ql"))
405
+ reward_type = str(params.get("reward_type", "diff-waiting-time"))
406
+
407
+ return run_rl_training(
408
+ net_file=str(net_file),
409
+ route_file=str(route_file),
410
+ out_dir=str(out_dir),
411
+ episodes=episodes,
412
+ steps_per_episode=steps_per_episode,
413
+ algorithm=algorithm,
414
+ reward_type=reward_type,
415
+ )
416
+
417
+ return f"Unknown action: {action}"
418
+
419
+ # --- Legacy/Misc ---
420
+ @server.tool(name="get_sumo_info", description="Get the version and path of the installed SUMO.")
421
+ def get_sumo_info() -> str:
422
+ try:
423
+ sumo_binary = find_sumo_binary("sumo")
424
+ if not sumo_binary:
425
+ return (
426
+ "Error: Could not locate SUMO executable. "
427
+ "Please ensure SUMO is installed and either `sumo` is available in PATH or `SUMO_HOME` is set."
428
+ )
429
+
430
+ result = subprocess.run(
431
+ [sumo_binary, "--version"],
432
+ capture_output=True,
433
+ text=True,
434
+ check=True,
435
+ timeout=10,
436
+ )
437
+ version_output = (result.stdout.splitlines() or ["Unknown"])[0]
438
+
439
+ sumo_home = find_sumo_home()
440
+ tools_dir = find_sumo_tools_dir()
441
+ return "\n".join(
442
+ [
443
+ f"SUMO Binary: {sumo_binary}",
444
+ f"SUMO Version: {version_output}",
445
+ f"SUMO_HOME: {sumo_home or 'Not Set'}",
446
+ f"SUMO Tools Dir: {tools_dir or 'Not Found'}",
447
+ ]
448
+ )
449
+ except Exception as e:
450
+ return f"Error checking SUMO: {str(e)}"
451
+
452
+ @server.tool(name="run_simple_simulation", description="Run a SUMO simulation using a config file.")
453
+ def run_simple_simulation_tool(config_path: str, steps: int = 100) -> str:
454
+ return run_simple_simulation(config_path, steps)
455
+
456
+ @server.tool(description="Analyze FCD output.")
457
+ def run_analysis(fcd_file: str) -> str:
458
+ return analyze_fcd(fcd_file)
459
+
460
+ if __name__ == "__main__":
461
+ # NOTE:
462
+ # MCP stdio transport relies on AnyIO/asyncio to process thread callbacks.
463
+ # In some environments, a lack of scheduled timers can cause the event loop to
464
+ # block indefinitely while waiting for stdio worker-thread results. A small
465
+ # periodic sleep keeps the loop responsive without emitting any stdout output.
466
+ import anyio
467
+
468
+ async def _wakeup_task() -> None:
469
+ while True:
470
+ await anyio.sleep(0.1)
471
+
472
+ async def _run_stdio_with_wakeup() -> None:
473
+ async with anyio.create_task_group() as tg:
474
+ tg.start_soon(_wakeup_task)
475
+ await server.run_stdio_async()
476
+
477
+ def main():
478
+ """Entry point for running the MCP server via uvx."""
479
+ import anyio
480
+
481
+ async def _wakeup_task() -> None:
482
+ while True:
483
+ await anyio.sleep(0.1)
484
+
485
+ async def _run_stdio_with_wakeup() -> None:
486
+ async with anyio.create_task_group() as tg:
487
+ tg.start_soon(_wakeup_task)
488
+ await server.run_stdio_async()
489
+
490
+ anyio.run(_run_stdio_with_wakeup)
491
+
492
+ if __name__ == "__main__":
493
+ main()
utils/__init__.py ADDED
File without changes