llms-py 2.0.35__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. llms/__pycache__/__init__.cpython-312.pyc +0 -0
  2. llms/__pycache__/__init__.cpython-313.pyc +0 -0
  3. llms/__pycache__/__init__.cpython-314.pyc +0 -0
  4. llms/__pycache__/__main__.cpython-312.pyc +0 -0
  5. llms/__pycache__/__main__.cpython-314.pyc +0 -0
  6. llms/__pycache__/llms.cpython-312.pyc +0 -0
  7. llms/__pycache__/main.cpython-312.pyc +0 -0
  8. llms/__pycache__/main.cpython-313.pyc +0 -0
  9. llms/__pycache__/main.cpython-314.pyc +0 -0
  10. llms/__pycache__/plugins.cpython-314.pyc +0 -0
  11. llms/{ui/Analytics.mjs → extensions/analytics/ui/index.mjs} +154 -238
  12. llms/extensions/app/README.md +20 -0
  13. llms/extensions/app/__init__.py +530 -0
  14. llms/extensions/app/__pycache__/__init__.cpython-314.pyc +0 -0
  15. llms/extensions/app/__pycache__/db.cpython-314.pyc +0 -0
  16. llms/extensions/app/__pycache__/db_manager.cpython-314.pyc +0 -0
  17. llms/extensions/app/db.py +644 -0
  18. llms/extensions/app/db_manager.py +195 -0
  19. llms/extensions/app/requests.json +9073 -0
  20. llms/extensions/app/threads.json +15290 -0
  21. llms/{ui → extensions/app/ui}/Recents.mjs +91 -65
  22. llms/{ui/Sidebar.mjs → extensions/app/ui/index.mjs} +124 -58
  23. llms/extensions/app/ui/threadStore.mjs +411 -0
  24. llms/extensions/core_tools/CALCULATOR.md +32 -0
  25. llms/extensions/core_tools/__init__.py +598 -0
  26. llms/extensions/core_tools/__pycache__/__init__.cpython-314.pyc +0 -0
  27. llms/extensions/core_tools/ui/codemirror/addon/edit/closebrackets.js +201 -0
  28. llms/extensions/core_tools/ui/codemirror/addon/edit/closetag.js +185 -0
  29. llms/extensions/core_tools/ui/codemirror/addon/edit/continuelist.js +101 -0
  30. llms/extensions/core_tools/ui/codemirror/addon/edit/matchbrackets.js +160 -0
  31. llms/extensions/core_tools/ui/codemirror/addon/edit/matchtags.js +66 -0
  32. llms/extensions/core_tools/ui/codemirror/addon/edit/trailingspace.js +27 -0
  33. llms/extensions/core_tools/ui/codemirror/addon/selection/active-line.js +72 -0
  34. llms/extensions/core_tools/ui/codemirror/addon/selection/mark-selection.js +119 -0
  35. llms/extensions/core_tools/ui/codemirror/addon/selection/selection-pointer.js +98 -0
  36. llms/extensions/core_tools/ui/codemirror/doc/docs.css +225 -0
  37. llms/extensions/core_tools/ui/codemirror/doc/source_sans.woff +0 -0
  38. llms/extensions/core_tools/ui/codemirror/lib/codemirror.css +344 -0
  39. llms/extensions/core_tools/ui/codemirror/lib/codemirror.js +9884 -0
  40. llms/extensions/core_tools/ui/codemirror/mode/clike/clike.js +942 -0
  41. llms/extensions/core_tools/ui/codemirror/mode/javascript/index.html +118 -0
  42. llms/extensions/core_tools/ui/codemirror/mode/javascript/javascript.js +962 -0
  43. llms/extensions/core_tools/ui/codemirror/mode/javascript/typescript.html +62 -0
  44. llms/extensions/core_tools/ui/codemirror/mode/python/python.js +402 -0
  45. llms/extensions/core_tools/ui/codemirror/theme/dracula.css +40 -0
  46. llms/extensions/core_tools/ui/codemirror/theme/mocha.css +135 -0
  47. llms/extensions/core_tools/ui/index.mjs +650 -0
  48. llms/extensions/gallery/README.md +61 -0
  49. llms/extensions/gallery/__init__.py +61 -0
  50. llms/extensions/gallery/__pycache__/__init__.cpython-314.pyc +0 -0
  51. llms/extensions/gallery/__pycache__/db.cpython-314.pyc +0 -0
  52. llms/extensions/gallery/db.py +298 -0
  53. llms/extensions/gallery/ui/index.mjs +482 -0
  54. llms/extensions/katex/README.md +39 -0
  55. llms/extensions/katex/__init__.py +6 -0
  56. llms/extensions/katex/__pycache__/__init__.cpython-314.pyc +0 -0
  57. llms/extensions/katex/ui/README.md +125 -0
  58. llms/extensions/katex/ui/contrib/auto-render.js +338 -0
  59. llms/extensions/katex/ui/contrib/auto-render.min.js +1 -0
  60. llms/extensions/katex/ui/contrib/auto-render.mjs +244 -0
  61. llms/extensions/katex/ui/contrib/copy-tex.js +127 -0
  62. llms/extensions/katex/ui/contrib/copy-tex.min.js +1 -0
  63. llms/extensions/katex/ui/contrib/copy-tex.mjs +105 -0
  64. llms/extensions/katex/ui/contrib/mathtex-script-type.js +109 -0
  65. llms/extensions/katex/ui/contrib/mathtex-script-type.min.js +1 -0
  66. llms/extensions/katex/ui/contrib/mathtex-script-type.mjs +24 -0
  67. llms/extensions/katex/ui/contrib/mhchem.js +3213 -0
  68. llms/extensions/katex/ui/contrib/mhchem.min.js +1 -0
  69. llms/extensions/katex/ui/contrib/mhchem.mjs +3109 -0
  70. llms/extensions/katex/ui/contrib/render-a11y-string.js +887 -0
  71. llms/extensions/katex/ui/contrib/render-a11y-string.min.js +1 -0
  72. llms/extensions/katex/ui/contrib/render-a11y-string.mjs +800 -0
  73. llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.ttf +0 -0
  74. llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff +0 -0
  75. llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff2 +0 -0
  76. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.ttf +0 -0
  77. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff +0 -0
  78. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff2 +0 -0
  79. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.ttf +0 -0
  80. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff +0 -0
  81. llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff2 +0 -0
  82. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.ttf +0 -0
  83. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff +0 -0
  84. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff2 +0 -0
  85. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.ttf +0 -0
  86. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff +0 -0
  87. llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff2 +0 -0
  88. llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.ttf +0 -0
  89. llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff +0 -0
  90. llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff2 +0 -0
  91. llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.ttf +0 -0
  92. llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff +0 -0
  93. llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff2 +0 -0
  94. llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.ttf +0 -0
  95. llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff +0 -0
  96. llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff2 +0 -0
  97. llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.ttf +0 -0
  98. llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff +0 -0
  99. llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff2 +0 -0
  100. llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.ttf +0 -0
  101. llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff +0 -0
  102. llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff2 +0 -0
  103. llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.ttf +0 -0
  104. llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff +0 -0
  105. llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff2 +0 -0
  106. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.ttf +0 -0
  107. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff +0 -0
  108. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff2 +0 -0
  109. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.ttf +0 -0
  110. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff +0 -0
  111. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff2 +0 -0
  112. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.ttf +0 -0
  113. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff +0 -0
  114. llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff2 +0 -0
  115. llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.ttf +0 -0
  116. llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff +0 -0
  117. llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff2 +0 -0
  118. llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.ttf +0 -0
  119. llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff +0 -0
  120. llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff2 +0 -0
  121. llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.ttf +0 -0
  122. llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff +0 -0
  123. llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff2 +0 -0
  124. llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.ttf +0 -0
  125. llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff +0 -0
  126. llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff2 +0 -0
  127. llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.ttf +0 -0
  128. llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff +0 -0
  129. llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff2 +0 -0
  130. llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.ttf +0 -0
  131. llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff +0 -0
  132. llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff2 +0 -0
  133. llms/extensions/katex/ui/index.mjs +92 -0
  134. llms/extensions/katex/ui/katex-swap.css +1230 -0
  135. llms/extensions/katex/ui/katex-swap.min.css +1 -0
  136. llms/extensions/katex/ui/katex.css +1230 -0
  137. llms/extensions/katex/ui/katex.js +19080 -0
  138. llms/extensions/katex/ui/katex.min.css +1 -0
  139. llms/extensions/katex/ui/katex.min.js +1 -0
  140. llms/extensions/katex/ui/katex.min.mjs +1 -0
  141. llms/extensions/katex/ui/katex.mjs +18547 -0
  142. llms/extensions/providers/__init__.py +18 -0
  143. llms/extensions/providers/__pycache__/__init__.cpython-314.pyc +0 -0
  144. llms/extensions/providers/__pycache__/anthropic.cpython-314.pyc +0 -0
  145. llms/extensions/providers/__pycache__/chutes.cpython-314.pyc +0 -0
  146. llms/extensions/providers/__pycache__/google.cpython-314.pyc +0 -0
  147. llms/extensions/providers/__pycache__/nvidia.cpython-314.pyc +0 -0
  148. llms/extensions/providers/__pycache__/openai.cpython-314.pyc +0 -0
  149. llms/extensions/providers/__pycache__/openrouter.cpython-314.pyc +0 -0
  150. llms/extensions/providers/anthropic.py +229 -0
  151. llms/extensions/providers/chutes.py +155 -0
  152. llms/extensions/providers/google.py +378 -0
  153. llms/extensions/providers/nvidia.py +105 -0
  154. llms/extensions/providers/openai.py +156 -0
  155. llms/extensions/providers/openrouter.py +72 -0
  156. llms/extensions/system_prompts/README.md +22 -0
  157. llms/extensions/system_prompts/__init__.py +45 -0
  158. llms/extensions/system_prompts/__pycache__/__init__.cpython-314.pyc +0 -0
  159. llms/extensions/system_prompts/ui/index.mjs +280 -0
  160. llms/extensions/system_prompts/ui/prompts.json +1067 -0
  161. llms/extensions/tools/__init__.py +5 -0
  162. llms/extensions/tools/__pycache__/__init__.cpython-314.pyc +0 -0
  163. llms/extensions/tools/ui/index.mjs +204 -0
  164. llms/index.html +35 -77
  165. llms/llms.json +357 -1186
  166. llms/main.py +2349 -591
  167. llms/providers-extra.json +356 -0
  168. llms/providers.json +1 -0
  169. llms/ui/App.mjs +151 -60
  170. llms/ui/ai.mjs +132 -60
  171. llms/ui/app.css +2173 -161
  172. llms/ui/ctx.mjs +365 -0
  173. llms/ui/index.mjs +129 -0
  174. llms/ui/lib/charts.mjs +9 -13
  175. llms/ui/lib/servicestack-vue.mjs +3 -3
  176. llms/ui/lib/vue.min.mjs +10 -9
  177. llms/ui/lib/vue.mjs +1796 -1635
  178. llms/ui/markdown.mjs +18 -7
  179. llms/ui/modules/chat/ChatBody.mjs +691 -0
  180. llms/ui/{SettingsDialog.mjs → modules/chat/SettingsDialog.mjs} +9 -9
  181. llms/ui/modules/chat/index.mjs +828 -0
  182. llms/ui/modules/layout.mjs +243 -0
  183. llms/ui/modules/model-selector.mjs +851 -0
  184. llms/ui/tailwind.input.css +496 -80
  185. llms/ui/utils.mjs +161 -93
  186. {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/METADATA +1 -1
  187. llms_py-3.0.0.dist-info/RECORD +202 -0
  188. llms/ui/Avatar.mjs +0 -85
  189. llms/ui/Brand.mjs +0 -52
  190. llms/ui/ChatPrompt.mjs +0 -590
  191. llms/ui/Main.mjs +0 -823
  192. llms/ui/ModelSelector.mjs +0 -78
  193. llms/ui/OAuthSignIn.mjs +0 -92
  194. llms/ui/ProviderIcon.mjs +0 -30
  195. llms/ui/ProviderStatus.mjs +0 -105
  196. llms/ui/SignIn.mjs +0 -64
  197. llms/ui/SystemPromptEditor.mjs +0 -31
  198. llms/ui/SystemPromptSelector.mjs +0 -56
  199. llms/ui/Welcome.mjs +0 -8
  200. llms/ui/threadStore.mjs +0 -563
  201. llms/ui.json +0 -1069
  202. llms_py-2.0.35.dist-info/RECORD +0 -48
  203. {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/WHEEL +0 -0
  204. {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/entry_points.txt +0 -0
  205. {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/licenses/LICENSE +0 -0
  206. {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/top_level.txt +0 -0
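Among the changes to llms/main.py (diff below) is a new function_to_tool_definition() helper that turns a plain, type-hinted Python function into an OpenAI-style tool definition for the tool-calling loop added in 3.0.0. A minimal usage sketch (the helper itself appears in full in the diff; the llms.main import path is an assumption):

    import json
    from llms.main import function_to_tool_definition  # assumed import path

    def get_weather(city: str, days: int = 3) -> str:
        """Return a short weather forecast for a city."""
        return f"{city}: sunny for {days} days"

    # Derives JSON-schema parameters from the signature's type hints;
    # arguments without defaults become "required".
    print(json.dumps(function_to_tool_definition(get_weather), indent=2))
    # prints (condensed):
    # {"type": "function", "function": {"name": "get_weather",
    #   "description": "Return a short weather forecast for a city.",
    #   "parameters": {"type": "object",
    #     "properties": {"city": {"type": "string"}, "days": {"type": "integer"}},
    #     "required": ["city"]}}}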
llms/main.py CHANGED
@@ -9,20 +9,27 @@
 import argparse
 import asyncio
 import base64
+import contextlib
+import hashlib
+import importlib.util
+import inspect
 import json
 import mimetypes
 import os
 import re
 import secrets
+import shutil
 import site
 import subprocess
 import sys
 import time
 import traceback
+from datetime import datetime
 from importlib import resources  # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
 from io import BytesIO
 from pathlib import Path
-from urllib.parse import parse_qs, urlencode
+from typing import get_type_hints
+from urllib.parse import parse_qs, urlencode, urljoin

 import aiohttp
 from aiohttp import web
@@ -34,25 +41,40 @@ try:
 except ImportError:
     HAS_PIL = False

-VERSION = "2.0.35"
+VERSION = "3.0.0"
 _ROOT = None
+DEBUG = os.getenv("DEBUG") == "1"
+MOCK = os.getenv("MOCK") == "1"
+MOCK_DIR = os.getenv("MOCK_DIR")
+DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
 g_config_path = None
-g_ui_path = None
 g_config = None
+g_providers = None
 g_handlers = {}
 g_verbose = False
 g_logprefix = ""
 g_default_model = ""
 g_sessions = {}  # OAuth session storage: {session_token: {userId, userName, displayName, profileUrl, email, created}}
 g_oauth_states = {}  # CSRF protection: {state: {created, redirect_uri}}
+g_app = None  # ExtensionsContext Singleton


 def _log(message):
-    """Helper method for logging from the global polling task."""
     if g_verbose:
         print(f"{g_logprefix}{message}", flush=True)


+def _dbg(message):
+    if DEBUG:
+        print(f"DEBUG: {message}", flush=True)
+
+
+def _err(message, e):
+    print(f"ERROR: {message}: {e}", flush=True)
+    if g_verbose:
+        print(traceback.format_exc(), flush=True)
+
+
 def printdump(obj):
     args = obj.__dict__ if hasattr(obj, "__dict__") else obj
     print(json.dumps(args, indent=2))
@@ -85,17 +107,6 @@ def chat_summary(chat):
     return json.dumps(clone, indent=2)


-def gemini_chat_summary(gemini_chat):
-    """Summarize Gemini chat completion request for logging. Replace inline_data with size of content only"""
-    clone = json.loads(json.dumps(gemini_chat))
-    for content in clone["contents"]:
-        for part in content["parts"]:
-            if "inline_data" in part:
-                data = part["inline_data"]["data"]
-                part["inline_data"]["data"] = f"({len(data)})"
-    return json.dumps(clone, indent=2)
-
-
 image_exts = ["png", "webp", "jpg", "jpeg", "gif", "bmp", "svg", "tiff", "ico"]
 audio_exts = ["mp3", "wav", "ogg", "flac", "m4a", "opus", "webm"]

@@ -189,6 +200,16 @@ def is_base_64(data):
     return False


+def id_to_name(id):
+    return id.replace("-", " ").title()
+
+
+def pluralize(word, count):
+    if count == 1:
+        return word
+    return word + "s"
+
+
 def get_file_mime_type(filename):
     mime_type, _ = mimetypes.guess_type(filename)
     return mime_type or "application/octet-stream"
@@ -310,11 +331,52 @@ def convert_image_if_needed(image_bytes, mimetype="image/png"):
310
331
  return image_bytes, mimetype
311
332
 
312
333
 
313
- async def process_chat(chat):
334
+ def to_content(result):
335
+ if isinstance(result, (str, int, float, bool)):
336
+ return str(result)
337
+ elif isinstance(result, (list, set, tuple, dict)):
338
+ return json.dumps(result)
339
+ else:
340
+ return str(result)
341
+
342
+
343
+ def function_to_tool_definition(func):
344
+ type_hints = get_type_hints(func)
345
+ signature = inspect.signature(func)
346
+ parameters = {"type": "object", "properties": {}, "required": []}
347
+
348
+ for name, param in signature.parameters.items():
349
+ param_type = type_hints.get(name, str)
350
+ param_type_name = "string"
351
+ if param_type is int:
352
+ param_type_name = "integer"
353
+ elif param_type is float:
354
+ param_type_name = "number"
355
+ elif param_type is bool:
356
+ param_type_name = "boolean"
357
+
358
+ parameters["properties"][name] = {"type": param_type_name}
359
+ if param.default == inspect.Parameter.empty:
360
+ parameters["required"].append(name)
361
+
362
+ return {
363
+ "type": "function",
364
+ "function": {
365
+ "name": func.__name__,
366
+ "description": func.__doc__ or "",
367
+ "parameters": parameters,
368
+ },
369
+ }
370
+
371
+
372
+ async def process_chat(chat, provider_id=None):
314
373
  if not chat:
315
374
  raise Exception("No chat provided")
316
375
  if "stream" not in chat:
317
376
  chat["stream"] = False
377
+ # Some providers don't support empty tools
378
+ if "tools" in chat and len(chat["tools"]) == 0:
379
+ del chat["tools"]
318
380
  if "messages" not in chat:
319
381
  return chat
320
382
 
@@ -331,6 +393,8 @@ async def process_chat(chat):
                         image_url = item["image_url"]
                         if "url" in image_url:
                             url = image_url["url"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             if is_url(url):
                                 _log(f"Downloading image: {url}")
                                 async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
@@ -377,6 +441,8 @@ async def process_chat(chat):
                         input_audio = item["input_audio"]
                         if "data" in input_audio:
                             url = input_audio["data"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             mimetype = get_file_mime_type(get_filename(url))
                             if is_url(url):
                                 _log(f"Downloading audio: {url}")
@@ -388,6 +454,8 @@ async def process_chat(chat):
                                         mimetype = response.headers["Content-Type"]
                                 # convert to base64
                                 input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                                if provider_id == "alibaba":
+                                    input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                                 input_audio["format"] = mimetype.rsplit("/", 1)[1]
                             elif is_file_path(url):
                                 _log(f"Reading audio: {url}")
@@ -395,6 +463,8 @@ async def process_chat(chat):
                                     content = f.read()
                                 # convert to base64
                                 input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                                if provider_id == "alibaba":
+                                    input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                                 input_audio["format"] = mimetype.rsplit("/", 1)[1]
                             elif is_base_64(url):
                                 pass  # use base64 data as-is
@@ -404,6 +474,8 @@ async def process_chat(chat):
                         file = item["file"]
                         if "file_data" in file:
                             url = file["file_data"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             mimetype = get_file_mime_type(get_filename(url))
                             if is_url(url):
                                 _log(f"Downloading file: {url}")
@@ -431,6 +503,92 @@ async def process_chat(chat):
     return chat


+def image_ext_from_mimetype(mimetype, default="png"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def audio_ext_from_format(format, default="mp3"):
+    if format == "mpeg":
+        return "mp3"
+    return format or default
+
+
+def file_ext_from_mimetype(mimetype, default="pdf"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def cache_message_inline_data(m):
+    """
+    Replaces and caches any inline data URIs in the message content.
+    """
+    if "content" not in m:
+        return
+
+    content = m["content"]
+    if isinstance(content, list):
+        for item in content:
+            if item.get("type") == "image_url":
+                image_url = item.get("image_url", {})
+                url = image_url.get("url")
+                if url and url.startswith("data:"):
+                    # Extract base64 and mimetype
+                    try:
+                        header, base64_data = url.split(";base64,")
+                        # header is like "data:image/png"
+                        ext = image_ext_from_mimetype(header.split(":")[1])
+                        filename = f"image.{ext}"  # Hash will handle uniqueness
+
+                        cache_url, _ = save_image_to_cache(base64_data, filename, {}, ignore_info=True)
+                        image_url["url"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline image: {e}")
+
+            elif item.get("type") == "input_audio":
+                input_audio = item.get("input_audio", {})
+                data = input_audio.get("data")
+                if data:
+                    # Handle data URI or raw base64
+                    base64_data = data
+                    if data.startswith("data:"):
+                        with contextlib.suppress(ValueError):
+                            header, base64_data = data.split(";base64,")
+
+                    fmt = audio_ext_from_format(input_audio.get("format"))
+                    filename = f"audio.{fmt}"
+
+                    try:
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        input_audio["data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline audio: {e}")
+
+            elif item.get("type") == "file":
+                file_info = item.get("file", {})
+                file_data = file_info.get("file_data")
+                if file_data and file_data.startswith("data:"):
+                    try:
+                        header, base64_data = file_data.split(";base64,")
+                        mimetype = header.split(":")[1]
+                        # Try to get extension from filename if available, else mimetype
+                        filename = file_info.get("filename", "file")
+                        if "." not in filename:
+                            ext = file_ext_from_mimetype(mimetype)
+                            filename = f"{filename}.{ext}"
+
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        file_info["file_data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline file: {e}")
+
+
 class HTTPError(Exception):
     def __init__(self, status, reason, body, headers=None):
         self.status = status
@@ -440,33 +598,323 @@ class HTTPError(Exception):
         super().__init__(f"HTTP {status} {reason}")


+def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
+    ext = filename.split(".")[-1]
+    mimetype = get_file_mime_type(filename)
+    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
+    sha256_hash = hashlib.sha256(content).hexdigest()
+
+    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+    # Use first 2 chars for subdir to avoid too many files in one dir
+    subdir = sha256_hash[:2]
+    relative_path = f"{subdir}/{save_filename}"
+    full_path = get_cache_path(relative_path)
+    url = f"/~cache/{relative_path}"
+
+    # if file and its .info.json already exists, return it
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    if os.path.exists(full_path) and os.path.exists(info_path):
+        _dbg(f"Cached bytes exists: {relative_path}")
+        if ignore_info:
+            return url, None
+        return url, json.load(open(info_path))
+
+    os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+    with open(full_path, "wb") as f:
+        f.write(content)
+    info = {
+        "date": int(time.time()),
+        "url": url,
+        "size": len(content),
+        "type": mimetype,
+        "name": filename,
+    }
+    info.update(file_info)
+
+    # Save metadata
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    with open(info_path, "w") as f:
+        json.dump(info, f)
+
+    _dbg(f"Saved cached bytes and info: {relative_path}")
+
+    g_app.on_cache_saved_filters({"url": url, "info": info})
+
+    return url, info
+
+
+def save_image_to_cache(base64_data, filename, image_info, ignore_info=False):
+    ext = filename.split(".")[-1]
+    mimetype = get_file_mime_type(filename)
+    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
+    sha256_hash = hashlib.sha256(content).hexdigest()
+
+    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+    # Use first 2 chars for subdir to avoid too many files in one dir
+    subdir = sha256_hash[:2]
+    relative_path = f"{subdir}/{save_filename}"
+    full_path = get_cache_path(relative_path)
+    url = f"/~cache/{relative_path}"
+
+    # if file and its .info.json already exists, return it
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    if os.path.exists(full_path) and os.path.exists(info_path):
+        _dbg(f"Saved image exists: {relative_path}")
+        if ignore_info:
+            return url, None
+        return url, json.load(open(info_path))
+
+    os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+    with open(full_path, "wb") as f:
+        f.write(content)
+    info = {
+        "date": int(time.time()),
+        "url": url,
+        "size": len(content),
+        "type": mimetype,
+        "name": filename,
+    }
+    info.update(image_info)
+
+    # If image, get dimensions
+    if HAS_PIL and mimetype.startswith("image/"):
+        try:
+            with Image.open(BytesIO(content)) as img:
+                info["width"] = img.width
+                info["height"] = img.height
+        except Exception:
+            pass
+
+    if "width" in info and "height" in info:
+        _log(f"Saved image to cache: {full_path} ({len(content)} bytes) {info['width']}x{info['height']}")
+    else:
+        _log(f"Saved image to cache: {full_path} ({len(content)} bytes)")
+
+    # Save metadata
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    with open(info_path, "w") as f:
+        json.dump(info, f)
+
+    _dbg(f"Saved image and info: {relative_path}")
+
+    g_app.on_cache_saved_filters({"url": url, "info": info})
+
+    return url, info
+
+
 async def response_json(response):
     text = await response.text()
     if response.status >= 400:
+        _dbg(f"HTTP {response.status} {response.reason}: {text}")
         raise HTTPError(response.status, reason=response.reason, body=text, headers=dict(response.headers))
     response.raise_for_status()
     body = json.loads(text)
     return body


-class OpenAiProvider:
-    def __init__(self, base_url, api_key=None, models=None, **kwargs):
-        if models is None:
-            models = {}
-        self.base_url = base_url.strip("/")
-        self.api_key = api_key
-        self.models = models
+def chat_to_prompt(chat):
+    prompt = ""
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "user":
+                # if content is string
+                if isinstance(message["content"], str):
+                    if prompt:
+                        prompt += "\n"
+                    prompt += message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            if prompt:
+                                prompt += "\n"
+                            prompt += part["text"]
+    return prompt
+
+
+def chat_to_system_prompt(chat):
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "system":
+                # if content is string
+                if isinstance(message["content"], str):
+                    return message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            return part["text"]
+    return None

-        # check if base_url ends with /v{\d} to handle providers with different versions (e.g. z.ai uses /v4)
-        last_segment = base_url.rsplit("/", 1)[1]
-        if last_segment.startswith("v") and last_segment[1:].isdigit():
-            self.chat_url = f"{base_url}/chat/completions"
-        else:
-            self.chat_url = f"{base_url}/v1/chat/completions"
+
+def chat_to_username(chat):
+    if "metadata" in chat and "user" in chat["metadata"]:
+        return chat["metadata"]["user"]
+    return None
+
+
+def last_user_prompt(chat):
+    prompt = ""
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "user":
+                # if content is string
+                if isinstance(message["content"], str):
+                    prompt = message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            prompt = part["text"]
+    return prompt
+
+
+def chat_response_to_message(openai_response):
+    """
+    Returns an assistant message from the OpenAI Response.
+    Handles normalizing text, image, and audio responses into the message content.
+    """
+    timestamp = int(time.time() * 1000)  # openai_response.get("created")
+    choices = openai_response
+    if isinstance(openai_response, dict) and "choices" in openai_response:
+        choices = openai_response["choices"]
+
+    choice = choices[0] if isinstance(choices, list) and choices else choices
+
+    if isinstance(choice, str):
+        return {"role": "assistant", "content": choice, "timestamp": timestamp}
+
+    if isinstance(choice, dict):
+        message = choice.get("message", choice)
+    else:
+        return {"role": "assistant", "content": str(choice), "timestamp": timestamp}
+
+    # Ensure message is a dict
+    if not isinstance(message, dict):
+        return {"role": "assistant", "content": message, "timestamp": timestamp}
+
+    message.update({"timestamp": timestamp})
+    return message
+
+
+def to_file_info(chat, info=None, response=None):
+    prompt = last_user_prompt(chat)
+    ret = info or {}
+    if chat["model"] and "model" not in ret:
+        ret["model"] = chat["model"]
+    if prompt and "prompt" not in ret:
+        ret["prompt"] = prompt
+    if "image_config" in chat:
+        ret.update(chat["image_config"])
+    user = chat_to_username(chat)
+    if user:
+        ret["user"] = user
+    return ret
+
+
+# Image Generator Providers
+class GeneratorBase:
+    def __init__(self, **kwargs):
+        self.id = kwargs.get("id")
+        self.api = kwargs.get("api")
+        self.api_key = kwargs.get("api_key")
+        self.headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        self.chat_url = f"{self.api}/chat/completions"
+        self.default_content = "I've generated the image for you."
+
+    def validate(self, **kwargs):
+        if not self.api_key:
+            api_keys = ", ".join(self.env)
+            return f"Provider '{self.name}' requires API Key {api_keys}"
+        return None
+
+    def test(self, **kwargs):
+        error_msg = self.validate(**kwargs)
+        if error_msg:
+            _log(error_msg)
+            return False
+        return True
+
+    async def load(self):
+        pass
+
+    def gen_summary(self, gen):
+        """Summarize gen response for logging."""
+        clone = json.loads(json.dumps(gen))
+        return json.dumps(clone, indent=2)
+
+    def chat_summary(self, chat):
+        return chat_summary(chat)
+
+    def process_chat(self, chat, provider_id=None):
+        return process_chat(chat, provider_id)
+
+    async def response_json(self, response):
+        return await response_json(response)
+
+    def get_headers(self, provider, chat):
+        headers = self.headers.copy()
+        if provider is not None:
+            headers["Authorization"] = f"Bearer {provider.api_key}"
+        elif self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def to_response(self, response, chat, started_at):
+        raise NotImplementedError
+
+    async def chat(self, chat, provider=None):
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Not Implemented",
+                        "images": [
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0Ij48cGF0aCBmaWxsPSJjdXJyZW50Q29sb3IiIGQ9Ik0xMiAyMGE4IDggMCAxIDAgMC0xNmE4IDggMCAwIDAgMCAxNm0wIDJDNi40NzcgMjIgMiAxNy41MjMgMiAxMlM2LjQ3NyAyIDEyIDJzMTAgNC40NzcgMTAgMTBzLTQuNDc3IDEwLTEwIDEwbS0xLTZoMnYyaC0yem0wLTEwaDJ2OGgtMnoiLz48L3N2Zz4=",
+                                },
+                            }
+                        ],
+                    }
+                }
+            ]
+        }
+
+
+# OpenAI Providers
+
+
+class OpenAiCompatible:
+    sdk = "@ai-sdk/openai-compatible"
+
+    def __init__(self, **kwargs):
+        required_args = ["id", "api"]
+        for arg in required_args:
+            if arg not in kwargs:
+                raise ValueError(f"Missing required argument: {arg}")
+
+        self.id = kwargs.get("id")
+        self.api = kwargs.get("api").strip("/")
+        self.env = kwargs.get("env", [])
+        self.api_key = kwargs.get("api_key")
+        self.name = kwargs.get("name", id_to_name(self.id))
+        self.set_models(**kwargs)
+
+        self.chat_url = f"{self.api}/chat/completions"

         self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
-        if api_key is not None:
-            self.headers["Authorization"] = f"Bearer {api_key}"
+        if self.api_key is not None:
+            self.headers["Authorization"] = f"Bearer {self.api_key}"

         self.frequency_penalty = float(kwargs["frequency_penalty"]) if "frequency_penalty" in kwargs else None
         self.max_completion_tokens = int(kwargs["max_completion_tokens"]) if "max_completion_tokens" in kwargs else None
@@ -486,44 +934,132 @@ class OpenAiProvider:
         self.verbosity = kwargs.get("verbosity")
         self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
         self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
-        self.pricing = kwargs.get("pricing")
-        self.default_pricing = kwargs.get("default_pricing")
         self.check = kwargs.get("check")
+        self.modalities = kwargs.get("modalities", {})
+
+    def set_models(self, **kwargs):
+        models = kwargs.get("models", {})
+        self.map_models = kwargs.get("map_models", {})
+        # if 'map_models' is provided, only include models in `map_models[model_id] = provider_model_id`
+        if self.map_models:
+            self.models = {}
+            for provider_model_id in self.map_models.values():
+                if provider_model_id in models:
+                    self.models[provider_model_id] = models[provider_model_id]
+        else:
+            self.models = models
+
+        include_models = kwargs.get("include_models")  # string regex pattern
+        # only include models that match the regex pattern
+        if include_models:
+            _log(f"Filtering {len(self.models)} models, only including models that match regex: {include_models}")
+            self.models = {k: v for k, v in self.models.items() if re.search(include_models, k)}
+
+        exclude_models = kwargs.get("exclude_models")  # string regex pattern
+        # exclude models that match the regex pattern
+        if exclude_models:
+            _log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
+            self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}
+
+    def validate(self, **kwargs):
+        if not self.api_key:
+            api_keys = ", ".join(self.env)
+            return f"Provider '{self.name}' requires API Key {api_keys}"
+        return None

-    @classmethod
-    def test(cls, base_url=None, api_key=None, models=None, **kwargs):
-        if models is None:
-            models = {}
-        return base_url and api_key and len(models) > 0
+    def test(self, **kwargs):
+        error_msg = self.validate(**kwargs)
+        if error_msg:
+            _log(error_msg)
+            return False
+        return True

     async def load(self):
-        pass
+        if not self.models:
+            await self.load_models()

-    def model_pricing(self, model):
+    def model_info(self, model):
         provider_model = self.provider_model(model) or model
-        if self.pricing and provider_model in self.pricing:
-            return self.pricing[provider_model]
-        return self.default_pricing or None
+        for model_id, model_info in self.models.items():
+            if model_id.lower() == provider_model.lower():
+                return model_info
+        return None
+
+    def model_cost(self, model):
+        model_info = self.model_info(model)
+        return model_info.get("cost") if model_info else None

     def provider_model(self, model):
-        if model in self.models:
-            return self.models[model]
+        # convert model to lowercase for case-insensitive comparison
+        model_lower = model.lower()
+
+        # if model is a map model id, return the provider model id
+        for model_id, provider_model in self.map_models.items():
+            if model_id.lower() == model_lower:
+                return provider_model
+
+        # if model is a provider model id, try again with just the model name
+        for provider_model in self.map_models.values():
+            if provider_model.lower() == model_lower:
+                return provider_model
+
+        # if model is a model id, try again with just the model id or name
+        for model_id, provider_model_info in self.models.items():
+            id = provider_model_info.get("id") or model_id
+            if model_id.lower() == model_lower or id.lower() == model_lower:
+                return id
+            name = provider_model_info.get("name")
+            if name and name.lower() == model_lower:
+                return id
+
+        # fallback to trying again with just the model short name
+        for model_id, provider_model_info in self.models.items():
+            id = provider_model_info.get("id") or model_id
+            if "/" in id:
+                model_name = id.split("/")[-1]
+                if model_name.lower() == model_lower:
+                    return id
+
+        # if model is a full provider model id, try again with just the model name
+        if "/" in model:
+            last_part = model.split("/")[-1]
+            return self.provider_model(last_part)
+
         return None

+    def response_json(self, response):
+        return response_json(response)
+
     def to_response(self, response, chat, started_at):
         if "metadata" not in response:
             response["metadata"] = {}
         response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
         if chat is not None and "model" in chat:
-            pricing = self.model_pricing(chat["model"])
+            pricing = self.model_cost(chat["model"])
             if pricing and "input" in pricing and "output" in pricing:
                 response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
-        _log(json.dumps(response, indent=2))
         return response

+    def chat_summary(self, chat):
+        return chat_summary(chat)
+
+    def process_chat(self, chat, provider_id=None):
+        return process_chat(chat, provider_id)
+
     async def chat(self, chat):
         chat["model"] = self.provider_model(chat["model"]) or chat["model"]

+        if "modalities" in chat:
+            for modality in chat.get("modalities", []):
+                # use default implementation for text modalities
+                if modality == "text":
+                    continue
+                modality_provider = self.modalities.get(modality)
+                if modality_provider:
+                    return await modality_provider.chat(chat, self)
+                else:
+                    raise Exception(f"Provider {self.name} does not support '{modality}' modality")
+
         # with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
         #     f.write(json.dumps(chat, indent=2))

@@ -562,285 +1098,152 @@ class OpenAiProvider:
         if self.enable_thinking is not None:
             chat["enable_thinking"] = self.enable_thinking

-        chat = await process_chat(chat)
+        chat = await process_chat(chat, provider_id=self.id)
         _log(f"POST {self.chat_url}")
         _log(chat_summary(chat))
         # remove metadata if any (conflicts with some providers, e.g. Z.ai)
-        chat.pop("metadata", None)
+        metadata = chat.pop("metadata", None)

         async with aiohttp.ClientSession() as session:
             started_at = time.time()
             async with session.post(
                 self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=aiohttp.ClientTimeout(total=120)
             ) as response:
+                chat["metadata"] = metadata
                 return self.to_response(await response_json(response), chat, started_at)


-class OllamaProvider(OpenAiProvider):
-    def __init__(self, base_url, models, all_models=False, **kwargs):
-        super().__init__(base_url=base_url, models=models, **kwargs)
-        self.all_models = all_models
+class MistralProvider(OpenAiCompatible):
+    sdk = "@ai-sdk/mistral"
+
+    def __init__(self, **kwargs):
+        if "api" not in kwargs:
+            kwargs["api"] = "https://api.mistral.ai/v1"
+        super().__init__(**kwargs)
+
+
+class GroqProvider(OpenAiCompatible):
+    sdk = "@ai-sdk/groq"
+
+    def __init__(self, **kwargs):
+        if "api" not in kwargs:
+            kwargs["api"] = "https://api.groq.com/openai/v1"
+        super().__init__(**kwargs)
+
+
+class XaiProvider(OpenAiCompatible):
+    sdk = "@ai-sdk/xai"
+
+    def __init__(self, **kwargs):
+        if "api" not in kwargs:
+            kwargs["api"] = "https://api.x.ai/v1"
+        super().__init__(**kwargs)
+
+
+class CodestralProvider(OpenAiCompatible):
+    sdk = "codestral"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class OllamaProvider(OpenAiCompatible):
+    sdk = "ollama"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Ollama's OpenAI-compatible endpoint is at /v1/chat/completions
+        self.chat_url = f"{self.api}/v1/chat/completions"

     async def load(self):
-        if self.all_models:
-            await self.load_models(default_models=self.models)
+        if not self.models:
+            await self.load_models()

     async def get_models(self):
         ret = {}
         try:
             async with aiohttp.ClientSession() as session:
-                _log(f"GET {self.base_url}/api/tags")
+                _log(f"GET {self.api}/api/tags")
                 async with session.get(
-                    f"{self.base_url}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
+                    f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
                 ) as response:
                     data = await response_json(response)
                     for model in data.get("models", []):
-                        name = model["model"]
-                        if name.endswith(":latest"):
-                            name = name[:-7]
-                        ret[name] = name
+                        model_id = model["model"]
+                        if model_id.endswith(":latest"):
+                            model_id = model_id[:-7]
+                        ret[model_id] = model_id
             _log(f"Loaded Ollama models: {ret}")
         except Exception as e:
             _log(f"Error getting Ollama models: {e}")
             # return empty dict if ollama is not available
         return ret

-    async def load_models(self, default_models):
+    async def load_models(self):
         """Load models if all_models was requested"""
-        if self.all_models:
-            self.models = await self.get_models()
-            if default_models:
-                self.models = {**default_models, **self.models}
-
-    @classmethod
-    def test(cls, base_url=None, models=None, all_models=False, **kwargs):
-        if models is None:
-            models = {}
-        return base_url and (len(models) > 0 or all_models)
-
-
-class GoogleOpenAiProvider(OpenAiProvider):
-    def __init__(self, api_key, models, **kwargs):
-        super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
-        self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
-
-    @classmethod
-    def test(cls, api_key=None, models=None, **kwargs):
-        if models is None:
-            models = {}
-        return api_key and len(models) > 0
-
-
-class GoogleProvider(OpenAiProvider):
-    def __init__(self, models, api_key, safety_settings=None, thinking_config=None, curl=False, **kwargs):
-        super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
-        self.safety_settings = safety_settings
-        self.thinking_config = thinking_config
-        self.curl = curl
-        self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
-        # Google fails when using Authorization header, use query string param instead
-        if "Authorization" in self.headers:
-            del self.headers["Authorization"]
-
-    @classmethod
-    def test(cls, api_key=None, models=None, **kwargs):
-        if models is None:
-            models = {}
-        return api_key is not None and len(models) > 0

-    async def chat(self, chat):
-        chat["model"] = self.provider_model(chat["model"]) or chat["model"]
+        # Map models to provider models {model_id:model_id}
+        model_map = await self.get_models()
+        if self.map_models:
+            map_model_values = set(self.map_models.values())
+            to = {}
+            for k, v in model_map.items():
+                if k in self.map_models:
+                    to[k] = v
+                if v in map_model_values:
+                    to[k] = v
+            model_map = to
+        else:
+            self.map_models = model_map
+        models = {}
+        for k, v in model_map.items():
+            models[k] = {
+                "id": k,
+                "name": v.replace(":", " "),
+                "modalities": {"input": ["text"], "output": ["text"]},
+                "cost": {
+                    "input": 0,
+                    "output": 0,
+                },
+            }
+        self.models = models

-        chat = await process_chat(chat)
-        generation_config = {}
+    def validate(self, **kwargs):
+        return None

-        # Filter out system messages and convert to proper Gemini format
-        contents = []
-        system_prompt = None

-        async with aiohttp.ClientSession() as session:
-            for message in chat["messages"]:
-                if message["role"] == "system":
-                    content = message["content"]
-                    if isinstance(content, list):
-                        for item in content:
-                            if "text" in item:
-                                system_prompt = item["text"]
-                                break
-                    elif isinstance(content, str):
-                        system_prompt = content
-                elif "content" in message:
-                    if isinstance(message["content"], list):
-                        parts = []
-                        for item in message["content"]:
-                            if "type" in item:
-                                if item["type"] == "image_url" and "image_url" in item:
-                                    image_url = item["image_url"]
-                                    if "url" not in image_url:
-                                        continue
-                                    url = image_url["url"]
-                                    if not url.startswith("data:"):
-                                        raise (Exception("Image was not downloaded: " + url))
-                                    # Extract mime type from data uri
-                                    mimetype = url.split(";", 1)[0].split(":", 1)[1] if ";" in url else "image/png"
-                                    base64_data = url.split(",", 1)[1]
-                                    parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
-                                elif item["type"] == "input_audio" and "input_audio" in item:
-                                    input_audio = item["input_audio"]
-                                    if "data" not in input_audio:
-                                        continue
-                                    data = input_audio["data"]
-                                    format = input_audio["format"]
-                                    mimetype = f"audio/{format}"
-                                    parts.append({"inline_data": {"mime_type": mimetype, "data": data}})
-                                elif item["type"] == "file" and "file" in item:
-                                    file = item["file"]
-                                    if "file_data" not in file:
-                                        continue
-                                    data = file["file_data"]
-                                    if not data.startswith("data:"):
-                                        raise (Exception("File was not downloaded: " + data))
-                                    # Extract mime type from data uri
-                                    mimetype = (
-                                        data.split(";", 1)[0].split(":", 1)[1]
-                                        if ";" in data
-                                        else "application/octet-stream"
-                                    )
-                                    base64_data = data.split(",", 1)[1]
-                                    parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
-                                if "text" in item:
-                                    text = item["text"]
-                                    parts.append({"text": text})
-                        if len(parts) > 0:
-                            contents.append(
-                                {
-                                    "role": message["role"]
-                                    if "role" in message and message["role"] == "user"
-                                    else "model",
-                                    "parts": parts,
-                                }
-                            )
-                    else:
-                        content = message["content"]
-                        contents.append(
-                            {
-                                "role": message["role"] if "role" in message and message["role"] == "user" else "model",
-                                "parts": [{"text": content}],
-                            }
-                        )
+class LMStudioProvider(OllamaProvider):
+    sdk = "lmstudio"

-            gemini_chat = {
-                "contents": contents,
-            }
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.chat_url = f"{self.api}/chat/completions"

-            if self.safety_settings:
-                gemini_chat["safetySettings"] = self.safety_settings
-
-            # Add system instruction if present
-            if system_prompt is not None:
-                gemini_chat["systemInstruction"] = {"parts": [{"text": system_prompt}]}
-
-            if "max_completion_tokens" in chat:
-                generation_config["maxOutputTokens"] = chat["max_completion_tokens"]
-            if "stop" in chat:
-                generation_config["stopSequences"] = [chat["stop"]]
-            if "temperature" in chat:
-                generation_config["temperature"] = chat["temperature"]
-            if "top_p" in chat:
-                generation_config["topP"] = chat["top_p"]
-            if "top_logprobs" in chat:
-                generation_config["topK"] = chat["top_logprobs"]
-
-            if "thinkingConfig" in chat:
-                generation_config["thinkingConfig"] = chat["thinkingConfig"]
-            elif self.thinking_config:
-                generation_config["thinkingConfig"] = self.thinking_config
-
-            if len(generation_config) > 0:
-                gemini_chat["generationConfig"] = generation_config
-
-            started_at = int(time.time() * 1000)
-            gemini_chat_url = f"https://generativelanguage.googleapis.com/v1beta/models/{chat['model']}:generateContent?key={self.api_key}"
-
-            _log(f"POST {gemini_chat_url}")
-            _log(gemini_chat_summary(gemini_chat))
-            started_at = time.time()
+    async def get_models(self):
+        ret = {}
+        try:
+            async with aiohttp.ClientSession() as session:
+                _log(f"GET {self.api}/models")
+                async with session.get(
+                    f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
+                ) as response:
+                    data = await response_json(response)
+                    for model in data.get("data", []):
+                        id = model["id"]
+                        ret[id] = id
+            _log(f"Loaded LMStudio models: {ret}")
+        except Exception as e:
+            _log(f"Error getting LMStudio models: {e}")
+            # return empty dict if ollama is not available
+        return ret

-            if self.curl:
-                curl_args = [
-                    "curl",
-                    "-X",
-                    "POST",
-                    "-H",
-                    "Content-Type: application/json",
-                    "-d",
-                    json.dumps(gemini_chat),
-                    gemini_chat_url,
-                ]
-                try:
-                    o = subprocess.run(curl_args, check=True, capture_output=True, text=True, timeout=120)
-                    obj = json.loads(o.stdout)
-                except Exception as e:
-                    raise Exception(f"Error executing curl: {e}") from e
-            else:
-                async with session.post(
-                    gemini_chat_url,
-                    headers=self.headers,
-                    data=json.dumps(gemini_chat),
-                    timeout=aiohttp.ClientTimeout(total=120),
-                ) as res:
-                    obj = await response_json(res)
-            _log(f"google response:\n{json.dumps(obj, indent=2)}")
-
-            response = {
-                "id": f"chatcmpl-{started_at}",
-                "created": started_at,
-                "model": obj.get("modelVersion", chat["model"]),
-            }
-            choices = []
-            if "error" in obj:
-                _log(f"Error: {obj['error']}")
-                raise Exception(obj["error"]["message"])
-            for i, candidate in enumerate(obj["candidates"]):
-                role = "assistant"
-                if "content" in candidate and "role" in candidate["content"]:
-                    role = "assistant" if candidate["content"]["role"] == "model" else candidate["content"]["role"]
-
-                # Safely extract content from all text parts
-                content = ""
-                reasoning = ""
-                if "content" in candidate and "parts" in candidate["content"]:
-                    text_parts = []
-                    reasoning_parts = []
-                    for part in candidate["content"]["parts"]:
-                        if "text" in part:
-                            if "thought" in part and part["thought"]:
-                                reasoning_parts.append(part["text"])
-                            else:
-                                text_parts.append(part["text"])
-                    content = " ".join(text_parts)
-                    reasoning = " ".join(reasoning_parts)

-                choice = {
-                    "index": i,
-                    "finish_reason": candidate.get("finishReason", "stop"),
-                    "message": {
-                        "role": role,
-                        "content": content,
-                    },
-                }
-                if reasoning:
-                    choice["message"]["reasoning"] = reasoning
-                choices.append(choice)
-            response["choices"] = choices
-            if "usageMetadata" in obj:
-                usage = obj["usageMetadata"]
-                response["usage"] = {
-                    "completion_tokens": usage["candidatesTokenCount"],
-                    "total_tokens": usage["totalTokenCount"],
-                    "prompt_tokens": usage["promptTokenCount"],
-                }
-            return self.to_response(response, chat, started_at)
+def get_provider_model(model_name):
+    for provider in g_handlers.values():
+        provider_model = provider.provider_model(model_name)
+        if provider_model:
+            return provider_model
+    return None


 def get_models():
@@ -856,42 +1259,259 @@ def get_models():
856
1259
  def get_active_models():
857
1260
  ret = []
858
1261
  existing_models = set()
859
- for id, provider in g_handlers.items():
860
- for model in provider.models:
861
- if model not in existing_models:
862
- existing_models.add(model)
863
- provider_model = provider.models[model]
864
- pricing = provider.model_pricing(model)
865
- ret.append({"id": model, "provider": id, "provider_model": provider_model, "pricing": pricing})
1262
+ for provider_id, provider in g_handlers.items():
1263
+ for model in provider.models.values():
1264
+ name = model.get("name")
1265
+ if not name:
1266
+ _log(f"Provider {provider_id} model {model} has no name")
1267
+ continue
1268
+ if name not in existing_models:
1269
+ existing_models.add(name)
1270
+ item = model.copy()
1271
+ item.update({"provider": provider_id})
1272
+ ret.append(item)
866
1273
  ret.sort(key=lambda x: x["id"])
867
1274
  return ret
868
1275
 
869
1276
 
870
- async def chat_completion(chat):
871
- model = chat["model"]
872
- # get first provider that has the model
873
- candidate_providers = [name for name, provider in g_handlers.items() if model in provider.models]
874
- if len(candidate_providers) == 0:
875
- raise (Exception(f"Model {model} not found"))
1277
+ def api_providers():
1278
+ ret = []
1279
+ for id, provider in g_handlers.items():
1280
+ ret.append({"id": id, "name": provider.name, "models": provider.models})
1281
+ return ret
1282
+
1283
+
1284
+ def to_error_message(e):
1285
+ return str(e)
1286
+
1287
+
1288
+ def to_error_response(e, stacktrace=False):
1289
+ status = {"errorCode": "Error", "message": to_error_message(e)}
1290
+ if stacktrace:
1291
+ status["stackTrace"] = traceback.format_exc()
1292
+ return {"responseStatus": status}
1293
+
1294
+
1295
+ def create_error_response(message, error_code="Error", stack_trace=None):
1296
+ ret = {"responseStatus": {"errorCode": error_code, "message": message}}
1297
+ if stack_trace:
1298
+ ret["responseStatus"]["stackTrace"] = stack_trace
1299
+ return ret
1300
+
1301
+
1302
+ def should_cancel_thread(context):
1303
+ ret = context.get("cancelled", False)
1304
+ if ret:
1305
+ thread_id = context.get("threadId")
1306
+ _dbg(f"Thread cancelled {thread_id}")
1307
+ return ret
1308
+
1309
+
1310
+ def g_chat_request(template=None, text=None, model=None, system_prompt=None):
1311
+ chat_template = g_config["defaults"].get(template or "text")
1312
+ if not chat_template:
1313
+ raise Exception(f"Chat template '{template}' not found")
1314
+
1315
+ chat = chat_template.copy()
1316
+ if model:
1317
+ chat["model"] = model
1318
+ if system_prompt is not None:
1319
+ chat["messages"].insert(0, {"role": "system", "content": system_prompt})
1320
+ if text is not None:
1321
+ if not chat["messages"] or len(chat["messages"]) == 0:
1322
+ chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
1323
+
1324
+ # replace the text of the last user message if one exists, otherwise append a new one
1325
+ last_msg = chat["messages"][-1] if "messages" in chat else None
1326
+ if last_msg and last_msg["role"] == "user":
1327
+ if isinstance(last_msg["content"], list):
1328
+ last_msg["content"][-1]["text"] = text
1329
+ else:
1330
+ last_msg["content"] = text
1331
+ else:
1332
+ chat["messages"].append({"role": "user", "content": text})
1333
+
1334
+ return chat
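A minimal sketch of the template handling above, using a hypothetical defaults template rather than the real llms.json contents: the system prompt is inserted at position 0 and the text replaces the last user message's text part. The sketch deep-copies the template so repeated calls never mutate the shared default.

import json

# Hypothetical defaults template (stand-in for the configured "text" template)
TEMPLATE = {"model": "", "messages": [{"role": "user", "content": [{"type": "text", "text": ""}]}]}

def build_chat(text, model=None, system_prompt=None, template=TEMPLATE):
    chat = json.loads(json.dumps(template))  # deep copy so the shared template is never mutated
    if model:
        chat["model"] = model
    if system_prompt is not None:
        chat["messages"].insert(0, {"role": "system", "content": system_prompt})
    last_msg = chat["messages"][-1]
    if last_msg["role"] == "user" and isinstance(last_msg["content"], list):
        last_msg["content"][-1]["text"] = text        # replace the text of the last user part
    else:
        chat["messages"].append({"role": "user", "content": text})
    return chat

chat = build_chat("Summarize this repo", model="example-model", system_prompt="Be brief.")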
1335
+
1336
+
1337
+ async def g_chat_completion(chat, context=None):
1338
+ try:
1339
+ model = chat.get("model")
1340
+ if not model:
1341
+ raise Exception("Model not specified")
876
1342
 
1343
+ if context is None:
1344
+ context = {"chat": chat, "tools": "all"}
1345
+
1346
+ # get first provider that has the model
1347
+ candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
1348
+ if len(candidate_providers) == 0:
1349
+ raise (Exception(f"Model {model} not found"))
1350
+ except Exception as e:
1351
+ await g_app.on_chat_error(e, context or {"chat": chat})
1352
+ raise e
1353
+
1354
+ started_at = time.time()
877
1355
  first_exception = None
1356
+ provider_name = "Unknown"
878
1357
  for name in candidate_providers:
879
- provider = g_handlers[name]
880
- _log(f"provider: {name} {type(provider).__name__}")
881
1358
  try:
882
- response = await provider.chat(chat.copy())
883
- return response
1359
+ provider_name = name
1360
+ provider = g_handlers[name]
1361
+ _log(f"provider: {name} {type(provider).__name__}")
1362
+ started_at = time.time()
1363
+ context["startedAt"] = datetime.now()
1364
+ context["provider"] = name
1365
+ model_info = provider.model_info(model)
1366
+ context["modelCost"] = model_info.get("cost", provider.model_cost(model)) or {"input": 0, "output": 0}
1367
+ context["modelInfo"] = model_info
1368
+
1369
+ # Accumulate usage across tool calls
1370
+ total_usage = {
1371
+ "prompt_tokens": 0,
1372
+ "completion_tokens": 0,
1373
+ "total_tokens": 0,
1374
+ }
1375
+ accumulated_cost = 0.0
1376
+
1377
+ # Inject global tools if present
1378
+ current_chat = chat.copy()
1379
+ if g_app.tool_definitions:
1380
+ only_tools_str = context.get("tools", "all")
1381
+ include_all_tools = only_tools_str == "all"
1382
+ only_tools = only_tools_str.split(",")
1383
+
1384
+ if include_all_tools or len(only_tools) > 0:
1385
+ if "tools" not in current_chat:
1386
+ current_chat["tools"] = []
1387
+
1388
+ existing_tools = {t["function"]["name"] for t in current_chat["tools"]}
1389
+ for tool_def in g_app.tool_definitions:
1390
+ name = tool_def["function"]["name"]
1391
+ if name not in existing_tools and (include_all_tools or name in only_tools):
1392
+ current_chat["tools"].append(tool_def)
1393
+
1394
+ # Apply pre-chat filters ONCE
1395
+ context["chat"] = current_chat
1396
+ for filter_func in g_app.chat_request_filters:
1397
+ await filter_func(current_chat, context)
1398
+
1399
+ # Tool execution loop
1400
+ max_iterations = 10
1401
+ tool_history = []
1402
+ final_response = None
1403
+
1404
+ for _ in range(max_iterations):
1405
+ if should_cancel_thread(context):
1406
+ return
1407
+
1408
+ response = await provider.chat(current_chat)
1409
+
1410
+ if should_cancel_thread(context):
1411
+ return
1412
+
1413
+ # Aggregate usage
1414
+ if "usage" in response:
1415
+ usage = response["usage"]
1416
+ total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
1417
+ total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
1418
+ total_usage["total_tokens"] += usage.get("total_tokens", 0)
1419
+
1420
+ # Calculate cost for this step if available
1421
+ if "cost" in response and isinstance(response["cost"], (int, float)):
1422
+ accumulated_cost += response["cost"]
1423
+ elif "cost" in usage and isinstance(usage["cost"], (int, float)):
1424
+ accumulated_cost += usage["cost"]
1425
+
1426
+ # Check for tool_calls in the response
1427
+ choice = response.get("choices", [])[0] if response.get("choices") else {}
1428
+ message = choice.get("message", {})
1429
+ tool_calls = message.get("tool_calls")
1430
+
1431
+ if tool_calls:
1432
+ # Append the assistant's message with tool calls to history
1433
+ if "messages" not in current_chat:
1434
+ current_chat["messages"] = []
1435
+ if "timestamp" not in message:
1436
+ message["timestamp"] = int(time.time() * 1000)
1437
+ current_chat["messages"].append(message)
1438
+ tool_history.append(message)
1439
+
1440
+ await g_app.on_chat_tool(current_chat, context)
1441
+
1442
+ for tool_call in tool_calls:
1443
+ function_name = tool_call["function"]["name"]
1444
+ try:
1445
+ function_args = json.loads(tool_call["function"]["arguments"])
1446
+ except Exception as e:
1447
+ tool_result = f"Error parsing JSON arguments for tool {function_name}: {e}"
1448
+ else:
1449
+ tool_result = f"Error: Tool {function_name} not found"
1450
+ if function_name in g_app.tools:
1451
+ try:
1452
+ func = g_app.tools[function_name]
1453
+ if inspect.iscoroutinefunction(func):
1454
+ tool_result = await func(**function_args)
1455
+ else:
1456
+ tool_result = func(**function_args)
1457
+ except Exception as e:
1458
+ tool_result = f"Error executing tool {function_name}: {e}"
1459
+
1460
+ # Append tool result to history
1461
+ tool_msg = {"role": "tool", "tool_call_id": tool_call["id"], "content": to_content(tool_result)}
1462
+ current_chat["messages"].append(tool_msg)
1463
+ tool_history.append(tool_msg)
1464
+
1465
+ await g_app.on_chat_tool(current_chat, context)
1466
+
1467
+ if should_cancel_thread(context):
1468
+ return
1469
+
1470
+ # Continue loop to send tool results back to LLM
1471
+ continue
1472
+
1473
+ # If no tool calls, this is the final response
1474
+ if tool_history:
1475
+ response["tool_history"] = tool_history
1476
+
1477
+ # Update final response with aggregated usage
1478
+ if "usage" not in response:
1479
+ response["usage"] = {}
1480
+ # convert to int seconds
1481
+ context["duration"] = duration = int(time.time() - started_at)
1482
+ total_usage.update({"duration": duration})
1483
+ response["usage"].update(total_usage)
1484
+ # If we accumulated cost, set it on the response
1485
+ if accumulated_cost > 0:
1486
+ response["cost"] = accumulated_cost
1487
+
1488
+ final_response = response
1489
+ break # Exit tool loop
1490
+
1491
+ if final_response:
1492
+ # Apply post-chat filters ONCE on final response
1493
+ for filter_func in g_app.chat_response_filters:
1494
+ await filter_func(final_response, context)
1495
+
1496
+ if DEBUG:
1497
+ _dbg(json.dumps(final_response, indent=2))
1498
+
1499
+ return final_response
1500
+
884
1501
  except Exception as e:
885
1502
  if first_exception is None:
886
1503
  first_exception = e
887
- _log(f"Provider {name} failed: {e}")
1504
+ context["stackTrace"] = traceback.format_exc()
1505
+ _err(f"Provider {provider_name} failed", first_exception)
1506
+ await g_app.on_chat_error(e, context)
1507
+
888
1508
  continue
889
1509
 
890
1510
  # If we get here, all providers failed
891
1511
  raise first_exception
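As a rough illustration of the tool loop above, here is a minimal, self-contained sketch of the OpenAI-style tool-calling cycle: call the model, execute any requested tools, append role "tool" messages keyed by tool_call_id, and repeat until the reply carries no tool_calls. The call_model callable and the tools registry are stand-ins, not the real provider.chat or g_app.tools, and the usage/cost aggregation is omitted.

import json

async def run_tool_loop(call_model, chat, tools, max_iterations=10):
    """call_model(chat) -> OpenAI-style response dict; tools: name -> callable."""
    response = None
    for _ in range(max_iterations):
        response = await call_model(chat)
        message = response["choices"][0]["message"]
        tool_calls = message.get("tool_calls")
        if not tool_calls:
            return response                       # final answer, no more tools requested
        chat["messages"].append(message)          # keep the assistant's tool request in history
        for call in tool_calls:
            name = call["function"]["name"]
            try:
                args = json.loads(call["function"]["arguments"])
                result = tools[name](**args) if name in tools else f"Error: Tool {name} not found"
            except Exception as e:
                result = f"Error executing tool {name}: {e}"
            chat["messages"].append(
                {"role": "tool", "tool_call_id": call["id"], "content": str(result)}
            )
    return response  # give up after max_iterations round-trips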
892
1512
 
893
1513
 
894
- async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False):
1514
+ async def cli_chat(chat, tools=None, image=None, audio=None, file=None, args=None, raw=False):
895
1515
  if g_default_model:
896
1516
  chat["model"] = g_default_model
897
1517
 
@@ -966,73 +1586,161 @@ async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False
966
1586
  printdump(chat)
967
1587
 
968
1588
  try:
969
- response = await chat_completion(chat)
1589
+ context = {
1590
+ "tools": tools or "all",
1591
+ }
1592
+ response = await g_app.chat_completion(chat, context=context)
1593
+
970
1594
  if raw:
971
1595
  print(json.dumps(response, indent=2))
972
1596
  exit(0)
973
1597
  else:
974
- answer = response["choices"][0]["message"]["content"]
975
- print(answer)
1598
+ msg = response["choices"][0]["message"]
1599
+ if "content" in msg or "answer" in msg:
1600
+ print(msg["content"])
1601
+
1602
+ generated_files = []
1603
+ for choice in response["choices"]:
1604
+ if "message" in choice:
1605
+ msg = choice["message"]
1606
+ if "images" in msg:
1607
+ for image in msg["images"]:
1608
+ image_url = image["image_url"]["url"]
1609
+ generated_files.append(image_url)
1610
+ if "audios" in msg:
1611
+ for audio in msg["audios"]:
1612
+ audio_url = audio["audio_url"]["url"]
1613
+ generated_files.append(audio_url)
1614
+
1615
+ if len(generated_files) > 0:
1616
+ print("\nSaved files:")
1617
+ for file in generated_files:
1618
+ if file.startswith("/~cache"):
1619
+ print(get_cache_path(file[8:]))
1620
+ print(urljoin("http://localhost:8000", file))
1621
+ else:
1622
+ print(file)
1623
+
976
1624
  except HTTPError as e:
977
1625
  # HTTP error (4xx, 5xx)
978
1626
  print(f"{e}:\n{e.body}")
979
- exit(1)
1627
+ g_app.exit(1)
980
1628
  except aiohttp.ClientConnectionError as e:
981
1629
  # Connection issues
982
1630
  print(f"Connection error: {e}")
983
- exit(1)
1631
+ g_app.exit(1)
984
1632
  except asyncio.TimeoutError as e:
985
1633
  # Timeout
986
1634
  print(f"Timeout error: {e}")
987
- exit(1)
1635
+ g_app.exit(1)
988
1636
 
989
1637
 
990
1638
  def config_str(key):
991
1639
  return key in g_config and g_config[key] or None
992
1640
 
993
1641
 
994
- def init_llms(config):
1642
+ def load_config(config, providers, verbose=None):
1643
+ global g_config, g_providers, g_verbose
1644
+ g_config = config
1645
+ g_providers = providers
1646
+ if verbose:
1647
+ g_verbose = verbose
1648
+
1649
+
1650
+ def init_llms(config, providers):
995
1651
  global g_config, g_handlers
996
1652
 
997
- g_config = config
1653
+ load_config(config, providers)
998
1654
  g_handlers = {}
999
1655
  # iterate over config and replace $ENV with env value
1000
1656
  for key, value in g_config.items():
1001
1657
  if isinstance(value, str) and value.startswith("$"):
1002
- g_config[key] = os.environ.get(value[1:], "")
1658
+ g_config[key] = os.getenv(value[1:], "")
1003
1659
 
1004
1660
  # if g_verbose:
1005
1661
  # printdump(g_config)
1006
1662
  providers = g_config["providers"]
1007
1663
 
1008
- for name, orig in providers.items():
1009
- definition = orig.copy()
1010
- provider_type = definition["type"]
1011
- if "enabled" in definition and not definition["enabled"]:
1664
+ for id, orig in providers.items():
1665
+ if "enabled" in orig and not orig["enabled"]:
1012
1666
  continue
1013
1667
 
1014
- # Replace API keys with environment variables if they start with $
1015
- if "api_key" in definition:
1016
- value = definition["api_key"]
1017
- if isinstance(value, str) and value.startswith("$"):
1018
- definition["api_key"] = os.environ.get(value[1:], "")
1019
-
1020
- # Create a copy of definition without the 'type' key for constructor kwargs
1021
- constructor_kwargs = {k: v for k, v in definition.items() if k != "type" and k != "enabled"}
1022
- constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
1023
-
1024
- if provider_type == "OpenAiProvider" and OpenAiProvider.test(**constructor_kwargs):
1025
- g_handlers[name] = OpenAiProvider(**constructor_kwargs)
1026
- elif provider_type == "OllamaProvider" and OllamaProvider.test(**constructor_kwargs):
1027
- g_handlers[name] = OllamaProvider(**constructor_kwargs)
1028
- elif provider_type == "GoogleProvider" and GoogleProvider.test(**constructor_kwargs):
1029
- g_handlers[name] = GoogleProvider(**constructor_kwargs)
1030
- elif provider_type == "GoogleOpenAiProvider" and GoogleOpenAiProvider.test(**constructor_kwargs):
1031
- g_handlers[name] = GoogleOpenAiProvider(**constructor_kwargs)
1032
-
1668
+ provider, constructor_kwargs = create_provider_from_definition(id, orig)
1669
+ if provider and provider.test(**constructor_kwargs):
1670
+ g_handlers[id] = provider
1033
1671
  return g_handlers
1034
1672
 
1035
1673
 
1674
+ def create_provider_from_definition(id, orig):
1675
+ definition = orig.copy()
1676
+ provider_id = definition.get("id", id)
1677
+ if "id" not in definition:
1678
+ definition["id"] = provider_id
1679
+ provider = g_providers.get(provider_id)
1680
+ constructor_kwargs = create_provider_kwargs(definition, provider)
1681
+ provider = create_provider(constructor_kwargs)
1682
+ return provider, constructor_kwargs
1683
+
1684
+
1685
+ def create_provider_kwargs(definition, provider=None):
1686
+ if provider:
1687
+ provider = provider.copy()
1688
+ provider.update(definition)
1689
+ else:
1690
+ provider = definition.copy()
1691
+
1692
+ # Replace API keys with environment variables if they start with $
1693
+ if "api_key" in provider:
1694
+ value = provider["api_key"]
1695
+ if isinstance(value, str) and value.startswith("$"):
1696
+ provider["api_key"] = os.getenv(value[1:], "")
1697
+
1698
+ if "api_key" not in provider and "env" in provider:
1699
+ for env_var in provider["env"]:
1700
+ val = os.getenv(env_var)
1701
+ if val:
1702
+ provider["api_key"] = val
1703
+ break
1704
+
1705
+ # Create a copy of provider
1706
+ constructor_kwargs = dict(provider.items())
1707
+ # Create a copy of all list and dict values
1708
+ for key, value in constructor_kwargs.items():
1709
+ if isinstance(value, (list, dict)):
1710
+ constructor_kwargs[key] = value.copy()
1711
+ constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
1712
+
1713
+ if "modalities" in definition:
1714
+ constructor_kwargs["modalities"] = {}
1715
+ for modality, modality_definition in definition["modalities"].items():
1716
+ modality_provider = create_provider(modality_definition)
1717
+ if not modality_provider:
1718
+ return None
1719
+ constructor_kwargs["modalities"][modality] = modality_provider
1720
+
1721
+ return constructor_kwargs
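A small sketch of the credential resolution above, with made-up provider definitions (EXAMPLE_API_KEY and EXAMPLE_TOKEN are not variables the project uses): a literal "$VAR" api_key is replaced by that environment variable, and the "env" list is only consulted when no api_key key is present at all.

import os

def resolve_api_key(provider):
    provider = provider.copy()
    value = provider.get("api_key")
    if isinstance(value, str) and value.startswith("$"):
        provider["api_key"] = os.getenv(value[1:], "")   # "$VAR" -> value of VAR, or ""
    if "api_key" not in provider and "env" in provider:
        for env_var in provider["env"]:                  # first set environment variable wins
            val = os.getenv(env_var)
            if val:
                provider["api_key"] = val
                break
    return provider

# e.g. resolve_api_key({"id": "example", "api_key": "$EXAMPLE_API_KEY"})
# or   resolve_api_key({"id": "example", "env": ["EXAMPLE_API_KEY", "EXAMPLE_TOKEN"]})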
1722
+
1723
+
1724
+ def create_provider(provider):
1725
+ if not isinstance(provider, dict):
1726
+ return None
1727
+ provider_label = provider.get("id", provider.get("name", "unknown"))
1728
+ npm_sdk = provider.get("npm")
1729
+ if not npm_sdk:
1730
+ _log(f"Provider {provider_label} is missing 'npm' sdk")
1731
+ return None
1732
+
1733
+ for provider_type in g_app.all_providers:
1734
+ if provider_type.sdk == npm_sdk:
1735
+ kwargs = create_provider_kwargs(provider)
1736
+ if kwargs is None:
1737
+ kwargs = provider
1738
+ return provider_type(**kwargs)
1739
+
1740
+ _log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
1741
+ return None
1742
+
1743
+
1036
1744
  async def load_llms():
1037
1745
  global g_handlers
1038
1746
  _log("Loading providers...")
@@ -1076,6 +1784,35 @@ async def save_default_config(config_path):
1076
1784
  g_config = json.loads(config_json)
1077
1785
 
1078
1786
 
1787
+ async def update_providers(home_providers_path):
1788
+ global g_providers
1789
+ text = await get_text("https://models.dev/api.json")
1790
+ all_providers = json.loads(text)
1791
+ extra_providers = {}
1792
+ extra_providers_path = home_providers_path.replace("providers.json", "providers-extra.json")
1793
+ if os.path.exists(extra_providers_path):
1794
+ with open(extra_providers_path) as f:
1795
+ extra_providers = json.load(f)
1796
+
1797
+ filtered_providers = {}
1798
+ for id, provider in all_providers.items():
1799
+ if id in g_config["providers"]:
1800
+ filtered_providers[id] = provider
1801
+ if id in extra_providers and "models" in extra_providers[id]:
1802
+ for model_id, model in extra_providers[id]["models"].items():
1803
+ if "id" not in model:
1804
+ model["id"] = model_id
1805
+ if "name" not in model:
1806
+ model["name"] = id_to_name(model["id"])
1807
+ filtered_providers[id]["models"][model_id] = model
1808
+
1809
+ os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
1810
+ with open(home_providers_path, "w", encoding="utf-8") as f:
1811
+ json.dump(filtered_providers, f)
1812
+
1813
+ g_providers = filtered_providers
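As a rough illustration of the filtering and overlay above (the schema is inferred from this code, not from models.dev documentation): only providers referenced in llms.json are kept, and models from providers-extra.json are merged in with their id and name defaulted.

def merge_providers(all_providers, configured_ids, extra_providers):
    filtered = {}
    for pid, provider in all_providers.items():
        if pid not in configured_ids:
            continue                                   # drop providers not configured in llms.json
        filtered[pid] = provider
        for model_id, model in extra_providers.get(pid, {}).get("models", {}).items():
            model.setdefault("id", model_id)           # default a missing id to the dict key
            model.setdefault("name", model_id)         # stand-in for the real id_to_name() helper
            filtered[pid]["models"][model_id] = model
    return filtered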
1814
+
1815
+
1079
1816
  def provider_status():
1080
1817
  enabled = list(g_handlers.keys())
1081
1818
  disabled = [provider for provider in g_config["providers"] if provider not in enabled]
@@ -1097,7 +1834,11 @@ def print_status():
1097
1834
 
1098
1835
 
1099
1836
  def home_llms_path(filename):
1100
- return f"{os.environ.get('HOME')}/.llms/{filename}"
1837
+ return f"{os.getenv('HOME')}/.llms/{filename}"
1838
+
1839
+
1840
+ def get_cache_path(path=""):
1841
+ return home_llms_path(f"cache/{path}") if path else home_llms_path("cache")
1101
1842
 
1102
1843
 
1103
1844
  def get_config_path():
@@ -1106,8 +1847,8 @@ def get_config_path():
1106
1847
  "./llms.json",
1107
1848
  home_config_path,
1108
1849
  ]
1109
- if os.environ.get("LLMS_CONFIG_PATH"):
1110
- check_paths.insert(0, os.environ.get("LLMS_CONFIG_PATH"))
1850
+ if os.getenv("LLMS_CONFIG_PATH"):
1851
+ check_paths.insert(0, os.getenv("LLMS_CONFIG_PATH"))
1111
1852
 
1112
1853
  for check_path in check_paths:
1113
1854
  g_config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), check_path))
@@ -1116,28 +1857,20 @@ def get_config_path():
1116
1857
  return None
1117
1858
 
1118
1859
 
1119
- def get_ui_path():
1120
- ui_paths = [home_llms_path("ui.json"), "ui.json"]
1121
- for ui_path in ui_paths:
1122
- if os.path.exists(ui_path):
1123
- return ui_path
1124
- return None
1125
-
1126
-
1127
1860
  def enable_provider(provider):
1128
1861
  msg = None
1129
1862
  provider_config = g_config["providers"][provider]
1863
+ if not provider_config:
1864
+ return None, f"Provider {provider} not found"
1865
+
1866
+ provider, constructor_kwargs = create_provider_from_definition(provider, provider_config)
1867
+ msg = provider.validate(**constructor_kwargs)
1868
+ if msg:
1869
+ return None, msg
1870
+
1130
1871
  provider_config["enabled"] = True
1131
- if "api_key" in provider_config:
1132
- api_key = provider_config["api_key"]
1133
- if isinstance(api_key, str):
1134
- if api_key.startswith("$"):
1135
- if not os.environ.get(api_key[1:], ""):
1136
- msg = f"WARNING: {provider} requires missing API Key in Environment Variable {api_key}"
1137
- else:
1138
- msg = f"WARNING: {provider} is not configured with an API Key"
1139
1872
  save_config(g_config)
1140
- init_llms(g_config)
1873
+ init_llms(g_config, g_providers)
1141
1874
  return provider_config, msg
1142
1875
 
1143
1876
 
@@ -1145,7 +1878,7 @@ def disable_provider(provider):
1145
1878
  provider_config = g_config["providers"][provider]
1146
1879
  provider_config["enabled"] = False
1147
1880
  save_config(g_config)
1148
- init_llms(g_config)
1881
+ init_llms(g_config, g_providers)
1149
1882
 
1150
1883
 
1151
1884
  def resolve_root():
@@ -1340,7 +2073,8 @@ async def check_models(provider_name, model_names=None):
1340
2073
  else:
1341
2074
  # Check only specified models
1342
2075
  for model_name in model_names:
1343
- if model_name in provider.models:
2076
+ provider_model = provider.provider_model(model_name)
2077
+ if provider_model:
1344
2078
  models_to_check.append(model_name)
1345
2079
  else:
1346
2080
  print(f"Model '{model_name}' not found in provider '{provider_name}'")
@@ -1355,69 +2089,76 @@ async def check_models(provider_name, model_names=None):
1355
2089
 
1356
2090
  # Test each model
1357
2091
  for model in models_to_check:
1358
- # Create a simple ping chat request
1359
- chat = (provider.check or g_config["defaults"]["check"]).copy()
1360
- chat["model"] = model
2092
+ await check_provider_model(provider, model)
1361
2093
 
1362
- started_at = time.time()
1363
- try:
1364
- # Try to get a response from the model
1365
- response = await provider.chat(chat)
1366
- duration_ms = int((time.time() - started_at) * 1000)
2094
+ print()
1367
2095
 
1368
- # Check if we got a valid response
1369
- if response and "choices" in response and len(response["choices"]) > 0:
1370
- print(f" ✓ {model:<40} ({duration_ms}ms)")
1371
- else:
1372
- print(f" ✗ {model:<40} Invalid response format")
1373
- except HTTPError as e:
1374
- duration_ms = int((time.time() - started_at) * 1000)
1375
- error_msg = f"HTTP {e.status}"
1376
- try:
1377
- # Try to parse error body for more details
1378
- error_body = json.loads(e.body) if e.body else {}
1379
- if "error" in error_body:
1380
- error = error_body["error"]
1381
- if isinstance(error, dict):
1382
- if "message" in error and isinstance(error["message"], str):
1383
- # OpenRouter
1384
- error_msg = error["message"]
1385
- if "code" in error:
1386
- error_msg = f"{error['code']} {error_msg}"
1387
- if "metadata" in error and "raw" in error["metadata"]:
1388
- error_msg += f" - {error['metadata']['raw']}"
1389
- if "provider" in error:
1390
- error_msg += f" ({error['provider']})"
1391
- elif isinstance(error, str):
1392
- error_msg = error
1393
- elif "message" in error_body:
1394
- if isinstance(error_body["message"], str):
1395
- error_msg = error_body["message"]
1396
- elif (
1397
- isinstance(error_body["message"], dict)
1398
- and "detail" in error_body["message"]
1399
- and isinstance(error_body["message"]["detail"], list)
1400
- ):
1401
- # codestral error format
1402
- error_msg = error_body["message"]["detail"][0]["msg"]
1403
- if (
1404
- "loc" in error_body["message"]["detail"][0]
1405
- and len(error_body["message"]["detail"][0]["loc"]) > 0
1406
- ):
1407
- error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
1408
- except Exception as parse_error:
1409
- _log(f"Error parsing error body: {parse_error}")
1410
- error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
1411
- print(f" ✗ {model:<40} {error_msg}")
1412
- except asyncio.TimeoutError:
1413
- duration_ms = int((time.time() - started_at) * 1000)
1414
- print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
1415
- except Exception as e:
1416
- duration_ms = int((time.time() - started_at) * 1000)
1417
- error_msg = str(e)[:100]
1418
- print(f" ✗ {model:<40} {error_msg}")
1419
2096
 
1420
- print()
2097
+ async def check_provider_model(provider, model):
2098
+ # Create a simple ping chat request
2099
+ chat = (provider.check or g_config["defaults"]["check"]).copy()
2100
+ chat["model"] = model
2101
+
2102
+ success = False
2103
+ started_at = time.time()
2104
+ try:
2105
+ # Try to get a response from the model
2106
+ response = await provider.chat(chat)
2107
+ duration_ms = int((time.time() - started_at) * 1000)
2108
+
2109
+ # Check if we got a valid response
2110
+ if response and "choices" in response and len(response["choices"]) > 0:
2111
+ success = True
2112
+ print(f" ✓ {model:<40} ({duration_ms}ms)")
2113
+ else:
2114
+ print(f" ✗ {model:<40} Invalid response format")
2115
+ except HTTPError as e:
2116
+ duration_ms = int((time.time() - started_at) * 1000)
2117
+ error_msg = f"HTTP {e.status}"
2118
+ try:
2119
+ # Try to parse error body for more details
2120
+ error_body = json.loads(e.body) if e.body else {}
2121
+ if "error" in error_body:
2122
+ error = error_body["error"]
2123
+ if isinstance(error, dict):
2124
+ if "message" in error and isinstance(error["message"], str):
2125
+ # OpenRouter
2126
+ error_msg = error["message"]
2127
+ if "code" in error:
2128
+ error_msg = f"{error['code']} {error_msg}"
2129
+ if "metadata" in error and "raw" in error["metadata"]:
2130
+ error_msg += f" - {error['metadata']['raw']}"
2131
+ if "provider" in error:
2132
+ error_msg += f" ({error['provider']})"
2133
+ elif isinstance(error, str):
2134
+ error_msg = error
2135
+ elif "message" in error_body:
2136
+ if isinstance(error_body["message"], str):
2137
+ error_msg = error_body["message"]
2138
+ elif (
2139
+ isinstance(error_body["message"], dict)
2140
+ and "detail" in error_body["message"]
2141
+ and isinstance(error_body["message"]["detail"], list)
2142
+ ):
2143
+ # codestral error format
2144
+ error_msg = error_body["message"]["detail"][0]["msg"]
2145
+ if (
2146
+ "loc" in error_body["message"]["detail"][0]
2147
+ and len(error_body["message"]["detail"][0]["loc"]) > 0
2148
+ ):
2149
+ error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
2150
+ except Exception as parse_error:
2151
+ _log(f"Error parsing error body: {parse_error}")
2152
+ error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
2153
+ print(f" ✗ {model:<40} {error_msg}")
2154
+ except asyncio.TimeoutError:
2155
+ duration_ms = int((time.time() - started_at) * 1000)
2156
+ print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
2157
+ except Exception as e:
2158
+ duration_ms = int((time.time() - started_at) * 1000)
2159
+ error_msg = str(e)[:100]
2160
+ print(f" ✗ {model:<40} {error_msg}")
2161
+ return success
1421
2162
 
1422
2163
 
1423
2164
  def text_from_resource(filename):
@@ -1452,8 +2193,14 @@ async def text_from_resource_or_url(filename):
1452
2193
 
1453
2194
  async def save_home_configs():
1454
2195
  home_config_path = home_llms_path("llms.json")
1455
- home_ui_path = home_llms_path("ui.json")
1456
- if os.path.exists(home_config_path) and os.path.exists(home_ui_path):
2196
+ home_providers_path = home_llms_path("providers.json")
2197
+ home_providers_extra_path = home_llms_path("providers-extra.json")
2198
+
2199
+ if (
2200
+ os.path.exists(home_config_path)
2201
+ and os.path.exists(home_providers_path)
2202
+ and os.path.exists(home_providers_extra_path)
2203
+ ):
1457
2204
  return
1458
2205
 
1459
2206
  llms_home = os.path.dirname(home_config_path)
@@ -1465,92 +2212,650 @@ async def save_home_configs():
1465
2212
  f.write(config_json)
1466
2213
  _log(f"Created default config at {home_config_path}")
1467
2214
 
1468
- if not os.path.exists(home_ui_path):
1469
- ui_json = await text_from_resource_or_url("ui.json")
1470
- with open(home_ui_path, "w", encoding="utf-8") as f:
1471
- f.write(ui_json)
1472
- _log(f"Created default ui config at {home_ui_path}")
2215
+ if not os.path.exists(home_providers_path):
2216
+ providers_json = await text_from_resource_or_url("providers.json")
2217
+ with open(home_providers_path, "w", encoding="utf-8") as f:
2218
+ f.write(providers_json)
2219
+ _log(f"Created default providers config at {home_providers_path}")
2220
+
2221
+ if not os.path.exists(home_providers_extra_path):
2222
+ extra_json = await text_from_resource_or_url("providers-extra.json")
2223
+ with open(home_providers_extra_path, "w", encoding="utf-8") as f:
2224
+ f.write(extra_json)
2225
+ _log(f"Created default extra providers config at {home_providers_extra_path}")
1473
2226
  except Exception:
1474
2227
  print("Could not create llms.json. Create one with --init or use --config <path>")
1475
2228
  exit(1)
1476
2229
 
1477
2230
 
1478
- async def reload_providers():
1479
- global g_config, g_handlers
1480
- g_handlers = init_llms(g_config)
1481
- await load_llms()
2231
+ def load_config_json(config_json):
2232
+ if config_json is None:
2233
+ return None
2234
+ config = json.loads(config_json)
2235
+ if not config or "version" not in config or config["version"] < 3:
2236
+ preserve_keys = ["auth", "defaults", "limits", "convert"]
2237
+ new_config = json.loads(text_from_resource("llms.json"))
2238
+ if config:
2239
+ for key in preserve_keys:
2240
+ if key in config:
2241
+ new_config[key] = config[key]
2242
+ config = new_config
2243
+ # move old config to YYYY-MM-DD.bak
2244
+ new_path = f"{g_config_path}.{datetime.now().strftime('%Y-%m-%d')}.bak"
2245
+ if os.path.exists(new_path):
2246
+ os.remove(new_path)
2247
+ os.rename(g_config_path, new_path)
2248
+ print(f"llms.json migrated. old config moved to {new_path}")
2249
+ # save new config
2250
+ save_config(g_config)
2251
+ return config
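The migration can be summarized with this sketch (the version threshold and preserved keys come from the code above; the backup and save steps are omitted): a config without version >= 3 is rebuilt from the bundled template while a few user-edited sections are carried over.

def migrate_config(old_config, template_config):
    """Return the config to use; template_config stands in for the bundled llms.json resource."""
    if old_config and old_config.get("version", 0) >= 3:
        return old_config                                  # already current, keep as-is
    new_config = dict(template_config)
    for key in ("auth", "defaults", "limits", "convert"):  # user-edited sections worth preserving
        if old_config and key in old_config:
            new_config[key] = old_config[key]
    return new_config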
2252
+
2253
+
2254
+ async def reload_providers():
2255
+ global g_config, g_handlers
2256
+ g_handlers = init_llms(g_config, g_providers)
2257
+ await load_llms()
1482
2258
  _log(f"{len(g_handlers)} providers loaded")
1483
2259
  return g_handlers
1484
2260
 
1485
2261
 
1486
- async def watch_config_files(config_path, ui_path, interval=1):
2262
+ async def watch_config_files(config_path, providers_path, interval=1):
1487
2263
  """Watch config files and reload providers when they change"""
1488
2264
  global g_config
1489
2265
 
1490
2266
  config_path = Path(config_path)
1491
- ui_path = Path(ui_path) if ui_path else None
2267
+ providers_path = Path(providers_path)
2268
+
2269
+ _log(f"Watching config file: {config_path}")
2270
+ _log(f"Watching providers file: {providers_path}")
1492
2271
 
1493
- file_mtimes = {}
2272
+ def get_latest_mtime():
2273
+ ret = 0
2274
+ name = "llms.json"
2275
+ if config_path.is_file():
2276
+ ret = config_path.stat().st_mtime
2277
+ name = config_path.name
2278
+ if providers_path.is_file() and providers_path.stat().st_mtime > ret:
2279
+ ret = providers_path.stat().st_mtime
2280
+ name = providers_path.name
2281
+ return ret, name
1494
2282
 
1495
- _log(f"Watching config files: {config_path}" + (f", {ui_path}" if ui_path else ""))
2283
+ latest_mtime, name = get_latest_mtime()
1496
2284
 
1497
2285
  while True:
1498
2286
  await asyncio.sleep(interval)
1499
2287
 
1500
2288
  # Check llms.json
1501
2289
  try:
1502
- if config_path.is_file():
1503
- mtime = config_path.stat().st_mtime
2290
+ new_mtime, name = get_latest_mtime()
2291
+ if new_mtime > latest_mtime:
2292
+ _log(f"Config file changed: {name}")
2293
+ latest_mtime = new_mtime
1504
2294
 
1505
- if str(config_path) not in file_mtimes:
1506
- file_mtimes[str(config_path)] = mtime
1507
- elif file_mtimes[str(config_path)] != mtime:
1508
- _log(f"Config file changed: {config_path.name}")
1509
- file_mtimes[str(config_path)] = mtime
1510
-
1511
- try:
1512
- # Reload llms.json
1513
- with open(config_path) as f:
1514
- g_config = json.load(f)
2295
+ try:
2296
+ # Reload llms.json
2297
+ with open(config_path) as f:
2298
+ g_config = json.load(f)
1515
2299
 
1516
- # Reload providers
1517
- await reload_providers()
1518
- _log("Providers reloaded successfully")
1519
- except Exception as e:
1520
- _log(f"Error reloading config: {e}")
2300
+ # Reload providers
2301
+ await reload_providers()
2302
+ _log("Providers reloaded successfully")
2303
+ except Exception as e:
2304
+ _log(f"Error reloading config: {e}")
1521
2305
  except FileNotFoundError:
1522
2306
  pass
1523
2307
 
1524
- # Check ui.json
1525
- if ui_path:
2308
+
2309
+ def get_session_token(request):
2310
+ return request.query.get("session") or request.headers.get("X-Session-Token") or request.cookies.get("llms-token")
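A tiny sketch of the lookup order above with a stand-in request object (aiohttp's Request exposes query, headers and cookies as mappings): the ?session= query parameter wins over the X-Session-Token header, which wins over the llms-token cookie.

class FakeRequest:                    # stand-in for aiohttp.web.Request in this sketch
    def __init__(self, query=None, headers=None, cookies=None):
        self.query = query or {}
        self.headers = headers or {}
        self.cookies = cookies or {}

def token_of(request):
    return request.query.get("session") or request.headers.get("X-Session-Token") or request.cookies.get("llms-token")

assert token_of(FakeRequest(cookies={"llms-token": "abc"})) == "abc"
assert token_of(FakeRequest(headers={"X-Session-Token": "hdr"}, cookies={"llms-token": "abc"})) == "hdr"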
2311
+
2312
+
2313
+ class AppExtensions:
2314
+ """
2315
+ APIs that extensions can use to extend the app
2316
+ """
2317
+
2318
+ def __init__(self, cli_args, extra_args):
2319
+ self.cli_args = cli_args
2320
+ self.extra_args = extra_args
2321
+ self.config = None
2322
+ self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
2323
+ self.ui_extensions = []
2324
+ self.chat_request_filters = []
2325
+ self.chat_tool_filters = []
2326
+ self.chat_response_filters = []
2327
+ self.chat_error_filters = []
2328
+ self.server_add_get = []
2329
+ self.server_add_post = []
2330
+ self.server_add_put = []
2331
+ self.server_add_delete = []
2332
+ self.server_add_patch = []
2333
+ self.cache_saved_filters = []
2334
+ self.shutdown_handlers = []
2335
+ self.tools = {}
2336
+ self.tool_definitions = []
2337
+ self.index_headers = []
2338
+ self.index_footers = []
2339
+ self.request_args = {
2340
+ "image_config": dict, # e.g. { "aspect_ratio": "1:1" }
2341
+ "temperature": float, # e.g: 0.7
2342
+ "max_completion_tokens": int, # e.g: 2048
2343
+ "seed": int, # e.g: 42
2344
+ "top_p": float, # e.g: 0.9
2345
+ "frequency_penalty": float, # e.g: 0.5
2346
+ "presence_penalty": float, # e.g: 0.5
2347
+ "stop": list, # e.g: ["Stop"]
2348
+ "reasoning_effort": str, # e.g: minimal, low, medium, high
2349
+ "verbosity": str, # e.g: low, medium, high
2350
+ "service_tier": str, # e.g: auto, default
2351
+ "top_logprobs": int,
2352
+ "safety_identifier": str,
2353
+ "store": bool,
2354
+ "enable_thinking": bool,
2355
+ }
2356
+ self.all_providers = [
2357
+ OpenAiCompatible,
2358
+ MistralProvider,
2359
+ GroqProvider,
2360
+ XaiProvider,
2361
+ CodestralProvider,
2362
+ OllamaProvider,
2363
+ LMStudioProvider,
2364
+ ]
2365
+ self.aspect_ratios = {
2366
+ "1:1": "1024×1024",
2367
+ "2:3": "832×1248",
2368
+ "3:2": "1248×832",
2369
+ "3:4": "864×1184",
2370
+ "4:3": "1184×864",
2371
+ "4:5": "896×1152",
2372
+ "5:4": "1152×896",
2373
+ "9:16": "768×1344",
2374
+ "16:9": "1344×768",
2375
+ "21:9": "1536×672",
2376
+ }
2377
+ self.import_maps = {
2378
+ "vue-prod": "/ui/lib/vue.min.mjs",
2379
+ "vue": "/ui/lib/vue.mjs",
2380
+ "vue-router": "/ui/lib/vue-router.min.mjs",
2381
+ "@servicestack/client": "/ui/lib/servicestack-client.mjs",
2382
+ "@servicestack/vue": "/ui/lib/servicestack-vue.mjs",
2383
+ "idb": "/ui/lib/idb.min.mjs",
2384
+ "marked": "/ui/lib/marked.min.mjs",
2385
+ "highlight.js": "/ui/lib/highlight.min.mjs",
2386
+ "chart.js": "/ui/lib/chart.js",
2387
+ "color.js": "/ui/lib/color.js",
2388
+ "ctx.mjs": "/ui/ctx.mjs",
2389
+ }
2390
+
2391
+ def set_config(self, config):
2392
+ self.config = config
2393
+ self.auth_enabled = self.config.get("auth", {}).get("enabled", False)
2394
+
2395
+ # Authentication middleware helper
2396
+ def check_auth(self, request):
2397
+ """Check if request is authenticated. Returns (is_authenticated, user_data)"""
2398
+ if not self.auth_enabled:
2399
+ return True, None
2400
+
2401
+ # Check for OAuth session token
2402
+ session_token = get_session_token(request)
2403
+ if session_token and session_token in g_sessions:
2404
+ return True, g_sessions[session_token]
2405
+
2406
+ # Check for API key
2407
+ auth_header = request.headers.get("Authorization", "")
2408
+ if auth_header.startswith("Bearer "):
2409
+ api_key = auth_header[7:]
2410
+ if api_key:
2411
+ return True, {"authProvider": "apikey"}
2412
+
2413
+ return False, None
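For illustration, a hypothetical aiohttp handler guarding a route with this check; the handler and route name are invented, and it assumes access to the module-level g_app instance created in main(). Only the (is_authenticated, user_data) contract comes from the code above.

from aiohttp import web

async def whoami(request):                       # hypothetical route handler
    authenticated, user = g_app.check_auth(request)
    if not authenticated:
        return web.json_response(g_app.error_auth_required, status=401)
    return web.json_response({"user": user or {"authProvider": "open"}})

# e.g. registered from an extension: ctx.add_get("whoami", whoami)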
2414
+
2415
+ def get_session(self, request):
2416
+ session_token = get_session_token(request)
2417
+
2418
+ if not session_token or session_token not in g_sessions:
2419
+ return None
2420
+
2421
+ session_data = g_sessions[session_token]
2422
+ return session_data
2423
+
2424
+ def get_username(self, request):
2425
+ session = self.get_session(request)
2426
+ if session:
2427
+ return session.get("userName")
2428
+ return None
2429
+
2430
+ def get_user_path(self, username=None):
2431
+ if username:
2432
+ return home_llms_path(os.path.join("user", username))
2433
+ return home_llms_path(os.path.join("user", "default"))
2434
+
2435
+ def chat_request(self, template=None, text=None, model=None, system_prompt=None):
2436
+ return g_chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
2437
+
2438
+ async def chat_completion(self, chat, context=None):
2439
+ response = await g_chat_completion(chat, context)
2440
+ return response
2441
+
2442
+ def on_cache_saved_filters(self, context):
2443
+ # _log(f"on_cache_saved_filters {len(self.cache_saved_filters)}: {context['url']}")
2444
+ for filter_func in self.cache_saved_filters:
2445
+ filter_func(context)
2446
+
2447
+ async def on_chat_error(self, e, context):
2448
+ # Apply chat error filters
2449
+ if "stackTrace" not in context:
2450
+ context["stackTrace"] = traceback.format_exc()
2451
+ for filter_func in self.chat_error_filters:
1526
2452
  try:
1527
- if ui_path.is_file():
1528
- mtime = ui_path.stat().st_mtime
2453
+ await filter_func(e, context)
2454
+ except Exception as e:
2455
+ _err("chat error filter failed", e)
2456
+
2457
+ async def on_chat_tool(self, chat, context):
2458
+ m_len = len(chat.get("messages", []))
2459
+ t_len = len(self.chat_tool_filters)
2460
+ _dbg(
2461
+ f"on_tool_call for thread {context.get('threadId', None)} with {m_len} {pluralize('message', m_len)}, invoking {t_len} {pluralize('filter', t_len)}:"
2462
+ )
2463
+ for filter_func in self.chat_tool_filters:
2464
+ await filter_func(chat, context)
2465
+
2466
+ def exit(self, exit_code=0):
2467
+ if len(self.shutdown_handlers) > 0:
2468
+ _dbg(f"running {len(self.shutdown_handlers)} shutdown handlers...")
2469
+ for handler in self.shutdown_handlers:
2470
+ handler()
2471
+
2472
+ _dbg(f"exit({exit_code})")
2473
+ sys.exit(exit_code)
2474
+
2475
+
2476
+ def handler_name(handler):
2477
+ if hasattr(handler, "__name__"):
2478
+ return handler.__name__
2479
+ return "unknown"
2480
+
2481
+
2482
+ class ExtensionContext:
2483
+ def __init__(self, app, path):
2484
+ self.app = app
2485
+ self.cli_args = app.cli_args
2486
+ self.extra_args = app.extra_args
2487
+ self.error_auth_required = app.error_auth_required
2488
+ self.path = path
2489
+ self.name = os.path.basename(path)
2490
+ if self.name.endswith(".py"):
2491
+ self.name = self.name[:-3]
2492
+ self.ext_prefix = f"/ext/{self.name}"
2493
+ self.MOCK = MOCK
2494
+ self.MOCK_DIR = MOCK_DIR
2495
+ self.debug = DEBUG
2496
+ self.verbose = g_verbose
2497
+ self.aspect_ratios = app.aspect_ratios
2498
+ self.request_args = app.request_args
2499
+
2500
+ def chat_to_prompt(self, chat):
2501
+ return chat_to_prompt(chat)
2502
+
2503
+ def chat_to_system_prompt(self, chat):
2504
+ return chat_to_system_prompt(chat)
2505
+
2506
+ def chat_response_to_message(self, response):
2507
+ return chat_response_to_message(response)
2508
+
2509
+ def last_user_prompt(self, chat):
2510
+ return last_user_prompt(chat)
2511
+
2512
+ def to_file_info(self, chat, info=None, response=None):
2513
+ return to_file_info(chat, info=info, response=response)
2514
+
2515
+ def save_image_to_cache(self, base64_data, filename, image_info):
2516
+ return save_image_to_cache(base64_data, filename, image_info)
2517
+
2518
+ def save_bytes_to_cache(self, bytes_data, filename, file_info):
2519
+ return save_bytes_to_cache(bytes_data, filename, file_info)
2520
+
2521
+ def text_from_file(self, path):
2522
+ return text_from_file(path)
2523
+
2524
+ def log(self, message):
2525
+ if self.verbose:
2526
+ print(f"[{self.name}] {message}", flush=True)
2527
+ return message
2528
+
2529
+ def log_json(self, obj):
2530
+ if self.verbose:
2531
+ print(f"[{self.name}] {json.dumps(obj, indent=2)}", flush=True)
2532
+ return obj
2533
+
2534
+ def dbg(self, message):
2535
+ if self.debug:
2536
+ print(f"DEBUG [{self.name}]: {message}", flush=True)
1529
2537
 
1530
- if str(ui_path) not in file_mtimes:
1531
- file_mtimes[str(ui_path)] = mtime
1532
- elif file_mtimes[str(ui_path)] != mtime:
1533
- _log(f"Config file changed: {ui_path.name}")
1534
- file_mtimes[str(ui_path)] = mtime
1535
- _log("ui.json reloaded - reload page to update")
1536
- except FileNotFoundError:
1537
- pass
2538
+ def err(self, message, e):
2539
+ print(f"ERROR [{self.name}]: {message}", e)
2540
+ if self.verbose:
2541
+ print(traceback.format_exc(), flush=True)
2542
+
2543
+ def error_message(self, e):
2544
+ return to_error_message(e)
2545
+
2546
+ def error_response(self, e, stacktrace=False):
2547
+ return to_error_response(e, stacktrace=stacktrace)
2548
+
2549
+ def add_provider(self, provider):
2550
+ self.log(f"Registered provider: {provider.__name__}")
2551
+ self.app.all_providers.append(provider)
2552
+
2553
+ def register_ui_extension(self, index):
2554
+ path = os.path.join(self.ext_prefix, index)
2555
+ self.log(f"Registered UI extension: {path}")
2556
+ self.app.ui_extensions.append({"id": self.name, "path": path})
2557
+
2558
+ def register_chat_request_filter(self, handler):
2559
+ self.log(f"Registered chat request filter: {handler_name(handler)}")
2560
+ self.app.chat_request_filters.append(handler)
2561
+
2562
+ def register_chat_tool_filter(self, handler):
2563
+ self.log(f"Registered chat tool filter: {handler_name(handler)}")
2564
+ self.app.chat_tool_filters.append(handler)
2565
+
2566
+ def register_chat_response_filter(self, handler):
2567
+ self.log(f"Registered chat response filter: {handler_name(handler)}")
2568
+ self.app.chat_response_filters.append(handler)
2569
+
2570
+ def register_chat_error_filter(self, handler):
2571
+ self.log(f"Registered chat error filter: {handler_name(handler)}")
2572
+ self.app.chat_error_filters.append(handler)
2573
+
2574
+ def register_cache_saved_filter(self, handler):
2575
+ self.log(f"Registered cache saved filter: {handler_name(handler)}")
2576
+ self.app.cache_saved_filters.append(handler)
2577
+
2578
+ def register_shutdown_handler(self, handler):
2579
+ self.log(f"Registered shutdown handler: {handler_name(handler)}")
2580
+ self.app.shutdown_handlers.append(handler)
2581
+
2582
+ def add_static_files(self, ext_dir):
2583
+ self.log(f"Registered static files: {ext_dir}")
2584
+
2585
+ async def serve_static(request):
2586
+ path = request.match_info["path"]
2587
+ file_path = os.path.join(ext_dir, path)
2588
+ if os.path.exists(file_path):
2589
+ return web.FileResponse(file_path)
2590
+ return web.Response(status=404)
2591
+
2592
+ self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))
2593
+
2594
+ def add_get(self, path, handler, **kwargs):
2595
+ self.dbg(f"Registered GET: {os.path.join(self.ext_prefix, path)}")
2596
+ self.app.server_add_get.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2597
+
2598
+ def add_post(self, path, handler, **kwargs):
2599
+ self.dbg(f"Registered POST: {os.path.join(self.ext_prefix, path)}")
2600
+ self.app.server_add_post.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2601
+
2602
+ def add_put(self, path, handler, **kwargs):
2603
+ self.dbg(f"Registered PUT: {os.path.join(self.ext_prefix, path)}")
2604
+ self.app.server_add_put.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2605
+
2606
+ def add_delete(self, path, handler, **kwargs):
2607
+ self.dbg(f"Registered DELETE: {os.path.join(self.ext_prefix, path)}")
2608
+ self.app.server_add_delete.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2609
+
2610
+ def add_patch(self, path, handler, **kwargs):
2611
+ self.dbg(f"Registered PATCH: {os.path.join(self.ext_prefix, path)}")
2612
+ self.app.server_add_patch.append((os.path.join(self.ext_prefix, path), handler, kwargs))
2613
+
2614
+ def add_importmaps(self, dict):
2615
+ self.app.import_maps.update(dict)
2616
+
2617
+ def add_index_header(self, html):
2618
+ self.app.index_headers.append(html)
2619
+
2620
+ def add_index_footer(self, html):
2621
+ self.app.index_footers.append(html)
2622
+
2623
+ def get_config(self):
2624
+ return g_config
2625
+
2626
+ def get_cache_path(self, path=""):
2627
+ return get_cache_path(path)
2628
+
2629
+ def chat_request(self, template=None, text=None, model=None, system_prompt=None):
2630
+ return self.app.chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
2631
+
2632
+ def chat_completion(self, chat, context=None):
2633
+ return self.app.chat_completion(chat, context=context)
2634
+
2635
+ def get_providers(self):
2636
+ return g_handlers
2637
+
2638
+ def get_provider(self, name):
2639
+ return g_handlers.get(name)
2640
+
2641
+ def register_tool(self, func, tool_def=None):
2642
+ if tool_def is None:
2643
+ tool_def = function_to_tool_definition(func)
2644
+
2645
+ name = tool_def["function"]["name"]
2646
+ self.log(f"Registered tool: {name}")
2647
+ self.app.tools[name] = func
2648
+ self.app.tool_definitions.append(tool_def)
2649
+
2650
+ def check_auth(self, request):
2651
+ return self.app.check_auth(request)
2652
+
2653
+ def get_session(self, request):
2654
+ return self.app.get_session(request)
2655
+
2656
+ def get_username(self, request):
2657
+ return self.app.get_username(request)
2658
+
2659
+ def get_user_path(self, username=None):
2660
+ return self.app.get_user_path(username)
2661
+
2662
+ def should_cancel_thread(self, context):
2663
+ return should_cancel_thread(context)
2664
+
2665
+ def cache_message_inline_data(self, message):
2666
+ return cache_message_inline_data(message)
2667
+
2668
+ def to_content(self, result):
2669
+ return to_content(result)
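To show how these hooks fit together, here is a hypothetical extension __init__.py; the tool, handler and extension name ("hello") are invented, while __install__(ctx), register_tool, add_get and register_chat_response_filter are the APIs defined above.

# ~/.llms/extensions/hello/__init__.py  (hypothetical example)
from aiohttp import web

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

async def hello(request):
    return web.json_response({"message": "hello from extension"})

async def log_usage(response, context):
    print(response.get("usage"))

def __install__(ctx):
    ctx.register_tool(add)                         # exposed to models as a callable tool
    ctx.add_get("hello", hello)                    # served under /ext/hello/hello
    ctx.register_chat_response_filter(log_usage)   # runs after each final chat response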
2670
+
2671
+
2672
+ def get_extensions_path():
2673
+ return os.getenv("LLMS_EXTENSIONS_DIR", os.path.join(Path.home(), ".llms", "extensions"))
2674
+
2675
+
2676
+ def get_disabled_extensions():
2677
+ ret = DISABLE_EXTENSIONS.copy()
2678
+ if g_config:
2679
+ for ext in g_config.get("disable_extensions", []):
2680
+ if ext not in ret:
2681
+ ret.append(ext)
2682
+ return ret
2683
+
2684
+
2685
+ def get_extensions_dirs():
2686
+ """
2687
+ Returns a list of extension directories.
2688
+ """
2689
+ extensions_path = get_extensions_path()
2690
+ os.makedirs(extensions_path, exist_ok=True)
2691
+
2692
+ # allow overriding builtin extensions
2693
+ override_extensions = []
2694
+ if os.path.exists(extensions_path):
2695
+ override_extensions = os.listdir(extensions_path)
2696
+
2697
+ ret = []
2698
+ disabled_extensions = get_disabled_extensions()
2699
+
2700
+ builtin_extensions_dir = _ROOT / "extensions"
2701
+ if os.path.exists(builtin_extensions_dir):
2702
+ for item in os.listdir(builtin_extensions_dir):
2703
+ if os.path.isdir(os.path.join(builtin_extensions_dir, item)):
2704
+ if item in override_extensions:
2705
+ continue
2706
+ if item in disabled_extensions:
2707
+ continue
2708
+ ret.append(os.path.join(builtin_extensions_dir, item))
2709
+
2710
+ if os.path.exists(extensions_path):
2711
+ for item in os.listdir(extensions_path):
2712
+ if os.path.isdir(os.path.join(extensions_path, item)):
2713
+ if item in disabled_extensions:
2714
+ continue
2715
+ ret.append(os.path.join(extensions_path, item))
2716
+
2717
+ return ret
2718
+
2719
+
2720
+ def init_extensions(parser):
2721
+ """
2722
+ Initializes extensions by loading their __init__.py files and calling the __parser__ function if it exists.
2723
+ """
2724
+ for item_path in get_extensions_dirs():
2725
+ item = os.path.basename(item_path)
2726
+
2727
+ if os.path.isdir(item_path):
2728
+ try:
2729
+ # check for a __parser__ function in __init__.py and call it with the parser
2730
+ init_file = os.path.join(item_path, "__init__.py")
2731
+ if os.path.exists(init_file):
2732
+ spec = importlib.util.spec_from_file_location(item, init_file)
2733
+ if spec and spec.loader:
2734
+ module = importlib.util.module_from_spec(spec)
2735
+ sys.modules[item] = module
2736
+ spec.loader.exec_module(module)
2737
+
2738
+ parser_func = getattr(module, "__parser__", None)
2739
+ if callable(parser_func):
2740
+ parser_func(parser)
2741
+ _log(f"Extension {item} parser loaded")
2742
+ except Exception as e:
2743
+ _err(f"Failed to load extension {item} parser", e)
2744
+
2745
+
2746
+ def install_extensions():
2747
+ """
2748
+ Scans ~/.llms/extensions/ (and the builtin extensions directory) for directories with __init__.py and loads them as extensions.
2749
+ Calls the `__install__(ctx)` function in the extension module.
2750
+ """
2751
+
2752
+ extension_dirs = get_extensions_dirs()
2753
+ ext_count = len(list(extension_dirs))
2754
+ if ext_count == 0:
2755
+ _log("No extensions found")
2756
+ return
2757
+
2758
+ disabled_extensions = get_disabled_extensions()
2759
+ if len(disabled_extensions) > 0:
2760
+ _log(f"Disabled extensions: {', '.join(disabled_extensions)}")
2761
+
2762
+ _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")
2763
+
2764
+ for item_path in extension_dirs:
2765
+ item = os.path.basename(item_path)
2766
+
2767
+ if os.path.isdir(item_path):
2768
+ sys.path.append(item_path)
2769
+ try:
2770
+ ctx = ExtensionContext(g_app, item_path)
2771
+ init_file = os.path.join(item_path, "__init__.py")
2772
+ if os.path.exists(init_file):
2773
+ spec = importlib.util.spec_from_file_location(item, init_file)
2774
+ if spec and spec.loader:
2775
+ module = importlib.util.module_from_spec(spec)
2776
+ sys.modules[item] = module
2777
+ spec.loader.exec_module(module)
2778
+
2779
+ install_func = getattr(module, "__install__", None)
2780
+ if callable(install_func):
2781
+ install_func(ctx)
2782
+ _log(f"Extension {item} installed")
2783
+ else:
2784
+ _dbg(f"Extension {item} has no __install__ function")
2785
+ else:
2786
+ _dbg(f"Extension {item} has no __init__.py")
2787
+ else:
2788
+ _dbg(f"Extension {init_file} not found")
2789
+
2790
+ # if ui folder exists, serve as static files at /ext/{item}/
2791
+ ui_path = os.path.join(item_path, "ui")
2792
+ if os.path.exists(ui_path):
2793
+ ctx.add_static_files(ui_path)
2794
+
2795
+ # Register UI extension if index.mjs exists (/ext/{item}/index.mjs)
2796
+ if os.path.exists(os.path.join(ui_path, "index.mjs")):
2797
+ ctx.register_ui_extension("index.mjs")
2798
+
2799
+ except Exception as e:
2800
+ _err(f"Failed to install extension {item}", e)
2801
+ else:
2802
+ _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")
2803
+
2804
+
2805
+ def run_extension_cli():
2806
+ """
2807
+ Run the CLI for an extension.
2808
+ """
2809
+ for item_path in get_extensions_dirs():
2810
+ item = os.path.basename(item_path)
2811
+
2812
+ if os.path.isdir(item_path):
2813
+ init_file = os.path.join(item_path, "__init__.py")
2814
+ if os.path.exists(init_file):
2815
+ ctx = ExtensionContext(g_app, item_path)
2816
+ try:
2817
+ spec = importlib.util.spec_from_file_location(item, init_file)
2818
+ if spec and spec.loader:
2819
+ module = importlib.util.module_from_spec(spec)
2820
+ sys.modules[item] = module
2821
+ spec.loader.exec_module(module)
2822
+
2823
+ # Check for a __run__ function in __init__.py and call it with ctx
2824
+ run_func = getattr(module, "__run__", None)
2825
+ if callable(run_func):
2826
+ _log(f"Running extension {item}...")
2827
+ handled = run_func(ctx)
2828
+ return handled
2829
+
2830
+ except Exception as e:
2831
+ _err(f"Failed to run extension {item}", e)
2832
+ return False
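And the matching __run__ hook, again hypothetical: run_extension_cli returns whatever __run__ returns to its caller, so an extension can report that it handled the CLI invocation.

# In a hypothetical extension's __init__.py
def __run__(ctx):
    if getattr(ctx.cli_args, "hello_name", None):   # --hello-name added by __parser__ above
        print(f"Hello, {ctx.cli_args.hello_name}!")
        return True                                 # report the invocation as handled
    return False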
1538
2833
 
1539
2834
 
1540
2835
  def main():
1541
- global _ROOT, g_verbose, g_default_model, g_logprefix, g_config, g_config_path, g_ui_path
2836
+ global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
2837
+
2838
+ _ROOT = os.getenv("LLMS_ROOT", resolve_root())
2839
+ if not _ROOT:
2840
+ print("Resource root not found")
2841
+ exit(1)
1542
2842
 
1543
2843
  parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
1544
2844
  parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
2845
+ parser.add_argument("--providers", default=None, help="Path to models.dev providers file", metavar="FILE")
1545
2846
  parser.add_argument("-m", "--model", default=None, help="Model to use")
1546
2847
 
1547
2848
  parser.add_argument("--chat", default=None, help="OpenAI Chat Completion Request to send", metavar="REQUEST")
1548
2849
  parser.add_argument(
1549
2850
  "-s", "--system", default=None, help="System prompt to use for chat completion", metavar="PROMPT"
1550
2851
  )
2852
+ parser.add_argument(
2853
+ "--tools", default=None, help="Tools to use for chat completion (all|none|<tool>,<tool>...)", metavar="TOOLS"
2854
+ )
1551
2855
  parser.add_argument("--image", default=None, help="Image input to use in chat completion")
1552
2856
  parser.add_argument("--audio", default=None, help="Audio input to use in chat completion")
1553
2857
  parser.add_argument("--file", default=None, help="File input to use in chat completion")
2858
+ parser.add_argument("--out", default=None, help="Image or Video Generation Request", metavar="MODALITY")
1554
2859
  parser.add_argument(
1555
2860
  "--args",
1556
2861
  default=None,
@@ -1573,15 +2878,46 @@ def main():
1573
2878
  parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")
1574
2879
 
1575
2880
  parser.add_argument("--init", action="store_true", help="Create a default llms.json")
2881
+ parser.add_argument("--update-providers", action="store_true", help="Update local models.dev providers.json")
1576
2882
 
1577
- parser.add_argument("--root", default=None, help="Change root directory for UI files", metavar="PATH")
1578
2883
  parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
1579
2884
  parser.add_argument("--verbose", action="store_true", help="Verbose output")
1580
2885
 
2886
+ parser.add_argument(
2887
+ "--add",
2888
+ nargs="?",
2889
+ const="ls",
2890
+ default=None,
2891
+ help="Install an extension (lists available extensions if no name provided)",
2892
+ metavar="EXTENSION",
2893
+ )
2894
+ parser.add_argument(
2895
+ "--remove",
2896
+ nargs="?",
2897
+ const="ls",
2898
+ default=None,
2899
+ help="Remove an extension (lists installed extensions if no name provided)",
2900
+ metavar="EXTENSION",
2901
+ )
2902
+
2903
+ parser.add_argument(
2904
+ "--update",
2905
+ nargs="?",
2906
+ const="ls",
2907
+ default=None,
2908
+ help="Update an extension (use 'all' to update all extensions)",
2909
+ metavar="EXTENSION",
2910
+ )
2911
+
2912
+ # Load parser extensions, go through all extensions and load their parser arguments
2913
+ init_extensions(parser)
2914
+
1581
2915
  cli_args, extra_args = parser.parse_known_args()
1582
2916
 
2917
+ g_app = AppExtensions(cli_args, extra_args)
2918
+
1583
2919
  # Check for verbose mode from CLI argument or environment variables
1584
- verbose_env = os.environ.get("VERBOSE", "").lower()
2920
+ verbose_env = os.getenv("VERBOSE", "").lower()
1585
2921
  if cli_args.verbose or verbose_env in ("1", "true"):
1586
2922
  g_verbose = True
1587
2923
  # printdump(cli_args)
@@ -1590,13 +2926,9 @@ def main():
1590
2926
  if cli_args.logprefix:
1591
2927
  g_logprefix = cli_args.logprefix
1592
2928
 
1593
- _ROOT = Path(cli_args.root) if cli_args.root else resolve_root()
1594
- if not _ROOT:
1595
- print("Resource root not found")
1596
- exit(1)
1597
-
1598
2929
  home_config_path = home_llms_path("llms.json")
1599
- home_ui_path = home_llms_path("ui.json")
2930
+ home_providers_path = home_llms_path("providers.json")
2931
+ home_providers_extra_path = home_llms_path("providers-extra.json")
1600
2932
 
1601
2933
  if cli_args.init:
1602
2934
  if os.path.exists(home_config_path):
@@ -1605,38 +2937,215 @@ def main():
1605
2937
  asyncio.run(save_default_config(home_config_path))
1606
2938
  print(f"Created default config at {home_config_path}")
1607
2939
 
1608
- if os.path.exists(home_ui_path):
1609
- print(f"ui.json already exists at {home_ui_path}")
2940
+ if os.path.exists(home_providers_path):
2941
+ print(f"providers.json already exists at {home_providers_path}")
2942
+ else:
2943
+ asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
2944
+ print(f"Created default providers config at {home_providers_path}")
2945
+
2946
+ if os.path.exists(home_providers_extra_path):
2947
+ print(f"providers-extra.json already exists at {home_providers_extra_path}")
1610
2948
  else:
1611
- asyncio.run(save_text_url(github_url("ui.json"), home_ui_path))
1612
- print(f"Created default ui config at {home_ui_path}")
2949
+ asyncio.run(save_text_url(github_url("providers-extra.json"), home_providers_extra_path))
2950
+ print(f"Created default extra providers config at {home_providers_extra_path}")
1613
2951
  exit(0)
1614
2952
 
2953
+ if cli_args.providers:
2954
+ if not os.path.exists(cli_args.providers):
2955
+ print(f"providers.json not found at {cli_args.providers}")
2956
+ exit(1)
2957
+ g_providers = json.loads(text_from_file(cli_args.providers))
2958
+
1615
2959
  if cli_args.config:
1616
2960
  # read contents
1617
2961
  g_config_path = cli_args.config
1618
2962
  with open(g_config_path, encoding="utf-8") as f:
1619
2963
  config_json = f.read()
1620
- g_config = json.loads(config_json)
2964
+ g_config = load_config_json(config_json)
1621
2965
 
1622
2966
  config_dir = os.path.dirname(g_config_path)
1623
- # look for ui.json in same directory as config
1624
- ui_path = os.path.join(config_dir, "ui.json")
1625
- if os.path.exists(ui_path):
1626
- g_ui_path = ui_path
1627
- else:
1628
- if not os.path.exists(home_ui_path):
1629
- ui_json = text_from_resource("ui.json")
1630
- with open(home_ui_path, "w", encoding="utf-8") as f:
1631
- f.write(ui_json)
1632
- _log(f"Created default ui config at {home_ui_path}")
1633
- g_ui_path = home_ui_path
2967
+
2968
+ if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
2969
+ g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
2970
+
1634
2971
  else:
1635
- # ensure llms.json and ui.json exist in home directory
2972
+ # ensure llms.json and providers.json exist in home directory
1636
2973
  asyncio.run(save_home_configs())
1637
2974
  g_config_path = home_config_path
1638
- g_ui_path = home_ui_path
1639
- g_config = json.loads(text_from_file(g_config_path))
2975
+ g_config = load_config_json(text_from_file(g_config_path))
2976
+
2977
+ g_app.set_config(g_config)
2978
+
2979
+ if not g_providers:
2980
+ g_providers = json.loads(text_from_file(home_providers_path))
2981
+
2982
+ if cli_args.update_providers:
2983
+ asyncio.run(update_providers(home_providers_path))
2984
+ print(f"Updated {home_providers_path}")
2985
+ exit(0)
2986
+
2987
+ # If home_providers_path is older than 1 day, refresh the providers list (skipped when LLMS_DISABLE_UPDATE=1)
2988
+ if (
2989
+ os.path.exists(home_providers_path)
2990
+ and (time.time() - os.path.getmtime(home_providers_path)) > 86400
2991
+ and os.getenv("LLMS_DISABLE_UPDATE", "") != "1"
2992
+ ):
2993
+ try:
2994
+ asyncio.run(update_providers(home_providers_path))
2995
+ _log(f"Updated {home_providers_path}")
2996
+ except Exception as e:
2997
+ _err("Failed to update providers", e)
2998
+
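
For reference, the auto-refresh added above is a plain mtime check; a minimal standalone sketch of the same rule (needs_refresh is an illustrative helper name, not part of the package):

    import os
    import time

    ONE_DAY = 86400  # seconds

    def needs_refresh(path: str, max_age: int = ONE_DAY) -> bool:
        """True when the cached providers file exists, is older than max_age,
        and auto-updates are not disabled via LLMS_DISABLE_UPDATE=1."""
        if os.getenv("LLMS_DISABLE_UPDATE", "") == "1":
            return False
        return os.path.exists(path) and (time.time() - os.path.getmtime(path)) > max_age
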
2999
+ if cli_args.add is not None:
3000
+ if cli_args.add == "ls":
3001
+
3002
+ async def list_extensions():
3003
+ print("\nAvailable extensions:")
3004
+ text = await get_text("https://api.github.com/orgs/llmspy/repos?per_page=100&sort=updated")
3005
+ repos = json.loads(text)
3006
+ max_name_length = 0
3007
+ for repo in repos:
3008
+ max_name_length = max(max_name_length, len(repo["name"]))
3009
+
3010
+ for repo in repos:
3011
+ print(f" {repo['name']:<{max_name_length + 2}} {repo['description']}")
3012
+
3013
+ print("\nUsage:")
3014
+ print(" llms --add <extension>")
3015
+ print(" llms --add <github-user>/<repo>")
3016
+
3017
+ asyncio.run(list_extensions())
3018
+ exit(0)
3019
+
3020
+ async def install_extension(name):
3021
+ # Determine git URL and target directory name
3022
+ if "/" in name:
3023
+ git_url = f"https://github.com/{name}"
3024
+ target_name = name.split("/")[-1]
3025
+ else:
3026
+ git_url = f"https://github.com/llmspy/{name}"
3027
+ target_name = name
3028
+
3029
+ # Check the extension is not already installed
3030
+ extensions_path = get_extensions_path()
3031
+ target_path = os.path.join(extensions_path, target_name)
3032
+
3033
+ if os.path.exists(target_path):
3034
+ print(f"Extension {target_name} is already installed at {target_path}")
3035
+ return
3036
+
3037
+ print(f"Installing extension: {name}")
3038
+ print(f"Cloning from {git_url} to {target_path}...")
3039
+
3040
+ try:
3041
+ subprocess.run(["git", "clone", git_url, target_path], check=True)
3042
+
3043
+ # Check for requirements.txt
3044
+ requirements_path = os.path.join(target_path, "requirements.txt")
3045
+ if os.path.exists(requirements_path):
3046
+ print(f"Installing dependencies from {requirements_path}...")
3047
+
3048
+ # Check if uv is installed
3049
+ has_uv = False
3050
+ try:
3051
+ subprocess.run(
3052
+ ["uv", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
3053
+ )
3054
+ has_uv = True
3055
+ except (subprocess.CalledProcessError, FileNotFoundError):
3056
+ pass
3057
+
3058
+ if has_uv:
3059
+ subprocess.run(
3060
+ ["uv", "pip", "install", "-p", sys.executable, "-r", "requirements.txt"],
3061
+ cwd=target_path,
3062
+ check=True,
3063
+ )
3064
+ else:
3065
+ subprocess.run(
3066
+ [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
3067
+ cwd=target_path,
3068
+ check=True,
3069
+ )
3070
+ print("Dependencies installed successfully.")
3071
+
3072
+ print(f"Extension {target_name} installed successfully.")
3073
+
3074
+ except subprocess.CalledProcessError as e:
3075
+ print(f"Failed to install extension: {e}")
3076
+ # cleanup if clone failed but directory was created (unlikely with simple git clone but good practice)
3077
+ if os.path.exists(target_path) and not os.listdir(target_path):
3078
+ os.rmdir(target_path)
3079
+
3080
+ asyncio.run(install_extension(cli_args.add))
3081
+ exit(0)
3082
+
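
The dependency step above probes for uv and falls back to pip; a condensed sketch of the same decision, using shutil.which in place of the subprocess probe (install_requirements is an illustrative helper name):

    import shutil
    import subprocess
    import sys

    def install_requirements(target_path: str) -> None:
        """Install an extension's requirements.txt, preferring uv when it is on PATH."""
        if shutil.which("uv"):
            cmd = ["uv", "pip", "install", "-p", sys.executable, "-r", "requirements.txt"]
        else:
            cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
        subprocess.run(cmd, cwd=target_path, check=True)
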
3083
+ if cli_args.remove is not None:
3084
+ if cli_args.remove == "ls":
3085
+ # List installed extensions
3086
+ extensions_path = get_extensions_path()
3087
+ extensions = os.listdir(extensions_path)
3088
+ if len(extensions) == 0:
3089
+ print("No extensions installed.")
3090
+ exit(0)
3091
+ print("Installed extensions:")
3092
+ for extension in extensions:
3093
+ print(f" {extension}")
3094
+ exit(0)
3095
+ # Remove an extension
3096
+ extension_name = cli_args.remove
3097
+ extensions_path = get_extensions_path()
3098
+ target_path = os.path.join(extensions_path, extension_name)
3099
+
3100
+ if not os.path.exists(target_path):
3101
+ print(f"Extension {extension_name} not found at {target_path}")
3102
+ exit(1)
3103
+
3104
+ print(f"Removing extension: {extension_name}...")
3105
+ try:
3106
+ shutil.rmtree(target_path)
3107
+ print(f"Extension {extension_name} removed successfully.")
3108
+ except Exception as e:
3109
+ print(f"Failed to remove extension: {e}")
3110
+ exit(1)
3111
+
3112
+ exit(0)
3113
+
3114
+ if cli_args.update:
3115
+ if cli_args.update == "ls":
3116
+ # List installed extensions
3117
+ extensions_path = get_extensions_path()
3118
+ extensions = os.listdir(extensions_path)
3119
+ if len(extensions) == 0:
3120
+ print("No extensions installed.")
3121
+ exit(0)
3122
+ print("Installed extensions:")
3123
+ for extension in extensions:
3124
+ print(f" {extension}")
3125
+
3126
+ print("\nUsage:")
3127
+ print(" llms --update <extension>")
3128
+ print(" llms --update all")
3129
+ exit(0)
3130
+
3131
+ async def update_extensions(extension_name):
3132
+ extensions_path = get_extensions_path()
3133
+ for extension in os.listdir(extensions_path):
3134
+ extension_path = os.path.join(extensions_path, extension)
3135
+ if os.path.isdir(extension_path):
3136
+ if extension_name != "all" and extension != extension_name:
3137
+ continue
3138
+ result = subprocess.run(["git", "pull"], cwd=extension_path, capture_output=True)
3139
+ if result.returncode != 0:
3140
+ print(f"Failed to update extension {extension}: {result.stderr.decode('utf-8')}")
3141
+ continue
3142
+ print(f"Updated extension {extension}")
3143
+ _log(result.stdout.decode("utf-8"))
3144
+
3145
+ asyncio.run(update_extensions(cli_args.update))
3146
+ exit(0)
3147
+
3148
+ install_extensions()
1640
3149
 
1641
3150
  asyncio.run(reload_providers())
1642
3151
 
@@ -1654,23 +3163,45 @@ def main():
1654
3163
  if cli_args.list:
1655
3164
  # Show list of enabled providers and their models
1656
3165
  enabled = []
3166
+ provider_count = 0
3167
+ model_count = 0
3168
+
3169
+ max_model_length = 0
3170
+ for name, provider in g_handlers.items():
3171
+ if len(filter_list) > 0 and name not in filter_list:
3172
+ continue
3173
+ for model in provider.models:
3174
+ max_model_length = max(max_model_length, len(model))
3175
+
1657
3176
  for name, provider in g_handlers.items():
1658
3177
  if len(filter_list) > 0 and name not in filter_list:
1659
3178
  continue
3179
+ provider_count += 1
1660
3180
  print(f"{name}:")
1661
3181
  enabled.append(name)
1662
3182
  for model in provider.models:
1663
- print(f" {model}")
3183
+ model_count += 1
3184
+ model_cost_info = None
3185
+ if "cost" in provider.models[model]:
3186
+ model_cost = provider.models[model]["cost"]
3187
+ if "input" in model_cost and "output" in model_cost:
3188
+ if model_cost["input"] == 0 and model_cost["output"] == 0:
3189
+ model_cost_info = " 0"
3190
+ else:
3191
+ model_cost_info = f"{model_cost['input']:5} / {model_cost['output']}"
3192
+ print(f" {model:{max_model_length}} {model_cost_info or ''}")
3193
+
3194
+ print(f"\n{model_count} models available from {provider_count} providers")
1664
3195
 
1665
3196
  print_status()
1666
- exit(0)
3197
+ g_app.exit(0)
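
The reworked --list output above pads every model id to the width of the longest one so the cost column lines up; a small sketch of the same alignment with made-up model data:

    models = {
        "gpt-4o-mini": {"cost": {"input": 0.15, "output": 0.6}},  # illustrative values
        "free-model": {"cost": {"input": 0, "output": 0}},
    }

    width = max(len(name) for name in models)
    for name, info in models.items():
        cost = info["cost"]
        if cost["input"] == 0 and cost["output"] == 0:
            cost_info = " 0"
        else:
            cost_info = f"{cost['input']:5} / {cost['output']}"
        print(f"  {name:{width}} {cost_info}")
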
1667
3198
 
1668
3199
  if cli_args.check is not None:
1669
3200
  # Check validity of models for a provider
1670
3201
  provider_name = cli_args.check
1671
3202
  model_names = extra_args if len(extra_args) > 0 else None
1672
3203
  asyncio.run(check_models(provider_name, model_names))
1673
- exit(0)
3204
+ g_app.exit(0)
1674
3205
 
1675
3206
  if cli_args.serve is not None:
1676
3207
  # Disable inactive providers and save to config before starting server
@@ -1690,10 +3221,6 @@ def main():
1690
3221
  # Start server
1691
3222
  port = int(cli_args.serve)
1692
3223
 
1693
- if not os.path.exists(g_ui_path):
1694
- print(f"UI not found at {g_ui_path}")
1695
- exit(1)
1696
-
1697
3224
  # Validate auth configuration if enabled
1698
3225
  auth_enabled = g_config.get("auth", {}).get("enabled", False)
1699
3226
  if auth_enabled:
@@ -1703,11 +3230,19 @@ def main():
1703
3230
 
1704
3231
  # Expand environment variables
1705
3232
  if client_id.startswith("$"):
1706
- client_id = os.environ.get(client_id[1:], "")
3233
+ client_id = client_id[1:]
1707
3234
  if client_secret.startswith("$"):
1708
- client_secret = os.environ.get(client_secret[1:], "")
3235
+ client_secret = client_secret[1:]
1709
3236
 
1710
- if not client_id or not client_secret:
3237
+ client_id = os.getenv(client_id, client_id)
3238
+ client_secret = os.getenv(client_secret, client_secret)
3239
+
3240
+ if (
3241
+ not client_id
3242
+ or not client_secret
3243
+ or client_id == "GITHUB_CLIENT_ID"
3244
+ or client_secret == "GITHUB_CLIENT_SECRET"
3245
+ ):
1711
3246
  print("ERROR: Authentication is enabled but GitHub OAuth is not properly configured.")
1712
3247
  print("Please set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET environment variables,")
1713
3248
  print("or disable authentication by setting 'auth.enabled' to false in llms.json")
@@ -1721,60 +3256,35 @@ def main():
1721
3256
  _log(f"client_max_size set to {client_max_size} bytes ({client_max_size / 1024 / 1024:.1f}MB)")
1722
3257
  app = web.Application(client_max_size=client_max_size)
1723
3258
 
1724
- # Authentication middleware helper
1725
- def check_auth(request):
1726
- """Check if request is authenticated. Returns (is_authenticated, user_data)"""
1727
- if not auth_enabled:
1728
- return True, None
1729
-
1730
- # Check for OAuth session token
1731
- session_token = request.query.get("session") or request.headers.get("X-Session-Token")
1732
- if session_token and session_token in g_sessions:
1733
- return True, g_sessions[session_token]
1734
-
1735
- # Check for API key
1736
- auth_header = request.headers.get("Authorization", "")
1737
- if auth_header.startswith("Bearer "):
1738
- api_key = auth_header[7:]
1739
- if api_key:
1740
- return True, {"authProvider": "apikey"}
1741
-
1742
- return False, None
1743
-
1744
3259
  async def chat_handler(request):
1745
3260
  # Check authentication if enabled
1746
- is_authenticated, user_data = check_auth(request)
3261
+ is_authenticated, user_data = g_app.check_auth(request)
1747
3262
  if not is_authenticated:
1748
- return web.json_response(
1749
- {
1750
- "error": {
1751
- "message": "Authentication required",
1752
- "type": "authentication_error",
1753
- "code": "unauthorized",
1754
- }
1755
- },
1756
- status=401,
1757
- )
3263
+ return web.json_response(g_app.error_auth_required, status=401)
1758
3264
 
1759
3265
  try:
1760
3266
  chat = await request.json()
1761
- response = await chat_completion(chat)
3267
+ context = {"chat": chat, "request": request, "user": g_app.get_username(request)}
3268
+ metadata = chat.get("metadata", {})
3269
+ context["threadId"] = metadata.get("threadId", None)
3270
+ context["tools"] = metadata.get("tools", "all")
3271
+ response = await g_app.chat_completion(chat, context)
1762
3272
  return web.json_response(response)
1763
3273
  except Exception as e:
1764
- return web.json_response({"error": str(e)}, status=500)
3274
+ return web.json_response(to_error_response(e), status=500)
1765
3275
 
1766
3276
  app.router.add_post("/v1/chat/completions", chat_handler)
1767
3277
 
1768
- async def models_handler(request):
1769
- return web.json_response(get_models())
1770
-
1771
- app.router.add_get("/models/list", models_handler)
1772
-
1773
3278
  async def active_models_handler(request):
1774
3279
  return web.json_response(get_active_models())
1775
3280
 
1776
3281
  app.router.add_get("/models", active_models_handler)
1777
3282
 
3283
+ async def active_providers_handler(request):
3284
+ return web.json_response(api_providers())
3285
+
3286
+ app.router.add_get("/providers", active_providers_handler)
3287
+
1778
3288
  async def status_handler(request):
1779
3289
  enabled, disabled = provider_status()
1780
3290
  return web.json_response(
@@ -1794,8 +3304,9 @@ def main():
1794
3304
  if provider:
1795
3305
  if data.get("enable", False):
1796
3306
  provider_config, msg = enable_provider(provider)
1797
- _log(f"Enabled provider {provider}")
1798
- await load_llms()
3307
+ _log(f"Enabled provider {provider} {msg}")
3308
+ if not msg:
3309
+ await load_llms()
1799
3310
  elif data.get("disable", False):
1800
3311
  disable_provider(provider)
1801
3312
  _log(f"Disabled provider {provider}")
@@ -1810,11 +3321,144 @@ def main():
1810
3321
 
1811
3322
  app.router.add_post("/providers/{provider}", provider_handler)
1812
3323
 
3324
+ async def upload_handler(request):
3325
+ # Check authentication if enabled
3326
+ is_authenticated, user_data = g_app.check_auth(request)
3327
+ if not is_authenticated:
3328
+ return web.json_response(g_app.error_auth_required, status=401)
3329
+
3330
+ reader = await request.multipart()
3331
+
3332
+ # Read first file field
3333
+ field = await reader.next()
3334
+ while field and field.name != "file":
3335
+ field = await reader.next()
3336
+
3337
+ if not field:
3338
+ return web.json_response(create_error_response("No file provided"), status=400)
3339
+
3340
+ filename = field.filename or "file"
3341
+ content = await field.read()
3342
+ mimetype = get_file_mime_type(filename)
3343
+
3344
+ # If image, resize if needed
3345
+ if mimetype.startswith("image/"):
3346
+ content, mimetype = convert_image_if_needed(content, mimetype)
3347
+
3348
+ # Calculate SHA256
3349
+ sha256_hash = hashlib.sha256(content).hexdigest()
3350
+ ext = filename.rsplit(".", 1)[1] if "." in filename else ""
3351
+ if not ext:
3352
+ ext = mimetypes.guess_extension(mimetype) or ""
3353
+ if ext.startswith("."):
3354
+ ext = ext[1:]
3355
+
3356
+ if not ext:
3357
+ ext = "bin"
3358
+
3359
+ save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
3360
+
3361
+ # Use first 2 chars for subdir to avoid too many files in one dir
3362
+ subdir = sha256_hash[:2]
3363
+ relative_path = f"{subdir}/{save_filename}"
3364
+ full_path = get_cache_path(relative_path)
3365
+
3366
+ # If the file and its .info.json already exist, return the cached info
3367
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
3368
+ if os.path.exists(full_path) and os.path.exists(info_path):
3369
+ return web.json_response(json.load(open(info_path)))
3370
+
3371
+ os.makedirs(os.path.dirname(full_path), exist_ok=True)
3372
+
3373
+ with open(full_path, "wb") as f:
3374
+ f.write(content)
3375
+
3376
+ url = f"/~cache/{relative_path}"
3377
+ response_data = {
3378
+ "date": int(time.time()),
3379
+ "url": url,
3380
+ "size": len(content),
3381
+ "type": mimetype,
3382
+ "name": filename,
3383
+ }
3384
+
3385
+ # If image, get dimensions
3386
+ if HAS_PIL and mimetype.startswith("image/"):
3387
+ try:
3388
+ with Image.open(BytesIO(content)) as img:
3389
+ response_data["width"] = img.width
3390
+ response_data["height"] = img.height
3391
+ except Exception:
3392
+ pass
3393
+
3394
+ # Save metadata
3395
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
3396
+ with open(info_path, "w") as f:
3397
+ json.dump(response_data, f)
3398
+
3399
+ g_app.on_cache_saved_filters({"url": url, "info": response_data})
3400
+
3401
+ return web.json_response(response_data)
3402
+
3403
+ app.router.add_post("/upload", upload_handler)
3404
+
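
Uploads above are stored content-addressed: the SHA-256 digest names the file and its first two hex characters pick a fan-out subdirectory, so identical uploads dedupe and no single directory grows unbounded. A minimal sketch of the layout (cache_relative_path is an illustrative helper):

    import hashlib

    def cache_relative_path(content: bytes, ext: str = "bin") -> str:
        digest = hashlib.sha256(content).hexdigest()
        return f"{digest[:2]}/{digest}.{ext}"

    print(cache_relative_path(b"hello", "txt"))
    # 2c/2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824.txt
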
3405
+ async def extensions_handler(request):
3406
+ return web.json_response(g_app.ui_extensions)
3407
+
3408
+ app.router.add_get("/ext", extensions_handler)
3409
+
3410
+ async def tools_handler(request):
3411
+ return web.json_response(g_app.tool_definitions)
3412
+
3413
+ app.router.add_get("/ext/tools", tools_handler)
3414
+
3415
+ async def cache_handler(request):
3416
+ path = request.match_info["tail"]
3417
+ full_path = get_cache_path(path)
3418
+
3419
+ if "info" in request.query:
3420
+ info_path = os.path.splitext(full_path)[0] + ".info.json"
3421
+ if not os.path.exists(info_path):
3422
+ return web.Response(text="404: Not Found", status=404)
3423
+
3424
+ # Check for directory traversal for info path
3425
+ try:
3426
+ cache_root = Path(get_cache_path())
3427
+ requested_path = Path(info_path).resolve()
3428
+ if not str(requested_path).startswith(str(cache_root)):
3429
+ return web.Response(text="403: Forbidden", status=403)
3430
+ except Exception:
3431
+ return web.Response(text="403: Forbidden", status=403)
3432
+
3433
+ with open(info_path) as f:
3434
+ content = f.read()
3435
+ return web.Response(text=content, content_type="application/json")
3436
+
3437
+ if not os.path.exists(full_path):
3438
+ return web.Response(text="404: Not Found", status=404)
3439
+
3440
+ # Check for directory traversal
3441
+ try:
3442
+ cache_root = Path(get_cache_path())
3443
+ requested_path = Path(full_path).resolve()
3444
+ if not str(requested_path).startswith(str(cache_root)):
3445
+ return web.Response(text="403: Forbidden", status=403)
3446
+ except Exception:
3447
+ return web.Response(text="403: Forbidden", status=403)
3448
+
3449
+ with open(full_path, "rb") as f:
3450
+ content = f.read()
3451
+
3452
+ mimetype = get_file_mime_type(full_path)
3453
+ return web.Response(body=content, content_type=mimetype)
3454
+
3455
+ app.router.add_get("/~cache/{tail:.*}", cache_handler)
3456
+
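
The cache route rejects paths that resolve outside the cache root; a sketch of the same guard using Path.relative_to rather than the string-prefix check in the diff (is_within is an illustrative name):

    from pathlib import Path

    def is_within(root: Path, candidate: Path) -> bool:
        """True when candidate (after resolving '..' segments etc.) stays inside root."""
        try:
            candidate.resolve().relative_to(root.resolve())
            return True
        except ValueError:
            return False

    # is_within(Path("/tmp/cache"), Path("/tmp/cache/../secrets"))  -> False
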
1813
3457
  # OAuth handlers
1814
3458
  async def github_auth_handler(request):
1815
3459
  """Initiate GitHub OAuth flow"""
1816
3460
  if "auth" not in g_config or "github" not in g_config["auth"]:
1817
- return web.json_response({"error": "GitHub OAuth not configured"}, status=500)
3461
+ return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
1818
3462
 
1819
3463
  auth_config = g_config["auth"]["github"]
1820
3464
  client_id = auth_config.get("client_id", "")
@@ -1822,12 +3466,15 @@ def main():
1822
3466
 
1823
3467
  # Expand environment variables
1824
3468
  if client_id.startswith("$"):
1825
- client_id = os.environ.get(client_id[1:], "")
3469
+ client_id = client_id[1:]
1826
3470
  if redirect_uri.startswith("$"):
1827
- redirect_uri = os.environ.get(redirect_uri[1:], "")
3471
+ redirect_uri = redirect_uri[1:]
3472
+
3473
+ client_id = os.getenv(client_id, client_id)
3474
+ redirect_uri = os.getenv(redirect_uri, redirect_uri)
1828
3475
 
1829
3476
  if not client_id:
1830
- return web.json_response({"error": "GitHub client_id not configured"}, status=500)
3477
+ return web.json_response(create_error_response("GitHub client_id not configured"), status=500)
1831
3478
 
1832
3479
  # Generate CSRF state token
1833
3480
  state = secrets.token_urlsafe(32)
@@ -1857,7 +3504,9 @@ def main():
1857
3504
 
1858
3505
  # Expand environment variables
1859
3506
  if restrict_to.startswith("$"):
1860
- restrict_to = os.environ.get(restrict_to[1:], "")
3507
+ restrict_to = restrict_to[1:]
3508
+
3509
+ restrict_to = os.getenv(restrict_to, None if restrict_to == "GITHUB_USERS" else restrict_to)
1861
3510
 
1862
3511
  # If restrict_to is configured, validate the user
1863
3512
  if restrict_to:
@@ -1878,6 +3527,14 @@ def main():
1878
3527
  code = request.query.get("code")
1879
3528
  state = request.query.get("state")
1880
3529
 
3530
+ # Handle malformed URLs where query params are appended with & instead of ?
3531
+ if not code and "tail" in request.match_info:
3532
+ tail = request.match_info["tail"]
3533
+ if tail.startswith("&"):
3534
+ params = parse_qs(tail[1:])
3535
+ code = params.get("code", [None])[0]
3536
+ state = params.get("state", [None])[0]
3537
+
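
The callback now tolerates URLs where the query string got glued on with '&' instead of '?'; a small example of the recovery with a made-up tail:

    from urllib.parse import parse_qs

    tail = "&code=abc123&state=xyz"                 # hypothetical malformed remainder
    if tail.startswith("&"):
        params = parse_qs(tail[1:])
        code = params.get("code", [None])[0]        # -> "abc123"
        state = params.get("state", [None])[0]      # -> "xyz"
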
1881
3538
  if not code or not state:
1882
3539
  return web.Response(text="Missing code or state parameter", status=400)
1883
3540
 
@@ -1888,7 +3545,7 @@ def main():
1888
3545
  g_oauth_states.pop(state)
1889
3546
 
1890
3547
  if "auth" not in g_config or "github" not in g_config["auth"]:
1891
- return web.json_response({"error": "GitHub OAuth not configured"}, status=500)
3548
+ return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
1892
3549
 
1893
3550
  auth_config = g_config["auth"]["github"]
1894
3551
  client_id = auth_config.get("client_id", "")
@@ -1897,14 +3554,18 @@ def main():
1897
3554
 
1898
3555
  # Expand environment variables
1899
3556
  if client_id.startswith("$"):
1900
- client_id = os.environ.get(client_id[1:], "")
3557
+ client_id = client_id[1:]
1901
3558
  if client_secret.startswith("$"):
1902
- client_secret = os.environ.get(client_secret[1:], "")
3559
+ client_secret = client_secret[1:]
1903
3560
  if redirect_uri.startswith("$"):
1904
- redirect_uri = os.environ.get(redirect_uri[1:], "")
3561
+ redirect_uri = redirect_uri[1:]
3562
+
3563
+ client_id = os.getenv(client_id, client_id)
3564
+ client_secret = os.getenv(client_secret, client_secret)
3565
+ redirect_uri = os.getenv(redirect_uri, redirect_uri)
1905
3566
 
1906
3567
  if not client_id or not client_secret:
1907
- return web.json_response({"error": "GitHub OAuth credentials not configured"}, status=500)
3568
+ return web.json_response(create_error_response("GitHub OAuth credentials not configured"), status=500)
1908
3569
 
1909
3570
  # Exchange code for access token
1910
3571
  async with aiohttp.ClientSession() as session:
@@ -1923,7 +3584,7 @@ def main():
1923
3584
 
1924
3585
  if not access_token:
1925
3586
  error = token_response.get("error_description", "Failed to get access token")
1926
- return web.Response(text=f"OAuth error: {error}", status=400)
3587
+ return web.json_response(create_error_response(f"OAuth error: {error}"), status=400)
1927
3588
 
1928
3589
  # Fetch user info
1929
3590
  user_url = "https://api.github.com/user"
@@ -1949,14 +3610,16 @@ def main():
1949
3610
  }
1950
3611
 
1951
3612
  # Redirect to UI with session token
1952
- return web.HTTPFound(f"/?session={session_token}")
3613
+ response = web.HTTPFound(f"/?session={session_token}")
3614
+ response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400)
3615
+ return response
1953
3616
 
1954
3617
  async def session_handler(request):
1955
3618
  """Validate and return session info"""
1956
- session_token = request.query.get("session") or request.headers.get("X-Session-Token")
3619
+ session_token = get_session_token(request)
1957
3620
 
1958
3621
  if not session_token or session_token not in g_sessions:
1959
- return web.json_response({"error": "Invalid or expired session"}, status=401)
3622
+ return web.json_response(create_error_response("Invalid or expired session"), status=401)
1960
3623
 
1961
3624
  session_data = g_sessions[session_token]
1962
3625
 
@@ -1970,17 +3633,19 @@ def main():
1970
3633
 
1971
3634
  async def logout_handler(request):
1972
3635
  """End OAuth session"""
1973
- session_token = request.query.get("session") or request.headers.get("X-Session-Token")
3636
+ session_token = get_session_token(request)
1974
3637
 
1975
3638
  if session_token and session_token in g_sessions:
1976
3639
  del g_sessions[session_token]
1977
3640
 
1978
- return web.json_response({"success": True})
3641
+ response = web.json_response({"success": True})
3642
+ response.del_cookie("llms-token")
3643
+ return response
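
get_session_token replaces the inline query/header lookups here and is defined elsewhere in this diff; a plausible sketch of what it covers, assuming it also reads the new llms-token cookie (an assumption, not confirmed by this hunk):

    def get_session_token(request):
        # Query param, custom header, or session cookie -- whichever is present first (assumed order).
        return (
            request.query.get("session")
            or request.headers.get("X-Session-Token")
            or request.cookies.get("llms-token")
        )
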
1979
3644
 
1980
3645
  async def auth_handler(request):
1981
3646
  """Check authentication status and return user info"""
1982
3647
  # Check for OAuth session token
1983
- session_token = request.query.get("session") or request.headers.get("X-Session-Token")
3648
+ session_token = get_session_token(request)
1984
3649
 
1985
3650
  if session_token and session_token in g_sessions:
1986
3651
  session_data = g_sessions[session_token]
@@ -2010,13 +3675,12 @@ def main():
2010
3675
  # })
2011
3676
 
2012
3677
  # Not authenticated - return error in expected format
2013
- return web.json_response(
2014
- {"responseStatus": {"errorCode": "Unauthorized", "message": "Not authenticated"}}, status=401
2015
- )
3678
+ return web.json_response(g_app.error_auth_required, status=401)
2016
3679
 
2017
3680
  app.router.add_get("/auth", auth_handler)
2018
3681
  app.router.add_get("/auth/github", github_auth_handler)
2019
3682
  app.router.add_get("/auth/github/callback", github_callback_handler)
3683
+ app.router.add_get("/auth/github/callback{tail:.*}", github_callback_handler)
2020
3684
  app.router.add_get("/auth/session", session_handler)
2021
3685
  app.router.add_post("/auth/logout", logout_handler)
2022
3686
 
@@ -2051,30 +3715,101 @@ def main():
2051
3715
 
2052
3716
  app.router.add_get("/ui/{path:.*}", ui_static, name="ui_static")
2053
3717
 
2054
- async def ui_config_handler(request):
2055
- with open(g_ui_path, encoding="utf-8") as f:
2056
- ui = json.load(f)
2057
- if "defaults" not in ui:
2058
- ui["defaults"] = g_config["defaults"]
2059
- enabled, disabled = provider_status()
2060
- ui["status"] = {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled}
2061
- # Add auth configuration
2062
- ui["requiresAuth"] = auth_enabled
2063
- ui["authType"] = "oauth" if auth_enabled else "apikey"
2064
- return web.json_response(ui)
3718
+ async def config_handler(request):
3719
+ ret = {}
3720
+ if "defaults" not in ret:
3721
+ ret["defaults"] = g_config["defaults"]
3722
+ enabled, disabled = provider_status()
3723
+ ret["status"] = {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled}
3724
+ # Add auth configuration
3725
+ ret["requiresAuth"] = auth_enabled
3726
+ ret["authType"] = "oauth" if auth_enabled else "apikey"
3727
+ return web.json_response(ret)
2065
3728
 
2066
- app.router.add_get("/config", ui_config_handler)
3729
+ app.router.add_get("/config", config_handler)
2067
3730
 
2068
3731
  async def not_found_handler(request):
2069
3732
  return web.Response(text="404: Not Found", status=404)
2070
3733
 
2071
3734
  app.router.add_get("/favicon.ico", not_found_handler)
2072
3735
 
3736
+ # Register all routes contributed by g_app extensions
3737
+ for handler in g_app.server_add_get:
3738
+ handler_fn = handler[1]
3739
+
3740
+ async def managed_handler(request, handler_fn=handler_fn):
3741
+ try:
3742
+ return await handler_fn(request)
3743
+ except Exception as e:
3744
+ return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
3745
+
3746
+ app.router.add_get(handler[0], managed_handler, **handler[2])
3747
+ for handler in g_app.server_add_post:
3748
+ handler_fn = handler[1]
3749
+
3750
+ async def managed_handler(request, handler_fn=handler_fn):
3751
+ try:
3752
+ return await handler_fn(request)
3753
+ except Exception as e:
3754
+ return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
3755
+
3756
+ app.router.add_post(handler[0], managed_handler, **handler[2])
3757
+ for handler in g_app.server_add_put:
3758
+ handler_fn = handler[1]
3759
+
3760
+ async def managed_handler(request, handler_fn=handler_fn):
3761
+ try:
3762
+ return await handler_fn(request)
3763
+ except Exception as e:
3764
+ return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
3765
+
3766
+ app.router.add_put(handler[0], managed_handler, **handler[2])
3767
+ for handler in g_app.server_add_delete:
3768
+ handler_fn = handler[1]
3769
+
3770
+ async def managed_handler(request, handler_fn=handler_fn):
3771
+ try:
3772
+ return await handler_fn(request)
3773
+ except Exception as e:
3774
+ return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
3775
+
3776
+ app.router.add_delete(handler[0], managed_handler, **handler[2])
3777
+ for handler in g_app.server_add_patch:
3778
+ handler_fn = handler[1]
3779
+
3780
+ async def managed_handler(request, handler_fn=handler_fn):
3781
+ try:
3782
+ return await handler_fn(request)
3783
+ except Exception as e:
3784
+ return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
3785
+
3786
+ app.router.add_patch(handler[0], managed_handler, **handler[2])
3787
+
2073
3788
  # Serve index.html from root
2074
3789
  async def index_handler(request):
2075
3790
  index_content = read_resource_file_bytes("index.html")
2076
- if index_content is None:
2077
- raise web.HTTPNotFound
3791
+
3792
+ importmaps = {"imports": g_app.import_maps}
3793
+ importmaps_script = '<script type="importmap">\n' + json.dumps(importmaps, indent=4) + "\n</script>"
3794
+ index_content = index_content.replace(
3795
+ b'<script type="importmap"></script>',
3796
+ importmaps_script.encode("utf-8"),
3797
+ )
3798
+
3799
+ if len(g_app.index_headers) > 0:
3800
+ html_header = ""
3801
+ for header in g_app.index_headers:
3802
+ html_header += header
3803
+ # replace </head> with html_header
3804
+ index_content = index_content.replace(b"</head>", html_header.encode("utf-8") + b"\n</head>")
3805
+
3806
+ if len(g_app.index_footers) > 0:
3807
+ html_footer = ""
3808
+ for footer in g_app.index_footers:
3809
+ html_footer += footer
3810
+ # replace </body> with html_footer
3811
+ index_content = index_content.replace(b"</body>", html_footer.encode("utf-8") + b"\n</body>")
3812
+
2078
3813
  return web.Response(body=index_content, content_type="text/html")
2079
3814
 
2080
3815
  app.router.add_get("/", index_handler)
@@ -2086,13 +3821,15 @@ def main():
2086
3821
  async def start_background_tasks(app):
2087
3822
  """Start background tasks when the app starts"""
2088
3823
  # Start watching config files in the background
2089
- asyncio.create_task(watch_config_files(g_config_path, g_ui_path))
3824
+ asyncio.create_task(watch_config_files(g_config_path, home_providers_path))
2090
3825
 
2091
3826
  app.on_startup.append(start_background_tasks)
2092
3827
 
3828
+ # go through and register all g_app extensions
3829
+
2093
3830
  print(f"Starting server on port {port}...")
2094
3831
  web.run_app(app, host="0.0.0.0", port=port, print=_log)
2095
- exit(0)
3832
+ g_app.exit(0)
2096
3833
 
2097
3834
  if cli_args.enable is not None:
2098
3835
  if cli_args.enable.endswith(","):
@@ -2109,7 +3846,7 @@ def main():
2109
3846
 
2110
3847
  for provider in enable_providers:
2111
3848
  if provider not in g_config["providers"]:
2112
- print(f"Provider {provider} not found")
3849
+ print(f"Provider '{provider}' not found")
2113
3850
  print(f"Available providers: {', '.join(g_config['providers'].keys())}")
2114
3851
  exit(1)
2115
3852
  if provider in g_config["providers"]:
@@ -2122,7 +3859,7 @@ def main():
2122
3859
  print_status()
2123
3860
  if len(msgs) > 0:
2124
3861
  print("\n" + "\n".join(msgs))
2125
- exit(0)
3862
+ g_app.exit(0)
2126
3863
 
2127
3864
  if cli_args.disable is not None:
2128
3865
  if cli_args.disable.endswith(","):
@@ -2145,26 +3882,26 @@ def main():
2145
3882
  print(f"\nDisabled provider {provider}")
2146
3883
 
2147
3884
  print_status()
2148
- exit(0)
3885
+ g_app.exit(0)
2149
3886
 
2150
3887
  if cli_args.default is not None:
2151
3888
  default_model = cli_args.default
2152
- all_models = get_models()
2153
- if default_model not in all_models:
3889
+ provider_model = get_provider_model(default_model)
3890
+ if provider_model is None:
2154
3891
  print(f"Model {default_model} not found")
2155
- print(f"Available models: {', '.join(all_models)}")
2156
3892
  exit(1)
2157
3893
  default_text = g_config["defaults"]["text"]
2158
3894
  default_text["model"] = default_model
2159
3895
  save_config(g_config)
2160
3896
  print(f"\nDefault model set to: {default_model}")
2161
- exit(0)
3897
+ g_app.exit(0)
2162
3898
 
2163
3899
  if (
2164
3900
  cli_args.chat is not None
2165
3901
  or cli_args.image is not None
2166
3902
  or cli_args.audio is not None
2167
3903
  or cli_args.file is not None
3904
+ or cli_args.out is not None
2168
3905
  or len(extra_args) > 0
2169
3906
  ):
2170
3907
  try:
@@ -2175,6 +3912,12 @@ def main():
2175
3912
  chat = g_config["defaults"]["audio"]
2176
3913
  elif cli_args.file is not None:
2177
3914
  chat = g_config["defaults"]["file"]
3915
+ elif cli_args.out is not None:
3916
+ template = f"out:{cli_args.out}"
3917
+ if template not in g_config["defaults"]:
3918
+ print(f"Template for output modality '{cli_args.out}' not found")
3919
+ exit(1)
3920
+ chat = g_config["defaults"][template]
2178
3921
  if cli_args.chat is not None:
2179
3922
  chat_path = os.path.join(os.path.dirname(__file__), cli_args.chat)
2180
3923
  if not os.path.exists(chat_path):
@@ -2191,6 +3934,9 @@ def main():
2191
3934
 
2192
3935
  if len(extra_args) > 0:
2193
3936
  prompt = " ".join(extra_args)
3937
+ if not chat["messages"] or len(chat["messages"]) == 0:
3938
+ chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
3939
+
2194
3940
  # replace content of last message if exists, else add
2195
3941
  last_msg = chat["messages"][-1] if "messages" in chat else None
2196
3942
  if last_msg and last_msg["role"] == "user":
@@ -2208,19 +3954,31 @@ def main():
2208
3954
 
2209
3955
  asyncio.run(
2210
3956
  cli_chat(
2211
- chat, image=cli_args.image, audio=cli_args.audio, file=cli_args.file, args=args, raw=cli_args.raw
3957
+ chat,
3958
+ tools=cli_args.tools,
3959
+ image=cli_args.image,
3960
+ audio=cli_args.audio,
3961
+ file=cli_args.file,
3962
+ args=args,
3963
+ raw=cli_args.raw,
2212
3964
  )
2213
3965
  )
2214
- exit(0)
3966
+ g_app.exit(0)
2215
3967
  except Exception as e:
2216
3968
  print(f"{cli_args.logprefix}Error: {e}")
2217
3969
  if cli_args.verbose:
2218
3970
  traceback.print_exc()
2219
- exit(1)
3971
+ g_app.exit(1)
3972
+
3973
+ handled = run_extension_cli()
2220
3974
 
2221
- # show usage from ArgumentParser
2222
- parser.print_help()
3975
+ if not handled:
3976
+ # show usage from ArgumentParser
3977
+ parser.print_help()
3978
+ g_app.exit(0)
2223
3979
 
2224
3980
 
2225
3981
  if __name__ == "__main__":
3982
+ if MOCK or DEBUG:
3983
+ print(f"MOCK={MOCK} or DEBUG={DEBUG}")
2226
3984
  main()