pygpt-net 2.6.50__py3-none-any.whl → 2.6.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,513 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # ================================================== #
4
+ # This file is a part of PYGPT package #
5
+ # Website: https://pygpt.net #
6
+ # GitHub: https://github.com/szczyglis-dev/py-gpt #
7
+ # MIT License #
8
+ # Created By : Marcin Szczygliński #
9
+ # Updated Date: 2025.09.16 22:00:00 #
10
+ # ================================================== #
11
+
12
+ import asyncio
13
+ import hashlib
14
+ import re
15
+ import shlex
16
+ import time
17
+ from typing import Dict, List, Tuple, Any, Optional
18
+ from urllib.parse import urlparse
19
+
20
+ from pygpt_net.plugin.base.plugin import BasePlugin
21
+ from pygpt_net.core.events import Event
22
+ from pygpt_net.item.ctx import CtxItem
23
+
24
+ from .config import Config
25
+
26
+
27
+ class Plugin(BasePlugin):
28
def __init__(self, *args, **kwargs):
    """
    Create the MCP plugin instance.

    Sets the plugin identity attributes, attaches the Config helper,
    loads default options and creates the empty runtime registries.

    :param args: positional args forwarded to the base plugin
    :param kwargs: keyword args forwarded to the base plugin
    """
    super().__init__(*args, **kwargs)

    # Identity / presentation
    self.id = "mcp"
    self.name = "MCP"
    self.description = "Use remote tools via MCP"
    self.prefix = "RemoteTool"
    self.order = 100
    self.use_locale = True

    # Execution helpers; options are loaded only once the config exists
    self.worker = None
    self.config = Config(self)
    self.init_options()

    # cmd name -> execution metadata, rebuilt on every CMD_SYNTAX event
    self.tools_index: Dict[str, Dict[str, Any]] = {}

    # server key -> cached discovery result, plus the signature of the
    # configuration the cache was built for
    self._tools_cache: Dict[str, Dict[str, Any]] = {}
    self._last_config_signature: Optional[str] = None
46
+
47
def init_options(self):
    """Initialize options by delegating to the config helper's from_defaults()."""
    self.config.from_defaults(self)
50
+
51
def handle(self, event: Event, *args, **kwargs):
    """
    Dispatch an incoming event to the matching handler.

    CMD_SYNTAX events are routed to cmd_syntax(), CMD_EXECUTE events
    to cmd(); any other event is ignored.

    :param event: event object
    :param args: event args
    :param kwargs: event kwargs
    """
    event_name, payload, context = event.name, event.data, event.ctx

    if event_name == Event.CMD_SYNTAX:
        self.cmd_syntax(payload)
        return

    if event_name == Event.CMD_EXECUTE:
        self.cmd(context, payload['commands'])
71
+
72
+ def cmd_syntax(self, data: dict):
73
+ """
74
+ Event: CMD_SYNTAX
75
+ Build "cmd" entries based on tools discovered from active MCP servers.
76
+ Applies allow/deny per server. Uses cache with TTL.
77
+
78
+ :param data: event data dict
79
+ """
80
+ servers: List[dict] = self.get_option_value("servers") or []
81
+ active_servers = [(i, s) for i, s in enumerate(servers) if s.get("active", False)]
82
+ self.tools_index.clear()
83
+
84
+ if len(active_servers) == 0:
85
+ return
86
+
87
+ # Invalidate cache if config changed
88
+ current_sig = self._config_signature(active_servers)
89
+ if current_sig != self._last_config_signature:
90
+ self._tools_cache.clear()
91
+ self._last_config_signature = current_sig
92
+
93
+ try:
94
+ discovered = self._discover_tools_sync(active_servers)
95
+ except Exception as e:
96
+ self.error(e)
97
+ self.error(f"MCP: discovery failed: {e}")
98
+ return
99
+
100
+ used_names = set() # to ensure unique tool names in this batch
101
+
102
+ for (server_idx, server_tag, transport, tool, server_cfg) in discovered:
103
+ tool_name = getattr(tool, "name", None) or tool.get("name")
104
+ description = getattr(tool, "description", None) or tool.get("description")
105
+ input_schema = getattr(tool, "inputSchema", None) or tool.get("inputSchema")
106
+
107
+ # Human-friendly display name
108
+ display_name = tool_name
109
+ try:
110
+ from mcp.shared.metadata_utils import get_display_name # type: ignore
111
+ display_name = get_display_name(tool) or tool_name
112
+ except Exception:
113
+ pass
114
+
115
+ # Server label -> used in tool name (avoid dots and other invalid chars)
116
+ server_label = (server_cfg.get("label") or server_tag or f"srv{server_idx}").strip()
117
+ server_slug = self._slugify(server_label)
118
+
119
+ # Compose final tool "cmd" name acceptable by OpenAI: ^[a-zA-Z0-9_-]+$, <=64 chars
120
+ cmd_name = self._compose_cmd_name(server_slug, tool_name, used_names)
121
+ used_names.add(cmd_name)
122
+
123
+ params = self.extract_params_from_schema(input_schema)
124
+
125
+ # Instruction for the model
126
+ if description and display_name and display_name != tool_name:
127
+ instruction = f"{display_name}: {description} (server: {server_label})"
128
+ elif description:
129
+ instruction = f"{description} (server: {server_label})"
130
+ else:
131
+ instruction = f"Call remote MCP tool '{display_name}' on server '{server_label}'."
132
+
133
+ cmd_syntax = {
134
+ "cmd": cmd_name,
135
+ "instruction": instruction,
136
+ "params": params,
137
+ "enabled": True,
138
+ }
139
+ data['cmd'].append(cmd_syntax)
140
+
141
+ # Index for execution
142
+ self.tools_index[cmd_name] = {
143
+ "server_idx": server_idx,
144
+ "server": server_cfg,
145
+ "server_tag": server_tag,
146
+ "transport": transport,
147
+ "tool_name": tool_name,
148
+ "schema": input_schema,
149
+ "description": description,
150
+ "display_name": display_name,
151
+ }
152
+
153
def cmd(self, ctx: CtxItem, cmds: list):
    """
    Event: CMD_EXECUTE
    Run the commands that belong to this plugin through a Worker.

    Only commands whose "cmd" is present in self.tools_index are handled;
    when there are none the method returns without doing anything.

    :param ctx: CtxItem
    :param cmds: list of command dicts
    """
    from .worker import Worker

    my_commands = []
    for item in cmds:
        if item.get("cmd") in self.tools_index:
            my_commands.append(item)
    if not my_commands:
        return

    # set state: busy
    self.cmd_prepare(ctx, my_commands)

    try:
        worker = Worker()
        worker.from_defaults(self)
        worker.plugin = self
        worker.cmds = my_commands
        worker.ctx = ctx
        worker.tools_index = self.tools_index

        # Synchronous contexts run inline; otherwise hand off to async execution
        if self.is_async(ctx):
            worker.run_async()
        else:
            worker.run()
    except Exception as e:
        self.error(e)
184
+
185
+ # ---------------------------
186
+ # Discovery + caching
187
+ # ---------------------------
188
+
189
+ def _discover_tools_sync(self, active_servers: List[Tuple[int, dict]]) -> List[Tuple[int, str, str, Any, dict]]:
190
+ """Run async discovery in a dedicated loop and return collected tools."""
191
+ return asyncio.run(self._discover_tools_async(active_servers))
192
+
193
async def _discover_tools_async(
        self,
        active_servers: List[Tuple[int, dict]],
        per_server_timeout: float = 8.0
) -> List[Tuple[int, str, str, Any, dict]]:
    """
    Discover tools for each active server (with cache).

    Servers are processed one after another. A server with an empty
    address is skipped. For every other server the tool list is taken
    from the in-memory cache when caching is enabled and the entry is
    fresh, otherwise it is fetched live over the detected transport.
    Tools are then filtered by the server's allow/deny lists.

    A timeout or error on one server is reported and does not stop
    discovery on the remaining servers. If the MCP SDK cannot be
    imported, an empty list is returned.

    :param active_servers: list of (server index, server config) tuples
    :param per_server_timeout: time limit in seconds for live discovery on one server
    :return: list of tuples (server_idx, server_tag, transport, tool, server_cfg)
    """
    results: List[Tuple[int, str, str, Any, dict]] = []

    # Lazy import: the SDK is only needed once discovery actually runs
    try:
        from mcp import ClientSession  # type: ignore
        from mcp.client.stdio import stdio_client  # type: ignore
        from mcp.client.streamable_http import streamablehttp_client  # type: ignore
        from mcp.client.sse import sse_client  # type: ignore
        from mcp import StdioServerParameters  # type: ignore
    except Exception as e:
        self.error('MCP SDK not installed. Install with: pip install "mcp[cli]"')
        self.log(f"MCP import error: {e}")
        return results

    cache_enabled = bool(self.get_option_value("tools_cache_enabled"))
    # TTL in seconds; falls back to 300 when the option is empty or not an integer
    try:
        ttl = int(self.get_option_value("tools_cache_ttl") or 300)
    except Exception:
        ttl = 300

    for server_idx, server in active_servers:
        address = (server.get("server_address") or "").strip()
        if not address:
            continue

        # NOTE(review): these helpers run outside the per-server try below, so
        # an exception raised by any of them aborts the whole discovery.
        transport = self._detect_transport(address)
        server_tag = self._make_server_tag(server, server_idx)
        server_key = self._server_key(server)
        headers = self._build_headers(server)

        allowed = self._parse_csv(server.get("allowed_commands"))
        disabled = self._parse_csv(server.get("disabled_commands"))

        # Cache: reuse an entry only if it was stored for the same transport
        # and is not older than the TTL
        cached_tools = None
        if cache_enabled:
            cached = self._tools_cache.get(server_key)
            if cached and cached.get("transport") == transport:
                if (time.time() - float(cached.get("ts", 0))) <= ttl:
                    cached_tools = cached.get("tools", None)

        async def list_tools_for_session(session: ClientSession) -> List[Any]:
            tools_resp = await session.list_tools()
            return list(tools_resp.tools)

        try:
            if cached_tools is None:
                async def _run_discovery():
                    if transport == "stdio":
                        cmd, args = self._parse_stdio_command(address)
                        params = StdioServerParameters(command=cmd, args=args)
                        async with stdio_client(params) as (read, write):
                            async with ClientSession(read, write) as session:
                                await session.initialize()
                                return await list_tools_for_session(session)

                    elif transport == "http":
                        # Streamable HTTP - pass Authorization if set
                        async with streamablehttp_client(address, headers=headers or None) as (read, write, _):
                            async with ClientSession(read, write) as session:
                                await session.initialize()
                                return await list_tools_for_session(session)

                    elif transport == "sse":
                        # SSE - pass Authorization if set
                        # NOTE(review): the address is passed through unchanged, so an
                        # 'sse://' or 'sse+http(s)://' prefixed address reaches the client
                        # with that scheme - confirm the client accepts it.
                        async with sse_client(address, headers=headers or None) as (read, write):
                            async with ClientSession(read, write) as session:
                                await session.initialize()
                                return await list_tools_for_session(session)

                    else:
                        raise RuntimeError(f"Unsupported MCP transport for server '{server_tag}': {transport}")

                # The time limit covers connect + initialize + list_tools for this server
                tools = await asyncio.wait_for(_run_discovery(), timeout=per_server_timeout)

                if cache_enabled:
                    self._tools_cache[server_key] = {
                        "ts": time.time(),
                        "transport": transport,
                        "tools": tools,
                    }
            else:
                tools = cached_tools

            # The unfiltered list is cached; allow/deny filtering is applied on every call
            for tool in tools:
                # NOTE(review): the .get() fallback is only reached when the attribute
                # is missing or falsy, and it assumes a dict-like tool - confirm.
                tname = getattr(tool, "name", None) or tool.get("name")
                if disabled and tname in disabled:
                    continue
                if allowed and tname not in allowed:
                    continue
                results.append((server_idx, server_tag, transport, tool, server))

        except asyncio.TimeoutError:
            self.error(f"MCP: timeout during discovery on server '{server_tag}'")
        except Exception as e:
            self.log(f"MCP discovery error on '{server_tag}': {e}")
            self.error(f"MCP: discovery error on '{server_tag}': {e}")

    return results
301
+
302
+ # --------------
303
+ # Schema helpers
304
+ # --------------
305
+
306
def extract_params(self, text: str) -> list:
    """
    Turn a comma-separated list of parameter names into param dicts.

    Every non-empty, stripped name becomes a "str" parameter whose
    description is the name itself.

    :param text: comma-separated parameter names (may be None or empty)
    :return: list of {"name", "type", "description"} dicts
    """
    if text is None or text == "":
        return []
    stripped_names = (chunk.strip() for chunk in text.split(","))
    return [
        {"name": pname, "type": "str", "description": pname}
        for pname in stripped_names
        if pname != ""
    ]
322
+
323
+ def extract_params_from_schema(self, schema: Optional[dict]) -> List[dict]:
324
+ """Convert MCP tool inputSchema (JSON Schema) to {name, type, description} list."""
325
+ params: List[dict] = []
326
+ if not schema or not isinstance(schema, dict):
327
+ return params
328
+
329
+ properties = schema.get("properties", {})
330
+ required = set(schema.get("required", []) or [])
331
+
332
+ for name, prop in properties.items():
333
+ jtype = prop.get("type", "string")
334
+ desc = prop.get("description", "")
335
+ ptype = self._map_json_type_to_param_type(jtype)
336
+ if name in required and desc:
337
+ desc = f"{desc} [required]"
338
+ elif name in required:
339
+ desc = "[required]"
340
+ params.append({
341
+ "name": name,
342
+ "type": ptype,
343
+ "description": desc or name,
344
+ })
345
+
346
+ return params
347
+
348
+ # ---------------------------
349
+ # Low-level utilities
350
+ # ---------------------------
351
+
352
+ def _map_json_type_to_param_type(self, jtype: str) -> str:
353
+ """Map JSON Schema types to simple plugin param types."""
354
+ t = (jtype or "string").lower()
355
+ if t in ("string",):
356
+ return "str"
357
+ if t in ("integer", "number"):
358
+ return "float" if t == "number" else "int"
359
+ if t in ("boolean",):
360
+ return "bool"
361
+ if t in ("array", "object"):
362
+ return "str"
363
+ return "str"
364
+
365
+ def _parse_csv(self, text: Optional[str]) -> Optional[set]:
366
+ """Parse comma-separated string into a set of stripped items or None if empty."""
367
+ if not text:
368
+ return None
369
+ items = [x.strip() for x in text.split(",")]
370
+ items = [x for x in items if x]
371
+ return set(items) if items else None
372
+
373
+ def _detect_transport(self, address: str) -> str:
374
+ """
375
+ Detect transport from address:
376
+ - 'stdio: ...' -> stdio
377
+ - 'http(s)://.../mcp' or general http(s) -> http (Streamable HTTP)
378
+ - 'sse://' or 'sse+http(s)://' or path containing '/sse' -> sse
379
+ """
380
+ if address.lower().startswith("stdio:"):
381
+ return "stdio"
382
+ lower = address.lower()
383
+ if lower.startswith(("sse://", "sse+http://", "sse+https://")):
384
+ return "sse"
385
+ if lower.startswith(("http://", "https://")):
386
+ try:
387
+ parsed = urlparse(address)
388
+ path = (parsed.path or "").lower()
389
+ except Exception:
390
+ path = ""
391
+ if "/sse" in path or path.endswith("/sse"):
392
+ return "sse"
393
+ return "http"
394
+ return "stdio"
395
+
396
+ def _parse_stdio_command(self, address: str) -> Tuple[str, List[str]]:
397
+ """Parse 'stdio: <command line>' into (command, args)."""
398
+ cmdline = address[len("stdio:"):].strip()
399
+ tokens = shlex.split(cmdline)
400
+ if not tokens:
401
+ raise ValueError("Invalid stdio address: empty command")
402
+ return tokens[0], tokens[1:]
403
+
404
+ def _make_server_tag(self, server: dict, idx: int) -> str:
405
+ """
406
+ Create a short tag for the server (for display only).
407
+ Prefer explicit 'label'; fallback to host/path-derived value.
408
+ """
409
+ label = (server.get("label") or "").strip()
410
+ if label:
411
+ return label
412
+ address = (server.get("server_address") or "").strip()
413
+ if address.startswith("stdio:"):
414
+ cmdline = address[len("stdio:"):].strip()
415
+ exe = shlex.split(cmdline)[0] if cmdline else f"stdio_{idx}"
416
+ return exe
417
+ try:
418
+ parsed = urlparse(address)
419
+ host = (parsed.netloc or f"server_{idx}")
420
+ tail = (parsed.path.rstrip("/").split("/")[-1] or "mcp")
421
+ return f"{host}_{tail}"
422
+ except Exception:
423
+ return f"server_{idx}"
424
+
425
+ def _build_headers(self, server: dict) -> Optional[dict]:
426
+ """
427
+ Build optional headers for HTTP/SSE transports.
428
+ Currently supports Authorization only.
429
+ """
430
+ auth = (server.get("authorization") or "").strip()
431
+ headers = {}
432
+ if auth:
433
+ # If user passed only token, you may expect 'Bearer <token>'
434
+ headers["Authorization"] = auth
435
+ return headers or None
436
+
437
+ def _slugify(self, text: str) -> str:
438
+ """
439
+ Sanitize text to allowed chars for tool names: [a-zA-Z0-9_-]
440
+ Collapse multiple underscores and strip from ends.
441
+ """
442
+ if not text:
443
+ return "srv"
444
+ s = re.sub(r"[^a-zA-Z0-9_-]+", "_", text)
445
+ s = re.sub(r"_+", "_", s).strip("_")
446
+ return s or "srv"
447
+
448
+ def _truncate_with_hash(self, base: str, max_len: int) -> str:
449
+ """
450
+ Truncate a string to max_len with a short hash suffix to preserve uniqueness.
451
+ """
452
+ if len(base) <= max_len:
453
+ return base
454
+ h = hashlib.sha1(base.encode("utf-8")).hexdigest()[:6]
455
+ keep = max_len - 7 # 1 for '-' + 6 for hash
456
+ keep = max(1, keep)
457
+ return f"{base[:keep]}-{h}"
458
+
459
+ def _compose_cmd_name(self, server_slug: str, tool_name: str, used: set) -> str:
460
+ """
461
+ Compose final command name:
462
+ - No global prefixes
463
+ - Format: <server_slug>__<tool_slug>
464
+ - Allowed charset: [a-zA-Z0-9_-]
465
+ - Max length: 64 (OpenAI requirement)
466
+ - Ensure uniqueness within one CMD_SYNTAX build
467
+ """
468
+ tool_slug = self._slugify(tool_name)
469
+
470
+ # Initial compose and length guard
471
+ base = f"{server_slug}__{tool_slug}"
472
+ name = self._truncate_with_hash(base, 64)
473
+
474
+ # Ensure uniqueness; add numeric suffix if needed (within 64 limit)
475
+ if name not in used:
476
+ return name
477
+
478
+ i = 2
479
+ while True:
480
+ suffix = f"-{i}"
481
+ max_len = 64 - len(suffix)
482
+ candidate = self._truncate_with_hash(base, max_len) + suffix
483
+ if candidate not in used:
484
+ return candidate
485
+ i += 1
486
+
487
+ def _server_key(self, server: dict) -> str:
488
+ """Deterministic key for a server config entry."""
489
+ addr = (server.get("server_address") or "").strip()
490
+ if addr.lower().startswith("http"):
491
+ try:
492
+ parsed = urlparse(addr)
493
+ return f"http::{parsed.netloc}{parsed.path}"
494
+ except Exception:
495
+ return f"http::{addr}"
496
+ if addr.lower().startswith(("sse://", "sse+http://", "sse+https://")):
497
+ return f"sse::{addr}"
498
+ if addr.startswith("stdio:"):
499
+ return f"stdio::{addr[len('stdio:'):].strip()}"
500
+ return addr
501
+
502
+ def _config_signature(self, active_servers: List[Tuple[int, dict]]) -> str:
503
+ """Signature of current config to invalidate cache when config changes."""
504
+ norm: List[str] = []
505
+ for idx, srv in active_servers:
506
+ addr = (srv.get("server_address") or "").strip()
507
+ label = (srv.get("label") or "").strip()
508
+ auth = (srv.get("authorization") or "").strip()
509
+ a = ",".join(sorted(list(self._parse_csv(srv.get("allowed_commands")) or [])))
510
+ d = ",".join(sorted(list(self._parse_csv(srv.get("disabled_commands")) or [])))
511
+ norm.append(f"{idx}|{label}|{addr}|AUTH:{bool(auth)}|A:{a}|D:{d}")
512
+ blob = "|#|".join(norm)
513
+ return hashlib.sha256(blob.encode("utf-8")).hexdigest()