llms-py 3.0.7__py3-none-any.whl → 3.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llms/main.py CHANGED
@@ -28,7 +28,7 @@ from datetime import datetime
28
28
  from importlib import resources # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
29
29
  from io import BytesIO
30
30
  from pathlib import Path
31
- from typing import get_type_hints
31
+ from typing import Optional, get_type_hints
32
32
  from urllib.parse import parse_qs, urlencode, urljoin
33
33
 
34
34
  import aiohttp
@@ -41,7 +41,7 @@ try:
41
41
  except ImportError:
42
42
  HAS_PIL = False
43
43
 
44
- VERSION = "3.0.7"
44
+ VERSION = "3.0.8"
45
45
  _ROOT = None
46
46
  DEBUG = os.getenv("DEBUG") == "1"
47
47
  MOCK = os.getenv("MOCK") == "1"
@@ -211,8 +211,8 @@ def pluralize(word, count):
211
211
 
212
212
 
213
213
  def get_file_mime_type(filename):
214
- mime_type, _ = mimetypes.guess_type(filename)
215
- return mime_type or "application/octet-stream"
214
+ mimetype, _ = mimetypes.guess_type(filename)
215
+ return mimetype or "application/octet-stream"
216
216
 
217
217
 
218
218
  def price_to_string(price: float | int | str | None) -> str | None:
@@ -369,6 +369,75 @@ def function_to_tool_definition(func):
369
369
  }
370
370
 
371
371
 
372
+ async def download_file(url):
373
+ async with aiohttp.ClientSession() as session:
374
+ return await session_download_file(session, url)
375
+
376
+
377
+ async def session_download_file(session, url, default_mimetype="application/octet-stream"):
378
+ try:
379
+ async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
380
+ response.raise_for_status()
381
+ content = await response.read()
382
+ mimetype = response.headers.get("Content-Type")
383
+ disposition = response.headers.get("Content-Disposition")
384
+ if mimetype and ";" in mimetype:
385
+ mimetype = mimetype.split(";")[0]
386
+ ext = None
387
+ if disposition:
388
+ start = disposition.index('filename="') + len('filename="')
389
+ end = disposition.index('"', start)
390
+ filename = disposition[start:end]
391
+ if not mimetype:
392
+ mimetype = mimetypes.guess_type(filename)[0] or default_mimetype
393
+ else:
394
+ filename = url.split("/")[-1]
395
+ if "." not in filename:
396
+ if mimetype is None:
397
+ mimetype = default_mimetype
398
+ ext = mimetypes.guess_extension(mimetype) or mimetype.split("/")[1]
399
+ filename = f"{filename}.{ext}"
400
+
401
+ if not ext:
402
+ ext = Path(filename).suffix.lstrip(".")
403
+
404
+ info = {
405
+ "url": url,
406
+ "type": mimetype,
407
+ "name": filename,
408
+ "ext": ext,
409
+ }
410
+ return content, info
411
+ except Exception as e:
412
+ _err(f"Error downloading file: {url}", e)
413
+ raise e
414
+
415
+
416
+ def read_binary_file(url):
417
+ try:
418
+ path = Path(url)
419
+ with open(url, "rb") as f:
420
+ content = f.read()
421
+ info_path = path.stem + ".info.json"
422
+ if os.path.exists(info_path):
423
+ with open(info_path) as f_info:
424
+ info = json.load(f_info)
425
+ return content, info
426
+
427
+ stat = path.stat()
428
+ info = {
429
+ "date": int(stat.st_mtime),
430
+ "name": path.name,
431
+ "ext": path.suffix.lstrip("."),
432
+ "type": mimetypes.guess_type(path.name)[0],
433
+ "url": f"/~cache/{path.name[:2]}/{path.name}",
434
+ }
435
+ return content, info
436
+ except Exception as e:
437
+ _err(f"Error reading file: {url}", e)
438
+ raise e
439
+
440
+
372
441
  async def process_chat(chat, provider_id=None):
373
442
  if not chat:
374
443
  raise Exception("No chat provided")
@@ -397,31 +466,20 @@ async def process_chat(chat, provider_id=None):
397
466
  url = get_cache_path(url[8:])
398
467
  if is_url(url):
399
468
  _log(f"Downloading image: {url}")
400
- async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
401
- response.raise_for_status()
402
- content = await response.read()
403
- # get mimetype from response headers
404
- mimetype = get_file_mime_type(get_filename(url))
405
- if "Content-Type" in response.headers:
406
- mimetype = response.headers["Content-Type"]
407
- # convert/resize image if needed
408
- content, mimetype = convert_image_if_needed(content, mimetype)
409
- # convert to data uri
410
- image_url["url"] = (
411
- f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
412
- )
469
+ content, info = await session_download_file(session, url, default_mimetype="image/png")
470
+ mimetype = info["type"]
471
+ # convert/resize image if needed
472
+ content, mimetype = convert_image_if_needed(content, mimetype)
473
+ # convert to data uri
474
+ image_url["url"] = f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
413
475
  elif is_file_path(url):
414
476
  _log(f"Reading image: {url}")
415
- with open(url, "rb") as f:
416
- content = f.read()
417
- # get mimetype from file extension
418
- mimetype = get_file_mime_type(get_filename(url))
419
- # convert/resize image if needed
420
- content, mimetype = convert_image_if_needed(content, mimetype)
421
- # convert to data uri
422
- image_url["url"] = (
423
- f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
424
- )
477
+ content, info = read_binary_file(url)
478
+ mimetype = info["type"]
479
+ # convert/resize image if needed
480
+ content, mimetype = convert_image_if_needed(content, mimetype)
481
+ # convert to data uri
482
+ image_url["url"] = f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
425
483
  elif url.startswith("data:"):
426
484
  # Extract existing data URI and process it
427
485
  if ";base64," in url:
@@ -443,29 +501,24 @@ async def process_chat(chat, provider_id=None):
443
501
  url = input_audio["data"]
444
502
  if url.startswith("/~cache/"):
445
503
  url = get_cache_path(url[8:])
446
- mimetype = get_file_mime_type(get_filename(url))
447
504
  if is_url(url):
448
505
  _log(f"Downloading audio: {url}")
449
- async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
450
- response.raise_for_status()
451
- content = await response.read()
452
- # get mimetype from response headers
453
- if "Content-Type" in response.headers:
454
- mimetype = response.headers["Content-Type"]
455
- # convert to base64
456
- input_audio["data"] = base64.b64encode(content).decode("utf-8")
457
- if provider_id == "alibaba":
458
- input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
459
- input_audio["format"] = mimetype.rsplit("/", 1)[1]
506
+ content, info = await session_download_file(session, url, default_mimetype="audio/mp3")
507
+ mimetype = info["type"]
508
+ # convert to base64
509
+ input_audio["data"] = base64.b64encode(content).decode("utf-8")
510
+ if provider_id == "alibaba":
511
+ input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
512
+ input_audio["format"] = mimetype.rsplit("/", 1)[1]
460
513
  elif is_file_path(url):
461
514
  _log(f"Reading audio: {url}")
462
- with open(url, "rb") as f:
463
- content = f.read()
464
- # convert to base64
465
- input_audio["data"] = base64.b64encode(content).decode("utf-8")
466
- if provider_id == "alibaba":
467
- input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
468
- input_audio["format"] = mimetype.rsplit("/", 1)[1]
515
+ content, info = read_binary_file(url)
516
+ mimetype = info["type"]
517
+ # convert to base64
518
+ input_audio["data"] = base64.b64encode(content).decode("utf-8")
519
+ if provider_id == "alibaba":
520
+ input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
521
+ input_audio["format"] = mimetype.rsplit("/", 1)[1]
469
522
  elif is_base_64(url):
470
523
  pass # use base64 data as-is
471
524
  else:
@@ -476,24 +529,24 @@ async def process_chat(chat, provider_id=None):
476
529
  url = file["file_data"]
477
530
  if url.startswith("/~cache/"):
478
531
  url = get_cache_path(url[8:])
479
- mimetype = get_file_mime_type(get_filename(url))
480
532
  if is_url(url):
481
533
  _log(f"Downloading file: {url}")
482
- async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
483
- response.raise_for_status()
484
- content = await response.read()
485
- file["filename"] = get_filename(url)
486
- file["file_data"] = (
487
- f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
488
- )
534
+ content, info = await session_download_file(
535
+ session, url, default_mimetype="application/pdf"
536
+ )
537
+ mimetype = info["type"]
538
+ file["filename"] = info["name"]
539
+ file["file_data"] = (
540
+ f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
541
+ )
489
542
  elif is_file_path(url):
490
543
  _log(f"Reading file: {url}")
491
- with open(url, "rb") as f:
492
- content = f.read()
493
- file["filename"] = get_filename(url)
494
- file["file_data"] = (
495
- f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
496
- )
544
+ content, info = read_binary_file(url)
545
+ mimetype = info["type"]
546
+ file["filename"] = info["name"]
547
+ file["file_data"] = (
548
+ f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
549
+ )
497
550
  elif url.startswith("data:"):
498
551
  if "filename" not in file:
499
552
  file["filename"] = "file"
@@ -583,8 +636,9 @@ def cache_message_inline_data(m):
583
636
  ext = file_ext_from_mimetype(mimetype)
584
637
  filename = f"{filename}.{ext}"
585
638
 
586
- cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
639
+ cache_url, info = save_bytes_to_cache(base64_data, filename)
587
640
  file_info["file_data"] = cache_url
641
+ file_info["filename"] = info["name"]
588
642
  except Exception as e:
589
643
  _log(f"Error caching inline file: {e}")
590
644
 
@@ -598,7 +652,7 @@ class HTTPError(Exception):
598
652
  super().__init__(f"HTTP {status} {reason}")
599
653
 
600
654
 
601
- def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
655
+ def save_bytes_to_cache(base64_data, filename, file_info=None, ignore_info=False):
602
656
  ext = filename.split(".")[-1]
603
657
  mimetype = get_file_mime_type(filename)
604
658
  content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
@@ -631,7 +685,8 @@ def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
631
685
  "type": mimetype,
632
686
  "name": filename,
633
687
  }
634
- info.update(file_info)
688
+ if file_info:
689
+ info.update(file_info)
635
690
 
636
691
  # Save metadata
637
692
  info_path = os.path.splitext(full_path)[0] + ".info.json"
@@ -645,6 +700,14 @@ def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
645
700
  return url, info
646
701
 
647
702
 
703
+ def save_audio_to_cache(base64_data, filename, audio_info, ignore_info=False):
704
+ return save_bytes_to_cache(base64_data, filename, audio_info, ignore_info)
705
+
706
+
707
+ def save_video_to_cache(base64_data, filename, file_info, ignore_info=False):
708
+ return save_bytes_to_cache(base64_data, filename, file_info, ignore_info)
709
+
710
+
648
711
  def save_image_to_cache(base64_data, filename, image_info, ignore_info=False):
649
712
  ext = filename.split(".")[-1]
650
713
  mimetype = get_file_mime_type(filename)
@@ -1348,6 +1411,162 @@ def g_chat_request(template=None, text=None, model=None, system_prompt=None):
1348
1411
  return chat
1349
1412
 
1350
1413
 
1414
+ def tool_result_part(result: dict, function_name: Optional[str] = None, function_args: Optional[dict] = None):
1415
+ args = function_args or {}
1416
+ type = result.get("type")
1417
+ prompt = args.get("prompt") or args.get("text") or args.get("message")
1418
+ if type == "text":
1419
+ return result.get("text"), None
1420
+ elif type == "image":
1421
+ format = result.get("format") or args.get("format") or "png"
1422
+ filename = result.get("filename") or args.get("filename") or f"{function_name}-{int(time.time())}.{format}"
1423
+ mime_type = get_file_mime_type(filename)
1424
+ image_info = {"type": mime_type}
1425
+ if prompt:
1426
+ image_info["prompt"] = prompt
1427
+ if "model" in args:
1428
+ image_info["model"] = args["model"]
1429
+ if "aspect_ratio" in args:
1430
+ image_info["aspect_ratio"] = args["aspect_ratio"]
1431
+ base64_data = result.get("data")
1432
+ if not base64_data:
1433
+ _dbg(f"Image data not found for {function_name}")
1434
+ return None, None
1435
+ url, _ = save_image_to_cache(base64_data, filename, image_info=image_info, ignore_info=True)
1436
+ resource = {
1437
+ "type": "image_url",
1438
+ "image_url": {
1439
+ "url": url,
1440
+ },
1441
+ }
1442
+ text = f"![{args.get('prompt') or filename}]({url})\n"
1443
+ return text, resource
1444
+ elif type == "audio":
1445
+ format = result.get("format") or args.get("format") or "mp3"
1446
+ filename = result.get("filename") or args.get("filename") or f"{function_name}-{int(time.time())}.{format}"
1447
+ mime_type = get_file_mime_type(filename)
1448
+ audio_info = {"type": mime_type}
1449
+ if prompt:
1450
+ audio_info["prompt"] = prompt
1451
+ if "model" in args:
1452
+ audio_info["model"] = args["model"]
1453
+ base64_data = result.get("data")
1454
+ if not base64_data:
1455
+ _dbg(f"Audio data not found for {function_name}")
1456
+ return None, None
1457
+ url, _ = save_audio_to_cache(base64_data, filename, audio_info=audio_info, ignore_info=True)
1458
+ resource = {
1459
+ "type": "audio_url",
1460
+ "audio_url": {
1461
+ "url": url,
1462
+ },
1463
+ }
1464
+ text = f"[{args.get('prompt') or filename}]({url})\n"
1465
+ return text, resource
1466
+ elif type == "file":
1467
+ filename = result.get("filename") or args.get("filename") or result.get("name") or args.get("name")
1468
+ format = result.get("format") or args.get("format") or (get_filename(filename) if filename else "txt")
1469
+ if not filename:
1470
+ filename = f"{function_name}-{int(time.time())}.{format}"
1471
+
1472
+ mime_type = get_file_mime_type(filename)
1473
+ file_info = {"type": mime_type}
1474
+ if prompt:
1475
+ file_info["prompt"] = prompt
1476
+ if "model" in args:
1477
+ file_info["model"] = args["model"]
1478
+ base64_data = result.get("data")
1479
+ if not base64_data:
1480
+ _dbg(f"File data not found for {function_name}")
1481
+ return None, None
1482
+ url, info = save_bytes_to_cache(base64_data, filename, file_info=file_info)
1483
+ resource = {
1484
+ "type": "file",
1485
+ "file": {
1486
+ "file_data": url,
1487
+ "filename": info["name"],
1488
+ },
1489
+ }
1490
+ text = f"[{args.get('prompt') or filename}]({url})\n"
1491
+ return text, resource
1492
+ else:
1493
+ try:
1494
+ return json.dumps(result), None
1495
+ except Exception as e:
1496
+ _dbg(f"Error converting result to JSON: {e}")
1497
+ try:
1498
+ return str(result), None
1499
+ except Exception as e:
1500
+ _dbg(f"Error converting result to string: {e}")
1501
+ return None, None
1502
+
1503
+
1504
+ def g_tool_result(result, function_name: Optional[str] = None, function_args: Optional[dict] = None):
1505
+ content = []
1506
+ resources = []
1507
+ args = function_args or {}
1508
+ if isinstance(result, dict):
1509
+ text, res = tool_result_part(result, function_name, args)
1510
+ if text:
1511
+ content.append(text)
1512
+ if res:
1513
+ resources.append(res)
1514
+ elif isinstance(result, list):
1515
+ for item in result:
1516
+ text, res = tool_result_part(item, function_name, args)
1517
+ if text:
1518
+ content.append(text)
1519
+ if res:
1520
+ resources.append(res)
1521
+ else:
1522
+ content = [str(result)]
1523
+
1524
+ text = "\n".join(content)
1525
+ return text, resources
1526
+
1527
+
1528
+ async def g_exec_tool(function_name, function_args):
1529
+ if function_name in g_app.tools:
1530
+ try:
1531
+ func = g_app.tools[function_name]
1532
+ is_async = inspect.iscoroutinefunction(func)
1533
+ _dbg(f"Executing {'async' if is_async else 'sync'} tool '{function_name}' with args: {function_args}")
1534
+ if is_async:
1535
+ return g_tool_result(await func(**function_args), function_name, function_args)
1536
+ else:
1537
+ return g_tool_result(func(**function_args), function_name, function_args)
1538
+ except Exception as e:
1539
+ return f"Error executing tool '{function_name}': {to_error_message(e)}", None
1540
+ return f"Error: Tool '{function_name}' not found", None
1541
+
1542
+
1543
+ def group_resources(resources: list):
1544
+ """
1545
+ converts list of parts into a grouped dictionary, e.g:
1546
+ [{"type: "image_url", "image_url": {"url": "/image.jpg"}}] =>
1547
+ {"images": [{"type": "image_url", "image_url": {"url": "/image.jpg"}}] }
1548
+ """
1549
+ grouped = {}
1550
+ for res in resources:
1551
+ type = res.get("type")
1552
+ if not type:
1553
+ continue
1554
+ if type == "image_url":
1555
+ group = "images"
1556
+ elif type == "audio_url":
1557
+ group = "audios"
1558
+ elif type == "file_urls" or type == "file":
1559
+ group = "files"
1560
+ elif type == "text":
1561
+ group = "texts"
1562
+ else:
1563
+ group = "others"
1564
+ if group not in grouped:
1565
+ grouped[group] = []
1566
+ grouped[group].append(res)
1567
+ return grouped
1568
+
1569
+
1351
1570
  async def g_chat_completion(chat, context=None):
1352
1571
  try:
1353
1572
  model = chat.get("model")
@@ -1445,21 +1664,15 @@ async def g_chat_completion(chat, context=None):
1445
1664
  try:
1446
1665
  function_args = json.loads(tool_call["function"]["arguments"])
1447
1666
  except Exception as e:
1448
- tool_result = f"Error parsing JSON arguments for tool {function_name}: {e}"
1667
+ tool_result = f"Error: Failed to parse JSON arguments for tool '{function_name}': {to_error_message(e)}"
1449
1668
  else:
1450
- tool_result = f"Error: Tool {function_name} not found"
1451
- if function_name in g_app.tools:
1452
- try:
1453
- func = g_app.tools[function_name]
1454
- if inspect.iscoroutinefunction(func):
1455
- tool_result = await func(**function_args)
1456
- else:
1457
- tool_result = func(**function_args)
1458
- except Exception as e:
1459
- tool_result = f"Error executing tool {function_name}: {e}"
1669
+ tool_result, resources = await g_exec_tool(function_name, function_args)
1460
1670
 
1461
1671
  # Append tool result to history
1462
1672
  tool_msg = {"role": "tool", "tool_call_id": tool_call["id"], "content": to_content(tool_result)}
1673
+
1674
+ tool_msg.update(group_resources(resources))
1675
+
1463
1676
  current_chat["messages"].append(tool_msg)
1464
1677
  tool_history.append(tool_msg)
1465
1678
 
@@ -2333,6 +2546,8 @@ class AppExtensions:
2333
2546
  self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
2334
2547
  self.ui_extensions = []
2335
2548
  self.chat_request_filters = []
2549
+ self.extensions = []
2550
+ self.loaded = False
2336
2551
  self.chat_tool_filters = []
2337
2552
  self.chat_response_filters = []
2338
2553
  self.chat_error_filters = []
@@ -2345,6 +2560,7 @@ class AppExtensions:
2345
2560
  self.shutdown_handlers = []
2346
2561
  self.tools = {}
2347
2562
  self.tool_definitions = []
2563
+ self.tool_groups = {}
2348
2564
  self.index_headers = []
2349
2565
  self.index_footers = []
2350
2566
  self.request_args = {
@@ -2557,6 +2773,15 @@ class ExtensionContext:
2557
2773
  def json_from_file(self, path):
2558
2774
  return json_from_file(path)
2559
2775
 
2776
+ def download_file(self, url):
2777
+ return download_file(url)
2778
+
2779
+ def session_download_file(self, session, url):
2780
+ return session_download_file(session, url)
2781
+
2782
+ def read_binary_file(self, url):
2783
+ return read_binary_file(url)
2784
+
2560
2785
  def log(self, message):
2561
2786
  if self.verbose:
2562
2787
  print(f"[{self.name}] {message}", flush=True)
@@ -2627,25 +2852,25 @@ class ExtensionContext:
2627
2852
 
2628
2853
  self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))
2629
2854
 
2855
+ def web_path(self, method, path):
2856
+ full_path = os.path.join(self.ext_prefix, path) if path else self.ext_prefix
2857
+ self.dbg(f"Registered {method:<6} {full_path}")
2858
+ return full_path
2859
+
2630
2860
  def add_get(self, path, handler, **kwargs):
2631
- self.dbg(f"Registered GET: {os.path.join(self.ext_prefix, path)}")
2632
- self.app.server_add_get.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2861
+ self.app.server_add_get.append((self.web_path("GET", path), handler, kwargs))
2633
2862
 
2634
2863
  def add_post(self, path, handler, **kwargs):
2635
- self.dbg(f"Registered POST: {os.path.join(self.ext_prefix, path)}")
2636
- self.app.server_add_post.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2864
+ self.app.server_add_post.append((self.web_path("POST", path), handler, kwargs))
2637
2865
 
2638
2866
  def add_put(self, path, handler, **kwargs):
2639
- self.dbg(f"Registered PUT: {os.path.join(self.ext_prefix, path)}")
2640
- self.app.server_add_put.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2867
+ self.app.server_add_put.append((self.web_path("PUT", path), handler, kwargs))
2641
2868
 
2642
2869
  def add_delete(self, path, handler, **kwargs):
2643
- self.dbg(f"Registered DELETE: {os.path.join(self.ext_prefix, path)}")
2644
- self.app.server_add_delete.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2870
+ self.app.server_add_delete.append((self.web_path("DELETE", path), handler, kwargs))
2645
2871
 
2646
2872
  def add_patch(self, path, handler, **kwargs):
2647
- self.dbg(f"Registered PATCH: {os.path.join(self.ext_prefix, path)}")
2648
- self.app.server_add_patch.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2873
+ self.app.server_add_patch.append((self.web_path("PATCH", path), handler, kwargs))
2649
2874
 
2650
2875
  def add_importmaps(self, dict):
2651
2876
  self.app.import_maps.update(dict)
@@ -2677,14 +2902,88 @@ class ExtensionContext:
2677
2902
  def get_provider(self, name):
2678
2903
  return g_handlers.get(name)
2679
2904
 
2680
- def register_tool(self, func, tool_def=None):
2905
+ def sanitize_tool_def(self, tool_def):
2906
+ """
2907
+ Merge $defs parameter into tool_def property to reduce client/server complexity
2908
+ """
2909
+ # parameters = {
2910
+ # "$defs": {
2911
+ # "AspectRatio": {
2912
+ # "description": "Supported aspect ratios for image generation.",
2913
+ # "enum": [
2914
+ # "1:1",
2915
+ # "2:3",
2916
+ # "16:9"
2917
+ # ],
2918
+ # "type": "string"
2919
+ # }
2920
+ # },
2921
+ # "properties": {
2922
+ # "prompt": {
2923
+ # "type": "string"
2924
+ # },
2925
+ # "model": {
2926
+ # "default": "gemini-2.5-flash-image",
2927
+ # "type": "string"
2928
+ # },
2929
+ # "aspect_ratio": {
2930
+ # "$ref": "#/$defs/AspectRatio",
2931
+ # "default": "1:1"
2932
+ # }
2933
+ # },
2934
+ # "required": [
2935
+ # "prompt"
2936
+ # ],
2937
+ # "type": "object"
2938
+ # }
2939
+ type = tool_def.get("type")
2940
+ if type == "function":
2941
+ func_def = tool_def.get("function", {})
2942
+ parameters = func_def.get("parameters", {})
2943
+ defs = parameters.get("$defs", {})
2944
+ properties = parameters.get("properties", {})
2945
+ for prop_name, prop_def in properties.items():
2946
+ if "$ref" in prop_def:
2947
+ ref = prop_def["$ref"]
2948
+ if ref.startswith("#/$defs/"):
2949
+ def_name = ref.replace("#/$defs/", "")
2950
+ if def_name in defs:
2951
+ prop_def.update(defs[def_name])
2952
+ del prop_def["$ref"]
2953
+ if "$defs" in parameters:
2954
+ del parameters["$defs"]
2955
+ return tool_def
2956
+
2957
+ def register_tool(self, func, tool_def=None, group=None):
2681
2958
  if tool_def is None:
2682
2959
  tool_def = function_to_tool_definition(func)
2683
2960
 
2684
2961
  name = tool_def["function"]["name"]
2685
- self.log(f"Registered tool: {name}")
2962
+ if name in self.app.tools:
2963
+ self.log(f"Overriding existing tool: {name}")
2964
+ self.app.tool_definitions = [t for t in self.app.tool_definitions if t["function"]["name"] != name]
2965
+ for g_tools in self.app.tool_groups.values():
2966
+ if name in g_tools:
2967
+ g_tools.remove(name)
2968
+ else:
2969
+ self.log(f"Registered tool: {name}")
2970
+
2686
2971
  self.app.tools[name] = func
2687
- self.app.tool_definitions.append(tool_def)
2972
+ self.app.tool_definitions.append(self.sanitize_tool_def(tool_def))
2973
+ if not group:
2974
+ group = "custom"
2975
+ if group not in self.app.tool_groups:
2976
+ self.app.tool_groups[group] = []
2977
+ self.app.tool_groups[group].append(name)
2978
+
2979
+ def get_tool_definition(self, name):
2980
+ for tool_def in self.app.tool_definitions:
2981
+ if tool_def["function"]["name"] == name:
2982
+ return tool_def
2983
+ return None
2984
+
2985
+ def group_resources(self, resources: list):
2986
+ return group_resources(resources)
2688
2987
 
2689
2988
  def check_auth(self, request):
2690
2989
  return self.app.check_auth(request)
@@ -2709,6 +3008,15 @@ class ExtensionContext:
2709
3008
  def cache_message_inline_data(self, message):
2710
3009
  return cache_message_inline_data(message)
2711
3010
 
3011
+ async def exec_tool(self, name, args):
3012
+ return await g_exec_tool(name, args)
3013
+
3014
+ def tool_result(self, result, function_name: Optional[str] = None, function_args: Optional[dict] = None):
3015
+ return g_tool_result(result, function_name, function_args)
3016
+
3017
+ def tool_result_part(self, result: dict, function_name: Optional[str] = None, function_args: Optional[dict] = None):
3018
+ return tool_result_part(result, function_name, function_args)
3019
+
2712
3020
  def to_content(self, result):
2713
3021
  return to_content(result)
2714
3022
 
@@ -2816,6 +3124,8 @@ def install_extensions():
2816
3124
 
2817
3125
  _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")
2818
3126
 
3127
+ extensions = []
3128
+
2819
3129
  for item_path in extension_dirs:
2820
3130
  item = os.path.basename(item_path)
2821
3131
 
@@ -2855,11 +3165,45 @@ def install_extensions():
2855
3165
  if os.path.exists(os.path.join(ui_path, "index.mjs")):
2856
3166
  ctx.register_ui_extension("index.mjs")
2857
3167
 
3168
+ # include __load__ and __run__ hooks if they exist
3169
+ load_func = getattr(module, "__load__", None)
3170
+ if callable(load_func) and not inspect.iscoroutinefunction(load_func):
3171
+ _log(f"Warning: Extension {item} __load__ must be async")
3172
+ load_func = None
3173
+
3174
+ run_func = getattr(module, "__run__", None)
3175
+ if callable(run_func) and inspect.iscoroutinefunction(run_func):
3176
+ _log(f"Warning: Extension {item} __run__ must be sync")
3177
+ run_func = None
3178
+
3179
+ extensions.append({"name": item, "module": module, "ctx": ctx, "load": load_func, "run": run_func})
2858
3180
  except Exception as e:
2859
3181
  _err(f"Failed to install extension {item}", e)
2860
3182
  else:
2861
3183
  _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")
2862
3184
 
3185
+ return extensions
3186
+
3187
+
3188
+ async def load_extensions():
3189
+ """
3190
+ Calls the `__load__(ctx)` async function in all installed extensions concurrently.
3191
+ """
3192
+ tasks = []
3193
+ for ext in g_app.extensions:
3194
+ if ext.get("load"):
3195
+ task = ext["load"](ext["ctx"])
3196
+ tasks.append({"name": ext["name"], "task": task})
3197
+
3198
+ if len(tasks) > 0:
3199
+ _log(f"Loading {len(tasks)} extensions...")
3200
+ results = await asyncio.gather(*[t["task"] for t in tasks], return_exceptions=True)
3201
+ for i, result in enumerate(results):
3202
+ if isinstance(result, Exception):
3203
+ # Gather returns results in order corresponding to tasks
3204
+ extension = tasks[i]
3205
+ _err(f"Failed to load extension {extension['name']}", result)
3206
+
2863
3207
 
2864
3208
  def run_extension_cli():
2865
3209
  """
@@ -3204,9 +3548,16 @@ def main():
3204
3548
  asyncio.run(update_extensions(cli_args.update))
3205
3549
  exit(0)
3206
3550
 
3207
- install_extensions()
3551
+ g_app.extensions = install_extensions()
3552
+
3553
+ # Use a persistent event loop to ensure async connections (like MCP)
3554
+ # established in load_extensions() remain active during cli_chat()
3555
+ loop = asyncio.new_event_loop()
3556
+ asyncio.set_event_loop(loop)
3208
3557
 
3209
- asyncio.run(reload_providers())
3558
+ loop.run_until_complete(reload_providers())
3559
+ loop.run_until_complete(load_extensions())
3560
+ g_app.loaded = True
3210
3561
 
3211
3562
  # print names
3212
3563
  _log(f"enabled providers: {', '.join(g_handlers.keys())}")
@@ -3259,7 +3610,9 @@ def main():
3259
3610
  # Check validity of models for a provider
3260
3611
  provider_name = cli_args.check
3261
3612
  model_names = extra_args if len(extra_args) > 0 else None
3262
- asyncio.run(check_models(provider_name, model_names))
3613
+ provider_name = cli_args.check
3614
+ model_names = extra_args if len(extra_args) > 0 else None
3615
+ loop.run_until_complete(check_models(provider_name, model_names))
3263
3616
  g_app.exit(0)
3264
3617
 
3265
3618
  if cli_args.serve is not None:
@@ -3466,11 +3819,6 @@ def main():
3466
3819
 
3467
3820
  app.router.add_get("/ext", extensions_handler)
3468
3821
 
3469
- async def tools_handler(request):
3470
- return web.json_response(g_app.tool_definitions)
3471
-
3472
- app.router.add_get("/ext/tools", tools_handler)
3473
-
3474
3822
  async def cache_handler(request):
3475
3823
  path = request.match_info["tail"]
3476
3824
  full_path = get_cache_path(path)
@@ -3505,7 +3853,7 @@ def main():
3505
3853
  if not str(requested_path).startswith(str(cache_root)):
3506
3854
  _dbg(f"Forbidden: {requested_path} is not in {cache_root}")
3507
3855
  return web.Response(text="403: Forbidden", status=403)
3508
- except Exception:
3856
+ except Exception as e:
3509
3857
  _err(f"Forbidden: {requested_path} is not in {cache_root}", e)
3510
3858
  return web.Response(text="403: Forbidden", status=403)
3511
3859
 
@@ -4027,7 +4375,7 @@ def main():
4027
4375
  if cli_args.args is not None:
4028
4376
  args = parse_args_params(cli_args.args)
4029
4377
 
4030
- asyncio.run(
4378
+ loop.run_until_complete(
4031
4379
  cli_chat(
4032
4380
  chat,
4033
4381
  tools=cli_args.tools,