llms-py 2.0.35__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llms/__pycache__/__init__.cpython-312.pyc +0 -0
- llms/__pycache__/__init__.cpython-313.pyc +0 -0
- llms/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/__pycache__/__main__.cpython-312.pyc +0 -0
- llms/__pycache__/__main__.cpython-314.pyc +0 -0
- llms/__pycache__/llms.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-313.pyc +0 -0
- llms/__pycache__/main.cpython-314.pyc +0 -0
- llms/__pycache__/plugins.cpython-314.pyc +0 -0
- llms/{ui/Analytics.mjs → extensions/analytics/ui/index.mjs} +154 -238
- llms/extensions/app/README.md +20 -0
- llms/extensions/app/__init__.py +530 -0
- llms/extensions/app/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/app/__pycache__/db.cpython-314.pyc +0 -0
- llms/extensions/app/__pycache__/db_manager.cpython-314.pyc +0 -0
- llms/extensions/app/db.py +644 -0
- llms/extensions/app/db_manager.py +195 -0
- llms/extensions/app/requests.json +9073 -0
- llms/extensions/app/threads.json +15290 -0
- llms/{ui → extensions/app/ui}/Recents.mjs +91 -65
- llms/{ui/Sidebar.mjs → extensions/app/ui/index.mjs} +124 -58
- llms/extensions/app/ui/threadStore.mjs +411 -0
- llms/extensions/core_tools/CALCULATOR.md +32 -0
- llms/extensions/core_tools/__init__.py +598 -0
- llms/extensions/core_tools/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closebrackets.js +201 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closetag.js +185 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/continuelist.js +101 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchbrackets.js +160 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchtags.js +66 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/trailingspace.js +27 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/active-line.js +72 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/mark-selection.js +119 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/selection-pointer.js +98 -0
- llms/extensions/core_tools/ui/codemirror/doc/docs.css +225 -0
- llms/extensions/core_tools/ui/codemirror/doc/source_sans.woff +0 -0
- llms/extensions/core_tools/ui/codemirror/lib/codemirror.css +344 -0
- llms/extensions/core_tools/ui/codemirror/lib/codemirror.js +9884 -0
- llms/extensions/core_tools/ui/codemirror/mode/clike/clike.js +942 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/index.html +118 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/javascript.js +962 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/typescript.html +62 -0
- llms/extensions/core_tools/ui/codemirror/mode/python/python.js +402 -0
- llms/extensions/core_tools/ui/codemirror/theme/dracula.css +40 -0
- llms/extensions/core_tools/ui/codemirror/theme/mocha.css +135 -0
- llms/extensions/core_tools/ui/index.mjs +650 -0
- llms/extensions/gallery/README.md +61 -0
- llms/extensions/gallery/__init__.py +61 -0
- llms/extensions/gallery/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/gallery/__pycache__/db.cpython-314.pyc +0 -0
- llms/extensions/gallery/db.py +298 -0
- llms/extensions/gallery/ui/index.mjs +482 -0
- llms/extensions/katex/README.md +39 -0
- llms/extensions/katex/__init__.py +6 -0
- llms/extensions/katex/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/katex/ui/README.md +125 -0
- llms/extensions/katex/ui/contrib/auto-render.js +338 -0
- llms/extensions/katex/ui/contrib/auto-render.min.js +1 -0
- llms/extensions/katex/ui/contrib/auto-render.mjs +244 -0
- llms/extensions/katex/ui/contrib/copy-tex.js +127 -0
- llms/extensions/katex/ui/contrib/copy-tex.min.js +1 -0
- llms/extensions/katex/ui/contrib/copy-tex.mjs +105 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.js +109 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.min.js +1 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.mjs +24 -0
- llms/extensions/katex/ui/contrib/mhchem.js +3213 -0
- llms/extensions/katex/ui/contrib/mhchem.min.js +1 -0
- llms/extensions/katex/ui/contrib/mhchem.mjs +3109 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.js +887 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.min.js +1 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.mjs +800 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff2 +0 -0
- llms/extensions/katex/ui/index.mjs +92 -0
- llms/extensions/katex/ui/katex-swap.css +1230 -0
- llms/extensions/katex/ui/katex-swap.min.css +1 -0
- llms/extensions/katex/ui/katex.css +1230 -0
- llms/extensions/katex/ui/katex.js +19080 -0
- llms/extensions/katex/ui/katex.min.css +1 -0
- llms/extensions/katex/ui/katex.min.js +1 -0
- llms/extensions/katex/ui/katex.min.mjs +1 -0
- llms/extensions/katex/ui/katex.mjs +18547 -0
- llms/extensions/providers/__init__.py +18 -0
- llms/extensions/providers/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/anthropic.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/chutes.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/google.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/nvidia.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/openai.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/openrouter.cpython-314.pyc +0 -0
- llms/extensions/providers/anthropic.py +229 -0
- llms/extensions/providers/chutes.py +155 -0
- llms/extensions/providers/google.py +378 -0
- llms/extensions/providers/nvidia.py +105 -0
- llms/extensions/providers/openai.py +156 -0
- llms/extensions/providers/openrouter.py +72 -0
- llms/extensions/system_prompts/README.md +22 -0
- llms/extensions/system_prompts/__init__.py +45 -0
- llms/extensions/system_prompts/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/system_prompts/ui/index.mjs +280 -0
- llms/extensions/system_prompts/ui/prompts.json +1067 -0
- llms/extensions/tools/__init__.py +5 -0
- llms/extensions/tools/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/tools/ui/index.mjs +204 -0
- llms/index.html +35 -77
- llms/llms.json +357 -1186
- llms/main.py +2349 -591
- llms/providers-extra.json +356 -0
- llms/providers.json +1 -0
- llms/ui/App.mjs +151 -60
- llms/ui/ai.mjs +132 -60
- llms/ui/app.css +2173 -161
- llms/ui/ctx.mjs +365 -0
- llms/ui/index.mjs +129 -0
- llms/ui/lib/charts.mjs +9 -13
- llms/ui/lib/servicestack-vue.mjs +3 -3
- llms/ui/lib/vue.min.mjs +10 -9
- llms/ui/lib/vue.mjs +1796 -1635
- llms/ui/markdown.mjs +18 -7
- llms/ui/modules/chat/ChatBody.mjs +691 -0
- llms/ui/{SettingsDialog.mjs → modules/chat/SettingsDialog.mjs} +9 -9
- llms/ui/modules/chat/index.mjs +828 -0
- llms/ui/modules/layout.mjs +243 -0
- llms/ui/modules/model-selector.mjs +851 -0
- llms/ui/tailwind.input.css +496 -80
- llms/ui/utils.mjs +161 -93
- {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/METADATA +1 -1
- llms_py-3.0.0.dist-info/RECORD +202 -0
- llms/ui/Avatar.mjs +0 -85
- llms/ui/Brand.mjs +0 -52
- llms/ui/ChatPrompt.mjs +0 -590
- llms/ui/Main.mjs +0 -823
- llms/ui/ModelSelector.mjs +0 -78
- llms/ui/OAuthSignIn.mjs +0 -92
- llms/ui/ProviderIcon.mjs +0 -30
- llms/ui/ProviderStatus.mjs +0 -105
- llms/ui/SignIn.mjs +0 -64
- llms/ui/SystemPromptEditor.mjs +0 -31
- llms/ui/SystemPromptSelector.mjs +0 -56
- llms/ui/Welcome.mjs +0 -8
- llms/ui/threadStore.mjs +0 -563
- llms/ui.json +0 -1069
- llms_py-2.0.35.dist-info/RECORD +0 -48
- {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/WHEEL +0 -0
- {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/entry_points.txt +0 -0
- {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/licenses/LICENSE +0 -0
- {llms_py-2.0.35.dist-info → llms_py-3.0.0.dist-info}/top_level.txt +0 -0
llms/main.py
CHANGED
@@ -9,20 +9,27 @@
 import argparse
 import asyncio
 import base64
+import contextlib
+import hashlib
+import importlib.util
+import inspect
 import json
 import mimetypes
 import os
 import re
 import secrets
+import shutil
 import site
 import subprocess
 import sys
 import time
 import traceback
+from datetime import datetime
 from importlib import resources  # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
 from io import BytesIO
 from pathlib import Path
-from
+from typing import get_type_hints
+from urllib.parse import parse_qs, urlencode, urljoin
 
 import aiohttp
 from aiohttp import web
@@ -34,25 +41,40 @@ try:
 except ImportError:
     HAS_PIL = False
 
-VERSION = "
+VERSION = "3.0.0"
 _ROOT = None
+DEBUG = os.getenv("DEBUG") == "1"
+MOCK = os.getenv("MOCK") == "1"
+MOCK_DIR = os.getenv("MOCK_DIR")
+DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
 g_config_path = None
-g_ui_path = None
 g_config = None
+g_providers = None
 g_handlers = {}
 g_verbose = False
 g_logprefix = ""
 g_default_model = ""
 g_sessions = {}  # OAuth session storage: {session_token: {userId, userName, displayName, profileUrl, email, created}}
 g_oauth_states = {}  # CSRF protection: {state: {created, redirect_uri}}
+g_app = None  # ExtensionsContext Singleton
 
 
 def _log(message):
-    """Helper method for logging from the global polling task."""
     if g_verbose:
         print(f"{g_logprefix}{message}", flush=True)
 
 
+def _dbg(message):
+    if DEBUG:
+        print(f"DEBUG: {message}", flush=True)
+
+
+def _err(message, e):
+    print(f"ERROR: {message}: {e}", flush=True)
+    if g_verbose:
+        print(traceback.format_exc(), flush=True)
+
+
 def printdump(obj):
     args = obj.__dict__ if hasattr(obj, "__dict__") else obj
     print(json.dumps(args, indent=2))
@@ -85,17 +107,6 @@ def chat_summary(chat):
     return json.dumps(clone, indent=2)
 
 
-def gemini_chat_summary(gemini_chat):
-    """Summarize Gemini chat completion request for logging. Replace inline_data with size of content only"""
-    clone = json.loads(json.dumps(gemini_chat))
-    for content in clone["contents"]:
-        for part in content["parts"]:
-            if "inline_data" in part:
-                data = part["inline_data"]["data"]
-                part["inline_data"]["data"] = f"({len(data)})"
-    return json.dumps(clone, indent=2)
-
-
 image_exts = ["png", "webp", "jpg", "jpeg", "gif", "bmp", "svg", "tiff", "ico"]
 audio_exts = ["mp3", "wav", "ogg", "flac", "m4a", "opus", "webm"]
 
@@ -189,6 +200,16 @@ def is_base_64(data):
     return False
 
 
+def id_to_name(id):
+    return id.replace("-", " ").title()
+
+
+def pluralize(word, count):
+    if count == 1:
+        return word
+    return word + "s"
+
+
 def get_file_mime_type(filename):
     mime_type, _ = mimetypes.guess_type(filename)
     return mime_type or "application/octet-stream"
@@ -310,11 +331,52 @@ def convert_image_if_needed(image_bytes, mimetype="image/png"):
     return image_bytes, mimetype
 
 
-
+def to_content(result):
+    if isinstance(result, (str, int, float, bool)):
+        return str(result)
+    elif isinstance(result, (list, set, tuple, dict)):
+        return json.dumps(result)
+    else:
+        return str(result)
+
+
+def function_to_tool_definition(func):
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    parameters = {"type": "object", "properties": {}, "required": []}
+
+    for name, param in signature.parameters.items():
+        param_type = type_hints.get(name, str)
+        param_type_name = "string"
+        if param_type is int:
+            param_type_name = "integer"
+        elif param_type is float:
+            param_type_name = "number"
+        elif param_type is bool:
+            param_type_name = "boolean"
+
+        parameters["properties"][name] = {"type": param_type_name}
+        if param.default == inspect.Parameter.empty:
+            parameters["required"].append(name)
+
+    return {
+        "type": "function",
+        "function": {
+            "name": func.__name__,
+            "description": func.__doc__ or "",
+            "parameters": parameters,
+        },
+    }
+
+
+async def process_chat(chat, provider_id=None):
     if not chat:
         raise Exception("No chat provided")
     if "stream" not in chat:
         chat["stream"] = False
+    # Some providers don't support empty tools
+    if "tools" in chat and len(chat["tools"]) == 0:
+        del chat["tools"]
     if "messages" not in chat:
         return chat
 
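The new `function_to_tool_definition()` helper added in this hunk derives an OpenAI-style tool definition from a plain Python function's type hints and defaults. A minimal usage sketch, assuming the helper is importable as `llms.main` per the wheel layout (the `get_weather` function is illustrative, not from the package):

```python
from llms.main import function_to_tool_definition  # import path assumed from wheel layout

def get_weather(city: str, days: int = 3) -> str:
    """Return a short weather forecast for a city."""
    return f"Forecast for {city} over {days} days"

tool = function_to_tool_definition(get_weather)
# Per the diff, typed parameters map to JSON Schema types and parameters
# without defaults become required, e.g.:
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Return a short weather forecast for a city.",
#               "parameters": {"type": "object",
#                              "properties": {"city": {"type": "string"},
#                                             "days": {"type": "integer"}},
#                              "required": ["city"]}}}
print(tool)
```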
@@ -331,6 +393,8 @@ async def process_chat(chat):
                     image_url = item["image_url"]
                     if "url" in image_url:
                         url = image_url["url"]
+                        if url.startswith("/~cache/"):
+                            url = get_cache_path(url[8:])
                         if is_url(url):
                             _log(f"Downloading image: {url}")
                             async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
@@ -377,6 +441,8 @@ async def process_chat(chat):
                     input_audio = item["input_audio"]
                     if "data" in input_audio:
                         url = input_audio["data"]
+                        if url.startswith("/~cache/"):
+                            url = get_cache_path(url[8:])
                         mimetype = get_file_mime_type(get_filename(url))
                         if is_url(url):
                             _log(f"Downloading audio: {url}")
@@ -388,6 +454,8 @@ async def process_chat(chat):
                                     mimetype = response.headers["Content-Type"]
                             # convert to base64
                             input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                            if provider_id == "alibaba":
+                                input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                             input_audio["format"] = mimetype.rsplit("/", 1)[1]
                         elif is_file_path(url):
                             _log(f"Reading audio: {url}")
@@ -395,6 +463,8 @@ async def process_chat(chat):
                                 content = f.read()
                             # convert to base64
                             input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                            if provider_id == "alibaba":
+                                input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                             input_audio["format"] = mimetype.rsplit("/", 1)[1]
                         elif is_base_64(url):
                             pass  # use base64 data as-is
@@ -404,6 +474,8 @@ async def process_chat(chat):
                     file = item["file"]
                     if "file_data" in file:
                         url = file["file_data"]
+                        if url.startswith("/~cache/"):
+                            url = get_cache_path(url[8:])
                         mimetype = get_file_mime_type(get_filename(url))
                         if is_url(url):
                             _log(f"Downloading file: {url}")
@@ -431,6 +503,92 @@ async def process_chat(chat):
     return chat
 
 
+def image_ext_from_mimetype(mimetype, default="png"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def audio_ext_from_format(format, default="mp3"):
+    if format == "mpeg":
+        return "mp3"
+    return format or default
+
+
+def file_ext_from_mimetype(mimetype, default="pdf"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def cache_message_inline_data(m):
+    """
+    Replaces and caches any inline data URIs in the message content.
+    """
+    if "content" not in m:
+        return
+
+    content = m["content"]
+    if isinstance(content, list):
+        for item in content:
+            if item.get("type") == "image_url":
+                image_url = item.get("image_url", {})
+                url = image_url.get("url")
+                if url and url.startswith("data:"):
+                    # Extract base64 and mimetype
+                    try:
+                        header, base64_data = url.split(";base64,")
+                        # header is like "data:image/png"
+                        ext = image_ext_from_mimetype(header.split(":")[1])
+                        filename = f"image.{ext}"  # Hash will handle uniqueness
+
+                        cache_url, _ = save_image_to_cache(base64_data, filename, {}, ignore_info=True)
+                        image_url["url"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline image: {e}")
+
+            elif item.get("type") == "input_audio":
+                input_audio = item.get("input_audio", {})
+                data = input_audio.get("data")
+                if data:
+                    # Handle data URI or raw base64
+                    base64_data = data
+                    if data.startswith("data:"):
+                        with contextlib.suppress(ValueError):
+                            header, base64_data = data.split(";base64,")
+
+                    fmt = audio_ext_from_format(input_audio.get("format"))
+                    filename = f"audio.{fmt}"
+
+                    try:
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        input_audio["data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline audio: {e}")
+
+            elif item.get("type") == "file":
+                file_info = item.get("file", {})
+                file_data = file_info.get("file_data")
+                if file_data and file_data.startswith("data:"):
+                    try:
+                        header, base64_data = file_data.split(";base64,")
+                        mimetype = header.split(":")[1]
+                        # Try to get extension from filename if available, else mimetype
+                        filename = file_info.get("filename", "file")
+                        if "." not in filename:
+                            ext = file_ext_from_mimetype(mimetype)
+                            filename = f"{filename}.{ext}"
+
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        file_info["file_data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline file: {e}")
+
+
 class HTTPError(Exception):
     def __init__(self, status, reason, body, headers=None):
         self.status = status
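The new `cache_message_inline_data()` above rewrites inline `data:` URIs in message content into cached `/~cache/` URLs. A minimal standalone sketch of just the data-URI splitting step it relies on (the sample payload is illustrative):

```python
import mimetypes

# Split a data URI into its mimetype header and base64 payload, then derive a
# file extension for the cached filename, as cache_message_inline_data() does.
data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="  # sample payload

header, base64_data = data_uri.split(";base64,")
mimetype = header.split(":")[1]                               # "image/png"
ext = (mimetypes.guess_extension(mimetype) or ".png").lstrip(".")
filename = f"image.{ext}"                                     # content hash makes the final name unique

print(mimetype, filename, len(base64_data))
```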
@@ -440,33 +598,323 @@ class HTTPError(Exception):
         super().__init__(f"HTTP {status} {reason}")
 
 
+def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
+    ext = filename.split(".")[-1]
+    mimetype = get_file_mime_type(filename)
+    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
+    sha256_hash = hashlib.sha256(content).hexdigest()
+
+    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+    # Use first 2 chars for subdir to avoid too many files in one dir
+    subdir = sha256_hash[:2]
+    relative_path = f"{subdir}/{save_filename}"
+    full_path = get_cache_path(relative_path)
+    url = f"/~cache/{relative_path}"
+
+    # if file and its .info.json already exists, return it
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    if os.path.exists(full_path) and os.path.exists(info_path):
+        _dbg(f"Cached bytes exists: {relative_path}")
+        if ignore_info:
+            return url, None
+        return url, json.load(open(info_path))
+
+    os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+    with open(full_path, "wb") as f:
+        f.write(content)
+    info = {
+        "date": int(time.time()),
+        "url": url,
+        "size": len(content),
+        "type": mimetype,
+        "name": filename,
+    }
+    info.update(file_info)
+
+    # Save metadata
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    with open(info_path, "w") as f:
+        json.dump(info, f)
+
+    _dbg(f"Saved cached bytes and info: {relative_path}")
+
+    g_app.on_cache_saved_filters({"url": url, "info": info})
+
+    return url, info
+
+
+def save_image_to_cache(base64_data, filename, image_info, ignore_info=False):
+    ext = filename.split(".")[-1]
+    mimetype = get_file_mime_type(filename)
+    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
+    sha256_hash = hashlib.sha256(content).hexdigest()
+
+    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+    # Use first 2 chars for subdir to avoid too many files in one dir
+    subdir = sha256_hash[:2]
+    relative_path = f"{subdir}/{save_filename}"
+    full_path = get_cache_path(relative_path)
+    url = f"/~cache/{relative_path}"
+
+    # if file and its .info.json already exists, return it
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    if os.path.exists(full_path) and os.path.exists(info_path):
+        _dbg(f"Saved image exists: {relative_path}")
+        if ignore_info:
+            return url, None
+        return url, json.load(open(info_path))
+
+    os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+    with open(full_path, "wb") as f:
+        f.write(content)
+    info = {
+        "date": int(time.time()),
+        "url": url,
+        "size": len(content),
+        "type": mimetype,
+        "name": filename,
+    }
+    info.update(image_info)
+
+    # If image, get dimensions
+    if HAS_PIL and mimetype.startswith("image/"):
+        try:
+            with Image.open(BytesIO(content)) as img:
+                info["width"] = img.width
+                info["height"] = img.height
+        except Exception:
+            pass
+
+    if "width" in info and "height" in info:
+        _log(f"Saved image to cache: {full_path} ({len(content)} bytes) {info['width']}x{info['height']}")
+    else:
+        _log(f"Saved image to cache: {full_path} ({len(content)} bytes)")
+
+    # Save metadata
+    info_path = os.path.splitext(full_path)[0] + ".info.json"
+    with open(info_path, "w") as f:
+        json.dump(info, f)
+
+    _dbg(f"Saved image and info: {relative_path}")
+
+    g_app.on_cache_saved_filters({"url": url, "info": info})
+
+    return url, info
+
+
 async def response_json(response):
     text = await response.text()
     if response.status >= 400:
+        _dbg(f"HTTP {response.status} {response.reason}: {text}")
         raise HTTPError(response.status, reason=response.reason, body=text, headers=dict(response.headers))
     response.raise_for_status()
     body = json.loads(text)
     return body
 
 
-
-
-
-
-
-
-
+def chat_to_prompt(chat):
+    prompt = ""
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "user":
+                # if content is string
+                if isinstance(message["content"], str):
+                    if prompt:
+                        prompt += "\n"
+                    prompt += message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            if prompt:
+                                prompt += "\n"
+                            prompt += part["text"]
+    return prompt
+
+
+def chat_to_system_prompt(chat):
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "system":
+                # if content is string
+                if isinstance(message["content"], str):
+                    return message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            return part["text"]
+    return None
 
-
-
-
-
-
-
+
+def chat_to_username(chat):
+    if "metadata" in chat and "user" in chat["metadata"]:
+        return chat["metadata"]["user"]
+    return None
+
+
+def last_user_prompt(chat):
+    prompt = ""
+    if "messages" in chat:
+        for message in chat["messages"]:
+            if message["role"] == "user":
+                # if content is string
+                if isinstance(message["content"], str):
+                    prompt = message["content"]
+                elif isinstance(message["content"], list):
+                    # if content is array of objects
+                    for part in message["content"]:
+                        if part["type"] == "text":
+                            prompt = part["text"]
+    return prompt
+
+
+def chat_response_to_message(openai_response):
+    """
+    Returns an assistant message from the OpenAI Response.
+    Handles normalizing text, image, and audio responses into the message content.
+    """
+    timestamp = int(time.time() * 1000)  # openai_response.get("created")
+    choices = openai_response
+    if isinstance(openai_response, dict) and "choices" in openai_response:
+        choices = openai_response["choices"]
+
+    choice = choices[0] if isinstance(choices, list) and choices else choices
+
+    if isinstance(choice, str):
+        return {"role": "assistant", "content": choice, "timestamp": timestamp}
+
+    if isinstance(choice, dict):
+        message = choice.get("message", choice)
+    else:
+        return {"role": "assistant", "content": str(choice), "timestamp": timestamp}
+
+    # Ensure message is a dict
+    if not isinstance(message, dict):
+        return {"role": "assistant", "content": message, "timestamp": timestamp}
+
+    message.update({"timestamp": timestamp})
+    return message
+
+
+def to_file_info(chat, info=None, response=None):
+    prompt = last_user_prompt(chat)
+    ret = info or {}
+    if chat["model"] and "model" not in ret:
+        ret["model"] = chat["model"]
+    if prompt and "prompt" not in ret:
+        ret["prompt"] = prompt
+    if "image_config" in chat:
+        ret.update(chat["image_config"])
+    user = chat_to_username(chat)
+    if user:
+        ret["user"] = user
+    return ret
+
+
+# Image Generator Providers
+class GeneratorBase:
+    def __init__(self, **kwargs):
+        self.id = kwargs.get("id")
+        self.api = kwargs.get("api")
+        self.api_key = kwargs.get("api_key")
+        self.headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        self.chat_url = f"{self.api}/chat/completions"
+        self.default_content = "I've generated the image for you."
+
+    def validate(self, **kwargs):
+        if not self.api_key:
+            api_keys = ", ".join(self.env)
+            return f"Provider '{self.name}' requires API Key {api_keys}"
+        return None
+
+    def test(self, **kwargs):
+        error_msg = self.validate(**kwargs)
+        if error_msg:
+            _log(error_msg)
+            return False
+        return True
+
+    async def load(self):
+        pass
+
+    def gen_summary(self, gen):
+        """Summarize gen response for logging."""
+        clone = json.loads(json.dumps(gen))
+        return json.dumps(clone, indent=2)
+
+    def chat_summary(self, chat):
+        return chat_summary(chat)
+
+    def process_chat(self, chat, provider_id=None):
+        return process_chat(chat, provider_id)
+
+    async def response_json(self, response):
+        return await response_json(response)
+
+    def get_headers(self, provider, chat):
+        headers = self.headers.copy()
+        if provider is not None:
+            headers["Authorization"] = f"Bearer {provider.api_key}"
+        elif self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    def to_response(self, response, chat, started_at):
+        raise NotImplementedError
+
+    async def chat(self, chat, provider=None):
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Not Implemented",
+                        "images": [
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0Ij48cGF0aCBmaWxsPSJjdXJyZW50Q29sb3IiIGQ9Ik0xMiAyMGE4IDggMCAxIDAgMC0xNmE4IDggMCAwIDAgMCAxNm0wIDJDNi40NzcgMjIgMiAxNy41MjMgMiAxMlM2LjQ3NyAyIDEyIDJzMTAgNC40NzcgMTAgMTBzLTQuNDc3IDEwLTEwIDEwbS0xLTZoMnYyaC0yem0wLTEwaDJ2OGgtMnoiLz48L3N2Zz4=",
+                                },
+                            }
+                        ],
+                    }
+                }
+            ]
+        }
+
+
+# OpenAI Providers
+
+
+class OpenAiCompatible:
+    sdk = "@ai-sdk/openai-compatible"
+
+    def __init__(self, **kwargs):
+        required_args = ["id", "api"]
+        for arg in required_args:
+            if arg not in kwargs:
+                raise ValueError(f"Missing required argument: {arg}")
+
+        self.id = kwargs.get("id")
+        self.api = kwargs.get("api").strip("/")
+        self.env = kwargs.get("env", [])
+        self.api_key = kwargs.get("api_key")
+        self.name = kwargs.get("name", id_to_name(self.id))
+        self.set_models(**kwargs)
+
+        self.chat_url = f"{self.api}/chat/completions"
 
         self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
-        if api_key is not None:
-            self.headers["Authorization"] = f"Bearer {api_key}"
+        if self.api_key is not None:
+            self.headers["Authorization"] = f"Bearer {self.api_key}"
 
         self.frequency_penalty = float(kwargs["frequency_penalty"]) if "frequency_penalty" in kwargs else None
         self.max_completion_tokens = int(kwargs["max_completion_tokens"]) if "max_completion_tokens" in kwargs else None
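The new `save_bytes_to_cache()` / `save_image_to_cache()` helpers use a content-addressed layout: files are named by their SHA-256 digest, sharded into a two-character subdirectory, stored next to a `.info.json` metadata sidecar, and referenced by a `/~cache/` URL. A small standalone sketch of the path derivation (it does not write anything; `get_cache_path()`, which maps the relative path onto the cache root, is referenced but not shown in this hunk):

```python
import hashlib
import os

content = b"hello world"      # illustrative payload
filename = "note.txt"

sha256_hash = hashlib.sha256(content).hexdigest()
ext = filename.split(".")[-1]
relative_path = f"{sha256_hash[:2]}/{sha256_hash}.{ext}"   # 2-char shard + hashed name

url = f"/~cache/{relative_path}"                           # URL the UI references
info_path = os.path.splitext(relative_path)[0] + ".info.json"  # metadata sidecar
print(url)
print(info_path)
```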
@@ -486,44 +934,132 @@ class OpenAiProvider:
         self.verbosity = kwargs.get("verbosity")
         self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
         self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
-        self.pricing = kwargs.get("pricing")
-        self.default_pricing = kwargs.get("default_pricing")
         self.check = kwargs.get("check")
+        self.modalities = kwargs.get("modalities", {})
+
+    def set_models(self, **kwargs):
+        models = kwargs.get("models", {})
+        self.map_models = kwargs.get("map_models", {})
+        # if 'map_models' is provided, only include models in `map_models[model_id] = provider_model_id`
+        if self.map_models:
+            self.models = {}
+            for provider_model_id in self.map_models.values():
+                if provider_model_id in models:
+                    self.models[provider_model_id] = models[provider_model_id]
+        else:
+            self.models = models
+
+        include_models = kwargs.get("include_models")  # string regex pattern
+        # only include models that match the regex pattern
+        if include_models:
+            _log(f"Filtering {len(self.models)} models, only including models that match regex: {include_models}")
+            self.models = {k: v for k, v in self.models.items() if re.search(include_models, k)}
+
+        exclude_models = kwargs.get("exclude_models")  # string regex pattern
+        # exclude models that match the regex pattern
+        if exclude_models:
+            _log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
+            self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}
+
+    def validate(self, **kwargs):
+        if not self.api_key:
+            api_keys = ", ".join(self.env)
+            return f"Provider '{self.name}' requires API Key {api_keys}"
+        return None
 
-
-
-        if
-
-
+    def test(self, **kwargs):
+        error_msg = self.validate(**kwargs)
+        if error_msg:
+            _log(error_msg)
+            return False
+        return True
 
     async def load(self):
-
+        if not self.models:
+            await self.load_models()
 
-    def
+    def model_info(self, model):
         provider_model = self.provider_model(model) or model
-
-
-
+        for model_id, model_info in self.models.items():
+            if model_id.lower() == provider_model.lower():
+                return model_info
+        return None
+
+    def model_cost(self, model):
+        model_info = self.model_info(model)
+        return model_info.get("cost") if model_info else None
 
     def provider_model(self, model):
-
-
+        # convert model to lowercase for case-insensitive comparison
+        model_lower = model.lower()
+
+        # if model is a map model id, return the provider model id
+        for model_id, provider_model in self.map_models.items():
+            if model_id.lower() == model_lower:
+                return provider_model
+
+        # if model is a provider model id, try again with just the model name
+        for provider_model in self.map_models.values():
+            if provider_model.lower() == model_lower:
+                return provider_model
+
+        # if model is a model id, try again with just the model id or name
+        for model_id, provider_model_info in self.models.items():
+            id = provider_model_info.get("id") or model_id
+            if model_id.lower() == model_lower or id.lower() == model_lower:
+                return id
+            name = provider_model_info.get("name")
+            if name and name.lower() == model_lower:
+                return id
+
+        # fallback to trying again with just the model short name
+        for model_id, provider_model_info in self.models.items():
+            id = provider_model_info.get("id") or model_id
+            if "/" in id:
+                model_name = id.split("/")[-1]
+                if model_name.lower() == model_lower:
+                    return id
+
+        # if model is a full provider model id, try again with just the model name
+        if "/" in model:
+            last_part = model.split("/")[-1]
+            return self.provider_model(last_part)
+
         return None
 
+    def response_json(self, response):
+        return response_json(response)
+
     def to_response(self, response, chat, started_at):
         if "metadata" not in response:
             response["metadata"] = {}
         response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
         if chat is not None and "model" in chat:
-            pricing = self.
+            pricing = self.model_cost(chat["model"])
             if pricing and "input" in pricing and "output" in pricing:
                 response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
-        _log(json.dumps(response, indent=2))
         return response
 
+    def chat_summary(self, chat):
+        return chat_summary(chat)
+
+    def process_chat(self, chat, provider_id=None):
+        return process_chat(chat, provider_id)
+
     async def chat(self, chat):
         chat["model"] = self.provider_model(chat["model"]) or chat["model"]
 
+        if "modalities" in chat:
+            for modality in chat.get("modalities", []):
+                # use default implementation for text modalities
+                if modality == "text":
+                    continue
+                modality_provider = self.modalities.get(modality)
+                if modality_provider:
+                    return await modality_provider.chat(chat, self)
+                else:
+                    raise Exception(f"Provider {self.name} does not support '{modality}' modality")
+
         # with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
         #     f.write(json.dumps(chat, indent=2))
 
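The new `set_models()` method supports `include_models` / `exclude_models` as plain regex patterns matched against model ids with `re.search()`. A standalone sketch of that filtering, using an illustrative model dict rather than data from the package:

```python
import re

models = {
    "gpt-4o-mini": {"id": "gpt-4o-mini"},
    "gpt-4o": {"id": "gpt-4o"},
    "o3-mini": {"id": "o3-mini"},
}

include_models = r"^gpt-"   # keep only ids matching this pattern
exclude_models = r"mini$"   # then drop ids matching this one

models = {k: v for k, v in models.items() if re.search(include_models, k)}
models = {k: v for k, v in models.items() if not re.search(exclude_models, k)}
print(list(models))         # ['gpt-4o']
```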
@@ -562,285 +1098,152 @@ class OpenAiProvider:
|
|
|
562
1098
|
if self.enable_thinking is not None:
|
|
563
1099
|
chat["enable_thinking"] = self.enable_thinking
|
|
564
1100
|
|
|
565
|
-
chat = await process_chat(chat)
|
|
1101
|
+
chat = await process_chat(chat, provider_id=self.id)
|
|
566
1102
|
_log(f"POST {self.chat_url}")
|
|
567
1103
|
_log(chat_summary(chat))
|
|
568
1104
|
# remove metadata if any (conflicts with some providers, e.g. Z.ai)
|
|
569
|
-
chat.pop("metadata", None)
|
|
1105
|
+
metadata = chat.pop("metadata", None)
|
|
570
1106
|
|
|
571
1107
|
async with aiohttp.ClientSession() as session:
|
|
572
1108
|
started_at = time.time()
|
|
573
1109
|
async with session.post(
|
|
574
1110
|
self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=aiohttp.ClientTimeout(total=120)
|
|
575
1111
|
) as response:
|
|
1112
|
+
chat["metadata"] = metadata
|
|
576
1113
|
return self.to_response(await response_json(response), chat, started_at)
|
|
577
1114
|
|
|
578
1115
|
|
|
579
|
-
class
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
1116
|
+
class MistralProvider(OpenAiCompatible):
|
|
1117
|
+
sdk = "@ai-sdk/mistral"
|
|
1118
|
+
|
|
1119
|
+
def __init__(self, **kwargs):
|
|
1120
|
+
if "api" not in kwargs:
|
|
1121
|
+
kwargs["api"] = "https://api.mistral.ai/v1"
|
|
1122
|
+
super().__init__(**kwargs)
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
class GroqProvider(OpenAiCompatible):
|
|
1126
|
+
sdk = "@ai-sdk/groq"
|
|
1127
|
+
|
|
1128
|
+
def __init__(self, **kwargs):
|
|
1129
|
+
if "api" not in kwargs:
|
|
1130
|
+
kwargs["api"] = "https://api.groq.com/openai/v1"
|
|
1131
|
+
super().__init__(**kwargs)
|
|
1132
|
+
|
|
1133
|
+
|
|
1134
|
+
class XaiProvider(OpenAiCompatible):
|
|
1135
|
+
sdk = "@ai-sdk/xai"
|
|
1136
|
+
|
|
1137
|
+
def __init__(self, **kwargs):
|
|
1138
|
+
if "api" not in kwargs:
|
|
1139
|
+
kwargs["api"] = "https://api.x.ai/v1"
|
|
1140
|
+
super().__init__(**kwargs)
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
class CodestralProvider(OpenAiCompatible):
|
|
1144
|
+
sdk = "codestral"
|
|
1145
|
+
|
|
1146
|
+
def __init__(self, **kwargs):
|
|
1147
|
+
super().__init__(**kwargs)
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
class OllamaProvider(OpenAiCompatible):
|
|
1151
|
+
sdk = "ollama"
|
|
1152
|
+
|
|
1153
|
+
def __init__(self, **kwargs):
|
|
1154
|
+
super().__init__(**kwargs)
|
|
1155
|
+
# Ollama's OpenAI-compatible endpoint is at /v1/chat/completions
|
|
1156
|
+
self.chat_url = f"{self.api}/v1/chat/completions"
|
|
583
1157
|
|
|
584
1158
|
async def load(self):
|
|
585
|
-
if self.
|
|
586
|
-
await self.load_models(
|
|
1159
|
+
if not self.models:
|
|
1160
|
+
await self.load_models()
|
|
587
1161
|
|
|
588
1162
|
async def get_models(self):
|
|
589
1163
|
ret = {}
|
|
590
1164
|
try:
|
|
591
1165
|
async with aiohttp.ClientSession() as session:
|
|
592
|
-
_log(f"GET {self.
|
|
1166
|
+
_log(f"GET {self.api}/api/tags")
|
|
593
1167
|
async with session.get(
|
|
594
|
-
f"{self.
|
|
1168
|
+
f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
|
|
595
1169
|
) as response:
|
|
596
1170
|
data = await response_json(response)
|
|
597
1171
|
for model in data.get("models", []):
|
|
598
|
-
|
|
599
|
-
if
|
|
600
|
-
|
|
601
|
-
ret[
|
|
1172
|
+
model_id = model["model"]
|
|
1173
|
+
if model_id.endswith(":latest"):
|
|
1174
|
+
model_id = model_id[:-7]
|
|
1175
|
+
ret[model_id] = model_id
|
|
602
1176
|
_log(f"Loaded Ollama models: {ret}")
|
|
603
1177
|
except Exception as e:
|
|
604
1178
|
_log(f"Error getting Ollama models: {e}")
|
|
605
1179
|
# return empty dict if ollama is not available
|
|
606
1180
|
return ret
|
|
607
1181
|
|
|
608
|
-
async def load_models(self
|
|
1182
|
+
async def load_models(self):
|
|
609
1183
|
"""Load models if all_models was requested"""
|
|
610
|
-
if self.all_models:
|
|
611
|
-
self.models = await self.get_models()
|
|
612
|
-
if default_models:
|
|
613
|
-
self.models = {**default_models, **self.models}
|
|
614
|
-
|
|
615
|
-
@classmethod
|
|
616
|
-
def test(cls, base_url=None, models=None, all_models=False, **kwargs):
|
|
617
|
-
if models is None:
|
|
618
|
-
models = {}
|
|
619
|
-
return base_url and (len(models) > 0 or all_models)
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
class GoogleOpenAiProvider(OpenAiProvider):
|
|
623
|
-
def __init__(self, api_key, models, **kwargs):
|
|
624
|
-
super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
|
|
625
|
-
self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
|
|
626
|
-
|
|
627
|
-
@classmethod
|
|
628
|
-
def test(cls, api_key=None, models=None, **kwargs):
|
|
629
|
-
if models is None:
|
|
630
|
-
models = {}
|
|
631
|
-
return api_key and len(models) > 0
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
class GoogleProvider(OpenAiProvider):
|
|
635
|
-
def __init__(self, models, api_key, safety_settings=None, thinking_config=None, curl=False, **kwargs):
|
|
636
|
-
super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
|
|
637
|
-
self.safety_settings = safety_settings
|
|
638
|
-
self.thinking_config = thinking_config
|
|
639
|
-
self.curl = curl
|
|
640
|
-
self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
|
|
641
|
-
# Google fails when using Authorization header, use query string param instead
|
|
642
|
-
if "Authorization" in self.headers:
|
|
643
|
-
del self.headers["Authorization"]
|
|
644
|
-
|
|
645
|
-
@classmethod
|
|
646
|
-
def test(cls, api_key=None, models=None, **kwargs):
|
|
647
|
-
if models is None:
|
|
648
|
-
models = {}
|
|
649
|
-
return api_key is not None and len(models) > 0
|
|
650
1184
|
|
|
651
|
-
|
|
652
|
-
|
|
1185
|
+
# Map models to provider models {model_id:model_id}
|
|
1186
|
+
model_map = await self.get_models()
|
|
1187
|
+
if self.map_models:
|
|
1188
|
+
map_model_values = set(self.map_models.values())
|
|
1189
|
+
to = {}
|
|
1190
|
+
for k, v in model_map.items():
|
|
1191
|
+
if k in self.map_models:
|
|
1192
|
+
to[k] = v
|
|
1193
|
+
if v in map_model_values:
|
|
1194
|
+
to[k] = v
|
|
1195
|
+
model_map = to
|
|
1196
|
+
else:
|
|
1197
|
+
self.map_models = model_map
|
|
1198
|
+
models = {}
|
|
1199
|
+
for k, v in model_map.items():
|
|
1200
|
+
models[k] = {
|
|
1201
|
+
"id": k,
|
|
1202
|
+
"name": v.replace(":", " "),
|
|
1203
|
+
"modalities": {"input": ["text"], "output": ["text"]},
|
|
1204
|
+
"cost": {
|
|
1205
|
+
"input": 0,
|
|
1206
|
+
"output": 0,
|
|
1207
|
+
},
|
|
1208
|
+
}
|
|
1209
|
+
self.models = models
|
|
653
1210
|
|
|
654
|
-
|
|
655
|
-
|
|
1211
|
+
def validate(self, **kwargs):
|
|
1212
|
+
return None
|
|
656
1213
|
|
|
657
|
-
# Filter out system messages and convert to proper Gemini format
|
|
658
|
-
contents = []
|
|
659
|
-
system_prompt = None
|
|
660
1214
|
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
if message["role"] == "system":
|
|
664
|
-
content = message["content"]
|
|
665
|
-
if isinstance(content, list):
|
|
666
|
-
for item in content:
|
|
667
|
-
if "text" in item:
|
|
668
|
-
system_prompt = item["text"]
|
|
669
|
-
break
|
|
670
|
-
elif isinstance(content, str):
|
|
671
|
-
system_prompt = content
|
|
672
|
-
elif "content" in message:
|
|
673
|
-
if isinstance(message["content"], list):
|
|
674
|
-
parts = []
|
|
675
|
-
for item in message["content"]:
|
|
676
|
-
if "type" in item:
|
|
677
|
-
if item["type"] == "image_url" and "image_url" in item:
|
|
678
|
-
image_url = item["image_url"]
|
|
679
|
-
if "url" not in image_url:
|
|
680
|
-
continue
|
|
681
|
-
url = image_url["url"]
|
|
682
|
-
if not url.startswith("data:"):
|
|
683
|
-
raise (Exception("Image was not downloaded: " + url))
|
|
684
|
-
# Extract mime type from data uri
|
|
685
|
-
mimetype = url.split(";", 1)[0].split(":", 1)[1] if ";" in url else "image/png"
|
|
686
|
-
base64_data = url.split(",", 1)[1]
|
|
687
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
|
|
688
|
-
elif item["type"] == "input_audio" and "input_audio" in item:
|
|
689
|
-
input_audio = item["input_audio"]
|
|
690
|
-
if "data" not in input_audio:
|
|
691
|
-
continue
|
|
692
|
-
data = input_audio["data"]
|
|
693
|
-
format = input_audio["format"]
|
|
694
|
-
mimetype = f"audio/{format}"
|
|
695
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": data}})
|
|
696
|
-
elif item["type"] == "file" and "file" in item:
|
|
697
|
-
file = item["file"]
|
|
698
|
-
if "file_data" not in file:
|
|
699
|
-
continue
|
|
700
|
-
data = file["file_data"]
|
|
701
|
-
if not data.startswith("data:"):
|
|
702
|
-
raise (Exception("File was not downloaded: " + data))
|
|
703
|
-
# Extract mime type from data uri
|
|
704
|
-
mimetype = (
|
|
705
|
-
data.split(";", 1)[0].split(":", 1)[1]
|
|
706
|
-
if ";" in data
|
|
707
|
-
else "application/octet-stream"
|
|
708
|
-
)
|
|
709
|
-
base64_data = data.split(",", 1)[1]
|
|
710
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
|
|
711
|
-
if "text" in item:
|
|
712
|
-
text = item["text"]
|
|
713
|
-
parts.append({"text": text})
|
|
714
|
-
if len(parts) > 0:
|
|
715
|
-
contents.append(
|
|
716
|
-
{
|
|
717
|
-
"role": message["role"]
|
|
718
|
-
if "role" in message and message["role"] == "user"
|
|
719
|
-
else "model",
|
|
720
|
-
"parts": parts,
|
|
721
|
-
}
|
|
722
|
-
)
|
|
723
|
-
else:
|
|
724
|
-
content = message["content"]
|
|
725
|
-
contents.append(
|
|
726
|
-
{
|
|
727
|
-
"role": message["role"] if "role" in message and message["role"] == "user" else "model",
|
|
728
|
-
"parts": [{"text": content}],
|
|
729
|
-
}
|
|
730
|
-
)
|
|
1215
|
+
class LMStudioProvider(OllamaProvider):
|
|
1216
|
+
sdk = "lmstudio"
|
|
731
1217
|
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
1218
|
+
def __init__(self, **kwargs):
|
|
1219
|
+
super().__init__(**kwargs)
|
|
1220
|
+
self.chat_url = f"{self.api}/chat/completions"
|
|
735
1221
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-        if
-
-
-        if "thinkingConfig" in chat:
-            generation_config["thinkingConfig"] = chat["thinkingConfig"]
-        elif self.thinking_config:
-            generation_config["thinkingConfig"] = self.thinking_config
-
-        if len(generation_config) > 0:
-            gemini_chat["generationConfig"] = generation_config
-
-        started_at = int(time.time() * 1000)
-        gemini_chat_url = f"https://generativelanguage.googleapis.com/v1beta/models/{chat['model']}:generateContent?key={self.api_key}"
-
-        _log(f"POST {gemini_chat_url}")
-        _log(gemini_chat_summary(gemini_chat))
-        started_at = time.time()
+    async def get_models(self):
+        ret = {}
+        try:
+            async with aiohttp.ClientSession() as session:
+                _log(f"GET {self.api}/models")
+                async with session.get(
+                    f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
+                ) as response:
+                    data = await response_json(response)
+                    for model in data.get("data", []):
+                        id = model["id"]
+                        ret[id] = id
+            _log(f"Loaded LMStudio models: {ret}")
+        except Exception as e:
+            _log(f"Error getting LMStudio models: {e}")
+            # return empty dict if ollama is not available
+        return ret
 
-        if self.curl:
-            curl_args = [
-                "curl",
-                "-X",
-                "POST",
-                "-H",
-                "Content-Type: application/json",
-                "-d",
-                json.dumps(gemini_chat),
-                gemini_chat_url,
-            ]
-            try:
-                o = subprocess.run(curl_args, check=True, capture_output=True, text=True, timeout=120)
-                obj = json.loads(o.stdout)
-            except Exception as e:
-                raise Exception(f"Error executing curl: {e}") from e
-        else:
-            async with session.post(
-                gemini_chat_url,
-                headers=self.headers,
-                data=json.dumps(gemini_chat),
-                timeout=aiohttp.ClientTimeout(total=120),
-            ) as res:
-                obj = await response_json(res)
-        _log(f"google response:\n{json.dumps(obj, indent=2)}")
-
-        response = {
-            "id": f"chatcmpl-{started_at}",
-            "created": started_at,
-            "model": obj.get("modelVersion", chat["model"]),
-        }
-        choices = []
-        if "error" in obj:
-            _log(f"Error: {obj['error']}")
-            raise Exception(obj["error"]["message"])
-        for i, candidate in enumerate(obj["candidates"]):
-            role = "assistant"
-            if "content" in candidate and "role" in candidate["content"]:
-                role = "assistant" if candidate["content"]["role"] == "model" else candidate["content"]["role"]
-
-            # Safely extract content from all text parts
-            content = ""
-            reasoning = ""
-            if "content" in candidate and "parts" in candidate["content"]:
-                text_parts = []
-                reasoning_parts = []
-                for part in candidate["content"]["parts"]:
-                    if "text" in part:
-                        if "thought" in part and part["thought"]:
-                            reasoning_parts.append(part["text"])
-                        else:
-                            text_parts.append(part["text"])
-                content = " ".join(text_parts)
-                reasoning = " ".join(reasoning_parts)
 
-
-
-
-
-
-
-                },
-            }
-            if reasoning:
-                choice["message"]["reasoning"] = reasoning
-            choices.append(choice)
-        response["choices"] = choices
-        if "usageMetadata" in obj:
-            usage = obj["usageMetadata"]
-            response["usage"] = {
-                "completion_tokens": usage["candidatesTokenCount"],
-                "total_tokens": usage["totalTokenCount"],
-                "prompt_tokens": usage["promptTokenCount"],
-            }
-        return self.to_response(response, chat, started_at)
+def get_provider_model(model_name):
+    for provider in g_handlers.values():
+        provider_model = provider.provider_model(model_name)
+        if provider_model:
+            return provider_model
+    return None
 
 
 def get_models():
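The `LMStudioProvider.get_models()` added above queries an OpenAI-compatible `GET {api}/models` listing and flattens it into a dict keyed by model id, while the new module-level `get_provider_model()` resolves a model name through the registered handlers. A minimal sketch of the response shape the parser assumes and the resulting mapping; the model ids below are illustrative, not taken from the package:

```python
# Hypothetical /models payload in the OpenAI-compatible shape get_models() parses
payload = {
    "data": [
        {"id": "llama-3.2-3b-instruct", "object": "model"},
        {"id": "qwen2.5-coder-7b", "object": "model"},
    ]
}

# Same reduction get_models() performs: keep only the ids, keyed by id
models = {m["id"]: m["id"] for m in payload.get("data", [])}
assert models == {
    "llama-3.2-3b-instruct": "llama-3.2-3b-instruct",
    "qwen2.5-coder-7b": "qwen2.5-coder-7b",
}
```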
@@ -856,42 +1259,259 @@ def get_models():
 def get_active_models():
     ret = []
     existing_models = set()
-    for
-        for model in provider.models:
-
-
-
-
-
+    for provider_id, provider in g_handlers.items():
+        for model in provider.models.values():
+            name = model.get("name")
+            if not name:
+                _log(f"Provider {provider_id} model {model} has no name")
+                continue
+            if name not in existing_models:
+                existing_models.add(name)
+                item = model.copy()
+                item.update({"provider": provider_id})
+                ret.append(item)
     ret.sort(key=lambda x: x["id"])
     return ret
 
 
-
-
-
-
-
-
+def api_providers():
+    ret = []
+    for id, provider in g_handlers.items():
+        ret.append({"id": id, "name": provider.name, "models": provider.models})
+    return ret
+
+
+def to_error_message(e):
+    return str(e)
+
+
+def to_error_response(e, stacktrace=False):
+    status = {"errorCode": "Error", "message": to_error_message(e)}
+    if stacktrace:
+        status["stackTrace"] = traceback.format_exc()
+    return {"responseStatus": status}
+
+
+def create_error_response(message, error_code="Error", stack_trace=None):
+    ret = {"responseStatus": {"errorCode": error_code, "message": message}}
+    if stack_trace:
+        ret["responseStatus"]["stackTrace"] = stack_trace
+    return ret
+
+
+def should_cancel_thread(context):
+    ret = context.get("cancelled", False)
+    if ret:
+        thread_id = context.get("threadId")
+        _dbg(f"Thread cancelled {thread_id}")
+    return ret
+
+
+def g_chat_request(template=None, text=None, model=None, system_prompt=None):
+    chat_template = g_config["defaults"].get(template or "text")
+    if not chat_template:
+        raise Exception(f"Chat template '{template}' not found")
+
+    chat = chat_template.copy()
+    if model:
+        chat["model"] = model
+    if system_prompt is not None:
+        chat["messages"].insert(0, {"role": "system", "content": system_prompt})
+    if text is not None:
+        if not chat["messages"] or len(chat["messages"]) == 0:
+            chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
+
+        # replace content of last message if exists, else add
+        last_msg = chat["messages"][-1] if "messages" in chat else None
+        if last_msg and last_msg["role"] == "user":
+            if isinstance(last_msg["content"], list):
+                last_msg["content"][-1]["text"] = text
+            else:
+                last_msg["content"] = text
+        else:
+            chat["messages"].append({"role": "user", "content": text})
+
+    return chat
+
+
+async def g_chat_completion(chat, context=None):
+    try:
+        model = chat.get("model")
+        if not model:
+            raise Exception("Model not specified")
 
+        if context is None:
+            context = {"chat": chat, "tools": "all"}
+
+        # get first provider that has the model
+        candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
+        if len(candidate_providers) == 0:
+            raise (Exception(f"Model {model} not found"))
+    except Exception as e:
+        await g_app.on_chat_error(e, context or {"chat": chat})
+        raise e
+
+    started_at = time.time()
     first_exception = None
+    provider_name = "Unknown"
     for name in candidate_providers:
-        provider = g_handlers[name]
-        _log(f"provider: {name} {type(provider).__name__}")
         try:
-
-
+            provider_name = name
+            provider = g_handlers[name]
+            _log(f"provider: {name} {type(provider).__name__}")
+            started_at = time.time()
+            context["startedAt"] = datetime.now()
+            context["provider"] = name
+            model_info = provider.model_info(model)
+            context["modelCost"] = model_info.get("cost", provider.model_cost(model)) or {"input": 0, "output": 0}
+            context["modelInfo"] = model_info
+
+            # Accumulate usage across tool calls
+            total_usage = {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0,
+            }
+            accumulated_cost = 0.0
+
+            # Inject global tools if present
+            current_chat = chat.copy()
+            if g_app.tool_definitions:
+                only_tools_str = context.get("tools", "all")
+                include_all_tools = only_tools_str == "all"
+                only_tools = only_tools_str.split(",")
+
+                if include_all_tools or len(only_tools) > 0:
+                    if "tools" not in current_chat:
+                        current_chat["tools"] = []
+
+                    existing_tools = {t["function"]["name"] for t in current_chat["tools"]}
+                    for tool_def in g_app.tool_definitions:
+                        name = tool_def["function"]["name"]
+                        if name not in existing_tools and (include_all_tools or name in only_tools):
+                            current_chat["tools"].append(tool_def)
+
+            # Apply pre-chat filters ONCE
+            context["chat"] = current_chat
+            for filter_func in g_app.chat_request_filters:
+                await filter_func(current_chat, context)
+
+            # Tool execution loop
+            max_iterations = 10
+            tool_history = []
+            final_response = None
+
+            for _ in range(max_iterations):
+                if should_cancel_thread(context):
+                    return
+
+                response = await provider.chat(current_chat)
+
+                if should_cancel_thread(context):
+                    return
+
+                # Aggregate usage
+                if "usage" in response:
+                    usage = response["usage"]
+                    total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
+                    total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
+                    total_usage["total_tokens"] += usage.get("total_tokens", 0)
+
+                    # Calculate cost for this step if available
+                    if "cost" in response and isinstance(response["cost"], (int, float)):
+                        accumulated_cost += response["cost"]
+                    elif "cost" in usage and isinstance(usage["cost"], (int, float)):
+                        accumulated_cost += usage["cost"]
+
+                # Check for tool_calls in the response
+                choice = response.get("choices", [])[0] if response.get("choices") else {}
+                message = choice.get("message", {})
+                tool_calls = message.get("tool_calls")
+
+                if tool_calls:
+                    # Append the assistant's message with tool calls to history
+                    if "messages" not in current_chat:
+                        current_chat["messages"] = []
+                    if "timestamp" not in message:
+                        message["timestamp"] = int(time.time() * 1000)
+                    current_chat["messages"].append(message)
+                    tool_history.append(message)
+
+                    await g_app.on_chat_tool(current_chat, context)
+
+                    for tool_call in tool_calls:
+                        function_name = tool_call["function"]["name"]
+                        try:
+                            function_args = json.loads(tool_call["function"]["arguments"])
+                        except Exception as e:
+                            tool_result = f"Error parsing JSON arguments for tool {function_name}: {e}"
+                        else:
+                            tool_result = f"Error: Tool {function_name} not found"
+                            if function_name in g_app.tools:
+                                try:
+                                    func = g_app.tools[function_name]
+                                    if inspect.iscoroutinefunction(func):
+                                        tool_result = await func(**function_args)
+                                    else:
+                                        tool_result = func(**function_args)
+                                except Exception as e:
+                                    tool_result = f"Error executing tool {function_name}: {e}"
+
+                        # Append tool result to history
+                        tool_msg = {"role": "tool", "tool_call_id": tool_call["id"], "content": to_content(tool_result)}
+                        current_chat["messages"].append(tool_msg)
+                        tool_history.append(tool_msg)
+
+                    await g_app.on_chat_tool(current_chat, context)
+
+                    if should_cancel_thread(context):
+                        return
+
+                    # Continue loop to send tool results back to LLM
+                    continue
+
+                # If no tool calls, this is the final response
+                if tool_history:
+                    response["tool_history"] = tool_history
+
+                # Update final response with aggregated usage
+                if "usage" not in response:
+                    response["usage"] = {}
+                # convert to int seconds
+                context["duration"] = duration = int(time.time() - started_at)
+                total_usage.update({"duration": duration})
+                response["usage"].update(total_usage)
+                # If we accumulated cost, set it on the response
+                if accumulated_cost > 0:
+                    response["cost"] = accumulated_cost
+
+                final_response = response
+                break  # Exit tool loop
+
+            if final_response:
+                # Apply post-chat filters ONCE on final response
+                for filter_func in g_app.chat_response_filters:
+                    await filter_func(final_response, context)
+
+                if DEBUG:
+                    _dbg(json.dumps(final_response, indent=2))
+
+                return final_response
+
         except Exception as e:
             if first_exception is None:
                 first_exception = e
-
+            context["stackTrace"] = traceback.format_exc()
+            _err(f"Provider {provider_name} failed", first_exception)
+            await g_app.on_chat_error(e, context)
+
             continue
 
     # If we get here, all providers failed
     raise first_exception
 
 
-async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False):
+async def cli_chat(chat, tools=None, image=None, audio=None, file=None, args=None, raw=False):
     if g_default_model:
         chat["model"] = g_default_model
 
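The hunk above introduces `g_chat_request()` and `g_chat_completion()`, the request builder and tool-calling completion loop that the rest of the app (and extensions) now route through. A minimal usage sketch, assuming providers have already been initialised via `init_llms()`/`load_llms()` and that a `"text"` default template exists in the config; the model name is illustrative:

```python
import asyncio

async def demo():
    # Build a request from the "text" default template; model name is hypothetical
    chat = g_chat_request(template="text", text="Summarize RFC 2616 in one line", model="gpt-4o-mini")

    # "tools": "all" injects every registered tool definition;
    # a comma-separated list (or "none") narrows which tools are offered
    context = {"chat": chat, "tools": "all"}
    response = await g_chat_completion(chat, context=context)
    print(response["choices"][0]["message"]["content"])

# asyncio.run(demo())  # requires providers to be loaded first
```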
@@ -966,73 +1586,161 @@ async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False
     printdump(chat)
 
     try:
-
+        context = {
+            "tools": tools or "all",
+        }
+        response = await g_app.chat_completion(chat, context=context)
+
         if raw:
             print(json.dumps(response, indent=2))
             exit(0)
         else:
-
-
+            msg = response["choices"][0]["message"]
+            if "content" in msg or "answer" in msg:
+                print(msg["content"])
+
+            generated_files = []
+            for choice in response["choices"]:
+                if "message" in choice:
+                    msg = choice["message"]
+                    if "images" in msg:
+                        for image in msg["images"]:
+                            image_url = image["image_url"]["url"]
+                            generated_files.append(image_url)
+                    if "audios" in msg:
+                        for audio in msg["audios"]:
+                            audio_url = audio["audio_url"]["url"]
+                            generated_files.append(audio_url)
+
+            if len(generated_files) > 0:
+                print("\nSaved files:")
+                for file in generated_files:
+                    if file.startswith("/~cache"):
+                        print(get_cache_path(file[8:]))
+                        print(urljoin("http://localhost:8000", file))
+                    else:
+                        print(file)
+
     except HTTPError as e:
         # HTTP error (4xx, 5xx)
         print(f"{e}:\n{e.body}")
-        exit(1)
+        g_app.exit(1)
     except aiohttp.ClientConnectionError as e:
         # Connection issues
         print(f"Connection error: {e}")
-        exit(1)
+        g_app.exit(1)
     except asyncio.TimeoutError as e:
         # Timeout
         print(f"Timeout error: {e}")
-        exit(1)
+        g_app.exit(1)
 
 
 def config_str(key):
     return key in g_config and g_config[key] or None
 
 
-def
+def load_config(config, providers, verbose=None):
+    global g_config, g_providers, g_verbose
+    g_config = config
+    g_providers = providers
+    if verbose:
+        g_verbose = verbose
+
+
+def init_llms(config, providers):
     global g_config, g_handlers
 
-
+    load_config(config, providers)
     g_handlers = {}
     # iterate over config and replace $ENV with env value
     for key, value in g_config.items():
         if isinstance(value, str) and value.startswith("$"):
-            g_config[key] = os.
+            g_config[key] = os.getenv(value[1:], "")
 
     # if g_verbose:
     #     printdump(g_config)
     providers = g_config["providers"]
 
-    for
-
-        provider_type = definition["type"]
-        if "enabled" in definition and not definition["enabled"]:
+    for id, orig in providers.items():
+        if "enabled" in orig and not orig["enabled"]:
             continue
 
-
-        if
-
-            if isinstance(value, str) and value.startswith("$"):
-                definition["api_key"] = os.environ.get(value[1:], "")
-
-        # Create a copy of definition without the 'type' key for constructor kwargs
-        constructor_kwargs = {k: v for k, v in definition.items() if k != "type" and k != "enabled"}
-        constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
-
-        if provider_type == "OpenAiProvider" and OpenAiProvider.test(**constructor_kwargs):
-            g_handlers[name] = OpenAiProvider(**constructor_kwargs)
-        elif provider_type == "OllamaProvider" and OllamaProvider.test(**constructor_kwargs):
-            g_handlers[name] = OllamaProvider(**constructor_kwargs)
-        elif provider_type == "GoogleProvider" and GoogleProvider.test(**constructor_kwargs):
-            g_handlers[name] = GoogleProvider(**constructor_kwargs)
-        elif provider_type == "GoogleOpenAiProvider" and GoogleOpenAiProvider.test(**constructor_kwargs):
-            g_handlers[name] = GoogleOpenAiProvider(**constructor_kwargs)
-
+        provider, constructor_kwargs = create_provider_from_definition(id, orig)
+        if provider and provider.test(**constructor_kwargs):
+            g_handlers[id] = provider
     return g_handlers
 
 
+def create_provider_from_definition(id, orig):
+    definition = orig.copy()
+    provider_id = definition.get("id", id)
+    if "id" not in definition:
+        definition["id"] = provider_id
+    provider = g_providers.get(provider_id)
+    constructor_kwargs = create_provider_kwargs(definition, provider)
+    provider = create_provider(constructor_kwargs)
+    return provider, constructor_kwargs
+
+
+def create_provider_kwargs(definition, provider=None):
+    if provider:
+        provider = provider.copy()
+        provider.update(definition)
+    else:
+        provider = definition.copy()
+
+    # Replace API keys with environment variables if they start with $
+    if "api_key" in provider:
+        value = provider["api_key"]
+        if isinstance(value, str) and value.startswith("$"):
+            provider["api_key"] = os.getenv(value[1:], "")
+
+    if "api_key" not in provider and "env" in provider:
+        for env_var in provider["env"]:
+            val = os.getenv(env_var)
+            if val:
+                provider["api_key"] = val
+                break
+
+    # Create a copy of provider
+    constructor_kwargs = dict(provider.items())
+    # Create a copy of all list and dict values
+    for key, value in constructor_kwargs.items():
+        if isinstance(value, (list, dict)):
+            constructor_kwargs[key] = value.copy()
+    constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
+
+    if "modalities" in definition:
+        constructor_kwargs["modalities"] = {}
+        for modality, modality_definition in definition["modalities"].items():
+            modality_provider = create_provider(modality_definition)
+            if not modality_provider:
+                return None
+            constructor_kwargs["modalities"][modality] = modality_provider
+
+    return constructor_kwargs
+
+
+def create_provider(provider):
+    if not isinstance(provider, dict):
+        return None
+    provider_label = provider.get("id", provider.get("name", "unknown"))
+    npm_sdk = provider.get("npm")
+    if not npm_sdk:
+        _log(f"Provider {provider_label} is missing 'npm' sdk")
+        return None
+
+    for provider_type in g_app.all_providers:
+        if provider_type.sdk == npm_sdk:
+            kwargs = create_provider_kwargs(provider)
+            if kwargs is None:
+                kwargs = provider
+            return provider_type(**kwargs)
+
+    _log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
+    return None
+
+
 async def load_llms():
     global g_handlers
     _log("Loading providers...")
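`init_llms()` now resolves each entry under the config's `providers` section against the models.dev catalogue (`g_providers`) and instantiates whichever registered provider class declares a matching `npm` sdk id, instead of switching on a hard-coded `type`. A rough sketch of that lookup under those assumptions; the key names `enabled`, `npm`, and `api_key` appear in the diff, but the values below are hypothetical:

```python
# Illustrative provider entry; "$OPENROUTER_API_KEY" would be expanded from the environment
definition = {
    "enabled": True,
    "npm": "@openrouter/ai-sdk-provider",  # matched against provider_type.sdk
    "api_key": "$OPENROUTER_API_KEY",
}

provider, kwargs = create_provider_from_definition("openrouter", definition)
if provider and provider.test(**kwargs):
    g_handlers["openrouter"] = provider  # same registration init_llms() performs
```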
@@ -1076,6 +1784,35 @@ async def save_default_config(config_path):
     g_config = json.loads(config_json)
 
 
+async def update_providers(home_providers_path):
+    global g_providers
+    text = await get_text("https://models.dev/api.json")
+    all_providers = json.loads(text)
+    extra_providers = {}
+    extra_providers_path = home_providers_path.replace("providers.json", "providers-extra.json")
+    if os.path.exists(extra_providers_path):
+        with open(extra_providers_path) as f:
+            extra_providers = json.load(f)
+
+    filtered_providers = {}
+    for id, provider in all_providers.items():
+        if id in g_config["providers"]:
+            filtered_providers[id] = provider
+            if id in extra_providers and "models" in extra_providers[id]:
+                for model_id, model in extra_providers[id]["models"].items():
+                    if "id" not in model:
+                        model["id"] = model_id
+                    if "name" not in model:
+                        model["name"] = id_to_name(model["id"])
+                    filtered_providers[id]["models"][model_id] = model
+
+    os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
+    with open(home_providers_path, "w", encoding="utf-8") as f:
+        json.dump(filtered_providers, f)
+
+    g_providers = filtered_providers
+
+
 def provider_status():
     enabled = list(g_handlers.keys())
     disabled = [provider for provider in g_config["providers"] if provider not in enabled]
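`update_providers()` above refreshes the local providers catalogue from https://models.dev/api.json, keeping only providers referenced in the config and overlaying any extra models found in `providers-extra.json`. A sketch of the overlay shape it iterates over, under those assumptions; provider and model ids here are illustrative:

```python
# Illustrative providers-extra.json content: extra models merged into the
# filtered models.dev catalogue for a provider already present in the config
extra_providers = {
    "openai": {
        "models": {
            "my-custom-finetune": {
                # "id" and "name" are filled in by update_providers() when missing
                "cost": {"input": 0, "output": 0},
            }
        }
    }
}
```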
@@ -1097,7 +1834,11 @@ def print_status():
 
 
 def home_llms_path(filename):
-    return f"{os.
+    return f"{os.getenv('HOME')}/.llms/{filename}"
+
+
+def get_cache_path(path=""):
+    return home_llms_path(f"cache/{path}") if path else home_llms_path("cache")
 
 
 def get_config_path():
@@ -1106,8 +1847,8 @@ def get_config_path():
         "./llms.json",
         home_config_path,
     ]
-    if os.
-        check_paths.insert(0, os.
+    if os.getenv("LLMS_CONFIG_PATH"):
+        check_paths.insert(0, os.getenv("LLMS_CONFIG_PATH"))
 
     for check_path in check_paths:
         g_config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), check_path))
@@ -1116,28 +1857,20 @@ def get_config_path():
     return None
 
 
-def get_ui_path():
-    ui_paths = [home_llms_path("ui.json"), "ui.json"]
-    for ui_path in ui_paths:
-        if os.path.exists(ui_path):
-            return ui_path
-    return None
-
-
 def enable_provider(provider):
     msg = None
     provider_config = g_config["providers"][provider]
+    if not provider_config:
+        return None, f"Provider {provider} not found"
+
+    provider, constructor_kwargs = create_provider_from_definition(provider, provider_config)
+    msg = provider.validate(**constructor_kwargs)
+    if msg:
+        return None, msg
+
     provider_config["enabled"] = True
-    if "api_key" in provider_config:
-        api_key = provider_config["api_key"]
-        if isinstance(api_key, str):
-            if api_key.startswith("$"):
-                if not os.environ.get(api_key[1:], ""):
-                    msg = f"WARNING: {provider} requires missing API Key in Environment Variable {api_key}"
-    else:
-        msg = f"WARNING: {provider} is not configured with an API Key"
     save_config(g_config)
-    init_llms(g_config)
+    init_llms(g_config, g_providers)
     return provider_config, msg
 
 
@@ -1145,7 +1878,7 @@ def disable_provider(provider):
     provider_config = g_config["providers"][provider]
     provider_config["enabled"] = False
     save_config(g_config)
-    init_llms(g_config)
+    init_llms(g_config, g_providers)
 
 
 def resolve_root():
@@ -1340,7 +2073,8 @@ async def check_models(provider_name, model_names=None):
     else:
         # Check only specified models
         for model_name in model_names:
-
+            provider_model = provider.provider_model(model_name)
+            if provider_model:
                 models_to_check.append(model_name)
             else:
                 print(f"Model '{model_name}' not found in provider '{provider_name}'")
@@ -1355,69 +2089,76 @@ async def check_models(provider_name, model_names=None):
 
     # Test each model
     for model in models_to_check:
-
-        chat = (provider.check or g_config["defaults"]["check"]).copy()
-        chat["model"] = model
+        await check_provider_model(provider, model)
 
-
-        try:
-            # Try to get a response from the model
-            response = await provider.chat(chat)
-            duration_ms = int((time.time() - started_at) * 1000)
+    print()
 
-            # Check if we got a valid response
-            if response and "choices" in response and len(response["choices"]) > 0:
-                print(f" ✓ {model:<40} ({duration_ms}ms)")
-            else:
-                print(f" ✗ {model:<40} Invalid response format")
-        except HTTPError as e:
-            duration_ms = int((time.time() - started_at) * 1000)
-            error_msg = f"HTTP {e.status}"
-            try:
-                # Try to parse error body for more details
-                error_body = json.loads(e.body) if e.body else {}
-                if "error" in error_body:
-                    error = error_body["error"]
-                    if isinstance(error, dict):
-                        if "message" in error and isinstance(error["message"], str):
-                            # OpenRouter
-                            error_msg = error["message"]
-                            if "code" in error:
-                                error_msg = f"{error['code']} {error_msg}"
-                            if "metadata" in error and "raw" in error["metadata"]:
-                                error_msg += f" - {error['metadata']['raw']}"
-                            if "provider" in error:
-                                error_msg += f" ({error['provider']})"
-                    elif isinstance(error, str):
-                        error_msg = error
-                elif "message" in error_body:
-                    if isinstance(error_body["message"], str):
-                        error_msg = error_body["message"]
-                    elif (
-                        isinstance(error_body["message"], dict)
-                        and "detail" in error_body["message"]
-                        and isinstance(error_body["message"]["detail"], list)
-                    ):
-                        # codestral error format
-                        error_msg = error_body["message"]["detail"][0]["msg"]
-                        if (
-                            "loc" in error_body["message"]["detail"][0]
-                            and len(error_body["message"]["detail"][0]["loc"]) > 0
-                        ):
-                            error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
-            except Exception as parse_error:
-                _log(f"Error parsing error body: {parse_error}")
-                error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
-            print(f" ✗ {model:<40} {error_msg}")
-        except asyncio.TimeoutError:
-            duration_ms = int((time.time() - started_at) * 1000)
-            print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
-        except Exception as e:
-            duration_ms = int((time.time() - started_at) * 1000)
-            error_msg = str(e)[:100]
-            print(f" ✗ {model:<40} {error_msg}")
 
-
+async def check_provider_model(provider, model):
+    # Create a simple ping chat request
+    chat = (provider.check or g_config["defaults"]["check"]).copy()
+    chat["model"] = model
+
+    success = False
+    started_at = time.time()
+    try:
+        # Try to get a response from the model
+        response = await provider.chat(chat)
+        duration_ms = int((time.time() - started_at) * 1000)
+
+        # Check if we got a valid response
+        if response and "choices" in response and len(response["choices"]) > 0:
+            success = True
+            print(f" ✓ {model:<40} ({duration_ms}ms)")
+        else:
+            print(f" ✗ {model:<40} Invalid response format")
+    except HTTPError as e:
+        duration_ms = int((time.time() - started_at) * 1000)
+        error_msg = f"HTTP {e.status}"
+        try:
+            # Try to parse error body for more details
+            error_body = json.loads(e.body) if e.body else {}
+            if "error" in error_body:
+                error = error_body["error"]
+                if isinstance(error, dict):
+                    if "message" in error and isinstance(error["message"], str):
+                        # OpenRouter
+                        error_msg = error["message"]
+                        if "code" in error:
+                            error_msg = f"{error['code']} {error_msg}"
+                        if "metadata" in error and "raw" in error["metadata"]:
+                            error_msg += f" - {error['metadata']['raw']}"
+                        if "provider" in error:
+                            error_msg += f" ({error['provider']})"
+                elif isinstance(error, str):
+                    error_msg = error
+            elif "message" in error_body:
+                if isinstance(error_body["message"], str):
+                    error_msg = error_body["message"]
+                elif (
+                    isinstance(error_body["message"], dict)
+                    and "detail" in error_body["message"]
+                    and isinstance(error_body["message"]["detail"], list)
+                ):
+                    # codestral error format
+                    error_msg = error_body["message"]["detail"][0]["msg"]
+                    if (
+                        "loc" in error_body["message"]["detail"][0]
+                        and len(error_body["message"]["detail"][0]["loc"]) > 0
+                    ):
+                        error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
+        except Exception as parse_error:
+            _log(f"Error parsing error body: {parse_error}")
+            error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
+        print(f" ✗ {model:<40} {error_msg}")
+    except asyncio.TimeoutError:
+        duration_ms = int((time.time() - started_at) * 1000)
+        print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
+    except Exception as e:
+        duration_ms = int((time.time() - started_at) * 1000)
+        error_msg = str(e)[:100]
+        print(f" ✗ {model:<40} {error_msg}")
+    return success
 
 
 def text_from_resource(filename):
@@ -1452,8 +2193,14 @@ async def text_from_resource_or_url(filename):
 
 async def save_home_configs():
     home_config_path = home_llms_path("llms.json")
-
-
+    home_providers_path = home_llms_path("providers.json")
+    home_providers_extra_path = home_llms_path("providers-extra.json")
+
+    if (
+        os.path.exists(home_config_path)
+        and os.path.exists(home_providers_path)
+        and os.path.exists(home_providers_extra_path)
+    ):
         return
 
     llms_home = os.path.dirname(home_config_path)
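The final hunk below adds the extension system (`AppExtensions`, `ExtensionContext`, `install_extensions()`), which imports each extension directory's `__init__.py` and calls its `__install__(ctx)` hook. A minimal sketch of what such an extension might look like under that contract; the extension name and tool are hypothetical, not part of the package:

```python
# ~/.llms/extensions/hello/__init__.py -- illustrative extension
from datetime import datetime, timezone


def utc_now() -> str:
    """Hypothetical tool: return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


async def log_response(response, context):
    # Chat response filters are awaited with (final_response, context)
    print("finished chat for model", response.get("model"))


def __install__(ctx):
    ctx.register_tool(utc_now)                       # exposed to tool-calling models
    ctx.register_chat_response_filter(log_response)  # runs on each final response
```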
@@ -1465,92 +2212,650 @@ async def save_home_configs():
|
|
|
1465
2212
|
f.write(config_json)
|
|
1466
2213
|
_log(f"Created default config at {home_config_path}")
|
|
1467
2214
|
|
|
1468
|
-
if not os.path.exists(
|
|
1469
|
-
|
|
1470
|
-
with open(
|
|
1471
|
-
f.write(
|
|
1472
|
-
_log(f"Created default
|
|
2215
|
+
if not os.path.exists(home_providers_path):
|
|
2216
|
+
providers_json = await text_from_resource_or_url("providers.json")
|
|
2217
|
+
with open(home_providers_path, "w", encoding="utf-8") as f:
|
|
2218
|
+
f.write(providers_json)
|
|
2219
|
+
_log(f"Created default providers config at {home_providers_path}")
|
|
2220
|
+
|
|
2221
|
+
if not os.path.exists(home_providers_extra_path):
|
|
2222
|
+
extra_json = await text_from_resource_or_url("providers-extra.json")
|
|
2223
|
+
with open(home_providers_extra_path, "w", encoding="utf-8") as f:
|
|
2224
|
+
f.write(extra_json)
|
|
2225
|
+
_log(f"Created default extra providers config at {home_providers_extra_path}")
|
|
1473
2226
|
except Exception:
|
|
1474
2227
|
print("Could not create llms.json. Create one with --init or use --config <path>")
|
|
1475
2228
|
exit(1)
|
|
1476
2229
|
|
|
1477
2230
|
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
2231
|
+
def load_config_json(config_json):
|
|
2232
|
+
if config_json is None:
|
|
2233
|
+
return None
|
|
2234
|
+
config = json.loads(config_json)
|
|
2235
|
+
if not config or "version" not in config or config["version"] < 3:
|
|
2236
|
+
preserve_keys = ["auth", "defaults", "limits", "convert"]
|
|
2237
|
+
new_config = json.loads(text_from_resource("llms.json"))
|
|
2238
|
+
if config:
|
|
2239
|
+
for key in preserve_keys:
|
|
2240
|
+
if key in config:
|
|
2241
|
+
new_config[key] = config[key]
|
|
2242
|
+
config = new_config
|
|
2243
|
+
# move old config to YYYY-MM-DD.bak
|
|
2244
|
+
new_path = f"{g_config_path}.{datetime.now().strftime('%Y-%m-%d')}.bak"
|
|
2245
|
+
if os.path.exists(new_path):
|
|
2246
|
+
os.remove(new_path)
|
|
2247
|
+
os.rename(g_config_path, new_path)
|
|
2248
|
+
print(f"llms.json migrated. old config moved to {new_path}")
|
|
2249
|
+
# save new config
|
|
2250
|
+
save_config(g_config)
|
|
2251
|
+
return config
|
|
2252
|
+
|
|
2253
|
+
|
|
2254
|
+
async def reload_providers():
|
|
2255
|
+
global g_config, g_handlers
|
|
2256
|
+
g_handlers = init_llms(g_config, g_providers)
|
|
2257
|
+
await load_llms()
|
|
1482
2258
|
_log(f"{len(g_handlers)} providers loaded")
|
|
1483
2259
|
return g_handlers
|
|
1484
2260
|
|
|
1485
2261
|
|
|
1486
|
-
async def watch_config_files(config_path,
|
|
2262
|
+
async def watch_config_files(config_path, providers_path, interval=1):
|
|
1487
2263
|
"""Watch config files and reload providers when they change"""
|
|
1488
2264
|
global g_config
|
|
1489
2265
|
|
|
1490
2266
|
config_path = Path(config_path)
|
|
1491
|
-
|
|
2267
|
+
providers_path = Path(providers_path)
|
|
2268
|
+
|
|
2269
|
+
_log(f"Watching config file: {config_path}")
|
|
2270
|
+
_log(f"Watching providers file: {providers_path}")
|
|
1492
2271
|
|
|
1493
|
-
|
|
2272
|
+
def get_latest_mtime():
|
|
2273
|
+
ret = 0
|
|
2274
|
+
name = "llms.json"
|
|
2275
|
+
if config_path.is_file():
|
|
2276
|
+
ret = config_path.stat().st_mtime
|
|
2277
|
+
name = config_path.name
|
|
2278
|
+
if providers_path.is_file() and providers_path.stat().st_mtime > ret:
|
|
2279
|
+
ret = providers_path.stat().st_mtime
|
|
2280
|
+
name = providers_path.name
|
|
2281
|
+
return ret, name
|
|
1494
2282
|
|
|
1495
|
-
|
|
2283
|
+
latest_mtime, name = get_latest_mtime()
|
|
1496
2284
|
|
|
1497
2285
|
while True:
|
|
1498
2286
|
await asyncio.sleep(interval)
|
|
1499
2287
|
|
|
1500
2288
|
# Check llms.json
|
|
1501
2289
|
try:
|
|
1502
|
-
|
|
1503
|
-
|
|
2290
|
+
new_mtime, name = get_latest_mtime()
|
|
2291
|
+
if new_mtime > latest_mtime:
|
|
2292
|
+
_log(f"Config file changed: {name}")
|
|
2293
|
+
latest_mtime = new_mtime
|
|
1504
2294
|
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
file_mtimes[str(config_path)] = mtime
|
|
1510
|
-
|
|
1511
|
-
try:
|
|
1512
|
-
# Reload llms.json
|
|
1513
|
-
with open(config_path) as f:
|
|
1514
|
-
g_config = json.load(f)
|
|
2295
|
+
try:
|
|
2296
|
+
# Reload llms.json
|
|
2297
|
+
with open(config_path) as f:
|
|
2298
|
+
g_config = json.load(f)
|
|
1515
2299
|
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
2300
|
+
# Reload providers
|
|
2301
|
+
await reload_providers()
|
|
2302
|
+
_log("Providers reloaded successfully")
|
|
2303
|
+
except Exception as e:
|
|
2304
|
+
_log(f"Error reloading config: {e}")
|
|
1521
2305
|
except FileNotFoundError:
|
|
1522
2306
|
pass
|
|
1523
2307
|
|
|
1524
|
-
|
|
1525
|
-
|
|
2308
|
+
|
|
2309
|
+
def get_session_token(request):
|
|
2310
|
+
return request.query.get("session") or request.headers.get("X-Session-Token") or request.cookies.get("llms-token")
|
|
2311
|
+
|
|
2312
|
+
|
|
2313
|
+
class AppExtensions:
|
|
2314
|
+
"""
|
|
2315
|
+
APIs extensions can use to extend the app
|
|
2316
|
+
"""
|
|
2317
|
+
|
|
2318
|
+
def __init__(self, cli_args, extra_args):
|
|
2319
|
+
self.cli_args = cli_args
|
|
2320
|
+
self.extra_args = extra_args
|
|
2321
|
+
self.config = None
|
|
2322
|
+
self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
|
|
2323
|
+
self.ui_extensions = []
|
|
2324
|
+
self.chat_request_filters = []
|
|
2325
|
+
self.chat_tool_filters = []
|
|
2326
|
+
self.chat_response_filters = []
|
|
2327
|
+
self.chat_error_filters = []
|
|
2328
|
+
self.server_add_get = []
|
|
2329
|
+
self.server_add_post = []
|
|
2330
|
+
self.server_add_put = []
|
|
2331
|
+
self.server_add_delete = []
|
|
2332
|
+
self.server_add_patch = []
|
|
2333
|
+
self.cache_saved_filters = []
|
|
2334
|
+
self.shutdown_handlers = []
|
|
2335
|
+
self.tools = {}
|
|
2336
|
+
self.tool_definitions = []
|
|
2337
|
+
self.index_headers = []
|
|
2338
|
+
self.index_footers = []
|
|
2339
|
+
self.request_args = {
|
|
2340
|
+
"image_config": dict, # e.g. { "aspect_ratio": "1:1" }
|
|
2341
|
+
"temperature": float, # e.g: 0.7
|
|
2342
|
+
"max_completion_tokens": int, # e.g: 2048
|
|
2343
|
+
"seed": int, # e.g: 42
|
|
2344
|
+
"top_p": float, # e.g: 0.9
|
|
2345
|
+
"frequency_penalty": float, # e.g: 0.5
|
|
2346
|
+
"presence_penalty": float, # e.g: 0.5
|
|
2347
|
+
"stop": list, # e.g: ["Stop"]
|
|
2348
|
+
"reasoning_effort": str, # e.g: minimal, low, medium, high
|
|
2349
|
+
"verbosity": str, # e.g: low, medium, high
|
|
2350
|
+
"service_tier": str, # e.g: auto, default
|
|
2351
|
+
"top_logprobs": int,
|
|
2352
|
+
"safety_identifier": str,
|
|
2353
|
+
"store": bool,
|
|
2354
|
+
"enable_thinking": bool,
|
|
2355
|
+
}
|
|
2356
|
+
self.all_providers = [
|
|
2357
|
+
OpenAiCompatible,
|
|
2358
|
+
MistralProvider,
|
|
2359
|
+
GroqProvider,
|
|
2360
|
+
XaiProvider,
|
|
2361
|
+
CodestralProvider,
|
|
2362
|
+
OllamaProvider,
|
|
2363
|
+
LMStudioProvider,
|
|
2364
|
+
]
|
|
2365
|
+
self.aspect_ratios = {
|
|
2366
|
+
"1:1": "1024×1024",
|
|
2367
|
+
"2:3": "832×1248",
|
|
2368
|
+
"3:2": "1248×832",
|
|
2369
|
+
"3:4": "864×1184",
|
|
2370
|
+
"4:3": "1184×864",
|
|
2371
|
+
"4:5": "896×1152",
|
|
2372
|
+
"5:4": "1152×896",
|
|
2373
|
+
"9:16": "768×1344",
|
|
2374
|
+
"16:9": "1344×768",
|
|
2375
|
+
"21:9": "1536×672",
|
|
2376
|
+
}
|
|
2377
|
+
self.import_maps = {
|
|
2378
|
+
"vue-prod": "/ui/lib/vue.min.mjs",
|
|
2379
|
+
"vue": "/ui/lib/vue.mjs",
|
|
2380
|
+
"vue-router": "/ui/lib/vue-router.min.mjs",
|
|
2381
|
+
"@servicestack/client": "/ui/lib/servicestack-client.mjs",
|
|
2382
|
+
"@servicestack/vue": "/ui/lib/servicestack-vue.mjs",
|
|
2383
|
+
"idb": "/ui/lib/idb.min.mjs",
|
|
2384
|
+
"marked": "/ui/lib/marked.min.mjs",
|
|
2385
|
+
"highlight.js": "/ui/lib/highlight.min.mjs",
|
|
2386
|
+
"chart.js": "/ui/lib/chart.js",
|
|
2387
|
+
"color.js": "/ui/lib/color.js",
|
|
2388
|
+
"ctx.mjs": "/ui/ctx.mjs",
|
|
2389
|
+
}
|
|
2390
|
+
|
|
2391
|
+
def set_config(self, config):
|
|
2392
|
+
self.config = config
|
|
2393
|
+
self.auth_enabled = self.config.get("auth", {}).get("enabled", False)
|
|
2394
|
+
|
|
2395
|
+
# Authentication middleware helper
|
|
2396
|
+
def check_auth(self, request):
|
|
2397
|
+
"""Check if request is authenticated. Returns (is_authenticated, user_data)"""
|
|
2398
|
+
if not self.auth_enabled:
|
|
2399
|
+
return True, None
|
|
2400
|
+
|
|
2401
|
+
# Check for OAuth session token
|
|
2402
|
+
session_token = get_session_token(request)
|
|
2403
|
+
if session_token and session_token in g_sessions:
|
|
2404
|
+
return True, g_sessions[session_token]
|
|
2405
|
+
|
|
2406
|
+
# Check for API key
|
|
2407
|
+
auth_header = request.headers.get("Authorization", "")
|
|
2408
|
+
if auth_header.startswith("Bearer "):
|
|
2409
|
+
api_key = auth_header[7:]
|
|
2410
|
+
if api_key:
|
|
2411
|
+
return True, {"authProvider": "apikey"}
|
|
2412
|
+
|
|
2413
|
+
return False, None
|
|
2414
|
+
|
|
2415
|
+
def get_session(self, request):
|
|
2416
|
+
session_token = get_session_token(request)
|
|
2417
|
+
|
|
2418
|
+
if not session_token or session_token not in g_sessions:
|
|
2419
|
+
return None
|
|
2420
|
+
|
|
2421
|
+
session_data = g_sessions[session_token]
|
|
2422
|
+
return session_data
|
|
2423
|
+
|
|
2424
|
+
def get_username(self, request):
|
|
2425
|
+
session = self.get_session(request)
|
|
2426
|
+
if session:
|
|
2427
|
+
return session.get("userName")
|
|
2428
|
+
return None
|
|
2429
|
+
|
|
2430
|
+
def get_user_path(self, username=None):
|
|
2431
|
+
if username:
|
|
2432
|
+
return home_llms_path(os.path.join("user", username))
|
|
2433
|
+
return home_llms_path(os.path.join("user", "default"))
|
|
2434
|
+
|
|
2435
|
+
def chat_request(self, template=None, text=None, model=None, system_prompt=None):
|
|
2436
|
+
return g_chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
|
|
2437
|
+
|
|
2438
|
+
async def chat_completion(self, chat, context=None):
|
|
2439
|
+
response = await g_chat_completion(chat, context)
|
|
2440
|
+
return response
|
|
2441
|
+
|
|
2442
|
+
def on_cache_saved_filters(self, context):
|
|
2443
|
+
# _log(f"on_cache_saved_filters {len(self.cache_saved_filters)}: {context['url']}")
|
|
2444
|
+
for filter_func in self.cache_saved_filters:
|
|
2445
|
+
filter_func(context)
|
|
2446
|
+
|
|
2447
|
+
async def on_chat_error(self, e, context):
|
|
2448
|
+
# Apply chat error filters
|
|
2449
|
+
if "stackTrace" not in context:
|
|
2450
|
+
context["stackTrace"] = traceback.format_exc()
|
|
2451
|
+
for filter_func in self.chat_error_filters:
|
|
1526
2452
|
try:
|
|
1527
|
-
|
|
1528
|
-
|
|
2453
|
+
await filter_func(e, context)
|
|
2454
|
+
except Exception as e:
|
|
2455
|
+
_err("chat error filter failed", e)
|
|
2456
|
+
|
|
2457
|
+
async def on_chat_tool(self, chat, context):
|
|
2458
|
+
m_len = len(chat.get("messages", []))
|
|
2459
|
+
t_len = len(self.chat_tool_filters)
|
|
2460
|
+
_dbg(
|
|
2461
|
+
f"on_tool_call for thread {context.get('threadId', None)} with {m_len} {pluralize('message', m_len)}, invoking {t_len} {pluralize('filter', t_len)}:"
|
|
2462
|
+
)
|
|
2463
|
+
for filter_func in self.chat_tool_filters:
|
|
2464
|
+
await filter_func(chat, context)
|
|
2465
|
+
|
|
2466
|
+
def exit(self, exit_code=0):
|
|
2467
|
+
if len(self.shutdown_handlers) > 0:
|
|
2468
|
+
_dbg(f"running {len(self.shutdown_handlers)} shutdown handlers...")
|
|
2469
|
+
for handler in self.shutdown_handlers:
|
|
2470
|
+
handler()
|
|
2471
|
+
|
|
2472
|
+
_dbg(f"exit({exit_code})")
|
|
2473
|
+
sys.exit(exit_code)
|
|
2474
|
+
|
|
2475
|
+
|
|
2476
|
+
def handler_name(handler):
|
|
2477
|
+
if hasattr(handler, "__name__"):
|
|
2478
|
+
return handler.__name__
|
|
2479
|
+
return "unknown"
|
|
2480
|
+
|
|
2481
|
+
|
|
2482
|
+
class ExtensionContext:
|
|
2483
|
+
def __init__(self, app, path):
|
|
2484
|
+
self.app = app
|
|
2485
|
+
self.cli_args = app.cli_args
|
|
2486
|
+
self.extra_args = app.extra_args
|
|
2487
|
+
self.error_auth_required = app.error_auth_required
|
|
2488
|
+
self.path = path
|
|
2489
|
+
self.name = os.path.basename(path)
|
|
2490
|
+
if self.name.endswith(".py"):
|
|
2491
|
+
self.name = self.name[:-3]
|
|
2492
|
+
self.ext_prefix = f"/ext/{self.name}"
|
|
2493
|
+
self.MOCK = MOCK
|
|
2494
|
+
self.MOCK_DIR = MOCK_DIR
|
|
2495
|
+
self.debug = DEBUG
|
|
2496
|
+
self.verbose = g_verbose
|
|
2497
|
+
self.aspect_ratios = app.aspect_ratios
|
|
2498
|
+
self.request_args = app.request_args
|
|
2499
|
+
|
|
2500
|
+
def chat_to_prompt(self, chat):
|
|
2501
|
+
return chat_to_prompt(chat)
|
|
2502
|
+
|
|
2503
|
+
def chat_to_system_prompt(self, chat):
|
|
2504
|
+
return chat_to_system_prompt(chat)
|
|
2505
|
+
|
|
2506
|
+
def chat_response_to_message(self, response):
|
|
2507
|
+
return chat_response_to_message(response)
|
|
2508
|
+
|
|
2509
|
+
def last_user_prompt(self, chat):
|
|
2510
|
+
return last_user_prompt(chat)
|
|
2511
|
+
|
|
2512
|
+
def to_file_info(self, chat, info=None, response=None):
|
|
2513
|
+
return to_file_info(chat, info=info, response=response)
|
|
2514
|
+
|
|
2515
|
+
def save_image_to_cache(self, base64_data, filename, image_info):
|
|
2516
|
+
return save_image_to_cache(base64_data, filename, image_info)
|
|
2517
|
+
|
|
2518
|
+
def save_bytes_to_cache(self, bytes_data, filename, file_info):
|
|
2519
|
+
return save_bytes_to_cache(bytes_data, filename, file_info)
|
|
2520
|
+
|
|
2521
|
+
def text_from_file(self, path):
|
|
2522
|
+
return text_from_file(path)
|
|
2523
|
+
|
|
2524
|
+
def log(self, message):
|
|
2525
|
+
if self.verbose:
|
|
2526
|
+
print(f"[{self.name}] {message}", flush=True)
|
|
2527
|
+
return message
|
|
2528
|
+
|
|
2529
|
+
def log_json(self, obj):
|
|
2530
|
+
if self.verbose:
|
|
2531
|
+
print(f"[{self.name}] {json.dumps(obj, indent=2)}", flush=True)
|
|
2532
|
+
return obj
|
|
2533
|
+
|
|
2534
|
+
def dbg(self, message):
|
|
2535
|
+
if self.debug:
|
|
2536
|
+
print(f"DEBUG [{self.name}]: {message}", flush=True)
|
|
1529
2537
|
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
2538
|
+
def err(self, message, e):
|
|
2539
|
+
print(f"ERROR [{self.name}]: {message}", e)
|
|
2540
|
+
if self.verbose:
|
|
2541
|
+
print(traceback.format_exc(), flush=True)
|
|
2542
|
+
|
|
2543
|
+
def error_message(self, e):
|
|
2544
|
+
return to_error_message(e)
|
|
2545
|
+
|
|
2546
|
+
def error_response(self, e, stacktrace=False):
|
|
2547
|
+
return to_error_response(e, stacktrace=stacktrace)
|
|
2548
|
+
|
|
2549
|
+
def add_provider(self, provider):
|
|
2550
|
+
self.log(f"Registered provider: {provider.__name__}")
|
|
2551
|
+
self.app.all_providers.append(provider)
|
|
2552
|
+
|
|
2553
|
+
def register_ui_extension(self, index):
|
|
2554
|
+
path = os.path.join(self.ext_prefix, index)
|
|
2555
|
+
self.log(f"Registered UI extension: {path}")
|
|
2556
|
+
self.app.ui_extensions.append({"id": self.name, "path": path})
|
|
2557
|
+
|
|
2558
|
+
def register_chat_request_filter(self, handler):
|
|
2559
|
+
self.log(f"Registered chat request filter: {handler_name(handler)}")
|
|
2560
|
+
self.app.chat_request_filters.append(handler)
|
|
2561
|
+
|
|
2562
|
+
def register_chat_tool_filter(self, handler):
|
|
2563
|
+
self.log(f"Registered chat tool filter: {handler_name(handler)}")
|
|
2564
|
+
self.app.chat_tool_filters.append(handler)
|
|
2565
|
+
|
|
2566
|
+
def register_chat_response_filter(self, handler):
|
|
2567
|
+
self.log(f"Registered chat response filter: {handler_name(handler)}")
|
|
2568
|
+
self.app.chat_response_filters.append(handler)
|
|
2569
|
+
|
|
2570
|
+
def register_chat_error_filter(self, handler):
|
|
2571
|
+
self.log(f"Registered chat error filter: {handler_name(handler)}")
|
|
2572
|
+
self.app.chat_error_filters.append(handler)
|
|
2573
|
+
|
|
2574
|
+
def register_cache_saved_filter(self, handler):
|
|
2575
|
+
self.log(f"Registered cache saved filter: {handler_name(handler)}")
|
|
2576
|
+
self.app.cache_saved_filters.append(handler)
|
|
2577
|
+
|
|
2578
|
+
def register_shutdown_handler(self, handler):
|
|
2579
|
+
self.log(f"Registered shutdown handler: {handler_name(handler)}")
|
|
2580
|
+
self.app.shutdown_handlers.append(handler)
|
|
2581
|
+
|
|
2582
|
+
def add_static_files(self, ext_dir):
|
|
2583
|
+
self.log(f"Registered static files: {ext_dir}")
|
|
2584
|
+
|
|
2585
|
+
async def serve_static(request):
|
|
2586
|
+
path = request.match_info["path"]
|
|
2587
|
+
file_path = os.path.join(ext_dir, path)
|
|
2588
|
+
if os.path.exists(file_path):
|
|
2589
|
+
return web.FileResponse(file_path)
|
|
2590
|
+
return web.Response(status=404)
|
|
2591
|
+
|
|
2592
|
+
self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))
|
|
2593
|
+
|
|
2594
|
+
def add_get(self, path, handler, **kwargs):
|
|
2595
|
+
self.dbg(f"Registered GET: {os.path.join(self.ext_prefix, path)}")
|
|
2596
|
+
self.app.server_add_get.append((os.path.join(self.ext_prefix, path), handler, kwargs))
|
|
2597
|
+
|
|
2598
|
+
def add_post(self, path, handler, **kwargs):
|
|
2599
|
+
self.dbg(f"Registered POST: {os.path.join(self.ext_prefix, path)}")
|
|
2600
|
+
self.app.server_add_post.append((os.path.join(self.ext_prefix, path), handler, kwargs))
|
|
2601
|
+
|
|
2602
|
+
def add_put(self, path, handler, **kwargs):
|
|
2603
|
+
self.dbg(f"Registered PUT: {os.path.join(self.ext_prefix, path)}")
|
|
2604
|
+
self.app.server_add_put.append((os.path.join(self.ext_prefix, path), handler, kwargs))
|
|
2605
|
+
|
|
2606
|
+
def add_delete(self, path, handler, **kwargs):
|
|
2607
|
+
self.dbg(f"Registered DELETE: {os.path.join(self.ext_prefix, path)}")
|
|
2608
|
+
self.app.server_add_delete.append((os.path.join(self.ext_prefix, path), handler, kwargs))
|
|
2609
|
+
|
|
2610
|
+
def add_patch(self, path, handler, **kwargs):
|
|
2611
|
+
self.dbg(f"Registered PATCH: {os.path.join(self.ext_prefix, path)}")
|
|
2612
|
+
self.app.server_add_patch.append((os.path.join(self.ext_prefix, path), handler, kwargs))
|
|
2613
|
+
|
|
2614
|
+
def add_importmaps(self, dict):
|
|
2615
|
+
self.app.import_maps.update(dict)
|
|
2616
|
+
|
|
2617
|
+
def add_index_header(self, html):
|
|
2618
|
+
self.app.index_headers.append(html)
|
|
2619
|
+
|
|
2620
|
+
def add_index_footer(self, html):
|
|
2621
|
+
self.app.index_footers.append(html)
|
|
2622
|
+
|
|
2623
|
+
def get_config(self):
|
|
2624
|
+
return g_config
|
|
2625
|
+
|
|
2626
|
+
def get_cache_path(self, path=""):
|
|
2627
|
+
return get_cache_path(path)
|
|
2628
|
+
|
|
2629
|
+
def chat_request(self, template=None, text=None, model=None, system_prompt=None):
|
|
2630
|
+
return self.app.chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
|
|
2631
|
+
|
|
2632
|
+
def chat_completion(self, chat, context=None):
|
|
2633
|
+
return self.app.chat_completion(chat, context=context)
|
|
2634
|
+
|
|
2635
|
+
def get_providers(self):
|
|
2636
|
+
return g_handlers
|
|
2637
|
+
|
|
2638
|
+
def get_provider(self, name):
|
|
2639
|
+
return g_handlers.get(name)
|
|
2640
|
+
|
|
2641
|
+
def register_tool(self, func, tool_def=None):
|
|
2642
|
+
if tool_def is None:
|
|
2643
|
+
tool_def = function_to_tool_definition(func)
|
|
2644
|
+
|
|
2645
|
+
name = tool_def["function"]["name"]
|
|
2646
|
+
self.log(f"Registered tool: {name}")
|
|
2647
|
+
self.app.tools[name] = func
|
|
2648
|
+
self.app.tool_definitions.append(tool_def)
|
|
2649
|
+
|
|
2650
|
+
def check_auth(self, request):
|
|
2651
|
+
return self.app.check_auth(request)
|
|
2652
|
+
|
|
2653
|
+
def get_session(self, request):
|
|
2654
|
+
return self.app.get_session(request)
|
|
2655
|
+
|
|
2656
|
+
def get_username(self, request):
|
|
2657
|
+
return self.app.get_username(request)
|
|
2658
|
+
|
|
2659
|
+
def get_user_path(self, username=None):
|
|
2660
|
+
return self.app.get_user_path(username)
|
|
2661
|
+
|
|
2662
|
+
def should_cancel_thread(self, context):
|
|
2663
|
+
return should_cancel_thread(context)
|
|
2664
|
+
|
|
2665
|
+
def cache_message_inline_data(self, message):
|
|
2666
|
+
return cache_message_inline_data(message)
|
|
2667
|
+
|
|
2668
|
+
def to_content(self, result):
|
|
2669
|
+
return to_content(result)
|
|
2670
|
+
|
|
2671
|
+
|
|
2672
|
+
def get_extensions_path():
|
|
2673
|
+
return os.getenv("LLMS_EXTENSIONS_DIR", os.path.join(Path.home(), ".llms", "extensions"))
|
|
2674
|
+
|
|
2675
|
+
|
|
2676
|
+
def get_disabled_extensions():
|
|
2677
|
+
ret = DISABLE_EXTENSIONS.copy()
|
|
2678
|
+
if g_config:
|
|
2679
|
+
for ext in g_config.get("disable_extensions", []):
|
|
2680
|
+
if ext not in ret:
|
|
2681
|
+
ret.append(ext)
|
|
2682
|
+
return ret
|
|
2683
|
+
|
|
2684
|
+
|
|
2685
|
+
def get_extensions_dirs():
|
|
2686
|
+
"""
|
|
2687
|
+
Returns a list of extension directories.
|
|
2688
|
+
"""
|
|
2689
|
+
extensions_path = get_extensions_path()
|
|
2690
|
+
os.makedirs(extensions_path, exist_ok=True)
|
|
2691
|
+
|
|
2692
|
+
# allow overriding builtin extensions
|
|
2693
|
+
override_extensions = []
|
|
2694
|
+
if os.path.exists(extensions_path):
|
|
2695
|
+
override_extensions = os.listdir(extensions_path)
|
|
2696
|
+
|
|
2697
|
+
ret = []
|
|
2698
|
+
+    disabled_extensions = get_disabled_extensions()
+
+    builtin_extensions_dir = _ROOT / "extensions"
+    if os.path.exists(builtin_extensions_dir):
+        for item in os.listdir(builtin_extensions_dir):
+            if os.path.isdir(os.path.join(builtin_extensions_dir, item)):
+                if item in override_extensions:
+                    continue
+                if item in disabled_extensions:
+                    continue
+                ret.append(os.path.join(builtin_extensions_dir, item))
+
+    if os.path.exists(extensions_path):
+        for item in os.listdir(extensions_path):
+            if os.path.isdir(os.path.join(extensions_path, item)):
+                if item in disabled_extensions:
+                    continue
+                ret.append(os.path.join(extensions_path, item))
+
+    return ret
+
+
+def init_extensions(parser):
+    """
+    Initializes extensions by loading their __init__.py files and calling the __parser__ function if it exists.
+    """
+    for item_path in get_extensions_dirs():
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            try:
+                # check for __parser__ function if exists in __init.__.py and call it with parser
+                init_file = os.path.join(item_path, "__init__.py")
+                if os.path.exists(init_file):
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        parser_func = getattr(module, "__parser__", None)
+                        if callable(parser_func):
+                            parser_func(parser)
+                            _log(f"Extension {item} parser loaded")
+            except Exception as e:
+                _err(f"Failed to load extension {item} parser", e)
+
+
+def install_extensions():
+    """
+    Scans ensure ~/.llms/extensions/ for directories with __init__.py and loads them as extensions.
+    Calls the `__install__(ctx)` function in the extension module.
+    """
+
+    extension_dirs = get_extensions_dirs()
+    ext_count = len(list(extension_dirs))
+    if ext_count == 0:
+        _log("No extensions found")
+        return
+
+    disabled_extensions = get_disabled_extensions()
+    if len(disabled_extensions) > 0:
+        _log(f"Disabled extensions: {', '.join(disabled_extensions)}")
+
+    _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")
+
+    for item_path in extension_dirs:
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            sys.path.append(item_path)
+            try:
+                ctx = ExtensionContext(g_app, item_path)
+                init_file = os.path.join(item_path, "__init__.py")
+                if os.path.exists(init_file):
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        install_func = getattr(module, "__install__", None)
+                        if callable(install_func):
+                            install_func(ctx)
+                            _log(f"Extension {item} installed")
+                        else:
+                            _dbg(f"Extension {item} has no __install__ function")
+                    else:
+                        _dbg(f"Extension {item} has no __init__.py")
+                else:
+                    _dbg(f"Extension {init_file} not found")
+
+                # if ui folder exists, serve as static files at /ext/{item}/
+                ui_path = os.path.join(item_path, "ui")
+                if os.path.exists(ui_path):
+                    ctx.add_static_files(ui_path)
+
+                    # Register UI extension if index.mjs exists (/ext/{item}/index.mjs)
+                    if os.path.exists(os.path.join(ui_path, "index.mjs")):
+                        ctx.register_ui_extension("index.mjs")
+
+            except Exception as e:
+                _err(f"Failed to install extension {item}", e)
+        else:
+            _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")
+
+
+def run_extension_cli():
+    """
+    Run the CLI for an extension.
+    """
+    for item_path in get_extensions_dirs():
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            init_file = os.path.join(item_path, "__init__.py")
+            if os.path.exists(init_file):
+                ctx = ExtensionContext(g_app, item_path)
+                try:
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        # Check for __run__ function if exists in __init__.py and call it with ctx
+                        run_func = getattr(module, "__run__", None)
+                        if callable(run_func):
+                            _log(f"Running extension {item}...")
+                            handled = run_func(ctx)
+                            return handled
+
+                except Exception as e:
+                    _err(f"Failed to run extension {item}", e)
+    return False
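Note: the loaders above resolve three optional module-level hooks by name — `__parser__(parser)`, `__install__(ctx)` and `__run__(ctx)` — and call whichever ones an extension defines. Below is a minimal sketch of a hypothetical extension `__init__.py` that these loaders would pick up; only the hook names and call signatures come from the code above, while the extension itself (and the `ctx.args` attribute it reads) are assumptions for illustration.

    # Hypothetical ~/.llms/extensions/hello/__init__.py (illustrative sketch only)

    def __parser__(parser):
        # init_extensions() passes the argparse.ArgumentParser so an extension
        # can contribute its own CLI flags.
        parser.add_argument("--hello", action="store_true", help="Print a greeting (example flag)")

    def __install__(ctx):
        # install_extensions() calls this once at startup with an ExtensionContext.
        print("hello extension installed")

    def __run__(ctx):
        # run_extension_cli() calls this after the built-in CLI commands; a truthy
        # return value tells llms the invocation was handled.
        # `ctx.args` is assumed here; the real ExtensionContext API may differ.
        if getattr(getattr(ctx, "args", None), "hello", False):
            print("Hello from the example extension")
            return True
        return False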
 
 
 def main():
-    global _ROOT, g_verbose, g_default_model, g_logprefix, g_config, g_config_path,
+    global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
+
+    _ROOT = os.getenv("LLMS_ROOT", resolve_root())
+    if not _ROOT:
+        print("Resource root not found")
+        exit(1)
 
     parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
     parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
+    parser.add_argument("--providers", default=None, help="Path to models.dev providers file", metavar="FILE")
     parser.add_argument("-m", "--model", default=None, help="Model to use")
 
     parser.add_argument("--chat", default=None, help="OpenAI Chat Completion Request to send", metavar="REQUEST")
     parser.add_argument(
         "-s", "--system", default=None, help="System prompt to use for chat completion", metavar="PROMPT"
     )
+    parser.add_argument(
+        "--tools", default=None, help="Tools to use for chat completion (all|none|<tool>,<tool>...)", metavar="TOOLS"
+    )
     parser.add_argument("--image", default=None, help="Image input to use in chat completion")
     parser.add_argument("--audio", default=None, help="Audio input to use in chat completion")
     parser.add_argument("--file", default=None, help="File input to use in chat completion")
+    parser.add_argument("--out", default=None, help="Image or Video Generation Request", metavar="MODALITY")
     parser.add_argument(
         "--args",
         default=None,
@@ -1573,15 +2878,46 @@ def main():
     parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")
 
     parser.add_argument("--init", action="store_true", help="Create a default llms.json")
+    parser.add_argument("--update-providers", action="store_true", help="Update local models.dev providers.json")
 
-    parser.add_argument("--root", default=None, help="Change root directory for UI files", metavar="PATH")
     parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
     parser.add_argument("--verbose", action="store_true", help="Verbose output")
 
+    parser.add_argument(
+        "--add",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Install an extension (lists available extensions if no name provided)",
+        metavar="EXTENSION",
+    )
+    parser.add_argument(
+        "--remove",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Remove an extension (lists installed extensions if no name provided)",
+        metavar="EXTENSION",
+    )
+
+    parser.add_argument(
+        "--update",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Update an extension (use 'all' to update all extensions)",
+        metavar="EXTENSION",
+    )
+
+    # Load parser extensions, go through all extensions and load their parser arguments
+    init_extensions(parser)
+
     cli_args, extra_args = parser.parse_known_args()
 
+    g_app = AppExtensions(cli_args, extra_args)
+
     # Check for verbose mode from CLI argument or environment variables
-    verbose_env = os.
+    verbose_env = os.getenv("VERBOSE", "").lower()
     if cli_args.verbose or verbose_env in ("1", "true"):
         g_verbose = True
     # printdump(cli_args)
@@ -1590,13 +2926,9 @@ def main():
     if cli_args.logprefix:
         g_logprefix = cli_args.logprefix
 
-    _ROOT = Path(cli_args.root) if cli_args.root else resolve_root()
-    if not _ROOT:
-        print("Resource root not found")
-        exit(1)
-
     home_config_path = home_llms_path("llms.json")
-
+    home_providers_path = home_llms_path("providers.json")
+    home_providers_extra_path = home_llms_path("providers-extra.json")
 
     if cli_args.init:
         if os.path.exists(home_config_path):
@@ -1605,38 +2937,215 @@ def main():
             asyncio.run(save_default_config(home_config_path))
             print(f"Created default config at {home_config_path}")
 
-        if os.path.exists(
-            print(f"
+        if os.path.exists(home_providers_path):
+            print(f"providers.json already exists at {home_providers_path}")
+        else:
+            asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
+            print(f"Created default providers config at {home_providers_path}")
+
+        if os.path.exists(home_providers_extra_path):
+            print(f"providers-extra.json already exists at {home_providers_extra_path}")
         else:
-            asyncio.run(save_text_url(github_url("
-            print(f"Created default
+            asyncio.run(save_text_url(github_url("providers-extra.json"), home_providers_extra_path))
+            print(f"Created default extra providers config at {home_providers_extra_path}")
         exit(0)
 
+    if cli_args.providers:
+        if not os.path.exists(cli_args.providers):
+            print(f"providers.json not found at {cli_args.providers}")
+            exit(1)
+        g_providers = json.loads(text_from_file(cli_args.providers))
+
     if cli_args.config:
         # read contents
         g_config_path = cli_args.config
         with open(g_config_path, encoding="utf-8") as f:
             config_json = f.read()
-        g_config =
+        g_config = load_config_json(config_json)
 
         config_dir = os.path.dirname(g_config_path)
-
-
-
-
-        else:
-            if not os.path.exists(home_ui_path):
-                ui_json = text_from_resource("ui.json")
-                with open(home_ui_path, "w", encoding="utf-8") as f:
-                    f.write(ui_json)
-                _log(f"Created default ui config at {home_ui_path}")
-            g_ui_path = home_ui_path
+
+        if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
+            g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
+
     else:
-        # ensure llms.json and
+        # ensure llms.json and providers.json exist in home directory
         asyncio.run(save_home_configs())
         g_config_path = home_config_path
-
-
+        g_config = load_config_json(text_from_file(g_config_path))
+
+    g_app.set_config(g_config)
+
+    if not g_providers:
+        g_providers = json.loads(text_from_file(home_providers_path))
+
+    if cli_args.update_providers:
+        asyncio.run(update_providers(home_providers_path))
+        print(f"Updated {home_providers_path}")
+        exit(0)
+
+    # if home_providers_path is older than 1 day, update providers list
+    if (
+        os.path.exists(home_providers_path)
+        and (time.time() - os.path.getmtime(home_providers_path)) > 86400
+        and os.getenv("LLMS_DISABLE_UPDATE", "") != "1"
+    ):
+        try:
+            asyncio.run(update_providers(home_providers_path))
+            _log(f"Updated {home_providers_path}")
+        except Exception as e:
+            _err("Failed to update providers", e)
+
+    if cli_args.add is not None:
+        if cli_args.add == "ls":
+
+            async def list_extensions():
+                print("\nAvailable extensions:")
+                text = await get_text("https://api.github.com/orgs/llmspy/repos?per_page=100&sort=updated")
+                repos = json.loads(text)
+                max_name_length = 0
+                for repo in repos:
+                    max_name_length = max(max_name_length, len(repo["name"]))
+
+                for repo in repos:
+                    print(f" {repo['name']:<{max_name_length + 2}} {repo['description']}")
+
+                print("\nUsage:")
+                print(" llms --add <extension>")
+                print(" llms --add <github-user>/<repo>")
+
+            asyncio.run(list_extensions())
+            exit(0)
+
+        async def install_extension(name):
+            # Determine git URL and target directory name
+            if "/" in name:
+                git_url = f"https://github.com/{name}"
+                target_name = name.split("/")[-1]
+            else:
+                git_url = f"https://github.com/llmspy/{name}"
+                target_name = name
+
+            # check extension is not already installed
+            extensions_path = get_extensions_path()
+            target_path = os.path.join(extensions_path, target_name)
+
+            if os.path.exists(target_path):
+                print(f"Extension {target_name} is already installed at {target_path}")
+                return
+
+            print(f"Installing extension: {name}")
+            print(f"Cloning from {git_url} to {target_path}...")
+
+            try:
+                subprocess.run(["git", "clone", git_url, target_path], check=True)
+
+                # Check for requirements.txt
+                requirements_path = os.path.join(target_path, "requirements.txt")
+                if os.path.exists(requirements_path):
+                    print(f"Installing dependencies from {requirements_path}...")
+
+                    # Check if uv is installed
+                    has_uv = False
+                    try:
+                        subprocess.run(
+                            ["uv", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+                        )
+                        has_uv = True
+                    except (subprocess.CalledProcessError, FileNotFoundError):
+                        pass
+
+                    if has_uv:
+                        subprocess.run(
+                            ["uv", "pip", "install", "-p", sys.executable, "-r", "requirements.txt"],
+                            cwd=target_path,
+                            check=True,
+                        )
+                    else:
+                        subprocess.run(
+                            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
+                            cwd=target_path,
+                            check=True,
+                        )
+                    print("Dependencies installed successfully.")
+
+                print(f"Extension {target_name} installed successfully.")
+
+            except subprocess.CalledProcessError as e:
+                print(f"Failed to install extension: {e}")
+                # cleanup if clone failed but directory was created (unlikely with simple git clone but good practice)
+                if os.path.exists(target_path) and not os.listdir(target_path):
+                    os.rmdir(target_path)
+
+        asyncio.run(install_extension(cli_args.add))
+        exit(0)
+
+    if cli_args.remove is not None:
+        if cli_args.remove == "ls":
+            # List installed extensions
+            extensions_path = get_extensions_path()
+            extensions = os.listdir(extensions_path)
+            if len(extensions) == 0:
+                print("No extensions installed.")
+                exit(0)
+            print("Installed extensions:")
+            for extension in extensions:
+                print(f" {extension}")
+            exit(0)
+        # Remove an extension
+        extension_name = cli_args.remove
+        extensions_path = get_extensions_path()
+        target_path = os.path.join(extensions_path, extension_name)
+
+        if not os.path.exists(target_path):
+            print(f"Extension {extension_name} not found at {target_path}")
+            exit(1)
+
+        print(f"Removing extension: {extension_name}...")
+        try:
+            shutil.rmtree(target_path)
+            print(f"Extension {extension_name} removed successfully.")
+        except Exception as e:
+            print(f"Failed to remove extension: {e}")
+            exit(1)
+
+        exit(0)
+
+    if cli_args.update:
+        if cli_args.update == "ls":
+            # List installed extensions
+            extensions_path = get_extensions_path()
+            extensions = os.listdir(extensions_path)
+            if len(extensions) == 0:
+                print("No extensions installed.")
+                exit(0)
+            print("Installed extensions:")
+            for extension in extensions:
+                print(f" {extension}")
+
+            print("\nUsage:")
+            print(" llms --update <extension>")
+            print(" llms --update all")
+            exit(0)
+
+        async def update_extensions(extension_name):
+            extensions_path = get_extensions_path()
+            for extension in os.listdir(extensions_path):
+                extension_path = os.path.join(extensions_path, extension)
+                if os.path.isdir(extension_path):
+                    if extension_name != "all" and extension != extension_name:
+                        continue
+                    result = subprocess.run(["git", "pull"], cwd=extension_path, capture_output=True)
+                    if result.returncode != 0:
+                        print(f"Failed to update extension {extension}: {result.stderr.decode('utf-8')}")
+                        continue
+                    print(f"Updated extension {extension}")
+                    _log(result.stdout.decode("utf-8"))
+
+        asyncio.run(update_extensions(cli_args.update))
+        exit(0)
+
+    install_extensions()
 
     asyncio.run(reload_providers())
 
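Note: the providers.json auto-refresh added above reduces to a single staleness test — the file exists, its mtime is more than a day old, and LLMS_DISABLE_UPDATE is not set to "1". Restated as a standalone sketch (the helper name is not part of the package):

    import os
    import time

    def providers_json_is_stale(path: str, max_age_secs: int = 86400) -> bool:
        # Mirrors the condition guarding the automatic update_providers() call above.
        return (
            os.path.exists(path)
            and (time.time() - os.path.getmtime(path)) > max_age_secs
            and os.getenv("LLMS_DISABLE_UPDATE", "") != "1"
        )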
@@ -1654,23 +3163,45 @@ def main():
     if cli_args.list:
         # Show list of enabled providers and their models
         enabled = []
+        provider_count = 0
+        model_count = 0
+
+        max_model_length = 0
+        for name, provider in g_handlers.items():
+            if len(filter_list) > 0 and name not in filter_list:
+                continue
+            for model in provider.models:
+                max_model_length = max(max_model_length, len(model))
+
         for name, provider in g_handlers.items():
             if len(filter_list) > 0 and name not in filter_list:
                 continue
+            provider_count += 1
             print(f"{name}:")
             enabled.append(name)
             for model in provider.models:
-
+                model_count += 1
+                model_cost_info = None
+                if "cost" in provider.models[model]:
+                    model_cost = provider.models[model]["cost"]
+                    if "input" in model_cost and "output" in model_cost:
+                        if model_cost["input"] == 0 and model_cost["output"] == 0:
+                            model_cost_info = " 0"
+                        else:
+                            model_cost_info = f"{model_cost['input']:5} / {model_cost['output']}"
+                print(f" {model:{max_model_length}} {model_cost_info or ''}")
+
+        print(f"\n{model_count} models available from {provider_count} providers")
 
         print_status()
-        exit(0)
+        g_app.exit(0)
 
     if cli_args.check is not None:
         # Check validity of models for a provider
         provider_name = cli_args.check
         model_names = extra_args if len(extra_args) > 0 else None
         asyncio.run(check_models(provider_name, model_names))
-        exit(0)
+        g_app.exit(0)
 
     if cli_args.serve is not None:
         # Disable inactive providers and save to config before starting server
@@ -1690,10 +3221,6 @@ def main():
         # Start server
         port = int(cli_args.serve)
 
-        if not os.path.exists(g_ui_path):
-            print(f"UI not found at {g_ui_path}")
-            exit(1)
-
         # Validate auth configuration if enabled
         auth_enabled = g_config.get("auth", {}).get("enabled", False)
         if auth_enabled:
@@ -1703,11 +3230,19 @@ def main():
 
             # Expand environment variables
             if client_id.startswith("$"):
-                client_id =
+                client_id = client_id[1:]
             if client_secret.startswith("$"):
-                client_secret =
+                client_secret = client_secret[1:]
 
-
+            client_id = os.getenv(client_id, client_id)
+            client_secret = os.getenv(client_secret, client_secret)
+
+            if (
+                not client_id
+                or not client_secret
+                or client_id == "GITHUB_CLIENT_ID"
+                or client_secret == "GITHUB_CLIENT_SECRET"
+            ):
                 print("ERROR: Authentication is enabled but GitHub OAuth is not properly configured.")
                 print("Please set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET environment variables,")
                 print("or disable authentication by setting 'auth.enabled' to false in llms.json")
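Note: the `$VAR` handling introduced in this and the later OAuth hunks follows one pattern — strip a leading `$`, then look the remaining name up in the environment, falling back to the literal value. A compact restatement (the helper name is not part of the package):

    import os

    def expand_env_ref(value: str) -> str:
        # "$GITHUB_CLIENT_ID" -> os.getenv("GITHUB_CLIENT_ID", "GITHUB_CLIENT_ID")
        if value.startswith("$"):
            value = value[1:]
        return os.getenv(value, value)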
@@ -1721,60 +3256,35 @@ def main():
             _log(f"client_max_size set to {client_max_size} bytes ({client_max_size / 1024 / 1024:.1f}MB)")
         app = web.Application(client_max_size=client_max_size)
 
-        # Authentication middleware helper
-        def check_auth(request):
-            """Check if request is authenticated. Returns (is_authenticated, user_data)"""
-            if not auth_enabled:
-                return True, None
-
-            # Check for OAuth session token
-            session_token = request.query.get("session") or request.headers.get("X-Session-Token")
-            if session_token and session_token in g_sessions:
-                return True, g_sessions[session_token]
-
-            # Check for API key
-            auth_header = request.headers.get("Authorization", "")
-            if auth_header.startswith("Bearer "):
-                api_key = auth_header[7:]
-                if api_key:
-                    return True, {"authProvider": "apikey"}
-
-            return False, None
-
         async def chat_handler(request):
             # Check authentication if enabled
-            is_authenticated, user_data = check_auth(request)
+            is_authenticated, user_data = g_app.check_auth(request)
             if not is_authenticated:
-                return web.json_response(
-                    {
-                        "error": {
-                            "message": "Authentication required",
-                            "type": "authentication_error",
-                            "code": "unauthorized",
-                        }
-                    },
-                    status=401,
-                )
+                return web.json_response(g_app.error_auth_required, status=401)
 
             try:
                 chat = await request.json()
-
+                context = {"chat": chat, "request": request, "user": g_app.get_username(request)}
+                metadata = chat.get("metadata", {})
+                context["threadId"] = metadata.get("threadId", None)
+                context["tools"] = metadata.get("tools", "all")
+                response = await g_app.chat_completion(chat, context)
                 return web.json_response(response)
             except Exception as e:
-                return web.json_response(
+                return web.json_response(to_error_response(e), status=500)
 
         app.router.add_post("/v1/chat/completions", chat_handler)
 
-        async def models_handler(request):
-            return web.json_response(get_models())
-
-        app.router.add_get("/models/list", models_handler)
-
         async def active_models_handler(request):
             return web.json_response(get_active_models())
 
         app.router.add_get("/models", active_models_handler)
 
+        async def active_providers_handler(request):
+            return web.json_response(api_providers())
+
+        app.router.add_get("/providers", active_providers_handler)
+
         async def status_handler(request):
             enabled, disabled = provider_status()
             return web.json_response(
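Note: the reworked chat_handler above now reads `metadata.threadId` and `metadata.tools` from the request body and forwards them, together with the authenticated user, to `g_app.chat_completion()`. A request carrying that metadata might look like the sketch below; only the two metadata keys are taken from the code, the rest follows the usual OpenAI-style chat schema and the values are invented:

    # Hypothetical body for POST /v1/chat/completions (values are illustrative)
    example_chat_request = {
        "model": "gpt-4o-mini",
        "messages": [
            {"role": "user", "content": "Summarise this thread"},
        ],
        "metadata": {
            "threadId": "thread-123",  # copied into context["threadId"]
            "tools": "all",            # copied into context["tools"]; defaults to "all"
        },
    }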
@@ -1794,8 +3304,9 @@ def main():
             if provider:
                 if data.get("enable", False):
                     provider_config, msg = enable_provider(provider)
-                    _log(f"Enabled provider {provider}")
-
+                    _log(f"Enabled provider {provider} {msg}")
+                    if not msg:
+                        await load_llms()
                 elif data.get("disable", False):
                     disable_provider(provider)
                     _log(f"Disabled provider {provider}")
@@ -1810,11 +3321,144 @@ def main():
 
         app.router.add_post("/providers/{provider}", provider_handler)
 
+        async def upload_handler(request):
+            # Check authentication if enabled
+            is_authenticated, user_data = g_app.check_auth(request)
+            if not is_authenticated:
+                return web.json_response(g_app.error_auth_required, status=401)
+
+            reader = await request.multipart()
+
+            # Read first file field
+            field = await reader.next()
+            while field and field.name != "file":
+                field = await reader.next()
+
+            if not field:
+                return web.json_response(create_error_response("No file provided"), status=400)
+
+            filename = field.filename or "file"
+            content = await field.read()
+            mimetype = get_file_mime_type(filename)
+
+            # If image, resize if needed
+            if mimetype.startswith("image/"):
+                content, mimetype = convert_image_if_needed(content, mimetype)
+
+            # Calculate SHA256
+            sha256_hash = hashlib.sha256(content).hexdigest()
+            ext = filename.rsplit(".", 1)[1] if "." in filename else ""
+            if not ext:
+                ext = mimetypes.guess_extension(mimetype) or ""
+                if ext.startswith("."):
+                    ext = ext[1:]
+
+            if not ext:
+                ext = "bin"
+
+            save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+            # Use first 2 chars for subdir to avoid too many files in one dir
+            subdir = sha256_hash[:2]
+            relative_path = f"{subdir}/{save_filename}"
+            full_path = get_cache_path(relative_path)
+
+            # if file and its .info.json already exists, return it
+            info_path = os.path.splitext(full_path)[0] + ".info.json"
+            if os.path.exists(full_path) and os.path.exists(info_path):
+                return web.json_response(json.load(open(info_path)))
+
+            os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+            with open(full_path, "wb") as f:
+                f.write(content)
+
+            url = f"/~cache/{relative_path}"
+            response_data = {
+                "date": int(time.time()),
+                "url": url,
+                "size": len(content),
+                "type": mimetype,
+                "name": filename,
+            }
+
+            # If image, get dimensions
+            if HAS_PIL and mimetype.startswith("image/"):
+                try:
+                    with Image.open(BytesIO(content)) as img:
+                        response_data["width"] = img.width
+                        response_data["height"] = img.height
+                except Exception:
+                    pass
+
+            # Save metadata
+            info_path = os.path.splitext(full_path)[0] + ".info.json"
+            with open(info_path, "w") as f:
+                json.dump(response_data, f)
+
+            g_app.on_cache_saved_filters({"url": url, "info": response_data})
+
+            return web.json_response(response_data)
+
+        app.router.add_post("/upload", upload_handler)
+
+        async def extensions_handler(request):
+            return web.json_response(g_app.ui_extensions)
+
+        app.router.add_get("/ext", extensions_handler)
+
+        async def tools_handler(request):
+            return web.json_response(g_app.tool_definitions)
+
+        app.router.add_get("/ext/tools", tools_handler)
+
+        async def cache_handler(request):
+            path = request.match_info["tail"]
+            full_path = get_cache_path(path)
+
+            if "info" in request.query:
+                info_path = os.path.splitext(full_path)[0] + ".info.json"
+                if not os.path.exists(info_path):
+                    return web.Response(text="404: Not Found", status=404)
+
+                # Check for directory traversal for info path
+                try:
+                    cache_root = Path(get_cache_path())
+                    requested_path = Path(info_path).resolve()
+                    if not str(requested_path).startswith(str(cache_root)):
+                        return web.Response(text="403: Forbidden", status=403)
+                except Exception:
+                    return web.Response(text="403: Forbidden", status=403)
+
+                with open(info_path) as f:
+                    content = f.read()
+                return web.Response(text=content, content_type="application/json")
+
+            if not os.path.exists(full_path):
+                return web.Response(text="404: Not Found", status=404)
+
+            # Check for directory traversal
+            try:
+                cache_root = Path(get_cache_path())
+                requested_path = Path(full_path).resolve()
+                if not str(requested_path).startswith(str(cache_root)):
+                    return web.Response(text="403: Forbidden", status=403)
+            except Exception:
+                return web.Response(text="403: Forbidden", status=403)
+
+            with open(full_path, "rb") as f:
+                content = f.read()
+
+            mimetype = get_file_mime_type(full_path)
+            return web.Response(body=content, content_type=mimetype)
+
+        app.router.add_get("/~cache/{tail:.*}", cache_handler)
+
         # OAuth handlers
         async def github_auth_handler(request):
             """Initiate GitHub OAuth flow"""
             if "auth" not in g_config or "github" not in g_config["auth"]:
-                return web.json_response(
+                return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
 
             auth_config = g_config["auth"]["github"]
             client_id = auth_config.get("client_id", "")
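Note: the new upload handler above stores uploads content-addressed — SHA-256 of the bytes, a two-character shard directory, and an extension taken from the filename or guessed from the MIME type — and serves them back under /~cache/. The path scheme can be restated as (the helper name is not part of the package):

    import hashlib
    import mimetypes

    def cache_relative_path(content: bytes, filename: str, mimetype: str) -> str:
        # Mirrors upload_handler's layout: <sha256[:2]>/<sha256>.<ext>
        sha256_hash = hashlib.sha256(content).hexdigest()
        ext = filename.rsplit(".", 1)[1] if "." in filename else ""
        if not ext:
            ext = (mimetypes.guess_extension(mimetype) or "").lstrip(".")
        if not ext:
            ext = "bin"
        return f"{sha256_hash[:2]}/{sha256_hash}.{ext}"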
@@ -1822,12 +3466,15 @@ def main():
 
             # Expand environment variables
             if client_id.startswith("$"):
-                client_id =
+                client_id = client_id[1:]
             if redirect_uri.startswith("$"):
-                redirect_uri =
+                redirect_uri = redirect_uri[1:]
+
+            client_id = os.getenv(client_id, client_id)
+            redirect_uri = os.getenv(redirect_uri, redirect_uri)
 
             if not client_id:
-                return web.json_response(
+                return web.json_response(create_error_response("GitHub client_id not configured"), status=500)
 
             # Generate CSRF state token
             state = secrets.token_urlsafe(32)
@@ -1857,7 +3504,9 @@ def main():
 
             # Expand environment variables
             if restrict_to.startswith("$"):
-                restrict_to =
+                restrict_to = restrict_to[1:]
+
+            restrict_to = os.getenv(restrict_to, None if restrict_to == "GITHUB_USERS" else restrict_to)
 
             # If restrict_to is configured, validate the user
             if restrict_to:
@@ -1878,6 +3527,14 @@ def main():
             code = request.query.get("code")
             state = request.query.get("state")
 
+            # Handle malformed URLs where query params are appended with & instead of ?
+            if not code and "tail" in request.match_info:
+                tail = request.match_info["tail"]
+                if tail.startswith("&"):
+                    params = parse_qs(tail[1:])
+                    code = params.get("code", [None])[0]
+                    state = params.get("state", [None])[0]
+
             if not code or not state:
                 return web.Response(text="Missing code or state parameter", status=400)
 
@@ -1888,7 +3545,7 @@ def main():
             g_oauth_states.pop(state)
 
             if "auth" not in g_config or "github" not in g_config["auth"]:
-                return web.json_response(
+                return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
 
             auth_config = g_config["auth"]["github"]
             client_id = auth_config.get("client_id", "")
@@ -1897,14 +3554,18 @@ def main():
 
             # Expand environment variables
             if client_id.startswith("$"):
-                client_id =
+                client_id = client_id[1:]
             if client_secret.startswith("$"):
-                client_secret =
+                client_secret = client_secret[1:]
             if redirect_uri.startswith("$"):
-                redirect_uri =
+                redirect_uri = redirect_uri[1:]
+
+            client_id = os.getenv(client_id, client_id)
+            client_secret = os.getenv(client_secret, client_secret)
+            redirect_uri = os.getenv(redirect_uri, redirect_uri)
 
             if not client_id or not client_secret:
-                return web.json_response(
+                return web.json_response(create_error_response("GitHub OAuth credentials not configured"), status=500)
 
             # Exchange code for access token
             async with aiohttp.ClientSession() as session:
@@ -1923,7 +3584,7 @@ def main():
 
                 if not access_token:
                     error = token_response.get("error_description", "Failed to get access token")
-                    return web.
+                    return web.json_response(create_error_response(f"OAuth error: {error}"), status=400)
 
                 # Fetch user info
                 user_url = "https://api.github.com/user"
@@ -1949,14 +3610,16 @@ def main():
             }
 
             # Redirect to UI with session token
-
+            response = web.HTTPFound(f"/?session={session_token}")
+            response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400)
+            return response
 
         async def session_handler(request):
             """Validate and return session info"""
-            session_token =
+            session_token = get_session_token(request)
 
             if not session_token or session_token not in g_sessions:
-                return web.json_response(
+                return web.json_response(create_error_response("Invalid or expired session"), status=401)
 
             session_data = g_sessions[session_token]
 
@@ -1970,17 +3633,19 @@ def main():
 
         async def logout_handler(request):
             """End OAuth session"""
-            session_token =
+            session_token = get_session_token(request)
 
             if session_token and session_token in g_sessions:
                 del g_sessions[session_token]
 
-
+            response = web.json_response({"success": True})
+            response.del_cookie("llms-token")
+            return response
 
         async def auth_handler(request):
             """Check authentication status and return user info"""
             # Check for OAuth session token
-            session_token =
+            session_token = get_session_token(request)
 
             if session_token and session_token in g_sessions:
                 session_data = g_sessions[session_token]
@@ -2010,13 +3675,12 @@ def main():
             # })
 
             # Not authenticated - return error in expected format
-            return web.json_response(
-                {"responseStatus": {"errorCode": "Unauthorized", "message": "Not authenticated"}}, status=401
-            )
+            return web.json_response(g_app.error_auth_required, status=401)
 
         app.router.add_get("/auth", auth_handler)
         app.router.add_get("/auth/github", github_auth_handler)
         app.router.add_get("/auth/github/callback", github_callback_handler)
+        app.router.add_get("/auth/github/callback{tail:.*}", github_callback_handler)
         app.router.add_get("/auth/session", session_handler)
         app.router.add_post("/auth/logout", logout_handler)
 
@@ -2051,30 +3715,101 @@ def main():
 
         app.router.add_get("/ui/{path:.*}", ui_static, name="ui_static")
 
-        async def
-
-
-
-
-
-
-
-
-
-            return web.json_response(ui)
+        async def config_handler(request):
+            ret = {}
+            if "defaults" not in ret:
+                ret["defaults"] = g_config["defaults"]
+            enabled, disabled = provider_status()
+            ret["status"] = {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled}
+            # Add auth configuration
+            ret["requiresAuth"] = auth_enabled
+            ret["authType"] = "oauth" if auth_enabled else "apikey"
+            return web.json_response(ret)
 
-        app.router.add_get("/config",
+        app.router.add_get("/config", config_handler)
 
         async def not_found_handler(request):
             return web.Response(text="404: Not Found", status=404)
 
         app.router.add_get("/favicon.ico", not_found_handler)
 
+        # go through and register all g_app extensions
+        for handler in g_app.server_add_get:
+            handler_fn = handler[1]
+
+            async def managed_handler(request, handler_fn=handler_fn):
+                try:
+                    return await handler_fn(request)
+                except Exception as e:
+                    return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+            app.router.add_get(handler[0], managed_handler, **handler[2])
+        for handler in g_app.server_add_post:
+            handler_fn = handler[1]
+
+            async def managed_handler(request, handler_fn=handler_fn):
+                try:
+                    return await handler_fn(request)
+                except Exception as e:
+                    return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+            app.router.add_post(handler[0], managed_handler, **handler[2])
+        for handler in g_app.server_add_put:
+            handler_fn = handler[1]
+
+            async def managed_handler(request, handler_fn=handler_fn):
+                try:
+                    return await handler_fn(request)
+                except Exception as e:
+                    return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+            app.router.add_put(handler[0], managed_handler, **handler[2])
+        for handler in g_app.server_add_delete:
+            handler_fn = handler[1]
+
+            async def managed_handler(request, handler_fn=handler_fn):
+                try:
+                    return await handler_fn(request)
+                except Exception as e:
+                    return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+            app.router.add_delete(handler[0], managed_handler, **handler[2])
+        for handler in g_app.server_add_patch:
+            handler_fn = handler[1]
+
+            async def managed_handler(request, handler_fn=handler_fn):
+                try:
+                    return await handler_fn(request)
+                except Exception as e:
+                    return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+            app.router.add_patch(handler[0], managed_handler, **handler[2])
+
         # Serve index.html from root
         async def index_handler(request):
             index_content = read_resource_file_bytes("index.html")
-
-
+
+            importmaps = {"imports": g_app.import_maps}
+            importmaps_script = '<script type="importmap">\n' + json.dumps(importmaps, indent=4) + "\n</script>"
+            index_content = index_content.replace(
+                b'<script type="importmap"></script>',
+                importmaps_script.encode("utf-8"),
+            )
+
+            if len(g_app.index_headers) > 0:
+                html_header = ""
+                for header in g_app.index_headers:
+                    html_header += header
+                # replace </head> with html_header
+                index_content = index_content.replace(b"</head>", html_header.encode("utf-8") + b"\n</head>")
+
+            if len(g_app.index_footers) > 0:
+                html_footer = ""
+                for footer in g_app.index_footers:
+                    html_footer += footer
+                # replace </body> with html_footer
+                index_content = index_content.replace(b"</body>", html_footer.encode("utf-8") + b"\n</body>")
+
             return web.Response(body=index_content, content_type="text/html")
 
         app.router.add_get("/", index_handler)
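Note: the new index_handler above injects an import map built from `g_app.import_maps` into the empty `<script type="importmap"></script>` placeholder in index.html, and lets extensions append markup before `</head>` and `</body>`. The importmap substitution in isolation looks like the sketch below (the helper name and the sample mapping are invented):

    import json

    def inject_importmap(index_html: bytes, import_maps: dict) -> bytes:
        # Same byte-level replacement index_handler performs on index.html.
        script = '<script type="importmap">\n' + json.dumps({"imports": import_maps}, indent=4) + "\n</script>"
        return index_html.replace(b'<script type="importmap"></script>', script.encode("utf-8"))

    # e.g. inject_importmap(b'<head><script type="importmap"></script></head>',
    #                       {"vue": "/ui/lib/vue.mjs"})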
@@ -2086,13 +3821,15 @@ def main():
         async def start_background_tasks(app):
             """Start background tasks when the app starts"""
             # Start watching config files in the background
-            asyncio.create_task(watch_config_files(g_config_path,
+            asyncio.create_task(watch_config_files(g_config_path, home_providers_path))
 
         app.on_startup.append(start_background_tasks)
 
+        # go through and register all g_app extensions
+
         print(f"Starting server on port {port}...")
         web.run_app(app, host="0.0.0.0", port=port, print=_log)
-        exit(0)
+        g_app.exit(0)
 
     if cli_args.enable is not None:
         if cli_args.enable.endswith(","):
@@ -2109,7 +3846,7 @@ def main():
 
         for provider in enable_providers:
             if provider not in g_config["providers"]:
-                print(f"Provider {provider} not found")
+                print(f"Provider '{provider}' not found")
                 print(f"Available providers: {', '.join(g_config['providers'].keys())}")
                 exit(1)
             if provider in g_config["providers"]:
@@ -2122,7 +3859,7 @@ def main():
         print_status()
         if len(msgs) > 0:
             print("\n" + "\n".join(msgs))
-        exit(0)
+        g_app.exit(0)
 
     if cli_args.disable is not None:
         if cli_args.disable.endswith(","):
@@ -2145,26 +3882,26 @@ def main():
             print(f"\nDisabled provider {provider}")
 
         print_status()
-        exit(0)
+        g_app.exit(0)
 
     if cli_args.default is not None:
         default_model = cli_args.default
-
-        if
+        provider_model = get_provider_model(default_model)
+        if provider_model is None:
             print(f"Model {default_model} not found")
-            print(f"Available models: {', '.join(all_models)}")
            exit(1)
         default_text = g_config["defaults"]["text"]
         default_text["model"] = default_model
         save_config(g_config)
         print(f"\nDefault model set to: {default_model}")
-        exit(0)
+        g_app.exit(0)
 
     if (
         cli_args.chat is not None
         or cli_args.image is not None
         or cli_args.audio is not None
        or cli_args.file is not None
+        or cli_args.out is not None
         or len(extra_args) > 0
     ):
         try:
@@ -2175,6 +3912,12 @@ def main():
             chat = g_config["defaults"]["audio"]
         elif cli_args.file is not None:
             chat = g_config["defaults"]["file"]
+        elif cli_args.out is not None:
+            template = f"out:{cli_args.out}"
+            if template not in g_config["defaults"]:
+                print(f"Template for output modality '{cli_args.out}' not found")
+                exit(1)
+            chat = g_config["defaults"][template]
         if cli_args.chat is not None:
             chat_path = os.path.join(os.path.dirname(__file__), cli_args.chat)
             if not os.path.exists(chat_path):
@@ -2191,6 +3934,9 @@ def main():
 
         if len(extra_args) > 0:
             prompt = " ".join(extra_args)
+            if not chat["messages"] or len(chat["messages"]) == 0:
+                chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
+
             # replace content of last message if exists, else add
             last_msg = chat["messages"][-1] if "messages" in chat else None
             if last_msg and last_msg["role"] == "user":
@@ -2208,19 +3954,31 @@ def main():
 
             asyncio.run(
                 cli_chat(
-                    chat,
+                    chat,
+                    tools=cli_args.tools,
+                    image=cli_args.image,
+                    audio=cli_args.audio,
+                    file=cli_args.file,
+                    args=args,
+                    raw=cli_args.raw,
                 )
             )
-            exit(0)
+            g_app.exit(0)
         except Exception as e:
             print(f"{cli_args.logprefix}Error: {e}")
             if cli_args.verbose:
                 traceback.print_exc()
-            exit(1)
+            g_app.exit(1)
+
+    handled = run_extension_cli()
 
-
-
+    if not handled:
+        # show usage from ArgumentParser
+        parser.print_help()
+        g_app.exit(0)
 
 
 if __name__ == "__main__":
+    if MOCK or DEBUG:
+        print(f"MOCK={MOCK} or DEBUG={DEBUG}")
     main()