llms-py 2.0.34__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llms/__init__.py +3 -1
- llms/__pycache__/__init__.cpython-312.pyc +0 -0
- llms/__pycache__/__init__.cpython-313.pyc +0 -0
- llms/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/__pycache__/__main__.cpython-312.pyc +0 -0
- llms/__pycache__/__main__.cpython-314.pyc +0 -0
- llms/__pycache__/llms.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-313.pyc +0 -0
- llms/__pycache__/main.cpython-314.pyc +0 -0
- llms/__pycache__/plugins.cpython-314.pyc +0 -0
- llms/{ui/Analytics.mjs → extensions/analytics/ui/index.mjs} +154 -238
- llms/extensions/app/README.md +20 -0
- llms/extensions/app/__init__.py +530 -0
- llms/extensions/app/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/app/__pycache__/db.cpython-314.pyc +0 -0
- llms/extensions/app/__pycache__/db_manager.cpython-314.pyc +0 -0
- llms/extensions/app/db.py +644 -0
- llms/extensions/app/db_manager.py +195 -0
- llms/extensions/app/requests.json +9073 -0
- llms/extensions/app/threads.json +15290 -0
- llms/{ui → extensions/app/ui}/Recents.mjs +91 -65
- llms/{ui/Sidebar.mjs → extensions/app/ui/index.mjs} +124 -58
- llms/extensions/app/ui/threadStore.mjs +411 -0
- llms/extensions/core_tools/CALCULATOR.md +32 -0
- llms/extensions/core_tools/__init__.py +598 -0
- llms/extensions/core_tools/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closebrackets.js +201 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closetag.js +185 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/continuelist.js +101 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchbrackets.js +160 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchtags.js +66 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/trailingspace.js +27 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/active-line.js +72 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/mark-selection.js +119 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/selection-pointer.js +98 -0
- llms/extensions/core_tools/ui/codemirror/doc/docs.css +225 -0
- llms/extensions/core_tools/ui/codemirror/doc/source_sans.woff +0 -0
- llms/extensions/core_tools/ui/codemirror/lib/codemirror.css +344 -0
- llms/extensions/core_tools/ui/codemirror/lib/codemirror.js +9884 -0
- llms/extensions/core_tools/ui/codemirror/mode/clike/clike.js +942 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/index.html +118 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/javascript.js +962 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/typescript.html +62 -0
- llms/extensions/core_tools/ui/codemirror/mode/python/python.js +402 -0
- llms/extensions/core_tools/ui/codemirror/theme/dracula.css +40 -0
- llms/extensions/core_tools/ui/codemirror/theme/mocha.css +135 -0
- llms/extensions/core_tools/ui/index.mjs +650 -0
- llms/extensions/gallery/README.md +61 -0
- llms/extensions/gallery/__init__.py +61 -0
- llms/extensions/gallery/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/gallery/__pycache__/db.cpython-314.pyc +0 -0
- llms/extensions/gallery/db.py +298 -0
- llms/extensions/gallery/ui/index.mjs +482 -0
- llms/extensions/katex/README.md +39 -0
- llms/extensions/katex/__init__.py +6 -0
- llms/extensions/katex/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/katex/ui/README.md +125 -0
- llms/extensions/katex/ui/contrib/auto-render.js +338 -0
- llms/extensions/katex/ui/contrib/auto-render.min.js +1 -0
- llms/extensions/katex/ui/contrib/auto-render.mjs +244 -0
- llms/extensions/katex/ui/contrib/copy-tex.js +127 -0
- llms/extensions/katex/ui/contrib/copy-tex.min.js +1 -0
- llms/extensions/katex/ui/contrib/copy-tex.mjs +105 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.js +109 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.min.js +1 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.mjs +24 -0
- llms/extensions/katex/ui/contrib/mhchem.js +3213 -0
- llms/extensions/katex/ui/contrib/mhchem.min.js +1 -0
- llms/extensions/katex/ui/contrib/mhchem.mjs +3109 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.js +887 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.min.js +1 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.mjs +800 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff2 +0 -0
- llms/extensions/katex/ui/index.mjs +92 -0
- llms/extensions/katex/ui/katex-swap.css +1230 -0
- llms/extensions/katex/ui/katex-swap.min.css +1 -0
- llms/extensions/katex/ui/katex.css +1230 -0
- llms/extensions/katex/ui/katex.js +19080 -0
- llms/extensions/katex/ui/katex.min.css +1 -0
- llms/extensions/katex/ui/katex.min.js +1 -0
- llms/extensions/katex/ui/katex.min.mjs +1 -0
- llms/extensions/katex/ui/katex.mjs +18547 -0
- llms/extensions/providers/__init__.py +18 -0
- llms/extensions/providers/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/anthropic.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/chutes.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/google.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/nvidia.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/openai.cpython-314.pyc +0 -0
- llms/extensions/providers/__pycache__/openrouter.cpython-314.pyc +0 -0
- llms/extensions/providers/anthropic.py +229 -0
- llms/extensions/providers/chutes.py +155 -0
- llms/extensions/providers/google.py +378 -0
- llms/extensions/providers/nvidia.py +105 -0
- llms/extensions/providers/openai.py +156 -0
- llms/extensions/providers/openrouter.py +72 -0
- llms/extensions/system_prompts/README.md +22 -0
- llms/extensions/system_prompts/__init__.py +45 -0
- llms/extensions/system_prompts/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/system_prompts/ui/index.mjs +280 -0
- llms/extensions/system_prompts/ui/prompts.json +1067 -0
- llms/extensions/tools/__init__.py +5 -0
- llms/extensions/tools/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/extensions/tools/ui/index.mjs +204 -0
- llms/index.html +35 -77
- llms/llms.json +357 -1186
- llms/main.py +2847 -999
- llms/providers-extra.json +356 -0
- llms/providers.json +1 -0
- llms/ui/App.mjs +151 -60
- llms/ui/ai.mjs +132 -60
- llms/ui/app.css +2173 -161
- llms/ui/ctx.mjs +365 -0
- llms/ui/index.mjs +129 -0
- llms/ui/lib/charts.mjs +9 -13
- llms/ui/lib/servicestack-vue.mjs +3 -3
- llms/ui/lib/vue.min.mjs +10 -9
- llms/ui/lib/vue.mjs +1796 -1635
- llms/ui/markdown.mjs +18 -7
- llms/ui/modules/chat/ChatBody.mjs +691 -0
- llms/ui/{SettingsDialog.mjs → modules/chat/SettingsDialog.mjs} +9 -9
- llms/ui/modules/chat/index.mjs +828 -0
- llms/ui/modules/layout.mjs +243 -0
- llms/ui/modules/model-selector.mjs +851 -0
- llms/ui/tailwind.input.css +496 -80
- llms/ui/utils.mjs +161 -93
- {llms_py-2.0.34.dist-info → llms_py-3.0.0.dist-info}/METADATA +1 -1
- llms_py-3.0.0.dist-info/RECORD +202 -0
- llms/ui/Avatar.mjs +0 -85
- llms/ui/Brand.mjs +0 -52
- llms/ui/ChatPrompt.mjs +0 -590
- llms/ui/Main.mjs +0 -823
- llms/ui/ModelSelector.mjs +0 -78
- llms/ui/OAuthSignIn.mjs +0 -92
- llms/ui/ProviderIcon.mjs +0 -30
- llms/ui/ProviderStatus.mjs +0 -105
- llms/ui/SignIn.mjs +0 -64
- llms/ui/SystemPromptEditor.mjs +0 -31
- llms/ui/SystemPromptSelector.mjs +0 -56
- llms/ui/Welcome.mjs +0 -8
- llms/ui/threadStore.mjs +0 -563
- llms/ui.json +0 -1069
- llms_py-2.0.34.dist-info/RECORD +0 -48
- {llms_py-2.0.34.dist-info → llms_py-3.0.0.dist-info}/WHEEL +0 -0
- {llms_py-2.0.34.dist-info → llms_py-3.0.0.dist-info}/entry_points.txt +0 -0
- {llms_py-2.0.34.dist-info → llms_py-3.0.0.dist-info}/licenses/LICENSE +0 -0
- {llms_py-2.0.34.dist-info → llms_py-3.0.0.dist-info}/top_level.txt +0 -0
llms/main.py
CHANGED
@@ -6,104 +6,123 @@
 # A lightweight CLI tool and OpenAI-compatible server for querying multiple Large Language Model (LLM) providers.
 # Docs: https://github.com/ServiceStack/llms

-import os
-import time
-import json
 import argparse
 import asyncio
-import subprocess
 import base64
+import contextlib
+import hashlib
+import importlib.util
+import inspect
+import json
 import mimetypes
-import
-import sys
-import site
-import secrets
+import os
 import re
+import secrets
+import shutil
+import site
+import subprocess
+import sys
+import time
+import traceback
+from datetime import datetime
+from importlib import resources  # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
 from io import BytesIO
-from
+from pathlib import Path
+from typing import get_type_hints
+from urllib.parse import parse_qs, urlencode, urljoin

 import aiohttp
 from aiohttp import web

-from pathlib import Path
-from importlib import resources  # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
-
 try:
     from PIL import Image
+
     HAS_PIL = True
 except ImportError:
     HAS_PIL = False

-VERSION = "
+VERSION = "3.0.0"
 _ROOT = None
+DEBUG = os.getenv("DEBUG") == "1"
+MOCK = os.getenv("MOCK") == "1"
+MOCK_DIR = os.getenv("MOCK_DIR")
+DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
 g_config_path = None
-g_ui_path = None
 g_config = None
+g_providers = None
 g_handlers = {}
 g_verbose = False
-g_logprefix=""
-g_default_model=""
+g_logprefix = ""
+g_default_model = ""
 g_sessions = {}  # OAuth session storage: {session_token: {userId, userName, displayName, profileUrl, email, created}}
 g_oauth_states = {}  # CSRF protection: {state: {created, redirect_uri}}
+g_app = None  # ExtensionsContext Singleton
+

 def _log(message):
-    """Helper method for logging from the global polling task."""
     if g_verbose:
         print(f"{g_logprefix}{message}", flush=True)

+
+def _dbg(message):
+    if DEBUG:
+        print(f"DEBUG: {message}", flush=True)
+
+
+def _err(message, e):
+    print(f"ERROR: {message}: {e}", flush=True)
+    if g_verbose:
+        print(traceback.format_exc(), flush=True)
+
+
 def printdump(obj):
-    args = obj.__dict__ if hasattr(obj,
+    args = obj.__dict__ if hasattr(obj, "__dict__") else obj
     print(json.dumps(args, indent=2))

+
 def print_chat(chat):
     _log(f"Chat: {chat_summary(chat)}")

+
 def chat_summary(chat):
     """Summarize chat completion request for logging."""
     # replace image_url.url with <image>
     clone = json.loads(json.dumps(chat))
-    for message in clone[
-        if
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-                    prefix = data.split(',', 1)[0]
-                    item['file']['file_data'] = prefix + f",({len(data) - len(prefix)})"
+    for message in clone["messages"]:
+        if "content" in message and isinstance(message["content"], list):
+            for item in message["content"]:
+                if "image_url" in item:
+                    if "url" in item["image_url"]:
+                        url = item["image_url"]["url"]
+                        prefix = url.split(",", 1)[0]
+                        item["image_url"]["url"] = prefix + f",({len(url) - len(prefix)})"
+                elif "input_audio" in item:
+                    if "data" in item["input_audio"]:
+                        data = item["input_audio"]["data"]
+                        item["input_audio"]["data"] = f"({len(data)})"
+                elif "file" in item and "file_data" in item["file"]:
+                    data = item["file"]["file_data"]
+                    prefix = data.split(",", 1)[0]
+                    item["file"]["file_data"] = prefix + f",({len(data) - len(prefix)})"
     return json.dumps(clone, indent=2)

-def gemini_chat_summary(gemini_chat):
-    """Summarize Gemini chat completion request for logging. Replace inline_data with size of content only"""
-    clone = json.loads(json.dumps(gemini_chat))
-    for content in clone['contents']:
-        for part in content['parts']:
-            if 'inline_data' in part:
-                data = part['inline_data']['data']
-                part['inline_data']['data'] = f"({len(data)})"
-    return json.dumps(clone, indent=2)

-image_exts =
-audio_exts =
+image_exts = ["png", "webp", "jpg", "jpeg", "gif", "bmp", "svg", "tiff", "ico"]
+audio_exts = ["mp3", "wav", "ogg", "flac", "m4a", "opus", "webm"]
+

 def is_file_path(path):
     # macOs max path is 1023
     return path and len(path) < 1024 and os.path.exists(path)

+
 def is_url(url):
-    return url and (url.startswith(
+    return url and (url.startswith("http://") or url.startswith("https://"))
+

 def get_filename(file):
-    return file.rsplit(
+    return file.rsplit("/", 1)[1] if "/" in file else "file"
+

 def parse_args_params(args_str):
     """Parse URL-encoded parameters and return a dictionary."""
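The first hunk above introduces environment-driven switches (DEBUG, MOCK, MOCK_DIR, LLMS_DISABLE) alongside the new _dbg/_err helpers. As a rough standalone illustration (not code from the package), this is how those flag expressions behave, including the `or ""` guard that keeps an unset LLMS_DISABLE from raising:

# Illustrative sketch only; variable names mirror the diff above.
import os

os.environ["DEBUG"] = "1"
os.environ["LLMS_DISABLE"] = "gallery,katex"

DEBUG = os.getenv("DEBUG") == "1"                                  # True only for the exact string "1"
DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")

print(DEBUG)                 # True
print(DISABLE_EXTENSIONS)    # ['gallery', 'katex']
print((os.getenv("SOME_UNSET_VAR") or "").split(","))  # [''] -- an unset variable yields one empty entry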
@@ -119,9 +138,9 @@ def parse_args_params(args_str):
         if len(values) == 1:
             value = values[0]
             # Try to convert to appropriate types
-            if value.lower() ==
+            if value.lower() == "true":
                 result[key] = True
-            elif value.lower() ==
+            elif value.lower() == "false":
                 result[key] = False
             elif value.isdigit():
                 result[key] = int(value)
@@ -138,6 +157,7 @@ def parse_args_params(args_str):

     return result

+
 def apply_args_to_chat(chat, args_params):
     """Apply parsed arguments to the chat request."""
     if not args_params:
@@ -146,19 +166,32 @@ def apply_args_to_chat(chat, args_params):
     # Apply each parameter to the chat request
     for key, value in args_params.items():
         if isinstance(value, str):
-            if key ==
-                if
-                    value = value.split(
-            elif
+            if key == "stop":
+                if "," in value:
+                    value = value.split(",")
+            elif (
+                key == "max_completion_tokens"
+                or key == "max_tokens"
+                or key == "n"
+                or key == "seed"
+                or key == "top_logprobs"
+            ):
                 value = int(value)
-            elif key ==
+            elif key == "temperature" or key == "top_p" or key == "frequency_penalty" or key == "presence_penalty":
                 value = float(value)
-            elif
+            elif (
+                key == "store"
+                or key == "logprobs"
+                or key == "enable_thinking"
+                or key == "parallel_tool_calls"
+                or key == "stream"
+            ):
                 value = bool(value)
         chat[key] = value

     return chat

+
 def is_base_64(data):
     try:
         base64.b64decode(data)
@@ -166,6 +199,17 @@ def is_base_64(data):
     except Exception:
         return False

+
+def id_to_name(id):
+    return id.replace("-", " ").title()
+
+
+def pluralize(word, count):
+    if count == 1:
+        return word
+    return word + "s"
+
+
 def get_file_mime_type(filename):
     mime_type, _ = mimetypes.guess_type(filename)
     return mime_type or "application/octet-stream"
@@ -182,36 +226,38 @@ def price_to_string(price: float | int | str | None) -> str | None:
     try:
         price_float = float(price)
         # Format with enough decimal places to avoid scientific notation
-        formatted = format(price_float,
+        formatted = format(price_float, ".20f")

         # Detect recurring 9s pattern (e.g., "...9999999")
         # If we have 4 or more consecutive 9s, round up
-        if
+        if "9999" in formatted:
             # Round up by adding a small amount and reformatting
             # Find the position of the 9s to determine precision
             import decimal
+
             decimal.getcontext().prec = 28
             d = decimal.Decimal(str(price_float))
             # Round to one less decimal place than where the 9s start
-            nines_pos = formatted.find(
+            nines_pos = formatted.find("9999")
             if nines_pos > 0:
                 # Round up at the position before the 9s
-                decimal_places = nines_pos - formatted.find(
+                decimal_places = nines_pos - formatted.find(".") - 1
                 if decimal_places > 0:
-                    quantize_str =
+                    quantize_str = "0." + "0" * (decimal_places - 1) + "1"
                     d = d.quantize(decimal.Decimal(quantize_str), rounding=decimal.ROUND_UP)
                     result = str(d)
                     # Remove trailing zeros
-                    if
-                        result = result.rstrip(
+                    if "." in result:
+                        result = result.rstrip("0").rstrip(".")
                     return result

         # Normal case: strip trailing zeros
-        return formatted.rstrip(
+        return formatted.rstrip("0").rstrip(".")
     except (ValueError, TypeError):
         return None

-
+
+def convert_image_if_needed(image_bytes, mimetype="image/png"):
     """
     Convert and resize image to WebP if it exceeds configured limits.

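To make the recurring-9s branch of price_to_string above concrete, here is a small standalone walk-through (an illustration built from the same decimal calls quoted in the diff, not the package's function itself):

# Illustrative sketch of the rounding step shown above.
import decimal

price_float = 2.675                                # binary float; .20f expands to '2.674999…' with a run of 9s
formatted = format(price_float, ".20f")
nines_pos = formatted.find("9999")                 # 5 -- the run of 9s starts right after '2.674'
decimal_places = nines_pos - formatted.find(".") - 1   # 3
quantize_str = "0." + "0" * (decimal_places - 1) + "1"  # '0.001'

d = decimal.Decimal(str(price_float)).quantize(
    decimal.Decimal(quantize_str), rounding=decimal.ROUND_UP
)
print(str(d).rstrip("0").rstrip("."))              # '2.675' rather than the long 9s expansion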
@@ -226,16 +272,16 @@ def convert_image_if_needed(image_bytes, mimetype='image/png'):
         return image_bytes, mimetype

     # Get conversion config
-    convert_config = g_config.get(
+    convert_config = g_config.get("convert", {}).get("image", {}) if g_config else {}
     if not convert_config:
         return image_bytes, mimetype

-    max_size_str = convert_config.get(
-    max_length = convert_config.get(
+    max_size_str = convert_config.get("max_size", "1536x1024")
+    max_length = convert_config.get("max_length", 1.5 * 1024 * 1024)  # 1.5MB

     try:
         # Parse max_size (e.g., "1536x1024")
-        max_width, max_height = map(int, max_size_str.split(
+        max_width, max_height = map(int, max_size_str.split("x"))

         # Open image
         with Image.open(BytesIO(image_bytes)) as img:
@@ -253,15 +299,15 @@ def convert_image_if_needed(image_bytes, mimetype='image/png'):
                 return image_bytes, mimetype

             # Convert RGBA to RGB if necessary (WebP doesn't support transparency in RGB mode)
-            if img.mode in (
+            if img.mode in ("RGBA", "LA", "P"):
                 # Create a white background
-                background = Image.new(
-                if img.mode ==
-                    img = img.convert(
-                background.paste(img, mask=img.split()[-1] if img.mode in (
+                background = Image.new("RGB", img.size, (255, 255, 255))
+                if img.mode == "P":
+                    img = img.convert("RGBA")
+                background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
                 img = background
-            elif img.mode !=
-                img = img.convert(
+            elif img.mode != "RGB":
+                img = img.convert("RGB")

             # Resize if needed (preserve aspect ratio)
             if needs_resize:
@@ -270,39 +316,85 @@ def convert_image_if_needed(image_bytes, mimetype='image/png'):

             # Convert to WebP
             output = BytesIO()
-            img.save(output, format=
+            img.save(output, format="WEBP", quality=85, method=6)
             converted_bytes = output.getvalue()

-            _log(
+            _log(
+                f"Converted image to WebP: {len(image_bytes)} bytes -> {len(converted_bytes)} bytes ({len(converted_bytes) * 100 // len(image_bytes)}%)"
+            )

-            return converted_bytes,
+            return converted_bytes, "image/webp"

     except Exception as e:
         _log(f"Error converting image: {e}")
         # Return original if conversion fails
         return image_bytes, mimetype

-
+
+def to_content(result):
+    if isinstance(result, (str, int, float, bool)):
+        return str(result)
+    elif isinstance(result, (list, set, tuple, dict)):
+        return json.dumps(result)
+    else:
+        return str(result)
+
+
+def function_to_tool_definition(func):
+    type_hints = get_type_hints(func)
+    signature = inspect.signature(func)
+    parameters = {"type": "object", "properties": {}, "required": []}
+
+    for name, param in signature.parameters.items():
+        param_type = type_hints.get(name, str)
+        param_type_name = "string"
+        if param_type is int:
+            param_type_name = "integer"
+        elif param_type is float:
+            param_type_name = "number"
+        elif param_type is bool:
+            param_type_name = "boolean"
+
+        parameters["properties"][name] = {"type": param_type_name}
+        if param.default == inspect.Parameter.empty:
+            parameters["required"].append(name)
+
+    return {
+        "type": "function",
+        "function": {
+            "name": func.__name__,
+            "description": func.__doc__ or "",
+            "parameters": parameters,
+        },
+    }
+
+
+async def process_chat(chat, provider_id=None):
     if not chat:
         raise Exception("No chat provided")
-    if
-        chat[
-
+    if "stream" not in chat:
+        chat["stream"] = False
+    # Some providers don't support empty tools
+    if "tools" in chat and len(chat["tools"]) == 0:
+        del chat["tools"]
+    if "messages" not in chat:
         return chat

     async with aiohttp.ClientSession() as session:
-        for message in chat[
-            if
+        for message in chat["messages"]:
+            if "content" not in message:
                 continue

-            if isinstance(message[
-                for item in message[
-                    if
+            if isinstance(message["content"], list):
+                for item in message["content"]:
+                    if "type" not in item:
                         continue
-                    if item[
-                        image_url = item[
-                        if
-                            url = image_url[
+                    if item["type"] == "image_url" and "image_url" in item:
+                        image_url = item["image_url"]
+                        if "url" in image_url:
+                            url = image_url["url"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             if is_url(url):
                                 _log(f"Downloading image: {url}")
                                 async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
@@ -310,12 +402,14 @@ async def process_chat(chat):
                                     content = await response.read()
                                     # get mimetype from response headers
                                     mimetype = get_file_mime_type(get_filename(url))
-                                    if
-                                        mimetype = response.headers[
+                                    if "Content-Type" in response.headers:
+                                        mimetype = response.headers["Content-Type"]
                                     # convert/resize image if needed
                                     content, mimetype = convert_image_if_needed(content, mimetype)
                                     # convert to data uri
-                                    image_url[
+                                    image_url["url"] = (
+                                        f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
+                                    )
                             elif is_file_path(url):
                                 _log(f"Reading image: {url}")
                                 with open(url, "rb") as f:
@@ -325,24 +419,30 @@ async def process_chat(chat):
                                 # convert/resize image if needed
                                 content, mimetype = convert_image_if_needed(content, mimetype)
                                 # convert to data uri
-                                image_url[
-
+                                image_url["url"] = (
+                                    f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
+                                )
+                            elif url.startswith("data:"):
                                 # Extract existing data URI and process it
-                                if
-                                    prefix = url.split(
-                                    mimetype = prefix.split(
-                                    base64_data = url.split(
+                                if ";base64," in url:
+                                    prefix = url.split(";base64,")[0]
+                                    mimetype = prefix.split(":")[1] if ":" in prefix else "image/png"
+                                    base64_data = url.split(";base64,")[1]
                                     content = base64.b64decode(base64_data)
                                     # convert/resize image if needed
                                     content, mimetype = convert_image_if_needed(content, mimetype)
                                     # update data uri with potentially converted image
-                                    image_url[
+                                    image_url["url"] = (
+                                        f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
+                                    )
                             else:
                                 raise Exception(f"Invalid image: {url}")
-                    elif item[
-                        input_audio = item[
-                        if
-                            url = input_audio[
+                    elif item["type"] == "input_audio" and "input_audio" in item:
+                        input_audio = item["input_audio"]
+                        if "data" in input_audio:
+                            url = input_audio["data"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             mimetype = get_file_mime_type(get_filename(url))
                             if is_url(url):
                                 _log(f"Downloading audio: {url}")
@@ -350,48 +450,145 @@ async def process_chat(chat):
                                     response.raise_for_status()
                                     content = await response.read()
                                     # get mimetype from response headers
-                                    if
-                                        mimetype = response.headers[
+                                    if "Content-Type" in response.headers:
+                                        mimetype = response.headers["Content-Type"]
                                     # convert to base64
-                                    input_audio[
-
+                                    input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                                    if provider_id == "alibaba":
+                                        input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
+                                    input_audio["format"] = mimetype.rsplit("/", 1)[1]
                             elif is_file_path(url):
                                 _log(f"Reading audio: {url}")
                                 with open(url, "rb") as f:
                                     content = f.read()
                                 # convert to base64
-                                input_audio[
-
+                                input_audio["data"] = base64.b64encode(content).decode("utf-8")
+                                if provider_id == "alibaba":
+                                    input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
+                                input_audio["format"] = mimetype.rsplit("/", 1)[1]
                             elif is_base_64(url):
-                                pass
+                                pass  # use base64 data as-is
                             else:
                                 raise Exception(f"Invalid audio: {url}")
-                    elif item[
-                        file = item[
-                        if
-                            url = file[
+                    elif item["type"] == "file" and "file" in item:
+                        file = item["file"]
+                        if "file_data" in file:
+                            url = file["file_data"]
+                            if url.startswith("/~cache/"):
+                                url = get_cache_path(url[8:])
                             mimetype = get_file_mime_type(get_filename(url))
                             if is_url(url):
                                 _log(f"Downloading file: {url}")
                                 async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
                                     response.raise_for_status()
                                     content = await response.read()
-                                    file[
-                                    file[
+                                    file["filename"] = get_filename(url)
+                                    file["file_data"] = (
+                                        f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
+                                    )
                             elif is_file_path(url):
                                 _log(f"Reading file: {url}")
                                 with open(url, "rb") as f:
                                     content = f.read()
-                                file[
-                                file[
-
-
-
-
+                                file["filename"] = get_filename(url)
+                                file["file_data"] = (
+                                    f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
+                                )
+                            elif url.startswith("data:"):
+                                if "filename" not in file:
+                                    file["filename"] = "file"
+                                pass  # use base64 data as-is
                             else:
                                 raise Exception(f"Invalid file: {url}")
     return chat

+
+def image_ext_from_mimetype(mimetype, default="png"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def audio_ext_from_format(format, default="mp3"):
+    if format == "mpeg":
+        return "mp3"
+    return format or default
+
+
+def file_ext_from_mimetype(mimetype, default="pdf"):
+    if "/" in mimetype:
+        _ext = mimetypes.guess_extension(mimetype)
+        if _ext:
+            return _ext.lstrip(".")
+    return default
+
+
+def cache_message_inline_data(m):
+    """
+    Replaces and caches any inline data URIs in the message content.
+    """
+    if "content" not in m:
+        return
+
+    content = m["content"]
+    if isinstance(content, list):
+        for item in content:
+            if item.get("type") == "image_url":
+                image_url = item.get("image_url", {})
+                url = image_url.get("url")
+                if url and url.startswith("data:"):
+                    # Extract base64 and mimetype
+                    try:
+                        header, base64_data = url.split(";base64,")
+                        # header is like "data:image/png"
+                        ext = image_ext_from_mimetype(header.split(":")[1])
+                        filename = f"image.{ext}"  # Hash will handle uniqueness
+
+                        cache_url, _ = save_image_to_cache(base64_data, filename, {}, ignore_info=True)
+                        image_url["url"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline image: {e}")
+
+            elif item.get("type") == "input_audio":
+                input_audio = item.get("input_audio", {})
+                data = input_audio.get("data")
+                if data:
+                    # Handle data URI or raw base64
+                    base64_data = data
+                    if data.startswith("data:"):
+                        with contextlib.suppress(ValueError):
+                            header, base64_data = data.split(";base64,")
+
+                    fmt = audio_ext_from_format(input_audio.get("format"))
+                    filename = f"audio.{fmt}"
+
+                    try:
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        input_audio["data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline audio: {e}")
+
+            elif item.get("type") == "file":
+                file_info = item.get("file", {})
+                file_data = file_info.get("file_data")
+                if file_data and file_data.startswith("data:"):
+                    try:
+                        header, base64_data = file_data.split(";base64,")
+                        mimetype = header.split(":")[1]
+                        # Try to get extension from filename if available, else mimetype
+                        filename = file_info.get("filename", "file")
+                        if "." not in filename:
+                            ext = file_ext_from_mimetype(mimetype)
+                            filename = f"{filename}.{ext}"
+
+                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
+                        file_info["file_data"] = cache_url
+                    except Exception as e:
+                        _log(f"Error caching inline file: {e}")
+
+
 class HTTPError(Exception):
     def __init__(self, status, reason, body, headers=None):
         self.status = status
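The next hunk introduces save_bytes_to_cache and save_image_to_cache, which store attachments content-addressed under a /~cache/ URL. A standalone sketch of the naming scheme they use (illustrative only; the real helpers also resolve a cache directory via get_cache_path and write a .info.json sidecar with more fields):

# Illustrative sketch of the content-addressed layout used by the cache helpers below:
# sha256 of the bytes, a 2-char subdirectory, and a metadata dict saved alongside.
import hashlib
import json

content = b"hello world"            # hypothetical payload to cache
filename = "notes.txt"

sha256_hash = hashlib.sha256(content).hexdigest()
ext = filename.split(".")[-1]
relative_path = f"{sha256_hash[:2]}/{sha256_hash}.{ext}"   # 2-char subdir avoids one huge directory

url = f"/~cache/{relative_path}"    # the URL written back into the message content
info = {"url": url, "size": len(content), "type": "text/plain", "name": filename}

print(url)
print(json.dumps(info))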
@@ -400,448 +597,923 @@ class HTTPError(Exception):
|
|
|
400
597
|
self.headers = headers
|
|
401
598
|
super().__init__(f"HTTP {status} {reason}")
|
|
402
599
|
|
|
600
|
+
|
|
601
|
+
def save_bytes_to_cache(base64_data, filename, file_info, ignore_info=False):
|
|
602
|
+
ext = filename.split(".")[-1]
|
|
603
|
+
mimetype = get_file_mime_type(filename)
|
|
604
|
+
content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
|
|
605
|
+
sha256_hash = hashlib.sha256(content).hexdigest()
|
|
606
|
+
|
|
607
|
+
save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
|
|
608
|
+
|
|
609
|
+
# Use first 2 chars for subdir to avoid too many files in one dir
|
|
610
|
+
subdir = sha256_hash[:2]
|
|
611
|
+
relative_path = f"{subdir}/{save_filename}"
|
|
612
|
+
full_path = get_cache_path(relative_path)
|
|
613
|
+
url = f"/~cache/{relative_path}"
|
|
614
|
+
|
|
615
|
+
# if file and its .info.json already exists, return it
|
|
616
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
617
|
+
if os.path.exists(full_path) and os.path.exists(info_path):
|
|
618
|
+
_dbg(f"Cached bytes exists: {relative_path}")
|
|
619
|
+
if ignore_info:
|
|
620
|
+
return url, None
|
|
621
|
+
return url, json.load(open(info_path))
|
|
622
|
+
|
|
623
|
+
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
|
624
|
+
|
|
625
|
+
with open(full_path, "wb") as f:
|
|
626
|
+
f.write(content)
|
|
627
|
+
info = {
|
|
628
|
+
"date": int(time.time()),
|
|
629
|
+
"url": url,
|
|
630
|
+
"size": len(content),
|
|
631
|
+
"type": mimetype,
|
|
632
|
+
"name": filename,
|
|
633
|
+
}
|
|
634
|
+
info.update(file_info)
|
|
635
|
+
|
|
636
|
+
# Save metadata
|
|
637
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
638
|
+
with open(info_path, "w") as f:
|
|
639
|
+
json.dump(info, f)
|
|
640
|
+
|
|
641
|
+
_dbg(f"Saved cached bytes and info: {relative_path}")
|
|
642
|
+
|
|
643
|
+
g_app.on_cache_saved_filters({"url": url, "info": info})
|
|
644
|
+
|
|
645
|
+
return url, info
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def save_image_to_cache(base64_data, filename, image_info, ignore_info=False):
|
|
649
|
+
ext = filename.split(".")[-1]
|
|
650
|
+
mimetype = get_file_mime_type(filename)
|
|
651
|
+
content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
|
|
652
|
+
sha256_hash = hashlib.sha256(content).hexdigest()
|
|
653
|
+
|
|
654
|
+
save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
|
|
655
|
+
|
|
656
|
+
# Use first 2 chars for subdir to avoid too many files in one dir
|
|
657
|
+
subdir = sha256_hash[:2]
|
|
658
|
+
relative_path = f"{subdir}/{save_filename}"
|
|
659
|
+
full_path = get_cache_path(relative_path)
|
|
660
|
+
url = f"/~cache/{relative_path}"
|
|
661
|
+
|
|
662
|
+
# if file and its .info.json already exists, return it
|
|
663
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
664
|
+
if os.path.exists(full_path) and os.path.exists(info_path):
|
|
665
|
+
_dbg(f"Saved image exists: {relative_path}")
|
|
666
|
+
if ignore_info:
|
|
667
|
+
return url, None
|
|
668
|
+
return url, json.load(open(info_path))
|
|
669
|
+
|
|
670
|
+
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
|
671
|
+
|
|
672
|
+
with open(full_path, "wb") as f:
|
|
673
|
+
f.write(content)
|
|
674
|
+
info = {
|
|
675
|
+
"date": int(time.time()),
|
|
676
|
+
"url": url,
|
|
677
|
+
"size": len(content),
|
|
678
|
+
"type": mimetype,
|
|
679
|
+
"name": filename,
|
|
680
|
+
}
|
|
681
|
+
info.update(image_info)
|
|
682
|
+
|
|
683
|
+
# If image, get dimensions
|
|
684
|
+
if HAS_PIL and mimetype.startswith("image/"):
|
|
685
|
+
try:
|
|
686
|
+
with Image.open(BytesIO(content)) as img:
|
|
687
|
+
info["width"] = img.width
|
|
688
|
+
info["height"] = img.height
|
|
689
|
+
except Exception:
|
|
690
|
+
pass
|
|
691
|
+
|
|
692
|
+
if "width" in info and "height" in info:
|
|
693
|
+
_log(f"Saved image to cache: {full_path} ({len(content)} bytes) {info['width']}x{info['height']}")
|
|
694
|
+
else:
|
|
695
|
+
_log(f"Saved image to cache: {full_path} ({len(content)} bytes)")
|
|
696
|
+
|
|
697
|
+
# Save metadata
|
|
698
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
699
|
+
with open(info_path, "w") as f:
|
|
700
|
+
json.dump(info, f)
|
|
701
|
+
|
|
702
|
+
_dbg(f"Saved image and info: {relative_path}")
|
|
703
|
+
|
|
704
|
+
g_app.on_cache_saved_filters({"url": url, "info": info})
|
|
705
|
+
|
|
706
|
+
return url, info
|
|
707
|
+
|
|
708
|
+
|
|
403
709
|
async def response_json(response):
|
|
404
710
|
text = await response.text()
|
|
405
711
|
if response.status >= 400:
|
|
712
|
+
_dbg(f"HTTP {response.status} {response.reason}: {text}")
|
|
406
713
|
raise HTTPError(response.status, reason=response.reason, body=text, headers=dict(response.headers))
|
|
407
714
|
response.raise_for_status()
|
|
408
715
|
body = json.loads(text)
|
|
409
716
|
return body
|
|
410
717
|
|
|
411
|
-
class OpenAiProvider:
|
|
412
|
-
def __init__(self, base_url, api_key=None, models={}, **kwargs):
|
|
413
|
-
self.base_url = base_url.strip("/")
|
|
414
|
-
self.api_key = api_key
|
|
415
|
-
self.models = models
|
|
416
718
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
719
|
+
def chat_to_prompt(chat):
|
|
720
|
+
prompt = ""
|
|
721
|
+
if "messages" in chat:
|
|
722
|
+
for message in chat["messages"]:
|
|
723
|
+
if message["role"] == "user":
|
|
724
|
+
# if content is string
|
|
725
|
+
if isinstance(message["content"], str):
|
|
726
|
+
if prompt:
|
|
727
|
+
prompt += "\n"
|
|
728
|
+
prompt += message["content"]
|
|
729
|
+
elif isinstance(message["content"], list):
|
|
730
|
+
# if content is array of objects
|
|
731
|
+
for part in message["content"]:
|
|
732
|
+
if part["type"] == "text":
|
|
733
|
+
if prompt:
|
|
734
|
+
prompt += "\n"
|
|
735
|
+
prompt += part["text"]
|
|
736
|
+
return prompt
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
def chat_to_system_prompt(chat):
|
|
740
|
+
if "messages" in chat:
|
|
741
|
+
for message in chat["messages"]:
|
|
742
|
+
if message["role"] == "system":
|
|
743
|
+
# if content is string
|
|
744
|
+
if isinstance(message["content"], str):
|
|
745
|
+
return message["content"]
|
|
746
|
+
elif isinstance(message["content"], list):
|
|
747
|
+
# if content is array of objects
|
|
748
|
+
for part in message["content"]:
|
|
749
|
+
if part["type"] == "text":
|
|
750
|
+
return part["text"]
|
|
751
|
+
return None
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def chat_to_username(chat):
|
|
755
|
+
if "metadata" in chat and "user" in chat["metadata"]:
|
|
756
|
+
return chat["metadata"]["user"]
|
|
757
|
+
return None
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def last_user_prompt(chat):
|
|
761
|
+
prompt = ""
|
|
762
|
+
if "messages" in chat:
|
|
763
|
+
for message in chat["messages"]:
|
|
764
|
+
if message["role"] == "user":
|
|
765
|
+
# if content is string
|
|
766
|
+
if isinstance(message["content"], str):
|
|
767
|
+
prompt = message["content"]
|
|
768
|
+
elif isinstance(message["content"], list):
|
|
769
|
+
# if content is array of objects
|
|
770
|
+
for part in message["content"]:
|
|
771
|
+
if part["type"] == "text":
|
|
772
|
+
prompt = part["text"]
|
|
773
|
+
return prompt
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def chat_response_to_message(openai_response):
|
|
777
|
+
"""
|
|
778
|
+
Returns an assistant message from the OpenAI Response.
|
|
779
|
+
Handles normalizing text, image, and audio responses into the message content.
|
|
780
|
+
"""
|
|
781
|
+
timestamp = int(time.time() * 1000) # openai_response.get("created")
|
|
782
|
+
choices = openai_response
|
|
783
|
+
if isinstance(openai_response, dict) and "choices" in openai_response:
|
|
784
|
+
choices = openai_response["choices"]
|
|
785
|
+
|
|
786
|
+
choice = choices[0] if isinstance(choices, list) and choices else choices
|
|
787
|
+
|
|
788
|
+
if isinstance(choice, str):
|
|
789
|
+
return {"role": "assistant", "content": choice, "timestamp": timestamp}
|
|
790
|
+
|
|
791
|
+
if isinstance(choice, dict):
|
|
792
|
+
message = choice.get("message", choice)
|
|
793
|
+
else:
|
|
794
|
+
return {"role": "assistant", "content": str(choice), "timestamp": timestamp}
|
|
795
|
+
|
|
796
|
+
# Ensure message is a dict
|
|
797
|
+
if not isinstance(message, dict):
|
|
798
|
+
return {"role": "assistant", "content": message, "timestamp": timestamp}
|
|
799
|
+
|
|
800
|
+
message.update({"timestamp": timestamp})
|
|
801
|
+
return message
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def to_file_info(chat, info=None, response=None):
|
|
805
|
+
prompt = last_user_prompt(chat)
|
|
806
|
+
ret = info or {}
|
|
807
|
+
if chat["model"] and "model" not in ret:
|
|
808
|
+
ret["model"] = chat["model"]
|
|
809
|
+
if prompt and "prompt" not in ret:
|
|
810
|
+
ret["prompt"] = prompt
|
|
811
|
+
if "image_config" in chat:
|
|
812
|
+
ret.update(chat["image_config"])
|
|
813
|
+
user = chat_to_username(chat)
|
|
814
|
+
if user:
|
|
815
|
+
ret["user"] = user
|
|
816
|
+
return ret
|
|
817
|
+
|
|
423
818
|
|
|
424
|
-
|
|
819
|
+
# Image Generator Providers
|
|
820
|
+
class GeneratorBase:
|
|
821
|
+
def __init__(self, **kwargs):
|
|
822
|
+
self.id = kwargs.get("id")
|
|
823
|
+
self.api = kwargs.get("api")
|
|
824
|
+
self.api_key = kwargs.get("api_key")
|
|
825
|
+
self.headers = {
|
|
826
|
+
"Accept": "application/json",
|
|
425
827
|
"Content-Type": "application/json",
|
|
426
828
|
}
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
self.temperature = float(kwargs['temperature']) if 'temperature' in kwargs else None
|
|
443
|
-
self.top_logprobs = int(kwargs['top_logprobs']) if 'top_logprobs' in kwargs else None
|
|
444
|
-
self.top_p = float(kwargs['top_p']) if 'top_p' in kwargs else None
|
|
445
|
-
self.verbosity = kwargs['verbosity'] if 'verbosity' in kwargs else None
|
|
446
|
-
self.stream = bool(kwargs['stream']) if 'stream' in kwargs else None
|
|
447
|
-
self.enable_thinking = bool(kwargs['enable_thinking']) if 'enable_thinking' in kwargs else None
|
|
448
|
-
self.pricing = kwargs['pricing'] if 'pricing' in kwargs else None
|
|
449
|
-
self.default_pricing = kwargs['default_pricing'] if 'default_pricing' in kwargs else None
|
|
450
|
-
self.check = kwargs['check'] if 'check' in kwargs else None
|
|
451
|
-
|
|
452
|
-
@classmethod
|
|
453
|
-
def test(cls, base_url=None, api_key=None, models={}, **kwargs):
|
|
454
|
-
return base_url and api_key and len(models) > 0
|
|
829
|
+
self.chat_url = f"{self.api}/chat/completions"
|
|
830
|
+
self.default_content = "I've generated the image for you."
|
|
831
|
+
|
|
832
|
+
def validate(self, **kwargs):
|
|
833
|
+
if not self.api_key:
|
|
834
|
+
api_keys = ", ".join(self.env)
|
|
835
|
+
return f"Provider '{self.name}' requires API Key {api_keys}"
|
|
836
|
+
return None
|
|
837
|
+
|
|
838
|
+
def test(self, **kwargs):
|
|
839
|
+
error_msg = self.validate(**kwargs)
|
|
840
|
+
if error_msg:
|
|
841
|
+
_log(error_msg)
|
|
842
|
+
return False
|
|
843
|
+
return True
|
|
455
844
|
|
|
456
845
|
async def load(self):
|
|
457
846
|
pass
|
|
458
847
|
|
|
459
|
-
def
|
|
848
|
+
def gen_summary(self, gen):
|
|
849
|
+
"""Summarize gen response for logging."""
|
|
850
|
+
clone = json.loads(json.dumps(gen))
|
|
851
|
+
return json.dumps(clone, indent=2)
|
|
852
|
+
|
|
853
|
+
def chat_summary(self, chat):
|
|
854
|
+
return chat_summary(chat)
|
|
855
|
+
|
|
856
|
+
def process_chat(self, chat, provider_id=None):
|
|
857
|
+
return process_chat(chat, provider_id)
|
|
858
|
+
|
|
859
|
+
async def response_json(self, response):
|
|
860
|
+
return await response_json(response)
|
|
861
|
+
|
|
862
|
+
def get_headers(self, provider, chat):
|
|
863
|
+
headers = self.headers.copy()
|
|
864
|
+
if provider is not None:
|
|
865
|
+
headers["Authorization"] = f"Bearer {provider.api_key}"
|
|
866
|
+
elif self.api_key:
|
|
867
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
868
|
+
return headers
|
|
869
|
+
|
|
870
|
+
def to_response(self, response, chat, started_at):
|
|
871
|
+
raise NotImplementedError
|
|
872
|
+
|
|
873
|
+
async def chat(self, chat, provider=None):
|
|
874
|
+
return {
|
|
875
|
+
"choices": [
|
|
876
|
+
{
|
|
877
|
+
"message": {
|
|
878
|
+
"role": "assistant",
|
|
879
|
+
"content": "Not Implemented",
|
|
880
|
+
"images": [
|
|
881
|
+
{
|
|
882
|
+
"type": "image_url",
|
|
883
|
+
"image_url": {
|
|
884
|
+
"url": "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0Ij48cGF0aCBmaWxsPSJjdXJyZW50Q29sb3IiIGQ9Ik0xMiAyMGE4IDggMCAxIDAgMC0xNmE4IDggMCAwIDAgMCAxNm0wIDJDNi40NzcgMjIgMiAxNy41MjMgMiAxMlM2LjQ3NyAyIDEyIDJzMTAgNC40NzcgMTAgMTBzLTQuNDc3IDEwLTEwIDEwbS0xLTZoMnYyaC0yem0wLTEwaDJ2OGgtMnoiLz48L3N2Zz4=",
|
|
885
|
+
},
|
|
886
|
+
}
|
|
887
|
+
],
|
|
888
|
+
}
|
|
889
|
+
}
|
|
890
|
+
]
|
|
891
|
+
}
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
# OpenAI Providers
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
class OpenAiCompatible:
|
|
898
|
+
sdk = "@ai-sdk/openai-compatible"
|
|
899
|
+
|
|
900
|
+
def __init__(self, **kwargs):
|
|
901
|
+
required_args = ["id", "api"]
|
|
902
|
+
for arg in required_args:
|
|
903
|
+
if arg not in kwargs:
|
|
904
|
+
raise ValueError(f"Missing required argument: {arg}")
|
|
905
|
+
|
|
906
|
+
self.id = kwargs.get("id")
|
|
907
|
+
self.api = kwargs.get("api").strip("/")
|
|
908
|
+
self.env = kwargs.get("env", [])
|
|
909
|
+
self.api_key = kwargs.get("api_key")
|
|
910
|
+
self.name = kwargs.get("name", id_to_name(self.id))
|
|
911
|
+
self.set_models(**kwargs)
|
|
912
|
+
|
|
913
|
+
self.chat_url = f"{self.api}/chat/completions"
|
|
914
|
+
|
|
915
|
+
self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
|
|
916
|
+
if self.api_key is not None:
|
|
917
|
+
self.headers["Authorization"] = f"Bearer {self.api_key}"
|
|
918
|
+
|
|
919
|
+
self.frequency_penalty = float(kwargs["frequency_penalty"]) if "frequency_penalty" in kwargs else None
|
|
920
|
+
self.max_completion_tokens = int(kwargs["max_completion_tokens"]) if "max_completion_tokens" in kwargs else None
|
|
921
|
+
self.n = int(kwargs["n"]) if "n" in kwargs else None
|
|
922
|
+
self.parallel_tool_calls = bool(kwargs["parallel_tool_calls"]) if "parallel_tool_calls" in kwargs else None
|
|
923
|
+
self.presence_penalty = float(kwargs["presence_penalty"]) if "presence_penalty" in kwargs else None
|
|
924
|
+
self.prompt_cache_key = kwargs.get("prompt_cache_key")
|
|
925
|
+
self.reasoning_effort = kwargs.get("reasoning_effort")
|
|
926
|
+
self.safety_identifier = kwargs.get("safety_identifier")
|
|
927
|
+
self.seed = int(kwargs["seed"]) if "seed" in kwargs else None
|
|
928
|
+
self.service_tier = kwargs.get("service_tier")
|
|
929
|
+
self.stop = kwargs.get("stop")
|
|
930
|
+
self.store = bool(kwargs["store"]) if "store" in kwargs else None
|
|
931
|
+
self.temperature = float(kwargs["temperature"]) if "temperature" in kwargs else None
|
|
932
|
+
self.top_logprobs = int(kwargs["top_logprobs"]) if "top_logprobs" in kwargs else None
|
|
933
|
+
self.top_p = float(kwargs["top_p"]) if "top_p" in kwargs else None
|
|
934
|
+
self.verbosity = kwargs.get("verbosity")
|
|
935
|
+
self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
|
|
936
|
+
self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
|
|
937
|
+
self.check = kwargs.get("check")
|
|
938
|
+
self.modalities = kwargs.get("modalities", {})
|
|
939
|
+
|
|
940
|
+
def set_models(self, **kwargs):
|
|
941
|
+
models = kwargs.get("models", {})
|
|
942
|
+
self.map_models = kwargs.get("map_models", {})
|
|
943
|
+
# if 'map_models' is provided, only include models in `map_models[model_id] = provider_model_id`
|
|
944
|
+
if self.map_models:
|
|
945
|
+
self.models = {}
|
|
946
|
+
for provider_model_id in self.map_models.values():
|
|
947
|
+
if provider_model_id in models:
|
|
948
|
+
self.models[provider_model_id] = models[provider_model_id]
|
|
949
|
+
else:
|
|
950
|
+
self.models = models
|
|
951
|
+
|
|
952
|
+
include_models = kwargs.get("include_models") # string regex pattern
|
|
953
|
+
# only include models that match the regex pattern
|
|
954
|
+
if include_models:
|
|
955
|
+
_log(f"Filtering {len(self.models)} models, only including models that match regex: {include_models}")
|
|
956
|
+
self.models = {k: v for k, v in self.models.items() if re.search(include_models, k)}
|
|
957
|
+
|
|
958
|
+
exclude_models = kwargs.get("exclude_models") # string regex pattern
|
|
959
|
+
# exclude models that match the regex pattern
|
|
960
|
+
if exclude_models:
|
|
961
|
+
_log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
|
|
962
|
+
self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}
|
|
963
|
+
|
|
964
|
+
def validate(self, **kwargs):
|
|
965
|
+
if not self.api_key:
|
|
966
|
+
api_keys = ", ".join(self.env)
|
|
967
|
+
return f"Provider '{self.name}' requires API Key {api_keys}"
|
|
968
|
+
return None
|
|
969
|
+
|
|
970
|
+
def test(self, **kwargs):
|
|
971
|
+
error_msg = self.validate(**kwargs)
|
|
972
|
+
if error_msg:
|
|
973
|
+
_log(error_msg)
|
|
974
|
+
return False
|
|
975
|
+
return True
|
|
976
|
+
|
|
977
|
+
async def load(self):
|
|
978
|
+
if not self.models:
|
|
979
|
+
await self.load_models()
|
|
980
|
+
|
|
981
|
+
def model_info(self, model):
|
|
460
982
|
provider_model = self.provider_model(model) or model
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
983
|
+
for model_id, model_info in self.models.items():
|
|
984
|
+
if model_id.lower() == provider_model.lower():
|
|
985
|
+
return model_info
|
|
986
|
+
return None
|
|
987
|
+
|
|
988
|
+
def model_cost(self, model):
|
|
989
|
+
model_info = self.model_info(model)
|
|
990
|
+
return model_info.get("cost") if model_info else None
|
|
464
991
|
|
|
465
992
|
def provider_model(self, model):
|
|
466
|
-
|
|
467
|
-
|
|
993
|
+
# convert model to lowercase for case-insensitive comparison
|
|
994
|
+
model_lower = model.lower()
|
|
995
|
+
|
|
996
|
+
# if model is a map model id, return the provider model id
|
|
997
|
+
for model_id, provider_model in self.map_models.items():
|
|
998
|
+
if model_id.lower() == model_lower:
|
|
999
|
+
return provider_model
|
|
1000
|
+
|
|
1001
|
+
# if model is a provider model id, try again with just the model name
|
|
1002
|
+
for provider_model in self.map_models.values():
|
|
1003
|
+
if provider_model.lower() == model_lower:
|
|
1004
|
+
return provider_model
|
|
1005
|
+
|
|
1006
|
+
# if model is a model id, try again with just the model id or name
|
|
1007
|
+
for model_id, provider_model_info in self.models.items():
|
|
1008
|
+
id = provider_model_info.get("id") or model_id
|
|
1009
|
+
if model_id.lower() == model_lower or id.lower() == model_lower:
|
|
1010
|
+
return id
|
|
1011
|
+
name = provider_model_info.get("name")
|
|
1012
|
+
if name and name.lower() == model_lower:
|
|
1013
|
+
return id
|
|
1014
|
+
|
|
1015
|
+
# fallback to trying again with just the model short name
|
|
1016
|
+
for model_id, provider_model_info in self.models.items():
|
|
1017
|
+
id = provider_model_info.get("id") or model_id
|
|
1018
|
+
if "/" in id:
|
|
1019
|
+
model_name = id.split("/")[-1]
|
|
1020
|
+
if model_name.lower() == model_lower:
|
|
1021
|
+
return id
|
|
1022
|
+
|
|
1023
|
+
# if model is a full provider model id, try again with just the model name
|
|
1024
|
+
if "/" in model:
|
|
1025
|
+
last_part = model.split("/")[-1]
|
|
1026
|
+
return self.provider_model(last_part)
|
|
1027
|
+
|
|
468
1028
|
return None
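For reference, a simplified standalone sketch of the lookup order provider_model() applies above. The resolve() helper and its plain-dict arguments are hypothetical stand-ins for the provider instance.

    def resolve(model, map_models, models):
        m = model.lower()
        # 1. explicit alias configured in map_models
        for alias, provider_model in map_models.items():
            if alias.lower() == m:
                return provider_model
        # 2. already a provider model id
        for provider_model in map_models.values():
            if provider_model.lower() == m:
                return provider_model
        # 3. match against known model ids / names
        for model_id, info in models.items():
            mid = info.get("id") or model_id
            if m in (model_id.lower(), mid.lower(), (info.get("name") or "").lower()):
                return mid
        # 4. match the short name after the last "/"
        for model_id, info in models.items():
            mid = info.get("id") or model_id
            if "/" in mid and mid.split("/")[-1].lower() == m:
                return mid
        # 5. retry with just the part after "/"
        if "/" in model:
            return resolve(model.split("/")[-1], map_models, models)
        return None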
|
|
469
1029
|
|
|
1030
|
+
def response_json(self, response):
|
|
1031
|
+
return response_json(response)
|
|
1032
|
+
|
|
470
1033
|
def to_response(self, response, chat, started_at):
|
|
471
|
-
if
|
|
472
|
-
response[
|
|
473
|
-
response[
|
|
474
|
-
if chat is not None and
|
|
475
|
-
pricing = self.
|
|
476
|
-
if pricing and
|
|
477
|
-
response[
|
|
478
|
-
_log(json.dumps(response, indent=2))
|
|
1034
|
+
if "metadata" not in response:
|
|
1035
|
+
response["metadata"] = {}
|
|
1036
|
+
response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
|
|
1037
|
+
if chat is not None and "model" in chat:
|
|
1038
|
+
pricing = self.model_cost(chat["model"])
|
|
1039
|
+
if pricing and "input" in pricing and "output" in pricing:
|
|
1040
|
+
response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
|
|
479
1041
|
return response
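Illustrative values only: the shape of the metadata to_response() attaches to a completed response.

    response = {"choices": [], "metadata": {}}
    response["metadata"]["duration"] = 842              # wall-clock time in milliseconds
    response["metadata"]["pricing"] = "2.5e-07/1e-06"   # f"{cost['input']}/{cost['output']}"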
|
|
480
1042
|
|
|
1043
|
+
def chat_summary(self, chat):
|
|
1044
|
+
return chat_summary(chat)
|
|
1045
|
+
|
|
1046
|
+
def process_chat(self, chat, provider_id=None):
|
|
1047
|
+
return process_chat(chat, provider_id)
|
|
1048
|
+
|
|
481
1049
|
async def chat(self, chat):
|
|
482
|
-
chat[
|
|
1050
|
+
chat["model"] = self.provider_model(chat["model"]) or chat["model"]
|
|
1051
|
+
|
|
1052
|
+
if "modalities" in chat:
|
|
1053
|
+
for modality in chat.get("modalities", []):
|
|
1054
|
+
# use default implementation for text modalities
|
|
1055
|
+
if modality == "text":
|
|
1056
|
+
continue
|
|
1057
|
+
modality_provider = self.modalities.get(modality)
|
|
1058
|
+
if modality_provider:
|
|
1059
|
+
return await modality_provider.chat(chat, self)
|
|
1060
|
+
else:
|
|
1061
|
+
raise Exception(f"Provider {self.name} does not support '{modality}' modality")
|
|
483
1062
|
|
|
484
1063
|
# with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
|
|
485
1064
|
# f.write(json.dumps(chat, indent=2))
|
|
486
1065
|
|
|
487
1066
|
if self.frequency_penalty is not None:
|
|
488
|
-
chat[
|
|
1067
|
+
chat["frequency_penalty"] = self.frequency_penalty
|
|
489
1068
|
if self.max_completion_tokens is not None:
|
|
490
|
-
chat[
|
|
1069
|
+
chat["max_completion_tokens"] = self.max_completion_tokens
|
|
491
1070
|
if self.n is not None:
|
|
492
|
-
chat[
|
|
1071
|
+
chat["n"] = self.n
|
|
493
1072
|
if self.parallel_tool_calls is not None:
|
|
494
|
-
chat[
|
|
1073
|
+
chat["parallel_tool_calls"] = self.parallel_tool_calls
|
|
495
1074
|
if self.presence_penalty is not None:
|
|
496
|
-
chat[
|
|
1075
|
+
chat["presence_penalty"] = self.presence_penalty
|
|
497
1076
|
if self.prompt_cache_key is not None:
|
|
498
|
-
chat[
|
|
1077
|
+
chat["prompt_cache_key"] = self.prompt_cache_key
|
|
499
1078
|
if self.reasoning_effort is not None:
|
|
500
|
-
chat[
|
|
1079
|
+
chat["reasoning_effort"] = self.reasoning_effort
|
|
501
1080
|
if self.safety_identifier is not None:
|
|
502
|
-
chat[
|
|
1081
|
+
chat["safety_identifier"] = self.safety_identifier
|
|
503
1082
|
if self.seed is not None:
|
|
504
|
-
chat[
|
|
1083
|
+
chat["seed"] = self.seed
|
|
505
1084
|
if self.service_tier is not None:
|
|
506
|
-
chat[
|
|
1085
|
+
chat["service_tier"] = self.service_tier
|
|
507
1086
|
if self.stop is not None:
|
|
508
|
-
chat[
|
|
1087
|
+
chat["stop"] = self.stop
|
|
509
1088
|
if self.store is not None:
|
|
510
|
-
chat[
|
|
1089
|
+
chat["store"] = self.store
|
|
511
1090
|
if self.temperature is not None:
|
|
512
|
-
chat[
|
|
1091
|
+
chat["temperature"] = self.temperature
|
|
513
1092
|
if self.top_logprobs is not None:
|
|
514
|
-
chat[
|
|
1093
|
+
chat["top_logprobs"] = self.top_logprobs
|
|
515
1094
|
if self.top_p is not None:
|
|
516
|
-
chat[
|
|
1095
|
+
chat["top_p"] = self.top_p
|
|
517
1096
|
if self.verbosity is not None:
|
|
518
|
-
chat[
|
|
1097
|
+
chat["verbosity"] = self.verbosity
|
|
519
1098
|
if self.enable_thinking is not None:
|
|
520
|
-
chat[
|
|
1099
|
+
chat["enable_thinking"] = self.enable_thinking
|
|
521
1100
|
|
|
522
|
-
chat = await process_chat(chat)
|
|
1101
|
+
chat = await process_chat(chat, provider_id=self.id)
|
|
523
1102
|
_log(f"POST {self.chat_url}")
|
|
524
1103
|
_log(chat_summary(chat))
|
|
525
1104
|
# remove metadata if any (conflicts with some providers, e.g. Z.ai)
|
|
526
|
-
chat.pop(
|
|
1105
|
+
metadata = chat.pop("metadata", None)
|
|
527
1106
|
|
|
528
1107
|
async with aiohttp.ClientSession() as session:
|
|
529
1108
|
started_at = time.time()
|
|
530
|
-
async with session.post(
|
|
1109
|
+
async with session.post(
|
|
1110
|
+
self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=aiohttp.ClientTimeout(total=120)
|
|
1111
|
+
) as response:
|
|
1112
|
+
chat["metadata"] = metadata
|
|
531
1113
|
return self.to_response(await response_json(response), chat, started_at)
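A minimal sketch of the metadata handling around the POST above (model name and thread id are assumed values): metadata is stripped before the request and restored afterwards because some providers, e.g. Z.ai, reject it.

    import json

    chat = {"model": "glm-4.6", "messages": [], "metadata": {"threadId": "t_123"}}
    metadata = chat.pop("metadata", None)   # removed before the POST
    body = json.dumps(chat)                 # what is actually sent to self.chat_url
    chat["metadata"] = metadata             # restored once the response is read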
|
|
532
1114
|
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
1115
|
+
|
|
1116
|
+
class MistralProvider(OpenAiCompatible):
|
|
1117
|
+
sdk = "@ai-sdk/mistral"
|
|
1118
|
+
|
|
1119
|
+
def __init__(self, **kwargs):
|
|
1120
|
+
if "api" not in kwargs:
|
|
1121
|
+
kwargs["api"] = "https://api.mistral.ai/v1"
|
|
1122
|
+
super().__init__(**kwargs)
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
class GroqProvider(OpenAiCompatible):
|
|
1126
|
+
sdk = "@ai-sdk/groq"
|
|
1127
|
+
|
|
1128
|
+
def __init__(self, **kwargs):
|
|
1129
|
+
if "api" not in kwargs:
|
|
1130
|
+
kwargs["api"] = "https://api.groq.com/openai/v1"
|
|
1131
|
+
super().__init__(**kwargs)
|
|
1132
|
+
|
|
1133
|
+
|
|
1134
|
+
class XaiProvider(OpenAiCompatible):
|
|
1135
|
+
sdk = "@ai-sdk/xai"
|
|
1136
|
+
|
|
1137
|
+
def __init__(self, **kwargs):
|
|
1138
|
+
if "api" not in kwargs:
|
|
1139
|
+
kwargs["api"] = "https://api.x.ai/v1"
|
|
1140
|
+
super().__init__(**kwargs)
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
class CodestralProvider(OpenAiCompatible):
|
|
1144
|
+
sdk = "codestral"
|
|
1145
|
+
|
|
1146
|
+
def __init__(self, **kwargs):
|
|
1147
|
+
super().__init__(**kwargs)
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
class OllamaProvider(OpenAiCompatible):
|
|
1151
|
+
sdk = "ollama"
|
|
1152
|
+
|
|
1153
|
+
def __init__(self, **kwargs):
|
|
1154
|
+
super().__init__(**kwargs)
|
|
1155
|
+
# Ollama's OpenAI-compatible endpoint is at /v1/chat/completions
|
|
1156
|
+
self.chat_url = f"{self.api}/v1/chat/completions"
|
|
537
1157
|
|
|
538
1158
|
async def load(self):
|
|
539
|
-
if self.
|
|
540
|
-
await self.load_models(
|
|
1159
|
+
if not self.models:
|
|
1160
|
+
await self.load_models()
|
|
541
1161
|
|
|
542
1162
|
async def get_models(self):
|
|
543
1163
|
ret = {}
|
|
544
1164
|
try:
|
|
545
1165
|
async with aiohttp.ClientSession() as session:
|
|
546
|
-
_log(f"GET {self.
|
|
547
|
-
async with session.get(
|
|
1166
|
+
_log(f"GET {self.api}/api/tags")
|
|
1167
|
+
async with session.get(
|
|
1168
|
+
f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
|
|
1169
|
+
) as response:
|
|
548
1170
|
data = await response_json(response)
|
|
549
|
-
for model in data.get(
|
|
550
|
-
|
|
551
|
-
if
|
|
552
|
-
|
|
553
|
-
ret[
|
|
1171
|
+
for model in data.get("models", []):
|
|
1172
|
+
model_id = model["model"]
|
|
1173
|
+
if model_id.endswith(":latest"):
|
|
1174
|
+
model_id = model_id[:-7]
|
|
1175
|
+
ret[model_id] = model_id
|
|
554
1176
|
_log(f"Loaded Ollama models: {ret}")
|
|
555
1177
|
except Exception as e:
|
|
556
1178
|
_log(f"Error getting Ollama models: {e}")
|
|
557
1179
|
# return empty dict if ollama is not available
|
|
558
1180
|
return ret
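Assumed sample payload: how the tag handling above normalises Ollama model ids returned by /api/tags.

    data = {"models": [{"model": "llama3.2:latest"}, {"model": "qwen2.5:7b"}]}
    ret = {}
    for model in data.get("models", []):
        model_id = model["model"]
        if model_id.endswith(":latest"):
            model_id = model_id[:-7]   # "llama3.2:latest" -> "llama3.2"
        ret[model_id] = model_id
    print(ret)  # {'llama3.2': 'llama3.2', 'qwen2.5:7b': 'qwen2.5:7b'}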
|
|
559
1181
|
|
|
560
|
-
async def load_models(self
|
|
1182
|
+
async def load_models(self):
|
|
561
1183
|
"""Load models if all_models was requested"""
|
|
562
|
-
if self.all_models:
|
|
563
|
-
self.models = await self.get_models()
|
|
564
|
-
if default_models:
|
|
565
|
-
self.models = {**default_models, **self.models}
|
|
566
|
-
|
|
567
|
-
@classmethod
|
|
568
|
-
def test(cls, base_url=None, models={}, all_models=False, **kwargs):
|
|
569
|
-
return base_url and (len(models) > 0 or all_models)
|
|
570
|
-
|
|
571
|
-
class GoogleOpenAiProvider(OpenAiProvider):
|
|
572
|
-
def __init__(self, api_key, models, **kwargs):
|
|
573
|
-
super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
|
|
574
|
-
self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
|
|
575
|
-
|
|
576
|
-
@classmethod
|
|
577
|
-
def test(cls, api_key=None, models={}, **kwargs):
|
|
578
|
-
return api_key and len(models) > 0
|
|
579
|
-
|
|
580
|
-
class GoogleProvider(OpenAiProvider):
|
|
581
|
-
def __init__(self, models, api_key, safety_settings=None, thinking_config=None, curl=False, **kwargs):
|
|
582
|
-
super().__init__(base_url="https://generativelanguage.googleapis.com", api_key=api_key, models=models, **kwargs)
|
|
583
|
-
self.safety_settings = safety_settings
|
|
584
|
-
self.thinking_config = thinking_config
|
|
585
|
-
self.curl = curl
|
|
586
|
-
self.headers = kwargs['headers'] if 'headers' in kwargs else {
|
|
587
|
-
"Content-Type": "application/json",
|
|
588
|
-
}
|
|
589
|
-
# Google fails when using Authorization header, use query string param instead
|
|
590
|
-
if 'Authorization' in self.headers:
|
|
591
|
-
del self.headers['Authorization']
|
|
592
|
-
|
|
593
|
-
@classmethod
|
|
594
|
-
def test(cls, api_key=None, models={}, **kwargs):
|
|
595
|
-
return api_key is not None and len(models) > 0
|
|
596
1184
|
|
|
597
|
-
|
|
598
|
-
|
|
1185
|
+
# Map models to provider models {model_id:model_id}
|
|
1186
|
+
model_map = await self.get_models()
|
|
1187
|
+
if self.map_models:
|
|
1188
|
+
map_model_values = set(self.map_models.values())
|
|
1189
|
+
to = {}
|
|
1190
|
+
for k, v in model_map.items():
|
|
1191
|
+
if k in self.map_models:
|
|
1192
|
+
to[k] = v
|
|
1193
|
+
if v in map_model_values:
|
|
1194
|
+
to[k] = v
|
|
1195
|
+
model_map = to
|
|
1196
|
+
else:
|
|
1197
|
+
self.map_models = model_map
|
|
1198
|
+
models = {}
|
|
1199
|
+
for k, v in model_map.items():
|
|
1200
|
+
models[k] = {
|
|
1201
|
+
"id": k,
|
|
1202
|
+
"name": v.replace(":", " "),
|
|
1203
|
+
"modalities": {"input": ["text"], "output": ["text"]},
|
|
1204
|
+
"cost": {
|
|
1205
|
+
"input": 0,
|
|
1206
|
+
"output": 0,
|
|
1207
|
+
},
|
|
1208
|
+
}
|
|
1209
|
+
self.models = models
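The synthesised entry each locally discovered model receives above, shown as a literal (model id is an assumed example):

    models = {}
    models["llama3.2"] = {
        "id": "llama3.2",
        "name": "llama3.2",                                    # v.replace(":", " ") for tagged ids
        "modalities": {"input": ["text"], "output": ["text"]},
        "cost": {"input": 0, "output": 0},                     # local models are treated as free
    }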
|
|
599
1210
|
|
|
600
|
-
|
|
601
|
-
|
|
1211
|
+
def validate(self, **kwargs):
|
|
1212
|
+
return None
|
|
602
1213
|
|
|
603
|
-
# Filter out system messages and convert to proper Gemini format
|
|
604
|
-
contents = []
|
|
605
|
-
system_prompt = None
|
|
606
1214
|
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
if message['role'] == 'system':
|
|
610
|
-
content = message['content']
|
|
611
|
-
if isinstance(content, list):
|
|
612
|
-
for item in content:
|
|
613
|
-
if 'text' in item:
|
|
614
|
-
system_prompt = item['text']
|
|
615
|
-
break
|
|
616
|
-
elif isinstance(content, str):
|
|
617
|
-
system_prompt = content
|
|
618
|
-
elif 'content' in message:
|
|
619
|
-
if isinstance(message['content'], list):
|
|
620
|
-
parts = []
|
|
621
|
-
for item in message['content']:
|
|
622
|
-
if 'type' in item:
|
|
623
|
-
if item['type'] == 'image_url' and 'image_url' in item:
|
|
624
|
-
image_url = item['image_url']
|
|
625
|
-
if 'url' not in image_url:
|
|
626
|
-
continue
|
|
627
|
-
url = image_url['url']
|
|
628
|
-
if not url.startswith('data:'):
|
|
629
|
-
raise(Exception("Image was not downloaded: " + url))
|
|
630
|
-
# Extract mime type from data uri
|
|
631
|
-
mimetype = url.split(';',1)[0].split(':',1)[1] if ';' in url else "image/png"
|
|
632
|
-
base64Data = url.split(',',1)[1]
|
|
633
|
-
parts.append({
|
|
634
|
-
"inline_data": {
|
|
635
|
-
"mime_type": mimetype,
|
|
636
|
-
"data": base64Data
|
|
637
|
-
}
|
|
638
|
-
})
|
|
639
|
-
elif item['type'] == 'input_audio' and 'input_audio' in item:
|
|
640
|
-
input_audio = item['input_audio']
|
|
641
|
-
if 'data' not in input_audio:
|
|
642
|
-
continue
|
|
643
|
-
data = input_audio['data']
|
|
644
|
-
format = input_audio['format']
|
|
645
|
-
mimetype = f"audio/{format}"
|
|
646
|
-
parts.append({
|
|
647
|
-
"inline_data": {
|
|
648
|
-
"mime_type": mimetype,
|
|
649
|
-
"data": data
|
|
650
|
-
}
|
|
651
|
-
})
|
|
652
|
-
elif item['type'] == 'file' and 'file' in item:
|
|
653
|
-
file = item['file']
|
|
654
|
-
if 'file_data' not in file:
|
|
655
|
-
continue
|
|
656
|
-
data = file['file_data']
|
|
657
|
-
if not data.startswith('data:'):
|
|
658
|
-
raise(Exception("File was not downloaded: " + data))
|
|
659
|
-
# Extract mime type from data uri
|
|
660
|
-
mimetype = data.split(';',1)[0].split(':',1)[1] if ';' in data else "application/octet-stream"
|
|
661
|
-
base64Data = data.split(',',1)[1]
|
|
662
|
-
parts.append({
|
|
663
|
-
"inline_data": {
|
|
664
|
-
"mime_type": mimetype,
|
|
665
|
-
"data": base64Data
|
|
666
|
-
}
|
|
667
|
-
})
|
|
668
|
-
if 'text' in item:
|
|
669
|
-
text = item['text']
|
|
670
|
-
parts.append({"text": text})
|
|
671
|
-
if len(parts) > 0:
|
|
672
|
-
contents.append({
|
|
673
|
-
"role": message['role'] if 'role' in message and message['role'] == 'user' else 'model',
|
|
674
|
-
"parts": parts
|
|
675
|
-
})
|
|
676
|
-
else:
|
|
677
|
-
content = message['content']
|
|
678
|
-
contents.append({
|
|
679
|
-
"role": message['role'] if 'role' in message and message['role'] == 'user' else 'model',
|
|
680
|
-
"parts": [{"text": content}]
|
|
681
|
-
})
|
|
682
|
-
|
|
683
|
-
gemini_chat = {
|
|
684
|
-
"contents": contents,
|
|
685
|
-
}
|
|
1215
|
+
class LMStudioProvider(OllamaProvider):
|
|
1216
|
+
sdk = "lmstudio"
|
|
686
1217
|
|
|
687
|
-
|
|
688
|
-
|
|
1218
|
+
def __init__(self, **kwargs):
|
|
1219
|
+
super().__init__(**kwargs)
|
|
1220
|
+
self.chat_url = f"{self.api}/chat/completions"
|
|
689
1221
|
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
}
|
|
1222
|
+
async def get_models(self):
|
|
1223
|
+
ret = {}
|
|
1224
|
+
try:
|
|
1225
|
+
async with aiohttp.ClientSession() as session:
|
|
1226
|
+
_log(f"GET {self.api}/models")
|
|
1227
|
+
async with session.get(
|
|
1228
|
+
f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
|
|
1229
|
+
) as response:
|
|
1230
|
+
data = await response_json(response)
|
|
1231
|
+
for model in data.get("data", []):
|
|
1232
|
+
id = model["id"]
|
|
1233
|
+
ret[id] = id
|
|
1234
|
+
_log(f"Loaded LMStudio models: {ret}")
|
|
1235
|
+
except Exception as e:
|
|
1236
|
+
_log(f"Error getting LMStudio models: {e}")
|
|
1237
|
+
# return empty dict if LM Studio is not available
|
|
1238
|
+
return ret
|
|
695
1239
|
|
|
696
|
-
if 'max_completion_tokens' in chat:
|
|
697
|
-
generationConfig['maxOutputTokens'] = chat['max_completion_tokens']
|
|
698
|
-
if 'stop' in chat:
|
|
699
|
-
generationConfig['stopSequences'] = [chat['stop']]
|
|
700
|
-
if 'temperature' in chat:
|
|
701
|
-
generationConfig['temperature'] = chat['temperature']
|
|
702
|
-
if 'top_p' in chat:
|
|
703
|
-
generationConfig['topP'] = chat['top_p']
|
|
704
|
-
if 'top_logprobs' in chat:
|
|
705
|
-
generationConfig['topK'] = chat['top_logprobs']
|
|
706
|
-
|
|
707
|
-
if 'thinkingConfig' in chat:
|
|
708
|
-
generationConfig['thinkingConfig'] = chat['thinkingConfig']
|
|
709
|
-
elif self.thinking_config:
|
|
710
|
-
generationConfig['thinkingConfig'] = self.thinking_config
|
|
711
|
-
|
|
712
|
-
if len(generationConfig) > 0:
|
|
713
|
-
gemini_chat['generationConfig'] = generationConfig
|
|
714
|
-
|
|
715
|
-
started_at = int(time.time() * 1000)
|
|
716
|
-
gemini_chat_url = f"https://generativelanguage.googleapis.com/v1beta/models/{chat['model']}:generateContent?key={self.api_key}"
|
|
717
|
-
|
|
718
|
-
_log(f"POST {gemini_chat_url}")
|
|
719
|
-
_log(gemini_chat_summary(gemini_chat))
|
|
720
|
-
started_at = time.time()
|
|
721
1240
|
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
gemini_chat_url
|
|
729
|
-
]
|
|
730
|
-
try:
|
|
731
|
-
o = subprocess.run(curl_args, check=True, capture_output=True, text=True, timeout=120)
|
|
732
|
-
obj = json.loads(o.stdout)
|
|
733
|
-
except Exception as e:
|
|
734
|
-
raise Exception(f"Error executing curl: {e}")
|
|
735
|
-
else:
|
|
736
|
-
async with session.post(gemini_chat_url, headers=self.headers, data=json.dumps(gemini_chat), timeout=aiohttp.ClientTimeout(total=120)) as res:
|
|
737
|
-
obj = await response_json(res)
|
|
738
|
-
_log(f"google response:\n{json.dumps(obj, indent=2)}")
|
|
739
|
-
|
|
740
|
-
response = {
|
|
741
|
-
"id": f"chatcmpl-{started_at}",
|
|
742
|
-
"created": started_at,
|
|
743
|
-
"model": obj.get('modelVersion', chat['model']),
|
|
744
|
-
}
|
|
745
|
-
choices = []
|
|
746
|
-
i = 0
|
|
747
|
-
if 'error' in obj:
|
|
748
|
-
_log(f"Error: {obj['error']}")
|
|
749
|
-
raise Exception(obj['error']['message'])
|
|
750
|
-
for candidate in obj['candidates']:
|
|
751
|
-
role = "assistant"
|
|
752
|
-
if 'content' in candidate and 'role' in candidate['content']:
|
|
753
|
-
role = "assistant" if candidate['content']['role'] == 'model' else candidate['content']['role']
|
|
754
|
-
|
|
755
|
-
# Safely extract content from all text parts
|
|
756
|
-
content = ""
|
|
757
|
-
reasoning = ""
|
|
758
|
-
if 'content' in candidate and 'parts' in candidate['content']:
|
|
759
|
-
text_parts = []
|
|
760
|
-
reasoning_parts = []
|
|
761
|
-
for part in candidate['content']['parts']:
|
|
762
|
-
if 'text' in part:
|
|
763
|
-
if 'thought' in part and part['thought']:
|
|
764
|
-
reasoning_parts.append(part['text'])
|
|
765
|
-
else:
|
|
766
|
-
text_parts.append(part['text'])
|
|
767
|
-
content = ' '.join(text_parts)
|
|
768
|
-
reasoning = ' '.join(reasoning_parts)
|
|
1241
|
+
def get_provider_model(model_name):
|
|
1242
|
+
for provider in g_handlers.values():
|
|
1243
|
+
provider_model = provider.provider_model(model_name)
|
|
1244
|
+
if provider_model:
|
|
1245
|
+
return provider_model
|
|
1246
|
+
return None
|
|
769
1247
|
|
|
770
|
-
choice = {
|
|
771
|
-
"index": i,
|
|
772
|
-
"finish_reason": candidate.get('finishReason', 'stop'),
|
|
773
|
-
"message": {
|
|
774
|
-
"role": role,
|
|
775
|
-
"content": content,
|
|
776
|
-
},
|
|
777
|
-
}
|
|
778
|
-
if reasoning:
|
|
779
|
-
choice['message']['reasoning'] = reasoning
|
|
780
|
-
choices.append(choice)
|
|
781
|
-
i += 1
|
|
782
|
-
response['choices'] = choices
|
|
783
|
-
if 'usageMetadata' in obj:
|
|
784
|
-
usage = obj['usageMetadata']
|
|
785
|
-
response['usage'] = {
|
|
786
|
-
"completion_tokens": usage['candidatesTokenCount'],
|
|
787
|
-
"total_tokens": usage['totalTokenCount'],
|
|
788
|
-
"prompt_tokens": usage['promptTokenCount'],
|
|
789
|
-
}
|
|
790
|
-
return self.to_response(response, chat, started_at)
|
|
791
1248
|
|
|
792
1249
|
def get_models():
|
|
793
1250
|
ret = []
|
|
794
1251
|
for provider in g_handlers.values():
|
|
795
|
-
for model in provider.models
|
|
1252
|
+
for model in provider.models:
|
|
796
1253
|
if model not in ret:
|
|
797
1254
|
ret.append(model)
|
|
798
1255
|
ret.sort()
|
|
799
1256
|
return ret
|
|
800
1257
|
|
|
1258
|
+
|
|
801
1259
|
def get_active_models():
|
|
802
1260
|
ret = []
|
|
803
1261
|
existing_models = set()
|
|
804
|
-
for
|
|
805
|
-
for model in provider.models.
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
})
|
|
1262
|
+
for provider_id, provider in g_handlers.items():
|
|
1263
|
+
for model in provider.models.values():
|
|
1264
|
+
name = model.get("name")
|
|
1265
|
+
if not name:
|
|
1266
|
+
_log(f"Provider {provider_id} model {model} has no name")
|
|
1267
|
+
continue
|
|
1268
|
+
if name not in existing_models:
|
|
1269
|
+
existing_models.add(name)
|
|
1270
|
+
item = model.copy()
|
|
1271
|
+
item.update({"provider": provider_id})
|
|
1272
|
+
ret.append(item)
|
|
816
1273
|
ret.sort(key=lambda x: x["id"])
|
|
817
1274
|
return ret
|
|
818
1275
|
|
|
819
|
-
async def chat_completion(chat):
|
|
820
|
-
model = chat['model']
|
|
821
|
-
# get first provider that has the model
|
|
822
|
-
candidate_providers = [name for name, provider in g_handlers.items() if model in provider.models]
|
|
823
|
-
if len(candidate_providers) == 0:
|
|
824
|
-
raise(Exception(f"Model {model} not found"))
|
|
825
1276
|
|
|
1277
|
+
def api_providers():
|
|
1278
|
+
ret = []
|
|
1279
|
+
for id, provider in g_handlers.items():
|
|
1280
|
+
ret.append({"id": id, "name": provider.name, "models": provider.models})
|
|
1281
|
+
return ret
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
def to_error_message(e):
|
|
1285
|
+
return str(e)
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
def to_error_response(e, stacktrace=False):
|
|
1289
|
+
status = {"errorCode": "Error", "message": to_error_message(e)}
|
|
1290
|
+
if stacktrace:
|
|
1291
|
+
status["stackTrace"] = traceback.format_exc()
|
|
1292
|
+
return {"responseStatus": status}
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
def create_error_response(message, error_code="Error", stack_trace=None):
|
|
1296
|
+
ret = {"responseStatus": {"errorCode": error_code, "message": message}}
|
|
1297
|
+
if stack_trace:
|
|
1298
|
+
ret["responseStatus"]["stackTrace"] = stack_trace
|
|
1299
|
+
return ret
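Illustrative only: the error shape produced by the helpers above (message and error code are assumed values).

    # equivalent to create_error_response("Model not found", error_code="NotFound")
    error = {"responseStatus": {"errorCode": "NotFound", "message": "Model not found"}}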
|
|
1300
|
+
|
|
1301
|
+
|
|
1302
|
+
def should_cancel_thread(context):
|
|
1303
|
+
ret = context.get("cancelled", False)
|
|
1304
|
+
if ret:
|
|
1305
|
+
thread_id = context.get("threadId")
|
|
1306
|
+
_dbg(f"Thread cancelled {thread_id}")
|
|
1307
|
+
return ret
|
|
1308
|
+
|
|
1309
|
+
|
|
1310
|
+
def g_chat_request(template=None, text=None, model=None, system_prompt=None):
|
|
1311
|
+
chat_template = g_config["defaults"].get(template or "text")
|
|
1312
|
+
if not chat_template:
|
|
1313
|
+
raise Exception(f"Chat template '{template}' not found")
|
|
1314
|
+
|
|
1315
|
+
chat = chat_template.copy()
|
|
1316
|
+
if model:
|
|
1317
|
+
chat["model"] = model
|
|
1318
|
+
if system_prompt is not None:
|
|
1319
|
+
chat["messages"].insert(0, {"role": "system", "content": system_prompt})
|
|
1320
|
+
if text is not None:
|
|
1321
|
+
if not chat["messages"] or len(chat["messages"]) == 0:
|
|
1322
|
+
chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
|
|
1323
|
+
|
|
1324
|
+
# replace content of last message if exists, else add
|
|
1325
|
+
last_msg = chat["messages"][-1] if "messages" in chat else None
|
|
1326
|
+
if last_msg and last_msg["role"] == "user":
|
|
1327
|
+
if isinstance(last_msg["content"], list):
|
|
1328
|
+
last_msg["content"][-1]["text"] = text
|
|
1329
|
+
else:
|
|
1330
|
+
last_msg["content"] = text
|
|
1331
|
+
else:
|
|
1332
|
+
chat["messages"].append({"role": "user", "content": text})
|
|
1333
|
+
|
|
1334
|
+
return chat
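Assuming a default "text" template containing a single user message, a call like g_chat_request(text=..., model=..., system_prompt=...) above yields a request of roughly this shape (all values illustrative):

    chat = {
        "model": "gpt-4o-mini",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [{"type": "text", "text": "Summarise this repository"}]},
        ],
    }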
|
|
1335
|
+
|
|
1336
|
+
|
|
1337
|
+
async def g_chat_completion(chat, context=None):
|
|
1338
|
+
try:
|
|
1339
|
+
model = chat.get("model")
|
|
1340
|
+
if not model:
|
|
1341
|
+
raise Exception("Model not specified")
|
|
1342
|
+
|
|
1343
|
+
if context is None:
|
|
1344
|
+
context = {"chat": chat, "tools": "all"}
|
|
1345
|
+
|
|
1346
|
+
# get first provider that has the model
|
|
1347
|
+
candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
|
|
1348
|
+
if len(candidate_providers) == 0:
|
|
1349
|
+
raise (Exception(f"Model {model} not found"))
|
|
1350
|
+
except Exception as e:
|
|
1351
|
+
await g_app.on_chat_error(e, context or {"chat": chat})
|
|
1352
|
+
raise e
|
|
1353
|
+
|
|
1354
|
+
started_at = time.time()
|
|
826
1355
|
first_exception = None
|
|
1356
|
+
provider_name = "Unknown"
|
|
827
1357
|
for name in candidate_providers:
|
|
828
|
-
provider = g_handlers[name]
|
|
829
|
-
_log(f"provider: {name} {type(provider).__name__}")
|
|
830
1358
|
try:
|
|
831
|
-
|
|
832
|
-
|
|
1359
|
+
provider_name = name
|
|
1360
|
+
provider = g_handlers[name]
|
|
1361
|
+
_log(f"provider: {name} {type(provider).__name__}")
|
|
1362
|
+
started_at = time.time()
|
|
1363
|
+
context["startedAt"] = datetime.now()
|
|
1364
|
+
context["provider"] = name
|
|
1365
|
+
model_info = provider.model_info(model)
|
|
1366
|
+
context["modelCost"] = model_info.get("cost", provider.model_cost(model)) or {"input": 0, "output": 0}
|
|
1367
|
+
context["modelInfo"] = model_info
|
|
1368
|
+
|
|
1369
|
+
# Accumulate usage across tool calls
|
|
1370
|
+
total_usage = {
|
|
1371
|
+
"prompt_tokens": 0,
|
|
1372
|
+
"completion_tokens": 0,
|
|
1373
|
+
"total_tokens": 0,
|
|
1374
|
+
}
|
|
1375
|
+
accumulated_cost = 0.0
|
|
1376
|
+
|
|
1377
|
+
# Inject global tools if present
|
|
1378
|
+
current_chat = chat.copy()
|
|
1379
|
+
if g_app.tool_definitions:
|
|
1380
|
+
only_tools_str = context.get("tools", "all")
|
|
1381
|
+
include_all_tools = only_tools_str == "all"
|
|
1382
|
+
only_tools = only_tools_str.split(",")
|
|
1383
|
+
|
|
1384
|
+
if include_all_tools or len(only_tools) > 0:
|
|
1385
|
+
if "tools" not in current_chat:
|
|
1386
|
+
current_chat["tools"] = []
|
|
1387
|
+
|
|
1388
|
+
existing_tools = {t["function"]["name"] for t in current_chat["tools"]}
|
|
1389
|
+
for tool_def in g_app.tool_definitions:
|
|
1390
|
+
name = tool_def["function"]["name"]
|
|
1391
|
+
if name not in existing_tools and (include_all_tools or name in only_tools):
|
|
1392
|
+
current_chat["tools"].append(tool_def)
|
|
1393
|
+
|
|
1394
|
+
# Apply pre-chat filters ONCE
|
|
1395
|
+
context["chat"] = current_chat
|
|
1396
|
+
for filter_func in g_app.chat_request_filters:
|
|
1397
|
+
await filter_func(current_chat, context)
|
|
1398
|
+
|
|
1399
|
+
# Tool execution loop
|
|
1400
|
+
max_iterations = 10
|
|
1401
|
+
tool_history = []
|
|
1402
|
+
final_response = None
|
|
1403
|
+
|
|
1404
|
+
for _ in range(max_iterations):
|
|
1405
|
+
if should_cancel_thread(context):
|
|
1406
|
+
return
|
|
1407
|
+
|
|
1408
|
+
response = await provider.chat(current_chat)
|
|
1409
|
+
|
|
1410
|
+
if should_cancel_thread(context):
|
|
1411
|
+
return
|
|
1412
|
+
|
|
1413
|
+
# Aggregate usage
|
|
1414
|
+
if "usage" in response:
|
|
1415
|
+
usage = response["usage"]
|
|
1416
|
+
total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
|
1417
|
+
total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
|
1418
|
+
total_usage["total_tokens"] += usage.get("total_tokens", 0)
|
|
1419
|
+
|
|
1420
|
+
# Calculate cost for this step if available
|
|
1421
|
+
if "cost" in response and isinstance(response["cost"], (int, float)):
|
|
1422
|
+
accumulated_cost += response["cost"]
|
|
1423
|
+
elif "cost" in usage and isinstance(usage["cost"], (int, float)):
|
|
1424
|
+
accumulated_cost += usage["cost"]
|
|
1425
|
+
|
|
1426
|
+
# Check for tool_calls in the response
|
|
1427
|
+
choice = response.get("choices", [])[0] if response.get("choices") else {}
|
|
1428
|
+
message = choice.get("message", {})
|
|
1429
|
+
tool_calls = message.get("tool_calls")
|
|
1430
|
+
|
|
1431
|
+
if tool_calls:
|
|
1432
|
+
# Append the assistant's message with tool calls to history
|
|
1433
|
+
if "messages" not in current_chat:
|
|
1434
|
+
current_chat["messages"] = []
|
|
1435
|
+
if "timestamp" not in message:
|
|
1436
|
+
message["timestamp"] = int(time.time() * 1000)
|
|
1437
|
+
current_chat["messages"].append(message)
|
|
1438
|
+
tool_history.append(message)
|
|
1439
|
+
|
|
1440
|
+
await g_app.on_chat_tool(current_chat, context)
|
|
1441
|
+
|
|
1442
|
+
for tool_call in tool_calls:
|
|
1443
|
+
function_name = tool_call["function"]["name"]
|
|
1444
|
+
try:
|
|
1445
|
+
function_args = json.loads(tool_call["function"]["arguments"])
|
|
1446
|
+
except Exception as e:
|
|
1447
|
+
tool_result = f"Error parsing JSON arguments for tool {function_name}: {e}"
|
|
1448
|
+
else:
|
|
1449
|
+
tool_result = f"Error: Tool {function_name} not found"
|
|
1450
|
+
if function_name in g_app.tools:
|
|
1451
|
+
try:
|
|
1452
|
+
func = g_app.tools[function_name]
|
|
1453
|
+
if inspect.iscoroutinefunction(func):
|
|
1454
|
+
tool_result = await func(**function_args)
|
|
1455
|
+
else:
|
|
1456
|
+
tool_result = func(**function_args)
|
|
1457
|
+
except Exception as e:
|
|
1458
|
+
tool_result = f"Error executing tool {function_name}: {e}"
|
|
1459
|
+
|
|
1460
|
+
# Append tool result to history
|
|
1461
|
+
tool_msg = {"role": "tool", "tool_call_id": tool_call["id"], "content": to_content(tool_result)}
|
|
1462
|
+
current_chat["messages"].append(tool_msg)
|
|
1463
|
+
tool_history.append(tool_msg)
|
|
1464
|
+
|
|
1465
|
+
await g_app.on_chat_tool(current_chat, context)
|
|
1466
|
+
|
|
1467
|
+
if should_cancel_thread(context):
|
|
1468
|
+
return
|
|
1469
|
+
|
|
1470
|
+
# Continue loop to send tool results back to LLM
|
|
1471
|
+
continue
|
|
1472
|
+
|
|
1473
|
+
# If no tool calls, this is the final response
|
|
1474
|
+
if tool_history:
|
|
1475
|
+
response["tool_history"] = tool_history
|
|
1476
|
+
|
|
1477
|
+
# Update final response with aggregated usage
|
|
1478
|
+
if "usage" not in response:
|
|
1479
|
+
response["usage"] = {}
|
|
1480
|
+
# convert to int seconds
|
|
1481
|
+
context["duration"] = duration = int(time.time() - started_at)
|
|
1482
|
+
total_usage.update({"duration": duration})
|
|
1483
|
+
response["usage"].update(total_usage)
|
|
1484
|
+
# If we accumulated cost, set it on the response
|
|
1485
|
+
if accumulated_cost > 0:
|
|
1486
|
+
response["cost"] = accumulated_cost
|
|
1487
|
+
|
|
1488
|
+
final_response = response
|
|
1489
|
+
break # Exit tool loop
|
|
1490
|
+
|
|
1491
|
+
if final_response:
|
|
1492
|
+
# Apply post-chat filters ONCE on final response
|
|
1493
|
+
for filter_func in g_app.chat_response_filters:
|
|
1494
|
+
await filter_func(final_response, context)
|
|
1495
|
+
|
|
1496
|
+
if DEBUG:
|
|
1497
|
+
_dbg(json.dumps(final_response, indent=2))
|
|
1498
|
+
|
|
1499
|
+
return final_response
|
|
1500
|
+
|
|
833
1501
|
except Exception as e:
|
|
834
1502
|
if first_exception is None:
|
|
835
1503
|
first_exception = e
|
|
836
|
-
|
|
1504
|
+
context["stackTrace"] = traceback.format_exc()
|
|
1505
|
+
_err(f"Provider {provider_name} failed", first_exception)
|
|
1506
|
+
await g_app.on_chat_error(e, context)
|
|
1507
|
+
|
|
837
1508
|
continue
|
|
838
1509
|
|
|
839
1510
|
# If we get here, all providers failed
|
|
840
1511
|
raise first_exception
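The tool loop above looks functions up in g_app.tools and advertises g_app.tool_definitions to the model. A hypothetical registration is sketched below with local stand-ins for those two structures; the tool name, schema, and OpenAI-style function-calling format are assumptions, not part of this diff.

    tools = {}              # stand-in for g_app.tools
    tool_definitions = []   # stand-in for g_app.tool_definitions

    def get_weather(city: str) -> str:
        return f"Sunny in {city}"   # toy implementation

    tools["get_weather"] = get_weather
    tool_definitions.append({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    })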
|
|
841
1512
|
|
|
842
|
-
|
|
1513
|
+
|
|
1514
|
+
async def cli_chat(chat, tools=None, image=None, audio=None, file=None, args=None, raw=False):
|
|
843
1515
|
if g_default_model:
|
|
844
|
-
chat[
|
|
1516
|
+
chat["model"] = g_default_model
|
|
845
1517
|
|
|
846
1518
|
# Apply args parameters to chat request
|
|
847
1519
|
if args:
|
|
@@ -850,176 +1522,244 @@ async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False
|
|
|
850
1522
|
# process_chat downloads the image, just adding the reference here
|
|
851
1523
|
if image is not None:
|
|
852
1524
|
first_message = None
|
|
853
|
-
for message in chat[
|
|
854
|
-
if message[
|
|
1525
|
+
for message in chat["messages"]:
|
|
1526
|
+
if message["role"] == "user":
|
|
855
1527
|
first_message = message
|
|
856
1528
|
break
|
|
857
|
-
image_content = {
|
|
858
|
-
|
|
859
|
-
"
|
|
860
|
-
"url": image
|
|
861
|
-
}
|
|
862
|
-
}
|
|
863
|
-
if 'content' in first_message:
|
|
864
|
-
if isinstance(first_message['content'], list):
|
|
1529
|
+
image_content = {"type": "image_url", "image_url": {"url": image}}
|
|
1530
|
+
if "content" in first_message:
|
|
1531
|
+
if isinstance(first_message["content"], list):
|
|
865
1532
|
image_url = None
|
|
866
|
-
for item in first_message[
|
|
867
|
-
if
|
|
868
|
-
image_url = item[
|
|
1533
|
+
for item in first_message["content"]:
|
|
1534
|
+
if "image_url" in item:
|
|
1535
|
+
image_url = item["image_url"]
|
|
869
1536
|
# If no image_url, add one
|
|
870
1537
|
if image_url is None:
|
|
871
|
-
first_message[
|
|
1538
|
+
first_message["content"].insert(0, image_content)
|
|
872
1539
|
else:
|
|
873
|
-
image_url[
|
|
1540
|
+
image_url["url"] = image
|
|
874
1541
|
else:
|
|
875
|
-
first_message[
|
|
876
|
-
image_content,
|
|
877
|
-
{ "type": "text", "text": first_message['content'] }
|
|
878
|
-
]
|
|
1542
|
+
first_message["content"] = [image_content, {"type": "text", "text": first_message["content"]}]
|
|
879
1543
|
if audio is not None:
|
|
880
1544
|
first_message = None
|
|
881
|
-
for message in chat[
|
|
882
|
-
if message[
|
|
1545
|
+
for message in chat["messages"]:
|
|
1546
|
+
if message["role"] == "user":
|
|
883
1547
|
first_message = message
|
|
884
1548
|
break
|
|
885
|
-
audio_content = {
|
|
886
|
-
|
|
887
|
-
"
|
|
888
|
-
"data": audio,
|
|
889
|
-
"format": "mp3"
|
|
890
|
-
}
|
|
891
|
-
}
|
|
892
|
-
if 'content' in first_message:
|
|
893
|
-
if isinstance(first_message['content'], list):
|
|
1549
|
+
audio_content = {"type": "input_audio", "input_audio": {"data": audio, "format": "mp3"}}
|
|
1550
|
+
if "content" in first_message:
|
|
1551
|
+
if isinstance(first_message["content"], list):
|
|
894
1552
|
input_audio = None
|
|
895
|
-
for item in first_message[
|
|
896
|
-
if
|
|
897
|
-
input_audio = item[
|
|
1553
|
+
for item in first_message["content"]:
|
|
1554
|
+
if "input_audio" in item:
|
|
1555
|
+
input_audio = item["input_audio"]
|
|
898
1556
|
# If no input_audio, add one
|
|
899
1557
|
if input_audio is None:
|
|
900
|
-
first_message[
|
|
1558
|
+
first_message["content"].insert(0, audio_content)
|
|
901
1559
|
else:
|
|
902
|
-
input_audio[
|
|
1560
|
+
input_audio["data"] = audio
|
|
903
1561
|
else:
|
|
904
|
-
first_message[
|
|
905
|
-
audio_content,
|
|
906
|
-
{ "type": "text", "text": first_message['content'] }
|
|
907
|
-
]
|
|
1562
|
+
first_message["content"] = [audio_content, {"type": "text", "text": first_message["content"]}]
|
|
908
1563
|
if file is not None:
|
|
909
1564
|
first_message = None
|
|
910
|
-
for message in chat[
|
|
911
|
-
if message[
|
|
1565
|
+
for message in chat["messages"]:
|
|
1566
|
+
if message["role"] == "user":
|
|
912
1567
|
first_message = message
|
|
913
1568
|
break
|
|
914
|
-
file_content = {
|
|
915
|
-
|
|
916
|
-
"
|
|
917
|
-
"filename": get_filename(file),
|
|
918
|
-
"file_data": file
|
|
919
|
-
}
|
|
920
|
-
}
|
|
921
|
-
if 'content' in first_message:
|
|
922
|
-
if isinstance(first_message['content'], list):
|
|
1569
|
+
file_content = {"type": "file", "file": {"filename": get_filename(file), "file_data": file}}
|
|
1570
|
+
if "content" in first_message:
|
|
1571
|
+
if isinstance(first_message["content"], list):
|
|
923
1572
|
file_data = None
|
|
924
|
-
for item in first_message[
|
|
925
|
-
if
|
|
926
|
-
file_data = item[
|
|
1573
|
+
for item in first_message["content"]:
|
|
1574
|
+
if "file" in item:
|
|
1575
|
+
file_data = item["file"]
|
|
927
1576
|
# If no file_data, add one
|
|
928
1577
|
if file_data is None:
|
|
929
|
-
first_message[
|
|
1578
|
+
first_message["content"].insert(0, file_content)
|
|
930
1579
|
else:
|
|
931
|
-
file_data[
|
|
932
|
-
file_data[
|
|
1580
|
+
file_data["filename"] = get_filename(file)
|
|
1581
|
+
file_data["file_data"] = file
|
|
933
1582
|
else:
|
|
934
|
-
first_message[
|
|
935
|
-
file_content,
|
|
936
|
-
{ "type": "text", "text": first_message['content'] }
|
|
937
|
-
]
|
|
1583
|
+
first_message["content"] = [file_content, {"type": "text", "text": first_message["content"]}]
|
|
938
1584
|
|
|
939
1585
|
if g_verbose:
|
|
940
1586
|
printdump(chat)
|
|
941
1587
|
|
|
942
1588
|
try:
|
|
943
|
-
|
|
1589
|
+
context = {
|
|
1590
|
+
"tools": tools or "all",
|
|
1591
|
+
}
|
|
1592
|
+
response = await g_app.chat_completion(chat, context=context)
|
|
1593
|
+
|
|
944
1594
|
if raw:
|
|
945
1595
|
print(json.dumps(response, indent=2))
|
|
946
1596
|
exit(0)
|
|
947
1597
|
else:
|
|
948
|
-
|
|
949
|
-
|
|
1598
|
+
msg = response["choices"][0]["message"]
|
|
1599
|
+
if "content" in msg or "answer" in msg:
|
|
1600
|
+
print(msg["content"])
|
|
1601
|
+
|
|
1602
|
+
generated_files = []
|
|
1603
|
+
for choice in response["choices"]:
|
|
1604
|
+
if "message" in choice:
|
|
1605
|
+
msg = choice["message"]
|
|
1606
|
+
if "images" in msg:
|
|
1607
|
+
for image in msg["images"]:
|
|
1608
|
+
image_url = image["image_url"]["url"]
|
|
1609
|
+
generated_files.append(image_url)
|
|
1610
|
+
if "audios" in msg:
|
|
1611
|
+
for audio in msg["audios"]:
|
|
1612
|
+
audio_url = audio["audio_url"]["url"]
|
|
1613
|
+
generated_files.append(audio_url)
|
|
1614
|
+
|
|
1615
|
+
if len(generated_files) > 0:
|
|
1616
|
+
print("\nSaved files:")
|
|
1617
|
+
for file in generated_files:
|
|
1618
|
+
if file.startswith("/~cache"):
|
|
1619
|
+
print(get_cache_path(file[8:]))
|
|
1620
|
+
print(urljoin("http://localhost:8000", file))
|
|
1621
|
+
else:
|
|
1622
|
+
print(file)
|
|
1623
|
+
|
|
950
1624
|
except HTTPError as e:
|
|
951
1625
|
# HTTP error (4xx, 5xx)
|
|
952
1626
|
print(f"{e}:\n{e.body}")
|
|
953
|
-
exit(1)
|
|
1627
|
+
g_app.exit(1)
|
|
954
1628
|
except aiohttp.ClientConnectionError as e:
|
|
955
1629
|
# Connection issues
|
|
956
1630
|
print(f"Connection error: {e}")
|
|
957
|
-
exit(1)
|
|
1631
|
+
g_app.exit(1)
|
|
958
1632
|
except asyncio.TimeoutError as e:
|
|
959
1633
|
# Timeout
|
|
960
1634
|
print(f"Timeout error: {e}")
|
|
961
|
-
exit(1)
|
|
1635
|
+
g_app.exit(1)
|
|
1636
|
+
|
|
962
1637
|
|
|
963
1638
|
def config_str(key):
|
|
964
1639
|
return key in g_config and g_config[key] or None
|
|
965
1640
|
|
|
966
|
-
def init_llms(config):
|
|
967
|
-
global g_config, g_handlers
|
|
968
1641
|
|
|
1642
|
+
def load_config(config, providers, verbose=None):
|
|
1643
|
+
global g_config, g_providers, g_verbose
|
|
969
1644
|
g_config = config
|
|
1645
|
+
g_providers = providers
|
|
1646
|
+
if verbose:
|
|
1647
|
+
g_verbose = verbose
|
|
1648
|
+
|
|
1649
|
+
|
|
1650
|
+
def init_llms(config, providers):
|
|
1651
|
+
global g_config, g_handlers
|
|
1652
|
+
|
|
1653
|
+
load_config(config, providers)
|
|
970
1654
|
g_handlers = {}
|
|
971
1655
|
# iterate over config and replace $ENV with env value
|
|
972
1656
|
for key, value in g_config.items():
|
|
973
1657
|
if isinstance(value, str) and value.startswith("$"):
|
|
974
|
-
g_config[key] = os.
|
|
1658
|
+
g_config[key] = os.getenv(value[1:], "")
|
|
975
1659
|
|
|
976
1660
|
# if g_verbose:
|
|
977
1661
|
# printdump(g_config)
|
|
978
|
-
providers = g_config[
|
|
1662
|
+
providers = g_config["providers"]
|
|
979
1663
|
|
|
980
|
-
for
|
|
981
|
-
|
|
982
|
-
provider_type = definition['type']
|
|
983
|
-
if 'enabled' in definition and not definition['enabled']:
|
|
1664
|
+
for id, orig in providers.items():
|
|
1665
|
+
if "enabled" in orig and not orig["enabled"]:
|
|
984
1666
|
continue
|
|
985
1667
|
|
|
986
|
-
|
|
987
|
-
if
|
|
988
|
-
|
|
989
|
-
if isinstance(value, str) and value.startswith("$"):
|
|
990
|
-
definition['api_key'] = os.environ.get(value[1:], "")
|
|
991
|
-
|
|
992
|
-
# Create a copy of definition without the 'type' key for constructor kwargs
|
|
993
|
-
constructor_kwargs = {k: v for k, v in definition.items() if k != 'type' and k != 'enabled'}
|
|
994
|
-
constructor_kwargs['headers'] = g_config['defaults']['headers'].copy()
|
|
995
|
-
|
|
996
|
-
if provider_type == 'OpenAiProvider' and OpenAiProvider.test(**constructor_kwargs):
|
|
997
|
-
g_handlers[name] = OpenAiProvider(**constructor_kwargs)
|
|
998
|
-
elif provider_type == 'OllamaProvider' and OllamaProvider.test(**constructor_kwargs):
|
|
999
|
-
g_handlers[name] = OllamaProvider(**constructor_kwargs)
|
|
1000
|
-
elif provider_type == 'GoogleProvider' and GoogleProvider.test(**constructor_kwargs):
|
|
1001
|
-
g_handlers[name] = GoogleProvider(**constructor_kwargs)
|
|
1002
|
-
elif provider_type == 'GoogleOpenAiProvider' and GoogleOpenAiProvider.test(**constructor_kwargs):
|
|
1003
|
-
g_handlers[name] = GoogleOpenAiProvider(**constructor_kwargs)
|
|
1004
|
-
|
|
1668
|
+
provider, constructor_kwargs = create_provider_from_definition(id, orig)
|
|
1669
|
+
if provider and provider.test(**constructor_kwargs):
|
|
1670
|
+
g_handlers[id] = provider
|
|
1005
1671
|
return g_handlers
|
|
1006
1672
|
|
|
1673
|
+
|
|
1674
|
+
def create_provider_from_definition(id, orig):
|
|
1675
|
+
definition = orig.copy()
|
|
1676
|
+
provider_id = definition.get("id", id)
|
|
1677
|
+
if "id" not in definition:
|
|
1678
|
+
definition["id"] = provider_id
|
|
1679
|
+
provider = g_providers.get(provider_id)
|
|
1680
|
+
constructor_kwargs = create_provider_kwargs(definition, provider)
|
|
1681
|
+
provider = create_provider(constructor_kwargs)
|
|
1682
|
+
return provider, constructor_kwargs
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
def create_provider_kwargs(definition, provider=None):
|
|
1686
|
+
if provider:
|
|
1687
|
+
provider = provider.copy()
|
|
1688
|
+
provider.update(definition)
|
|
1689
|
+
else:
|
|
1690
|
+
provider = definition.copy()
|
|
1691
|
+
|
|
1692
|
+
# Replace API keys with environment variables if they start with $
|
|
1693
|
+
if "api_key" in provider:
|
|
1694
|
+
value = provider["api_key"]
|
|
1695
|
+
if isinstance(value, str) and value.startswith("$"):
|
|
1696
|
+
provider["api_key"] = os.getenv(value[1:], "")
|
|
1697
|
+
|
|
1698
|
+
if "api_key" not in provider and "env" in provider:
|
|
1699
|
+
for env_var in provider["env"]:
|
|
1700
|
+
val = os.getenv(env_var)
|
|
1701
|
+
if val:
|
|
1702
|
+
provider["api_key"] = val
|
|
1703
|
+
break
|
|
1704
|
+
|
|
1705
|
+
# Create a copy of provider
|
|
1706
|
+
constructor_kwargs = dict(provider.items())
|
|
1707
|
+
# Create a copy of all list and dict values
|
|
1708
|
+
for key, value in constructor_kwargs.items():
|
|
1709
|
+
if isinstance(value, (list, dict)):
|
|
1710
|
+
constructor_kwargs[key] = value.copy()
|
|
1711
|
+
constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
|
|
1712
|
+
|
|
1713
|
+
if "modalities" in definition:
|
|
1714
|
+
constructor_kwargs["modalities"] = {}
|
|
1715
|
+
for modality, modality_definition in definition["modalities"].items():
|
|
1716
|
+
modality_provider = create_provider(modality_definition)
|
|
1717
|
+
if not modality_provider:
|
|
1718
|
+
return None
|
|
1719
|
+
constructor_kwargs["modalities"][modality] = modality_provider
|
|
1720
|
+
|
|
1721
|
+
return constructor_kwargs
|
|
1722
|
+
|
|
1723
|
+
|
|
1724
|
+
def create_provider(provider):
|
|
1725
|
+
if not isinstance(provider, dict):
|
|
1726
|
+
return None
|
|
1727
|
+
provider_label = provider.get("id", provider.get("name", "unknown"))
|
|
1728
|
+
npm_sdk = provider.get("npm")
|
|
1729
|
+
if not npm_sdk:
|
|
1730
|
+
_log(f"Provider {provider_label} is missing 'npm' sdk")
|
|
1731
|
+
return None
|
|
1732
|
+
|
|
1733
|
+
for provider_type in g_app.all_providers:
|
|
1734
|
+
if provider_type.sdk == npm_sdk:
|
|
1735
|
+
kwargs = create_provider_kwargs(provider)
|
|
1736
|
+
if kwargs is None:
|
|
1737
|
+
kwargs = provider
|
|
1738
|
+
return provider_type(**kwargs)
|
|
1739
|
+
|
|
1740
|
+
_log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
|
|
1741
|
+
return None
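A hypothetical llms.json provider entry that create_provider_from_definition()/create_provider() above would accept; the keys follow the diff, the values are illustrative.

    definition = {
        "id": "groq",
        "npm": "@ai-sdk/groq",        # matched against each provider class's sdk attribute
        "api_key": "$GROQ_API_KEY",   # "$ENV" values are resolved via os.getenv
        "models": {"llama-3.3-70b-versatile": {"id": "llama-3.3-70b-versatile"}},
        "enabled": True,
    }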
|
|
1742
|
+
|
|
1743
|
+
|
|
1007
1744
|
async def load_llms():
|
|
1008
1745
|
global g_handlers
|
|
1009
1746
|
_log("Loading providers...")
|
|
1010
|
-
for
|
|
1747
|
+
for _name, provider in g_handlers.items():
|
|
1011
1748
|
await provider.load()
|
|
1012
1749
|
|
|
1750
|
+
|
|
1013
1751
|
def save_config(config):
|
|
1014
1752
|
global g_config, g_config_path
|
|
1015
1753
|
g_config = config
|
|
1016
|
-
with open(g_config_path, "w") as f:
|
|
1754
|
+
with open(g_config_path, "w", encoding="utf-8") as f:
|
|
1017
1755
|
json.dump(g_config, f, indent=4)
|
|
1018
1756
|
_log(f"Saved config to {g_config_path}")
|
|
1019
1757
|
|
|
1758
|
+
|
|
1020
1759
|
def github_url(filename):
|
|
1021
1760
|
return f"https://raw.githubusercontent.com/ServiceStack/llms/refs/heads/main/llms/{filename}"
|
|
1022
1761
|
|
|
1762
|
+
|
|
1023
1763
|
async def get_text(url):
|
|
1024
1764
|
async with aiohttp.ClientSession() as session:
|
|
1025
1765
|
_log(f"GET {url}")
|
|
@@ -1029,25 +1769,58 @@ async def get_text(url):
|
|
|
1029
1769
|
raise HTTPError(resp.status, reason=resp.reason, body=text, headers=dict(resp.headers))
|
|
1030
1770
|
return text
|
|
1031
1771
|
|
|
1772
|
+
|
|
1032
1773
|
async def save_text_url(url, save_path):
|
|
1033
1774
|
text = await get_text(url)
|
|
1034
1775
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
1035
|
-
with open(save_path, "w") as f:
|
|
1776
|
+
with open(save_path, "w", encoding="utf-8") as f:
|
|
1036
1777
|
f.write(text)
|
|
1037
1778
|
return text
|
|
1038
1779
|
|
|
1780
|
+
|
|
1039
1781
|
async def save_default_config(config_path):
|
|
1040
1782
|
global g_config
|
|
1041
1783
|
config_json = await save_text_url(github_url("llms.json"), config_path)
|
|
1042
1784
|
g_config = json.loads(config_json)
|
|
1043
1785
|
|
|
1786
|
+
|
|
1787
|
+
async def update_providers(home_providers_path):
|
|
1788
|
+
global g_providers
|
|
1789
|
+
text = await get_text("https://models.dev/api.json")
|
|
1790
|
+
all_providers = json.loads(text)
|
|
1791
|
+
extra_providers = {}
|
|
1792
|
+
extra_providers_path = home_providers_path.replace("providers.json", "providers-extra.json")
|
|
1793
|
+
if os.path.exists(extra_providers_path):
|
|
1794
|
+
with open(extra_providers_path) as f:
|
|
1795
|
+
extra_providers = json.load(f)
|
|
1796
|
+
|
|
1797
|
+
filtered_providers = {}
|
|
1798
|
+
for id, provider in all_providers.items():
|
|
1799
|
+
if id in g_config["providers"]:
|
|
1800
|
+
filtered_providers[id] = provider
|
|
1801
|
+
if id in extra_providers and "models" in extra_providers[id]:
|
|
1802
|
+
for model_id, model in extra_providers[id]["models"].items():
|
|
1803
|
+
if "id" not in model:
|
|
1804
|
+
model["id"] = model_id
|
|
1805
|
+
if "name" not in model:
|
|
1806
|
+
model["name"] = id_to_name(model["id"])
|
|
1807
|
+
filtered_providers[id]["models"][model_id] = model
|
|
1808
|
+
|
|
1809
|
+
os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
|
|
1810
|
+
with open(home_providers_path, "w", encoding="utf-8") as f:
|
|
1811
|
+
json.dump(filtered_providers, f)
|
|
1812
|
+
|
|
1813
|
+
g_providers = filtered_providers
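An assumed providers-extra.json fragment illustrating the merge above, where extra models are layered onto a provider fetched from models.dev; the provider and model ids are placeholders.

    extra_providers = {
        "openrouter": {
            "models": {
                "my/custom-model": {"name": "Custom Model"}   # "id" is filled in from the key if missing
            }
        }
    }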
|
|
1814
|
+
|
|
1815
|
+
|
|
1044
1816
|
def provider_status():
|
|
1045
1817
|
enabled = list(g_handlers.keys())
|
|
1046
|
-
disabled = [provider for provider in g_config[
|
|
1818
|
+
disabled = [provider for provider in g_config["providers"] if provider not in enabled]
|
|
1047
1819
|
enabled.sort()
|
|
1048
1820
|
disabled.sort()
|
|
1049
1821
|
return enabled, disabled
|
|
1050
1822
|
|
|
1823
|
+
|
|
1051
1824
|
def print_status():
|
|
1052
1825
|
enabled, disabled = provider_status()
|
|
1053
1826
|
if len(enabled) > 0:
|
|
@@ -1059,8 +1832,14 @@ def print_status():
|
|
|
1059
1832
|
else:
|
|
1060
1833
|
print("Disabled: None")
|
|
1061
1834
|
|
|
1835
|
+
|
|
1062
1836
|
def home_llms_path(filename):
|
|
1063
|
-
return f"{os.
|
|
1837
|
+
return f"{os.getenv('HOME')}/.llms/{filename}"
|
|
1838
|
+
|
|
1839
|
+
|
|
1840
|
+
def get_cache_path(path=""):
|
|
1841
|
+
return home_llms_path(f"cache/{path}") if path else home_llms_path("cache")
|
|
1842
|
+
|
|
1064
1843
|
|
|
1065
1844
|
def get_config_path():
|
|
1066
1845
|
home_config_path = home_llms_path("llms.json")
|
|
@@ -1068,8 +1847,8 @@ def get_config_path():
|
|
|
1068
1847
|
"./llms.json",
|
|
1069
1848
|
home_config_path,
|
|
1070
1849
|
]
|
|
1071
|
-
if os.
|
|
1072
|
-
check_paths.insert(0, os.
|
|
1850
|
+
if os.getenv("LLMS_CONFIG_PATH"):
|
|
1851
|
+
check_paths.insert(0, os.getenv("LLMS_CONFIG_PATH"))
|
|
1073
1852
|
|
|
1074
1853
|
for check_path in check_paths:
|
|
1075
1854
|
g_config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), check_path))
|
|
@@ -1077,37 +1856,30 @@ def get_config_path():
|
|
|
1077
1856
|
return g_config_path
|
|
1078
1857
|
return None
|
|
1079
1858
|
|
|
1080
|
-
def get_ui_path():
|
|
1081
|
-
ui_paths = [
|
|
1082
|
-
home_llms_path("ui.json"),
|
|
1083
|
-
"ui.json"
|
|
1084
|
-
]
|
|
1085
|
-
for ui_path in ui_paths:
|
|
1086
|
-
if os.path.exists(ui_path):
|
|
1087
|
-
return ui_path
|
|
1088
|
-
return None
|
|
1089
1859
|
|
|
1090
1860
|
def enable_provider(provider):
|
|
1091
1861
|
msg = None
|
|
1092
|
-
provider_config = g_config[
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1862
|
+
provider_config = g_config["providers"][provider]
|
|
1863
|
+
if not provider_config:
|
|
1864
|
+
return None, f"Provider {provider} not found"
|
|
1865
|
+
|
|
1866
|
+
provider, constructor_kwargs = create_provider_from_definition(provider, provider_config)
|
|
1867
|
+
msg = provider.validate(**constructor_kwargs)
|
|
1868
|
+
if msg:
|
|
1869
|
+
return None, msg
|
|
1870
|
+
|
|
1871
|
+
provider_config["enabled"] = True
|
|
1102
1872
|
save_config(g_config)
|
|
1103
|
-
init_llms(g_config)
|
|
1873
|
+
init_llms(g_config, g_providers)
|
|
1104
1874
|
return provider_config, msg
|
|
1105
1875
|
|
|
1876
|
+
|
|
1106
1877
|
def disable_provider(provider):
|
|
1107
|
-
provider_config = g_config[
|
|
1108
|
-
provider_config[
|
|
1878
|
+
provider_config = g_config["providers"][provider]
|
|
1879
|
+
provider_config["enabled"] = False
|
|
1109
1880
|
save_config(g_config)
|
|
1110
|
-
init_llms(g_config)
|
|
1881
|
+
init_llms(g_config, g_providers)
|
|
1882
|
+
|
|
1111
1883
|
|
|
1112
1884
|
def resolve_root():
|
|
1113
1885
|
# Try to find the resource root directory
|
|
@@ -1119,7 +1891,7 @@ def resolve_root():
|
|
|
1119
1891
|
# Try to access the package resources
|
|
1120
1892
|
pkg_files = resources.files("llms")
|
|
1121
1893
|
# Check if ui directory exists in package resources
|
|
1122
|
-
if hasattr(pkg_files,
|
|
1894
|
+
if hasattr(pkg_files, "is_dir") and (pkg_files / "ui").is_dir():
|
|
1123
1895
|
_log(f"RESOURCE ROOT (package): {pkg_files}")
|
|
1124
1896
|
return pkg_files
|
|
1125
1897
|
except (FileNotFoundError, AttributeError, TypeError):
|
|
@@ -1132,8 +1904,9 @@ def resolve_root():
|
|
|
1132
1904
|
# Method 1b: Look for the installed package and check for UI files
|
|
1133
1905
|
try:
|
|
1134
1906
|
import llms
|
|
1907
|
+
|
|
1135
1908
|
# If llms is a package, check its directory
|
|
1136
|
-
if hasattr(llms,
|
|
1909
|
+
if hasattr(llms, "__path__"):
|
|
1137
1910
|
# It's a package
|
|
1138
1911
|
package_path = Path(llms.__path__[0])
|
|
1139
1912
|
|
|
@@ -1170,21 +1943,25 @@ def resolve_root():
|
|
|
1170
1943
|
|
|
1171
1944
|
# Add site-packages directories
|
|
1172
1945
|
for site_dir in site.getsitepackages():
|
|
1173
|
-
possible_roots.extend(
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1946
|
+
possible_roots.extend(
|
|
1947
|
+
[
|
|
1948
|
+
Path(site_dir),
|
|
1949
|
+
Path(site_dir).parent,
|
|
1950
|
+
Path(site_dir).parent / "share",
|
|
1951
|
+
]
|
|
1952
|
+
)
|
|
1178
1953
|
|
|
1179
1954
|
# Add user site directory
|
|
1180
1955
|
try:
|
|
1181
1956
|
user_site = site.getusersitepackages()
|
|
1182
1957
|
if user_site:
|
|
1183
|
-
possible_roots.extend(
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1958
|
+
possible_roots.extend(
|
|
1959
|
+
[
|
|
1960
|
+
Path(user_site),
|
|
1961
|
+
Path(user_site).parent,
|
|
1962
|
+
Path(user_site).parent / "share",
|
|
1963
|
+
]
|
|
1964
|
+
)
|
|
1188
1965
|
except AttributeError:
|
|
1189
1966
|
pass
|
|
1190
1967
|
|
|
@@ -1195,12 +1972,17 @@ def resolve_root():
|
|
|
1195
1972
|
homebrew_prefixes = ["/opt/homebrew", "/usr/local"] # Apple Silicon and Intel
|
|
1196
1973
|
for prefix in homebrew_prefixes:
|
|
1197
1974
|
if Path(prefix).exists():
|
|
1198
|
-
homebrew_roots.extend(
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1975
|
+
homebrew_roots.extend(
|
|
1976
|
+
[
|
|
1977
|
+
Path(prefix),
|
|
1978
|
+
Path(prefix) / "share",
|
|
1979
|
+
Path(prefix) / "lib" / "python3.11" / "site-packages",
|
|
1980
|
+
Path(prefix)
|
|
1981
|
+
/ "lib"
|
|
1982
|
+
/ f"python{sys.version_info.major}.{sys.version_info.minor}"
|
|
1983
|
+
/ "site-packages",
|
|
1984
|
+
]
|
|
1985
|
+
)
|
|
1204
1986
|
|
|
1205
1987
|
possible_roots.extend(homebrew_roots)
|
|
1206
1988
|
|
|
@@ -1232,26 +2014,29 @@ def resolve_root():
|
|
|
1232
2014
|
_log(f"RESOURCE ROOT (fallback): {from_file}")
|
|
1233
2015
|
return from_file
|
|
1234
2016
|
|
|
2017
|
+
|
|
1235
2018
|
def resource_exists(resource_path):
|
|
1236
2019
|
# Check if resource files exist (handle both Path and Traversable objects)
|
|
1237
2020
|
try:
|
|
1238
|
-
if hasattr(resource_path,
|
|
2021
|
+
if hasattr(resource_path, "is_file"):
|
|
1239
2022
|
return resource_path.is_file()
|
|
1240
2023
|
else:
|
|
1241
2024
|
return os.path.exists(resource_path)
|
|
1242
2025
|
except (OSError, AttributeError):
|
|
1243
2026
|
pass
|
|
1244
2027
|
|
|
2028
|
+
|
|
1245
2029
|
def read_resource_text(resource_path):
|
|
1246
|
-
if hasattr(resource_path,
|
|
2030
|
+
if hasattr(resource_path, "read_text"):
|
|
1247
2031
|
return resource_path.read_text()
|
|
1248
2032
|
else:
|
|
1249
|
-
with open(resource_path, "
|
|
2033
|
+
with open(resource_path, encoding="utf-8") as f:
|
|
1250
2034
|
return f.read()
|
|
1251
2035
|
|
|
2036
|
+
|
|
1252
2037
|
def read_resource_file_bytes(resource_file):
|
|
1253
2038
|
try:
|
|
1254
|
-
if hasattr(_ROOT,
|
|
2039
|
+
if hasattr(_ROOT, "joinpath"):
|
|
1255
2040
|
# importlib.resources Traversable
|
|
1256
2041
|
index_resource = _ROOT.joinpath(resource_file)
|
|
1257
2042
|
if index_resource.is_file():
|
|
@@ -1264,6 +2049,7 @@ def read_resource_file_bytes(resource_file):
|
|
|
1264
2049
|
except (OSError, PermissionError, AttributeError) as e:
|
|
1265
2050
|
_log(f"Error reading resource bytes: {e}")
|
|
1266
2051
|
|
|
2052
|
+
|
|
1267
2053
|
async def check_models(provider_name, model_names=None):
|
|
1268
2054
|
"""
|
|
1269
2055
|
Check validity of models for a specific provider by sending a ping message.
|
|
@@ -1281,13 +2067,14 @@ async def check_models(provider_name, model_names=None):
|
|
|
1281
2067
|
models_to_check = []
|
|
1282
2068
|
|
|
1283
2069
|
# Determine which models to check
|
|
1284
|
-
if model_names is None or (len(model_names) == 1 and model_names[0] ==
|
|
2070
|
+
if model_names is None or (len(model_names) == 1 and model_names[0] == "all"):
|
|
1285
2071
|
# Check all models for this provider
|
|
1286
2072
|
models_to_check = list(provider.models.keys())
|
|
1287
2073
|
else:
|
|
1288
2074
|
# Check only specified models
|
|
1289
2075
|
for model_name in model_names:
|
|
1290
|
-
|
|
2076
|
+
provider_model = provider.provider_model(model_name)
|
|
2077
|
+
if provider_model:
|
|
1291
2078
|
models_to_check.append(model_name)
|
|
1292
2079
|
else:
|
|
1293
2080
|
print(f"Model '{model_name}' not found in provider '{provider_name}'")
|
|
@@ -1296,68 +2083,83 @@ async def check_models(provider_name, model_names=None):
        print(f"No models to check for provider '{provider_name}'")
        return

-    print(
+    print(
+        f"\nChecking {len(models_to_check)} model{'' if len(models_to_check) == 1 else 's'} for provider '{provider_name}':\n"
+    )

    # Test each model
    for model in models_to_check:
-
-        chat = (provider.check or g_config['defaults']['check']).copy()
-        chat["model"] = model
+        await check_provider_model(provider, model)

-
-        try:
-            # Try to get a response from the model
-            response = await provider.chat(chat)
-            duration_ms = int((time.time() - started_at) * 1000)
+    print()

-            # Check if we got a valid response
-            if response and 'choices' in response and len(response['choices']) > 0:
-                print(f" ✓ {model:<40} ({duration_ms}ms)")
-            else:
-                print(f" ✗ {model:<40} Invalid response format")
-        except HTTPError as e:
-            duration_ms = int((time.time() - started_at) * 1000)
-            error_msg = f"HTTP {e.status}"
-            try:
-                # Try to parse error body for more details
-                error_body = json.loads(e.body) if e.body else {}
-                if 'error' in error_body:
-                    error = error_body['error']
-                    if isinstance(error, dict):
-                        if 'message' in error:
-                            # OpenRouter
-                            if isinstance(error['message'], str):
-                                error_msg = error['message']
-                            if 'code' in error:
-                                error_msg = f"{error['code']} {error_msg}"
-                            if 'metadata' in error and 'raw' in error['metadata']:
-                                error_msg += f" - {error['metadata']['raw']}"
-                            if 'provider' in error:
-                                error_msg += f" ({error['provider']})"
-                    elif isinstance(error, str):
-                        error_msg = error
-                elif 'message' in error_body:
-                    if isinstance(error_body['message'], str):
-                        error_msg = error_body['message']
-                    elif isinstance(error_body['message'], dict):
-                        # codestral error format
-                        if 'detail' in error_body['message'] and isinstance(error_body['message']['detail'], list):
-                            error_msg = error_body['message']['detail'][0]['msg']
-                            if 'loc' in error_body['message']['detail'][0] and len(error_body['message']['detail'][0]['loc']) > 0:
-                                error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
-            except Exception as parse_error:
-                _log(f"Error parsing error body: {parse_error}")
-                error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
-            print(f" ✗ {model:<40} {error_msg}")
-        except asyncio.TimeoutError:
-            duration_ms = int((time.time() - started_at) * 1000)
-            print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
-        except Exception as e:
-            duration_ms = int((time.time() - started_at) * 1000)
-            error_msg = str(e)[:100]
-            print(f" ✗ {model:<40} {error_msg}")

-
+async def check_provider_model(provider, model):
+    # Create a simple ping chat request
+    chat = (provider.check or g_config["defaults"]["check"]).copy()
+    chat["model"] = model
+
+    success = False
+    started_at = time.time()
+    try:
+        # Try to get a response from the model
+        response = await provider.chat(chat)
+        duration_ms = int((time.time() - started_at) * 1000)
+
+        # Check if we got a valid response
+        if response and "choices" in response and len(response["choices"]) > 0:
+            success = True
+            print(f" ✓ {model:<40} ({duration_ms}ms)")
+        else:
+            print(f" ✗ {model:<40} Invalid response format")
+    except HTTPError as e:
+        duration_ms = int((time.time() - started_at) * 1000)
+        error_msg = f"HTTP {e.status}"
+        try:
+            # Try to parse error body for more details
+            error_body = json.loads(e.body) if e.body else {}
+            if "error" in error_body:
+                error = error_body["error"]
+                if isinstance(error, dict):
+                    if "message" in error and isinstance(error["message"], str):
+                        # OpenRouter
+                        error_msg = error["message"]
+                    if "code" in error:
+                        error_msg = f"{error['code']} {error_msg}"
+                    if "metadata" in error and "raw" in error["metadata"]:
+                        error_msg += f" - {error['metadata']['raw']}"
+                    if "provider" in error:
+                        error_msg += f" ({error['provider']})"
+                elif isinstance(error, str):
+                    error_msg = error
+            elif "message" in error_body:
+                if isinstance(error_body["message"], str):
+                    error_msg = error_body["message"]
+                elif (
+                    isinstance(error_body["message"], dict)
+                    and "detail" in error_body["message"]
+                    and isinstance(error_body["message"]["detail"], list)
+                ):
+                    # codestral error format
+                    error_msg = error_body["message"]["detail"][0]["msg"]
+                    if (
+                        "loc" in error_body["message"]["detail"][0]
+                        and len(error_body["message"]["detail"][0]["loc"]) > 0
+                    ):
+                        error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
+        except Exception as parse_error:
+            _log(f"Error parsing error body: {parse_error}")
+            error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
+        print(f" ✗ {model:<40} {error_msg}")
+    except asyncio.TimeoutError:
+        duration_ms = int((time.time() - started_at) * 1000)
+        print(f" ✗ {model:<40} Timeout after {duration_ms}ms")
+    except Exception as e:
+        duration_ms = int((time.time() - started_at) * 1000)
+        error_msg = str(e)[:100]
+        print(f" ✗ {model:<40} {error_msg}")
+    return success
+

def text_from_resource(filename):
    global _ROOT
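For orientation only (this note and sketch are not part of the package diff): the v3 change above splits the per-model ping out of check_models() into the reusable check_provider_model() coroutine, which now returns True or False instead of only printing. A rough sketch of driving it directly, assuming providers have already been loaded into g_handlers and using placeholder provider and model names:

import asyncio

async def check_one():
    provider = g_handlers.get("groq")  # placeholder provider name
    if provider:
        healthy = await check_provider_model(provider, "llama-3.3-70b")  # placeholder model id
        print("ok" if healthy else "failed")

asyncio.run(check_one())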
@@ -1369,12 +2171,14 @@ def text_from_resource(filename):
        _log(f"Error reading resource config {filename}: {e}")
        return None

+
def text_from_file(filename):
    if os.path.exists(filename):
-        with open(filename, "
+        with open(filename, encoding="utf-8") as f:
            return f.read()
    return None

+
async def text_from_resource_or_url(filename):
    text = text_from_resource(filename)
    if not text:
@@ -1386,10 +2190,17 @@ async def text_from_resource_or_url(filename):
            raise e
    return text

+
async def save_home_configs():
    home_config_path = home_llms_path("llms.json")
-
-
+    home_providers_path = home_llms_path("providers.json")
+    home_providers_extra_path = home_llms_path("providers-extra.json")
+
+    if (
+        os.path.exists(home_config_path)
+        and os.path.exists(home_providers_path)
+        and os.path.exists(home_providers_extra_path)
+    ):
        return

    llms_home = os.path.dirname(home_config_path)
@@ -1397,114 +2208,717 @@ async def save_home_configs():
    try:
        if not os.path.exists(home_config_path):
            config_json = await text_from_resource_or_url("llms.json")
-            with open(home_config_path, "w") as f:
+            with open(home_config_path, "w", encoding="utf-8") as f:
                f.write(config_json)
            _log(f"Created default config at {home_config_path}")

-        if not os.path.exists(
-
-            with open(
-                f.write(
-            _log(f"Created default
-
+        if not os.path.exists(home_providers_path):
+            providers_json = await text_from_resource_or_url("providers.json")
+            with open(home_providers_path, "w", encoding="utf-8") as f:
+                f.write(providers_json)
+            _log(f"Created default providers config at {home_providers_path}")
+
+        if not os.path.exists(home_providers_extra_path):
+            extra_json = await text_from_resource_or_url("providers-extra.json")
+            with open(home_providers_extra_path, "w", encoding="utf-8") as f:
+                f.write(extra_json)
+            _log(f"Created default extra providers config at {home_providers_extra_path}")
+    except Exception:
        print("Could not create llms.json. Create one with --init or use --config <path>")
        exit(1)

+
+def load_config_json(config_json):
+    if config_json is None:
+        return None
+    config = json.loads(config_json)
+    if not config or "version" not in config or config["version"] < 3:
+        preserve_keys = ["auth", "defaults", "limits", "convert"]
+        new_config = json.loads(text_from_resource("llms.json"))
+        if config:
+            for key in preserve_keys:
+                if key in config:
+                    new_config[key] = config[key]
+        config = new_config
+        # move old config to YYYY-MM-DD.bak
+        new_path = f"{g_config_path}.{datetime.now().strftime('%Y-%m-%d')}.bak"
+        if os.path.exists(new_path):
+            os.remove(new_path)
+        os.rename(g_config_path, new_path)
+        print(f"llms.json migrated. old config moved to {new_path}")
+        # save new config
+        save_config(g_config)
+    return config
+
+
async def reload_providers():
    global g_config, g_handlers
-    g_handlers = init_llms(g_config)
+    g_handlers = init_llms(g_config, g_providers)
    await load_llms()
    _log(f"{len(g_handlers)} providers loaded")
    return g_handlers

-
+
+async def watch_config_files(config_path, providers_path, interval=1):
    """Watch config files and reload providers when they change"""
    global g_config

    config_path = Path(config_path)
-
+    providers_path = Path(providers_path)

-
+    _log(f"Watching config file: {config_path}")
+    _log(f"Watching providers file: {providers_path}")

-
+    def get_latest_mtime():
+        ret = 0
+        name = "llms.json"
+        if config_path.is_file():
+            ret = config_path.stat().st_mtime
+            name = config_path.name
+        if providers_path.is_file() and providers_path.stat().st_mtime > ret:
+            ret = providers_path.stat().st_mtime
+            name = providers_path.name
+        return ret, name
+
+    latest_mtime, name = get_latest_mtime()

    while True:
        await asyncio.sleep(interval)

        # Check llms.json
        try:
-
-
-
-
-                file_mtimes[str(config_path)] = mtime
-            elif file_mtimes[str(config_path)] != mtime:
-                _log(f"Config file changed: {config_path.name}")
-                file_mtimes[str(config_path)] = mtime
+            new_mtime, name = get_latest_mtime()
+            if new_mtime > latest_mtime:
+                _log(f"Config file changed: {name}")
+                latest_mtime = new_mtime

-
-
-
-
+                try:
+                    # Reload llms.json
+                    with open(config_path) as f:
+                        g_config = json.load(f)

-
-
-
-
-
+                    # Reload providers
+                    await reload_providers()
+                    _log("Providers reloaded successfully")
+                except Exception as e:
+                    _log(f"Error reloading config: {e}")
        except FileNotFoundError:
            pass

-
-
+
+def get_session_token(request):
+    return request.query.get("session") or request.headers.get("X-Session-Token") or request.cookies.get("llms-token")
+
+
+class AppExtensions:
+    """
+    APIs extensions can use to extend the app
+    """
+
+    def __init__(self, cli_args, extra_args):
+        self.cli_args = cli_args
+        self.extra_args = extra_args
+        self.config = None
+        self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
+        self.ui_extensions = []
+        self.chat_request_filters = []
+        self.chat_tool_filters = []
+        self.chat_response_filters = []
+        self.chat_error_filters = []
+        self.server_add_get = []
+        self.server_add_post = []
+        self.server_add_put = []
+        self.server_add_delete = []
+        self.server_add_patch = []
+        self.cache_saved_filters = []
+        self.shutdown_handlers = []
+        self.tools = {}
+        self.tool_definitions = []
+        self.index_headers = []
+        self.index_footers = []
+        self.request_args = {
+            "image_config": dict, # e.g. { "aspect_ratio": "1:1" }
+            "temperature": float, # e.g: 0.7
+            "max_completion_tokens": int, # e.g: 2048
+            "seed": int, # e.g: 42
+            "top_p": float, # e.g: 0.9
+            "frequency_penalty": float, # e.g: 0.5
+            "presence_penalty": float, # e.g: 0.5
+            "stop": list, # e.g: ["Stop"]
+            "reasoning_effort": str, # e.g: minimal, low, medium, high
+            "verbosity": str, # e.g: low, medium, high
+            "service_tier": str, # e.g: auto, default
+            "top_logprobs": int,
+            "safety_identifier": str,
+            "store": bool,
+            "enable_thinking": bool,
+        }
+        self.all_providers = [
+            OpenAiCompatible,
+            MistralProvider,
+            GroqProvider,
+            XaiProvider,
+            CodestralProvider,
+            OllamaProvider,
+            LMStudioProvider,
+        ]
+        self.aspect_ratios = {
+            "1:1": "1024×1024",
+            "2:3": "832×1248",
+            "3:2": "1248×832",
+            "3:4": "864×1184",
+            "4:3": "1184×864",
+            "4:5": "896×1152",
+            "5:4": "1152×896",
+            "9:16": "768×1344",
+            "16:9": "1344×768",
+            "21:9": "1536×672",
+        }
+        self.import_maps = {
+            "vue-prod": "/ui/lib/vue.min.mjs",
+            "vue": "/ui/lib/vue.mjs",
+            "vue-router": "/ui/lib/vue-router.min.mjs",
+            "@servicestack/client": "/ui/lib/servicestack-client.mjs",
+            "@servicestack/vue": "/ui/lib/servicestack-vue.mjs",
+            "idb": "/ui/lib/idb.min.mjs",
+            "marked": "/ui/lib/marked.min.mjs",
+            "highlight.js": "/ui/lib/highlight.min.mjs",
+            "chart.js": "/ui/lib/chart.js",
+            "color.js": "/ui/lib/color.js",
+            "ctx.mjs": "/ui/ctx.mjs",
+        }
+
+    def set_config(self, config):
+        self.config = config
+        self.auth_enabled = self.config.get("auth", {}).get("enabled", False)
+
+    # Authentication middleware helper
+    def check_auth(self, request):
+        """Check if request is authenticated. Returns (is_authenticated, user_data)"""
+        if not self.auth_enabled:
+            return True, None
+
+        # Check for OAuth session token
+        session_token = get_session_token(request)
+        if session_token and session_token in g_sessions:
+            return True, g_sessions[session_token]
+
+        # Check for API key
+        auth_header = request.headers.get("Authorization", "")
+        if auth_header.startswith("Bearer "):
+            api_key = auth_header[7:]
+            if api_key:
+                return True, {"authProvider": "apikey"}
+
+        return False, None
+
+    def get_session(self, request):
+        session_token = get_session_token(request)
+
+        if not session_token or session_token not in g_sessions:
+            return None
+
+        session_data = g_sessions[session_token]
+        return session_data
+
+    def get_username(self, request):
+        session = self.get_session(request)
+        if session:
+            return session.get("userName")
+        return None
+
+    def get_user_path(self, username=None):
+        if username:
+            return home_llms_path(os.path.join("user", username))
+        return home_llms_path(os.path.join("user", "default"))
+
+    def chat_request(self, template=None, text=None, model=None, system_prompt=None):
+        return g_chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
+
+    async def chat_completion(self, chat, context=None):
+        response = await g_chat_completion(chat, context)
+        return response
+
+    def on_cache_saved_filters(self, context):
+        # _log(f"on_cache_saved_filters {len(self.cache_saved_filters)}: {context['url']}")
+        for filter_func in self.cache_saved_filters:
+            filter_func(context)
+
+    async def on_chat_error(self, e, context):
+        # Apply chat error filters
+        if "stackTrace" not in context:
+            context["stackTrace"] = traceback.format_exc()
+        for filter_func in self.chat_error_filters:
            try:
-
-
-
-
-
-
-
-
-
-
-
+                await filter_func(e, context)
+            except Exception as e:
+                _err("chat error filter failed", e)
+
+    async def on_chat_tool(self, chat, context):
+        m_len = len(chat.get("messages", []))
+        t_len = len(self.chat_tool_filters)
+        _dbg(
+            f"on_tool_call for thread {context.get('threadId', None)} with {m_len} {pluralize('message', m_len)}, invoking {t_len} {pluralize('filter', t_len)}:"
+        )
+        for filter_func in self.chat_tool_filters:
+            await filter_func(chat, context)
+
+    def exit(self, exit_code=0):
+        if len(self.shutdown_handlers) > 0:
+            _dbg(f"running {len(self.shutdown_handlers)} shutdown handlers...")
+            for handler in self.shutdown_handlers:
+                handler()
+
+        _dbg(f"exit({exit_code})")
+        sys.exit(exit_code)
+
+
+def handler_name(handler):
+    if hasattr(handler, "__name__"):
+        return handler.__name__
+    return "unknown"
+
+
+class ExtensionContext:
+    def __init__(self, app, path):
+        self.app = app
+        self.cli_args = app.cli_args
+        self.extra_args = app.extra_args
+        self.error_auth_required = app.error_auth_required
+        self.path = path
+        self.name = os.path.basename(path)
+        if self.name.endswith(".py"):
+            self.name = self.name[:-3]
+        self.ext_prefix = f"/ext/{self.name}"
+        self.MOCK = MOCK
+        self.MOCK_DIR = MOCK_DIR
+        self.debug = DEBUG
+        self.verbose = g_verbose
+        self.aspect_ratios = app.aspect_ratios
+        self.request_args = app.request_args
+
+    def chat_to_prompt(self, chat):
+        return chat_to_prompt(chat)
+
+    def chat_to_system_prompt(self, chat):
+        return chat_to_system_prompt(chat)
+
+    def chat_response_to_message(self, response):
+        return chat_response_to_message(response)
+
+    def last_user_prompt(self, chat):
+        return last_user_prompt(chat)
+
+    def to_file_info(self, chat, info=None, response=None):
+        return to_file_info(chat, info=info, response=response)
+
+    def save_image_to_cache(self, base64_data, filename, image_info):
+        return save_image_to_cache(base64_data, filename, image_info)
+
+    def save_bytes_to_cache(self, bytes_data, filename, file_info):
+        return save_bytes_to_cache(bytes_data, filename, file_info)
+
+    def text_from_file(self, path):
+        return text_from_file(path)
+
+    def log(self, message):
+        if self.verbose:
+            print(f"[{self.name}] {message}", flush=True)
+        return message

-    def
-
+    def log_json(self, obj):
+        if self.verbose:
+            print(f"[{self.name}] {json.dumps(obj, indent=2)}", flush=True)
+        return obj

-
-
-
+    def dbg(self, message):
+        if self.debug:
+            print(f"DEBUG [{self.name}]: {message}", flush=True)

-
-
-
-
-    parser.add_argument('--file', default=None, help='File input to use in chat completion')
-    parser.add_argument('--args', default=None, help='URL-encoded parameters to add to chat request (e.g. "temperature=0.7&seed=111")', metavar='PARAMS')
-    parser.add_argument('--raw', action='store_true', help='Return raw AI JSON response')
+    def err(self, message, e):
+        print(f"ERROR [{self.name}]: {message}", e)
+        if self.verbose:
+            print(traceback.format_exc(), flush=True)

-
-
+    def error_message(self, e):
+        return to_error_message(e)

-
+    def error_response(self, e, stacktrace=False):
+        return to_error_response(e, stacktrace=stacktrace)

-
-
-
+    def add_provider(self, provider):
+        self.log(f"Registered provider: {provider.__name__}")
+        self.app.all_providers.append(provider)

-
+    def register_ui_extension(self, index):
+        path = os.path.join(self.ext_prefix, index)
+        self.log(f"Registered UI extension: {path}")
+        self.app.ui_extensions.append({"id": self.name, "path": path})

-
-
-
+    def register_chat_request_filter(self, handler):
+        self.log(f"Registered chat request filter: {handler_name(handler)}")
+        self.app.chat_request_filters.append(handler)
+
+    def register_chat_tool_filter(self, handler):
+        self.log(f"Registered chat tool filter: {handler_name(handler)}")
+        self.app.chat_tool_filters.append(handler)
+
+    def register_chat_response_filter(self, handler):
+        self.log(f"Registered chat response filter: {handler_name(handler)}")
+        self.app.chat_response_filters.append(handler)
+
+    def register_chat_error_filter(self, handler):
+        self.log(f"Registered chat error filter: {handler_name(handler)}")
+        self.app.chat_error_filters.append(handler)
+
+    def register_cache_saved_filter(self, handler):
+        self.log(f"Registered cache saved filter: {handler_name(handler)}")
+        self.app.cache_saved_filters.append(handler)
+
+    def register_shutdown_handler(self, handler):
+        self.log(f"Registered shutdown handler: {handler_name(handler)}")
+        self.app.shutdown_handlers.append(handler)
+
+    def add_static_files(self, ext_dir):
+        self.log(f"Registered static files: {ext_dir}")
+
+        async def serve_static(request):
+            path = request.match_info["path"]
+            file_path = os.path.join(ext_dir, path)
+            if os.path.exists(file_path):
+                return web.FileResponse(file_path)
+            return web.Response(status=404)
+
+        self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))
+
+    def add_get(self, path, handler, **kwargs):
+        self.dbg(f"Registered GET: {os.path.join(self.ext_prefix, path)}")
+        self.app.server_add_get.append((os.path.join(self.ext_prefix, path), handler, kwargs))
+
+    def add_post(self, path, handler, **kwargs):
+        self.dbg(f"Registered POST: {os.path.join(self.ext_prefix, path)}")
+        self.app.server_add_post.append((os.path.join(self.ext_prefix, path), handler, kwargs))
+
+    def add_put(self, path, handler, **kwargs):
+        self.dbg(f"Registered PUT: {os.path.join(self.ext_prefix, path)}")
+        self.app.server_add_put.append((os.path.join(self.ext_prefix, path), handler, kwargs))
+
+    def add_delete(self, path, handler, **kwargs):
+        self.dbg(f"Registered DELETE: {os.path.join(self.ext_prefix, path)}")
+        self.app.server_add_delete.append((os.path.join(self.ext_prefix, path), handler, kwargs))
+
+    def add_patch(self, path, handler, **kwargs):
+        self.dbg(f"Registered PATCH: {os.path.join(self.ext_prefix, path)}")
+        self.app.server_add_patch.append((os.path.join(self.ext_prefix, path), handler, kwargs))
+
+    def add_importmaps(self, dict):
+        self.app.import_maps.update(dict)
+
+    def add_index_header(self, html):
+        self.app.index_headers.append(html)
+
+    def add_index_footer(self, html):
+        self.app.index_footers.append(html)
+
+    def get_config(self):
+        return g_config
+
+    def get_cache_path(self, path=""):
+        return get_cache_path(path)
+
+    def chat_request(self, template=None, text=None, model=None, system_prompt=None):
+        return self.app.chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
+
+    def chat_completion(self, chat, context=None):
+        return self.app.chat_completion(chat, context=context)
+
+    def get_providers(self):
+        return g_handlers
+
+    def get_provider(self, name):
+        return g_handlers.get(name)
+
+    def register_tool(self, func, tool_def=None):
+        if tool_def is None:
+            tool_def = function_to_tool_definition(func)
+
+        name = tool_def["function"]["name"]
+        self.log(f"Registered tool: {name}")
+        self.app.tools[name] = func
+        self.app.tool_definitions.append(tool_def)
+
+    def check_auth(self, request):
+        return self.app.check_auth(request)
+
+    def get_session(self, request):
+        return self.app.get_session(request)
+
+    def get_username(self, request):
+        return self.app.get_username(request)
+
+    def get_user_path(self, username=None):
+        return self.app.get_user_path(username)
+
+    def should_cancel_thread(self, context):
+        return should_cancel_thread(context)
+
+    def cache_message_inline_data(self, message):
+        return cache_message_inline_data(message)
+
+    def to_content(self, result):
+        return to_content(result)
+
+
+def get_extensions_path():
+    return os.getenv("LLMS_EXTENSIONS_DIR", os.path.join(Path.home(), ".llms", "extensions"))
+
+
+def get_disabled_extensions():
+    ret = DISABLE_EXTENSIONS.copy()
+    if g_config:
+        for ext in g_config.get("disable_extensions", []):
+            if ext not in ret:
+                ret.append(ext)
+    return ret
+
+
+def get_extensions_dirs():
+    """
+    Returns a list of extension directories.
+    """
+    extensions_path = get_extensions_path()
+    os.makedirs(extensions_path, exist_ok=True)
+
+    # allow overriding builtin extensions
+    override_extensions = []
+    if os.path.exists(extensions_path):
+        override_extensions = os.listdir(extensions_path)
+
+    ret = []
+    disabled_extensions = get_disabled_extensions()
+
+    builtin_extensions_dir = _ROOT / "extensions"
+    if os.path.exists(builtin_extensions_dir):
+        for item in os.listdir(builtin_extensions_dir):
+            if os.path.isdir(os.path.join(builtin_extensions_dir, item)):
+                if item in override_extensions:
+                    continue
+                if item in disabled_extensions:
+                    continue
+                ret.append(os.path.join(builtin_extensions_dir, item))
+
+    if os.path.exists(extensions_path):
+        for item in os.listdir(extensions_path):
+            if os.path.isdir(os.path.join(extensions_path, item)):
+                if item in disabled_extensions:
+                    continue
+                ret.append(os.path.join(extensions_path, item))
+
+    return ret
+
+
+def init_extensions(parser):
+    """
+    Initializes extensions by loading their __init__.py files and calling the __parser__ function if it exists.
+    """
+    for item_path in get_extensions_dirs():
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            try:
+                # check for __parser__ function if exists in __init.__.py and call it with parser
+                init_file = os.path.join(item_path, "__init__.py")
+                if os.path.exists(init_file):
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        parser_func = getattr(module, "__parser__", None)
+                        if callable(parser_func):
+                            parser_func(parser)
+                            _log(f"Extension {item} parser loaded")
+            except Exception as e:
+                _err(f"Failed to load extension {item} parser", e)
+
+
+def install_extensions():
+    """
+    Scans ensure ~/.llms/extensions/ for directories with __init__.py and loads them as extensions.
+    Calls the `__install__(ctx)` function in the extension module.
+    """
+
+    extension_dirs = get_extensions_dirs()
+    ext_count = len(list(extension_dirs))
+    if ext_count == 0:
+        _log("No extensions found")
+        return
+
+    disabled_extensions = get_disabled_extensions()
+    if len(disabled_extensions) > 0:
+        _log(f"Disabled extensions: {', '.join(disabled_extensions)}")
+
+    _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")
+
+    for item_path in extension_dirs:
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            sys.path.append(item_path)
+            try:
+                ctx = ExtensionContext(g_app, item_path)
+                init_file = os.path.join(item_path, "__init__.py")
+                if os.path.exists(init_file):
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        install_func = getattr(module, "__install__", None)
+                        if callable(install_func):
+                            install_func(ctx)
+                            _log(f"Extension {item} installed")
+                        else:
+                            _dbg(f"Extension {item} has no __install__ function")
+                    else:
+                        _dbg(f"Extension {item} has no __init__.py")
+                else:
+                    _dbg(f"Extension {init_file} not found")
+
+                # if ui folder exists, serve as static files at /ext/{item}/
+                ui_path = os.path.join(item_path, "ui")
+                if os.path.exists(ui_path):
+                    ctx.add_static_files(ui_path)
+
+                # Register UI extension if index.mjs exists (/ext/{item}/index.mjs)
+                if os.path.exists(os.path.join(ui_path, "index.mjs")):
+                    ctx.register_ui_extension("index.mjs")
+
+            except Exception as e:
+                _err(f"Failed to install extension {item}", e)
+        else:
+            _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")
+
+
+def run_extension_cli():
+    """
+    Run the CLI for an extension.
+    """
+    for item_path in get_extensions_dirs():
+        item = os.path.basename(item_path)
+
+        if os.path.isdir(item_path):
+            init_file = os.path.join(item_path, "__init__.py")
+            if os.path.exists(init_file):
+                ctx = ExtensionContext(g_app, item_path)
+                try:
+                    spec = importlib.util.spec_from_file_location(item, init_file)
+                    if spec and spec.loader:
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[item] = module
+                        spec.loader.exec_module(module)
+
+                        # Check for __run__ function if exists in __init__.py and call it with ctx
+                        run_func = getattr(module, "__run__", None)
+                        if callable(run_func):
+                            _log(f"Running extension {item}...")
+                            handled = run_func(ctx)
+                            return handled
+
+                except Exception as e:
+                    _err(f"Failed to run extension {item}", e)
+    return False
+
+
+def main():
+    global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
+
+    _ROOT = os.getenv("LLMS_ROOT", resolve_root())
+    if not _ROOT:
+        print("Resource root not found")
+        exit(1)
+
+    parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
+    parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
+    parser.add_argument("--providers", default=None, help="Path to models.dev providers file", metavar="FILE")
+    parser.add_argument("-m", "--model", default=None, help="Model to use")
+
+    parser.add_argument("--chat", default=None, help="OpenAI Chat Completion Request to send", metavar="REQUEST")
+    parser.add_argument(
+        "-s", "--system", default=None, help="System prompt to use for chat completion", metavar="PROMPT"
+    )
+    parser.add_argument(
+        "--tools", default=None, help="Tools to use for chat completion (all|none|<tool>,<tool>...)", metavar="TOOLS"
+    )
+    parser.add_argument("--image", default=None, help="Image input to use in chat completion")
+    parser.add_argument("--audio", default=None, help="Audio input to use in chat completion")
+    parser.add_argument("--file", default=None, help="File input to use in chat completion")
+    parser.add_argument("--out", default=None, help="Image or Video Generation Request", metavar="MODALITY")
+    parser.add_argument(
+        "--args",
+        default=None,
+        help='URL-encoded parameters to add to chat request (e.g. "temperature=0.7&seed=111")',
+        metavar="PARAMS",
+    )
+    parser.add_argument("--raw", action="store_true", help="Return raw AI JSON response")
+
+    parser.add_argument(
+        "--list", action="store_true", help="Show list of enabled providers and their models (alias ls provider?)"
+    )
+    parser.add_argument("--check", default=None, help="Check validity of models for a provider", metavar="PROVIDER")
+
+    parser.add_argument(
+        "--serve", default=None, help="Port to start an OpenAI Chat compatible server on", metavar="PORT"
+    )
+
+    parser.add_argument("--enable", default=None, help="Enable a provider", metavar="PROVIDER")
+    parser.add_argument("--disable", default=None, help="Disable a provider", metavar="PROVIDER")
+    parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")
+
+    parser.add_argument("--init", action="store_true", help="Create a default llms.json")
+    parser.add_argument("--update-providers", action="store_true", help="Update local models.dev providers.json")
+
+    parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
+    parser.add_argument("--verbose", action="store_true", help="Verbose output")
+
+    parser.add_argument(
+        "--add",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Install an extension (lists available extensions if no name provided)",
+        metavar="EXTENSION",
+    )
+    parser.add_argument(
+        "--remove",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Remove an extension (lists installed extensions if no name provided)",
+        metavar="EXTENSION",
+    )
+
+    parser.add_argument(
+        "--update",
+        nargs="?",
+        const="ls",
+        default=None,
+        help="Update an extension (use 'all' to update all extensions)",
+        metavar="EXTENSION",
+    )
+
+    # Load parser extensions, go through all extensions and load their parser arguments
+    init_extensions(parser)

    cli_args, extra_args = parser.parse_known_args()

+    g_app = AppExtensions(cli_args, extra_args)
+
    # Check for verbose mode from CLI argument or environment variables
-    verbose_env = os.
-    if cli_args.verbose or verbose_env in (
+    verbose_env = os.getenv("VERBOSE", "").lower()
+    if cli_args.verbose or verbose_env in ("1", "true"):
        g_verbose = True
    # printdump(cli_args)
    if cli_args.model:
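As an illustration of the extension API introduced above (this sketch is not part of the package diff): install_extensions() hands each extension's __init__.py an ExtensionContext, and init_extensions() calls an optional __parser__ hook before arguments are parsed. A minimal extension using those hooks might look like the following, where the extension name "hello", the greet tool, and the /ext/hello/ping route are all made-up placeholders:

# ~/.llms/extensions/hello/__init__.py -- illustrative sketch only
from aiohttp import web


def greet(name: str) -> str:
    """Return a friendly greeting."""
    return f"Hello, {name}!"


async def ping_handler(request):
    # Registered below under this extension's /ext/hello/ prefix
    return web.json_response({"ok": True})


def __parser__(parser):
    # Optional hook called by init_extensions() before parse_known_args()
    parser.add_argument("--hello", default=None, help="Example option added by this extension")


def __install__(ctx):
    # ctx is the ExtensionContext shown in the diff above
    ctx.log("installing hello extension")
    ctx.register_tool(greet)           # tool definition derived via function_to_tool_definition
    ctx.add_get("ping", ping_handler)  # queued on app.server_add_get for the server to register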
@@ -1512,13 +2926,9 @@ def main():
    if cli_args.logprefix:
        g_logprefix = cli_args.logprefix

-    _ROOT = Path(cli_args.root) if cli_args.root else resolve_root()
-    if not _ROOT:
-        print("Resource root not found")
-        exit(1)
-
    home_config_path = home_llms_path("llms.json")
-
+    home_providers_path = home_llms_path("providers.json")
+    home_providers_extra_path = home_llms_path("providers-extra.json")

    if cli_args.init:
        if os.path.exists(home_config_path):
@@ -1527,38 +2937,215 @@ def main():
        asyncio.run(save_default_config(home_config_path))
        print(f"Created default config at {home_config_path}")

-        if os.path.exists(
-            print(f"
+        if os.path.exists(home_providers_path):
+            print(f"providers.json already exists at {home_providers_path}")
+        else:
+            asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
+            print(f"Created default providers config at {home_providers_path}")
+
+        if os.path.exists(home_providers_extra_path):
+            print(f"providers-extra.json already exists at {home_providers_extra_path}")
        else:
-            asyncio.run(save_text_url(github_url("
-            print(f"Created default
+            asyncio.run(save_text_url(github_url("providers-extra.json"), home_providers_extra_path))
+            print(f"Created default extra providers config at {home_providers_extra_path}")
        exit(0)

+    if cli_args.providers:
+        if not os.path.exists(cli_args.providers):
+            print(f"providers.json not found at {cli_args.providers}")
+            exit(1)
+        g_providers = json.loads(text_from_file(cli_args.providers))
+
    if cli_args.config:
        # read contents
        g_config_path = cli_args.config
-        with open(g_config_path, "
+        with open(g_config_path, encoding="utf-8") as f:
            config_json = f.read()
-        g_config =
+        g_config = load_config_json(config_json)

        config_dir = os.path.dirname(g_config_path)
-
-
-
-
-    else:
-        if not os.path.exists(home_ui_path):
-            ui_json = text_from_resource("ui.json")
-            with open(home_ui_path, "w") as f:
-                f.write(ui_json)
-            _log(f"Created default ui config at {home_ui_path}")
-        g_ui_path = home_ui_path
+
+        if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
+            g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
+
    else:
-        # ensure llms.json and
+        # ensure llms.json and providers.json exist in home directory
        asyncio.run(save_home_configs())
        g_config_path = home_config_path
-
-
+        g_config = load_config_json(text_from_file(g_config_path))
+
+    g_app.set_config(g_config)
+
+    if not g_providers:
+        g_providers = json.loads(text_from_file(home_providers_path))
+
+    if cli_args.update_providers:
+        asyncio.run(update_providers(home_providers_path))
+        print(f"Updated {home_providers_path}")
+        exit(0)
+
+    # if home_providers_path is older than 1 day, update providers list
+    if (
+        os.path.exists(home_providers_path)
+        and (time.time() - os.path.getmtime(home_providers_path)) > 86400
+        and os.getenv("LLMS_DISABLE_UPDATE", "") != "1"
+    ):
+        try:
+            asyncio.run(update_providers(home_providers_path))
+            _log(f"Updated {home_providers_path}")
+        except Exception as e:
+            _err("Failed to update providers", e)
+
+    if cli_args.add is not None:
+        if cli_args.add == "ls":
+
+            async def list_extensions():
+                print("\nAvailable extensions:")
+                text = await get_text("https://api.github.com/orgs/llmspy/repos?per_page=100&sort=updated")
+                repos = json.loads(text)
+                max_name_length = 0
+                for repo in repos:
+                    max_name_length = max(max_name_length, len(repo["name"]))
+
+                for repo in repos:
+                    print(f" {repo['name']:<{max_name_length + 2}} {repo['description']}")
+
+                print("\nUsage:")
+                print(" llms --add <extension>")
+                print(" llms --add <github-user>/<repo>")
+
+            asyncio.run(list_extensions())
+            exit(0)
+
+        async def install_extension(name):
+            # Determine git URL and target directory name
+            if "/" in name:
+                git_url = f"https://github.com/{name}"
+                target_name = name.split("/")[-1]
+            else:
+                git_url = f"https://github.com/llmspy/{name}"
+                target_name = name
+
+            # check extension is not already installed
+            extensions_path = get_extensions_path()
+            target_path = os.path.join(extensions_path, target_name)
+
+            if os.path.exists(target_path):
+                print(f"Extension {target_name} is already installed at {target_path}")
+                return
+
+            print(f"Installing extension: {name}")
+            print(f"Cloning from {git_url} to {target_path}...")
+
+            try:
+                subprocess.run(["git", "clone", git_url, target_path], check=True)
+
+                # Check for requirements.txt
+                requirements_path = os.path.join(target_path, "requirements.txt")
+                if os.path.exists(requirements_path):
+                    print(f"Installing dependencies from {requirements_path}...")
+
+                    # Check if uv is installed
+                    has_uv = False
+                    try:
+                        subprocess.run(
+                            ["uv", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+                        )
+                        has_uv = True
+                    except (subprocess.CalledProcessError, FileNotFoundError):
+                        pass
+
+                    if has_uv:
+                        subprocess.run(
+                            ["uv", "pip", "install", "-p", sys.executable, "-r", "requirements.txt"],
+                            cwd=target_path,
+                            check=True,
+                        )
+                    else:
+                        subprocess.run(
+                            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
+                            cwd=target_path,
+                            check=True,
+                        )
+                    print("Dependencies installed successfully.")
+
+                print(f"Extension {target_name} installed successfully.")
+
+            except subprocess.CalledProcessError as e:
+                print(f"Failed to install extension: {e}")
+                # cleanup if clone failed but directory was created (unlikely with simple git clone but good practice)
+                if os.path.exists(target_path) and not os.listdir(target_path):
+                    os.rmdir(target_path)
+
+        asyncio.run(install_extension(cli_args.add))
+        exit(0)
+
+    if cli_args.remove is not None:
+        if cli_args.remove == "ls":
+            # List installed extensions
+            extensions_path = get_extensions_path()
+            extensions = os.listdir(extensions_path)
+            if len(extensions) == 0:
+                print("No extensions installed.")
+                exit(0)
+            print("Installed extensions:")
+            for extension in extensions:
+                print(f" {extension}")
+            exit(0)
+        # Remove an extension
+        extension_name = cli_args.remove
+        extensions_path = get_extensions_path()
+        target_path = os.path.join(extensions_path, extension_name)
+
+        if not os.path.exists(target_path):
+            print(f"Extension {extension_name} not found at {target_path}")
+            exit(1)
+
+        print(f"Removing extension: {extension_name}...")
+        try:
+            shutil.rmtree(target_path)
+            print(f"Extension {extension_name} removed successfully.")
+        except Exception as e:
+            print(f"Failed to remove extension: {e}")
+            exit(1)
+
+        exit(0)
+
+    if cli_args.update:
+        if cli_args.update == "ls":
+            # List installed extensions
+            extensions_path = get_extensions_path()
+            extensions = os.listdir(extensions_path)
+            if len(extensions) == 0:
+                print("No extensions installed.")
+                exit(0)
+            print("Installed extensions:")
+            for extension in extensions:
+                print(f" {extension}")
+
+            print("\nUsage:")
+            print(" llms --update <extension>")
+            print(" llms --update all")
+            exit(0)
+
+        async def update_extensions(extension_name):
+            extensions_path = get_extensions_path()
+            for extension in os.listdir(extensions_path):
+                extension_path = os.path.join(extensions_path, extension)
+                if os.path.isdir(extension_path):
+                    if extension_name != "all" and extension != extension_name:
+                        continue
+                    result = subprocess.run(["git", "pull"], cwd=extension_path, capture_output=True)
+                    if result.returncode != 0:
+                        print(f"Failed to update extension {extension}: {result.stderr.decode('utf-8')}")
+                        continue
+                    print(f"Updated extension {extension}")
+                    _log(result.stdout.decode("utf-8"))
+
+        asyncio.run(update_extensions(cli_args.update))
+        exit(0)
+
+    install_extensions()

    asyncio.run(reload_providers())

@@ -1568,7 +3155,7 @@ def main():
    filter_list = []
    if len(extra_args) > 0:
        arg = extra_args[0]
-        if arg ==
+        if arg == "ls":
            cli_args.list = True
            if len(extra_args) > 1:
                filter_list = extra_args[1:]
@@ -1576,36 +3163,57 @@ def main():
    if cli_args.list:
        # Show list of enabled providers and their models
        enabled = []
+        provider_count = 0
+        model_count = 0
+
+        max_model_length = 0
        for name, provider in g_handlers.items():
            if len(filter_list) > 0 and name not in filter_list:
                continue
+            for model in provider.models:
+                max_model_length = max(max_model_length, len(model))
+
+        for name, provider in g_handlers.items():
+            if len(filter_list) > 0 and name not in filter_list:
+                continue
+            provider_count += 1
            print(f"{name}:")
            enabled.append(name)
            for model in provider.models:
-
+                model_count += 1
+                model_cost_info = None
+                if "cost" in provider.models[model]:
+                    model_cost = provider.models[model]["cost"]
+                    if "input" in model_cost and "output" in model_cost:
+                        if model_cost["input"] == 0 and model_cost["output"] == 0:
+                            model_cost_info = " 0"
+                        else:
+                            model_cost_info = f"{model_cost['input']:5} / {model_cost['output']}"
+                print(f" {model:{max_model_length}} {model_cost_info or ''}")
+
+        print(f"\n{model_count} models available from {provider_count} providers")

        print_status()
-        exit(0)
+        g_app.exit(0)

    if cli_args.check is not None:
        # Check validity of models for a provider
        provider_name = cli_args.check
        model_names = extra_args if len(extra_args) > 0 else None
        asyncio.run(check_models(provider_name, model_names))
-        exit(0)
+        g_app.exit(0)

    if cli_args.serve is not None:
        # Disable inactive providers and save to config before starting server
-        all_providers = g_config[
+        all_providers = g_config["providers"].keys()
        enabled_providers = list(g_handlers.keys())
        disable_providers = []
        for provider in all_providers:
-            provider_config = g_config[
-            if provider not in enabled_providers:
-
-
-
-
+            provider_config = g_config["providers"][provider]
+            if provider not in enabled_providers and "enabled" in provider_config and provider_config["enabled"]:
+                provider_config["enabled"] = False
+                disable_providers.append(provider)
+
        if len(disable_providers) > 0:
            _log(f"Disabled unavailable providers: {', '.join(disable_providers)}")
            save_config(g_config)
@@ -1613,24 +3221,28 @@ def main():
        # Start server
        port = int(cli_args.serve)

-        if not os.path.exists(g_ui_path):
-            print(f"UI not found at {g_ui_path}")
-            exit(1)
-
        # Validate auth configuration if enabled
-        auth_enabled = g_config.get(
+        auth_enabled = g_config.get("auth", {}).get("enabled", False)
        if auth_enabled:
-            github_config = g_config.get(
-            client_id = github_config.get(
-            client_secret = github_config.get(
+            github_config = g_config.get("auth", {}).get("github", {})
+            client_id = github_config.get("client_id", "")
+            client_secret = github_config.get("client_secret", "")

            # Expand environment variables
-            if client_id.startswith(
-                client_id =
-            if client_secret.startswith(
-                client_secret =
-
-
+            if client_id.startswith("$"):
+                client_id = client_id[1:]
+            if client_secret.startswith("$"):
+                client_secret = client_secret[1:]
+
+            client_id = os.getenv(client_id, client_id)
+            client_secret = os.getenv(client_secret, client_secret)
+
+            if (
+                not client_id
+                or not client_secret
+                or client_id == "GITHUB_CLIENT_ID"
+                or client_secret == "GITHUB_CLIENT_SECRET"
+            ):
                print("ERROR: Authentication is enabled but GitHub OAuth is not properly configured.")
                print("Please set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET environment variables,")
                print("or disable authentication by setting 'auth.enabled' to false in llms.json")
@@ -1638,157 +3250,290 @@ def main():
 
 _log("Authentication enabled - GitHub OAuth configured")
 
-client_max_size = g_config.get(
-
+client_max_size = g_config.get("limits", {}).get(
+"client_max_size", 20 * 1024 * 1024
+) # 20MB max request size (to handle base64 encoding overhead)
+_log(f"client_max_size set to {client_max_size} bytes ({client_max_size / 1024 / 1024:.1f}MB)")
 app = web.Application(client_max_size=client_max_size)
 
-# Authentication middleware helper
-def check_auth(request):
-"""Check if request is authenticated. Returns (is_authenticated, user_data)"""
-if not auth_enabled:
-return True, None
-
-# Check for OAuth session token
-session_token = request.query.get('session') or request.headers.get('X-Session-Token')
-if session_token and session_token in g_sessions:
-return True, g_sessions[session_token]
-
-# Check for API key
-auth_header = request.headers.get('Authorization', '')
-if auth_header.startswith('Bearer '):
-api_key = auth_header[7:]
-if api_key:
-return True, {"authProvider": "apikey"}
-
-return False, None
-
 async def chat_handler(request):
 # Check authentication if enabled
-is_authenticated, user_data = check_auth(request)
+is_authenticated, user_data = g_app.check_auth(request)
 if not is_authenticated:
-return web.json_response(
-"error": {
-"message": "Authentication required",
-"type": "authentication_error",
-"code": "unauthorized"
-}
-}, status=401)
+return web.json_response(g_app.error_auth_required, status=401)
 
 try:
 chat = await request.json()
-
+context = {"chat": chat, "request": request, "user": g_app.get_username(request)}
+metadata = chat.get("metadata", {})
+context["threadId"] = metadata.get("threadId", None)
+context["tools"] = metadata.get("tools", "all")
+response = await g_app.chat_completion(chat, context)
 return web.json_response(response)
 except Exception as e:
-return web.json_response(
-app.router.add_post('/v1/chat/completions', chat_handler)
+return web.json_response(to_error_response(e), status=500)
 
-
-return web.json_response(get_models())
-app.router.add_get('/models/list', models_handler)
+app.router.add_post("/v1/chat/completions", chat_handler)
 
 async def active_models_handler(request):
 return web.json_response(get_active_models())
-
+
+app.router.add_get("/models", active_models_handler)
+
+async def active_providers_handler(request):
+return web.json_response(api_providers())
+
+app.router.add_get("/providers", active_providers_handler)
 
 async def status_handler(request):
 enabled, disabled = provider_status()
-return web.json_response(
-
-
-
-
-
+return web.json_response(
+{
+"all": list(g_config["providers"].keys()),
+"enabled": enabled,
+"disabled": disabled,
+}
+)
+
+app.router.add_get("/status", status_handler)
 
 async def provider_handler(request):
-provider = request.match_info.get(
+provider = request.match_info.get("provider", "")
 data = await request.json()
 msg = None
-if provider:
-if data.get(
+if provider:
+if data.get("enable", False):
 provider_config, msg = enable_provider(provider)
-_log(f"Enabled provider {provider}")
-
-
+_log(f"Enabled provider {provider} {msg}")
+if not msg:
+await load_llms()
+elif data.get("disable", False):
 disable_provider(provider)
 _log(f"Disabled provider {provider}")
 enabled, disabled = provider_status()
-return web.json_response(
-
-
-
-
-
+return web.json_response(
+{
+"enabled": enabled,
+"disabled": disabled,
+"feedback": msg or "",
+}
+)
+
+app.router.add_post("/providers/{provider}", provider_handler)
+
+async def upload_handler(request):
+# Check authentication if enabled
+is_authenticated, user_data = g_app.check_auth(request)
+if not is_authenticated:
+return web.json_response(g_app.error_auth_required, status=401)
+
+reader = await request.multipart()
+
+# Read first file field
+field = await reader.next()
+while field and field.name != "file":
+field = await reader.next()
+
+if not field:
+return web.json_response(create_error_response("No file provided"), status=400)
+
+filename = field.filename or "file"
+content = await field.read()
+mimetype = get_file_mime_type(filename)
+
+# If image, resize if needed
+if mimetype.startswith("image/"):
+content, mimetype = convert_image_if_needed(content, mimetype)
+
+# Calculate SHA256
+sha256_hash = hashlib.sha256(content).hexdigest()
+ext = filename.rsplit(".", 1)[1] if "." in filename else ""
+if not ext:
+ext = mimetypes.guess_extension(mimetype) or ""
+if ext.startswith("."):
+ext = ext[1:]
+
+if not ext:
+ext = "bin"
+
+save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
+
+# Use first 2 chars for subdir to avoid too many files in one dir
+subdir = sha256_hash[:2]
+relative_path = f"{subdir}/{save_filename}"
+full_path = get_cache_path(relative_path)
+
+# if file and its .info.json already exists, return it
+info_path = os.path.splitext(full_path)[0] + ".info.json"
+if os.path.exists(full_path) and os.path.exists(info_path):
+return web.json_response(json.load(open(info_path)))
+
+os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+with open(full_path, "wb") as f:
+f.write(content)
+
+url = f"/~cache/{relative_path}"
+response_data = {
+"date": int(time.time()),
+"url": url,
+"size": len(content),
+"type": mimetype,
+"name": filename,
+}
+
+# If image, get dimensions
+if HAS_PIL and mimetype.startswith("image/"):
+try:
+with Image.open(BytesIO(content)) as img:
+response_data["width"] = img.width
+response_data["height"] = img.height
+except Exception:
+pass
+
+# Save metadata
+info_path = os.path.splitext(full_path)[0] + ".info.json"
+with open(info_path, "w") as f:
+json.dump(response_data, f)
+
+g_app.on_cache_saved_filters({"url": url, "info": response_data})
+
+return web.json_response(response_data)
+
+app.router.add_post("/upload", upload_handler)
+
+async def extensions_handler(request):
+return web.json_response(g_app.ui_extensions)
+
+app.router.add_get("/ext", extensions_handler)
+
+async def tools_handler(request):
+return web.json_response(g_app.tool_definitions)
+
+app.router.add_get("/ext/tools", tools_handler)
+
+async def cache_handler(request):
+path = request.match_info["tail"]
+full_path = get_cache_path(path)
+
+if "info" in request.query:
+info_path = os.path.splitext(full_path)[0] + ".info.json"
+if not os.path.exists(info_path):
+return web.Response(text="404: Not Found", status=404)
+
+# Check for directory traversal for info path
+try:
+cache_root = Path(get_cache_path())
+requested_path = Path(info_path).resolve()
+if not str(requested_path).startswith(str(cache_root)):
+return web.Response(text="403: Forbidden", status=403)
+except Exception:
+return web.Response(text="403: Forbidden", status=403)
+
+with open(info_path) as f:
+content = f.read()
+return web.Response(text=content, content_type="application/json")
+
+if not os.path.exists(full_path):
+return web.Response(text="404: Not Found", status=404)
+
+# Check for directory traversal
+try:
+cache_root = Path(get_cache_path())
+requested_path = Path(full_path).resolve()
+if not str(requested_path).startswith(str(cache_root)):
+return web.Response(text="403: Forbidden", status=403)
+except Exception:
+return web.Response(text="403: Forbidden", status=403)
+
+with open(full_path, "rb") as f:
+content = f.read()
+
+mimetype = get_file_mime_type(full_path)
+return web.Response(body=content, content_type=mimetype)
+
+app.router.add_get("/~cache/{tail:.*}", cache_handler)
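The new `/upload` handler stores files content-addressed: the SHA-256 of the (possibly converted) bytes determines both the filename and a two-character subdirectory, and a `.info.json` sidecar carries the metadata returned to the client and later served via `/~cache/...?info`. A minimal sketch of the path derivation, with the cache root assumed for illustration (the real root comes from `get_cache_path()`):

```python
import hashlib
import os

CACHE_ROOT = "/tmp/llms-cache"  # assumed location for this sketch

def cache_paths(content: bytes, ext: str = "bin"):
    """Derive the content-addressed file path and its .info.json sidecar."""
    digest = hashlib.sha256(content).hexdigest()
    relative_path = f"{digest[:2]}/{digest}.{ext}"  # 2-char subdir avoids huge directories
    full_path = os.path.join(CACHE_ROOT, relative_path)
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    return relative_path, full_path, info_path

rel, full, info = cache_paths(b"hello", ext="txt")
# rel looks like "2c/2cf24dba5f....txt" and is served back as f"/~cache/{rel}"
```

Re-uploading identical bytes hits the existing file and sidecar, so the handler can return the stored metadata without writing anything.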
 
 # OAuth handlers
 async def github_auth_handler(request):
 """Initiate GitHub OAuth flow"""
-if
-return web.json_response(
+if "auth" not in g_config or "github" not in g_config["auth"]:
+return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
 
-auth_config = g_config[
-client_id = auth_config.get(
-redirect_uri = auth_config.get(
+auth_config = g_config["auth"]["github"]
+client_id = auth_config.get("client_id", "")
+redirect_uri = auth_config.get("redirect_uri", "")
 
 # Expand environment variables
-if client_id.startswith(
-client_id =
-if redirect_uri.startswith(
-redirect_uri =
+if client_id.startswith("$"):
+client_id = client_id[1:]
+if redirect_uri.startswith("$"):
+redirect_uri = redirect_uri[1:]
+
+client_id = os.getenv(client_id, client_id)
+redirect_uri = os.getenv(redirect_uri, redirect_uri)
 
 if not client_id:
-return web.json_response(
+return web.json_response(create_error_response("GitHub client_id not configured"), status=500)
 
 # Generate CSRF state token
 state = secrets.token_urlsafe(32)
-g_oauth_states[state] = {
-'created': time.time(),
-'redirect_uri': redirect_uri
-}
+g_oauth_states[state] = {"created": time.time(), "redirect_uri": redirect_uri}
 
 # Clean up old states (older than 10 minutes)
 current_time = time.time()
-expired_states = [s for s, data in g_oauth_states.items() if current_time - data[
+expired_states = [s for s, data in g_oauth_states.items() if current_time - data["created"] > 600]
 for s in expired_states:
 del g_oauth_states[s]
 
 # Build GitHub authorization URL
 params = {
-
-
-
-
+"client_id": client_id,
+"redirect_uri": redirect_uri,
+"state": state,
+"scope": "read:user user:email",
 }
 auth_url = f"https://github.com/login/oauth/authorize?{urlencode(params)}"
 
 return web.HTTPFound(auth_url)
-
+
 def validate_user(github_username):
-auth_config = g_config[
+auth_config = g_config["auth"]["github"]
 # Check if user is restricted
-restrict_to = auth_config.get(
+restrict_to = auth_config.get("restrict_to", "")
 
 # Expand environment variables
-if restrict_to.startswith(
-restrict_to =
+if restrict_to.startswith("$"):
+restrict_to = restrict_to[1:]
+
+restrict_to = os.getenv(restrict_to, None if restrict_to == "GITHUB_USERS" else restrict_to)
 
 # If restrict_to is configured, validate the user
 if restrict_to:
 # Parse allowed users (comma or space delimited)
-allowed_users = [u.strip() for u in re.split(r
+allowed_users = [u.strip() for u in re.split(r"[,\s]+", restrict_to) if u.strip()]
 
 # Check if user is in the allowed list
 if not github_username or github_username not in allowed_users:
 _log(f"Access denied for user: {github_username}. Not in allowed list: {allowed_users}")
 return web.Response(
 text=f"Access denied. User '{github_username}' is not authorized to access this application.",
-status=403
+status=403,
 )
 return None
 
 async def github_callback_handler(request):
 """Handle GitHub OAuth callback"""
-code = request.query.get(
-state = request.query.get(
+code = request.query.get("code")
+state = request.query.get("state")
+
+# Handle malformed URLs where query params are appended with & instead of ?
+if not code and "tail" in request.match_info:
+tail = request.match_info["tail"]
+if tail.startswith("&"):
+params = parse_qs(tail[1:])
+code = params.get("code", [None])[0]
+state = params.get("state", [None])[0]
 
 if not code or not state:
 return web.Response(text="Missing code or state parameter", status=400)
@@ -1797,118 +3542,122 @@ def main():
 if state not in g_oauth_states:
 return web.Response(text="Invalid state parameter", status=400)
 
-
+g_oauth_states.pop(state)
 
-if
-return web.json_response(
+if "auth" not in g_config or "github" not in g_config["auth"]:
+return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)
 
-auth_config = g_config[
-client_id = auth_config.get(
-client_secret = auth_config.get(
-redirect_uri = auth_config.get(
+auth_config = g_config["auth"]["github"]
+client_id = auth_config.get("client_id", "")
+client_secret = auth_config.get("client_secret", "")
+redirect_uri = auth_config.get("redirect_uri", "")
 
 # Expand environment variables
-if client_id.startswith(
-client_id =
-if client_secret.startswith(
-client_secret =
-if redirect_uri.startswith(
-redirect_uri =
+if client_id.startswith("$"):
+client_id = client_id[1:]
+if client_secret.startswith("$"):
+client_secret = client_secret[1:]
+if redirect_uri.startswith("$"):
+redirect_uri = redirect_uri[1:]
+
+client_id = os.getenv(client_id, client_id)
+client_secret = os.getenv(client_secret, client_secret)
+redirect_uri = os.getenv(redirect_uri, redirect_uri)
 
 if not client_id or not client_secret:
-return web.json_response(
+return web.json_response(create_error_response("GitHub OAuth credentials not configured"), status=500)
 
 # Exchange code for access token
 async with aiohttp.ClientSession() as session:
 token_url = "https://github.com/login/oauth/access_token"
 token_data = {
-
-
-
-
+"client_id": client_id,
+"client_secret": client_secret,
+"code": code,
+"redirect_uri": redirect_uri,
 }
-headers = {
+headers = {"Accept": "application/json"}
 
 async with session.post(token_url, data=token_data, headers=headers) as resp:
 token_response = await resp.json()
-access_token = token_response.get(
+access_token = token_response.get("access_token")
 
 if not access_token:
-error = token_response.get(
-return web.
+error = token_response.get("error_description", "Failed to get access token")
+return web.json_response(create_error_response(f"OAuth error: {error}"), status=400)
 
 # Fetch user info
 user_url = "https://api.github.com/user"
-headers = {
-"Authorization": f"Bearer {access_token}",
-"Accept": "application/json"
-}
+headers = {"Authorization": f"Bearer {access_token}", "Accept": "application/json"}
 
 async with session.get(user_url, headers=headers) as resp:
 user_data = await resp.json()
 
 # Validate user
-error_response = validate_user(user_data.get(
+error_response = validate_user(user_data.get("login", ""))
 if error_response:
 return error_response
 
 # Create session
 session_token = secrets.token_urlsafe(32)
 g_sessions[session_token] = {
-"userId": str(user_data.get(
-"userName": user_data.get(
-"displayName": user_data.get(
-"profileUrl": user_data.get(
-"email": user_data.get(
-"created": time.time()
+"userId": str(user_data.get("id", "")),
+"userName": user_data.get("login", ""),
+"displayName": user_data.get("name", ""),
+"profileUrl": user_data.get("avatar_url", ""),
+"email": user_data.get("email", ""),
+"created": time.time(),
 }
 
 # Redirect to UI with session token
-
+response = web.HTTPFound(f"/?session={session_token}")
+response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400)
+return response
 
 async def session_handler(request):
 """Validate and return session info"""
-session_token =
+session_token = get_session_token(request)
 
 if not session_token or session_token not in g_sessions:
-return web.json_response(
+return web.json_response(create_error_response("Invalid or expired session"), status=401)
 
 session_data = g_sessions[session_token]
 
 # Clean up old sessions (older than 24 hours)
 current_time = time.time()
-expired_sessions = [token for token, data in g_sessions.items() if current_time - data[
+expired_sessions = [token for token, data in g_sessions.items() if current_time - data["created"] > 86400]
 for token in expired_sessions:
 del g_sessions[token]
 
-return web.json_response({
-**session_data,
-"sessionToken": session_token
-})
+return web.json_response({**session_data, "sessionToken": session_token})
 
 async def logout_handler(request):
 """End OAuth session"""
-session_token =
+session_token = get_session_token(request)
 
 if session_token and session_token in g_sessions:
 del g_sessions[session_token]
 
-
+response = web.json_response({"success": True})
+response.del_cookie("llms-token")
+return response
 
 async def auth_handler(request):
 """Check authentication status and return user info"""
 # Check for OAuth session token
-session_token =
+session_token = get_session_token(request)
 
 if session_token and session_token in g_sessions:
 session_data = g_sessions[session_token]
-return web.json_response(
-
-
-
-
-
-
+return web.json_response(
+{
+"userId": session_data.get("userId", ""),
+"userName": session_data.get("userName", ""),
+"displayName": session_data.get("displayName", ""),
+"profileUrl": session_data.get("profileUrl", ""),
+"authProvider": "github",
+}
+)
 
 # Check for API key in Authorization header
 # auth_header = request.headers.get('Authorization', '')
@@ -1926,25 +3675,21 @@ def main():
 # })
 
 # Not authenticated - return error in expected format
-return web.json_response(
-"responseStatus": {
-"errorCode": "Unauthorized",
-"message": "Not authenticated"
-}
-}, status=401)
+return web.json_response(g_app.error_auth_required, status=401)
 
-app.router.add_get(
-app.router.add_get(
-app.router.add_get(
-app.router.add_get(
-app.router.
+app.router.add_get("/auth", auth_handler)
+app.router.add_get("/auth/github", github_auth_handler)
+app.router.add_get("/auth/github/callback", github_callback_handler)
+app.router.add_get("/auth/github/callback{tail:.*}", github_callback_handler)
+app.router.add_get("/auth/session", session_handler)
+app.router.add_post("/auth/logout", logout_handler)
 
 async def ui_static(request: web.Request) -> web.Response:
 path = Path(request.match_info["path"])
 
 try:
 # Handle both Path objects and importlib.resources Traversable objects
-if hasattr(_ROOT,
+if hasattr(_ROOT, "joinpath"):
 # importlib.resources Traversable
 resource = _ROOT.joinpath("ui").joinpath(str(path))
 if not resource.is_file():
@@ -1957,82 +3702,154 @@ def main():
 raise web.HTTPNotFound
 try:
 resource.relative_to(Path(_ROOT)) # basic directory-traversal guard
-except ValueError:
-raise web.HTTPBadRequest(text="Invalid path")
+except ValueError as e:
+raise web.HTTPBadRequest(text="Invalid path") from e
 content = resource.read_bytes()
 
 content_type, _ = mimetypes.guess_type(str(path))
 if content_type is None:
 content_type = "application/octet-stream"
 return web.Response(body=content, content_type=content_type)
-except (OSError, PermissionError, AttributeError):
-raise web.HTTPNotFound
+except (OSError, PermissionError, AttributeError) as e:
+raise web.HTTPNotFound from e
 
 app.router.add_get("/ui/{path:.*}", ui_static, name="ui_static")
-
-async def
-
-
-
-
-
-
-
-
-
-
-
-ui['requiresAuth'] = auth_enabled
-ui['authType'] = 'oauth' if auth_enabled else 'apikey'
-return web.json_response(ui)
-app.router.add_get('/config', ui_config_handler)
+
+async def config_handler(request):
+ret = {}
+if "defaults" not in ret:
+ret["defaults"] = g_config["defaults"]
+enabled, disabled = provider_status()
+ret["status"] = {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled}
+# Add auth configuration
+ret["requiresAuth"] = auth_enabled
+ret["authType"] = "oauth" if auth_enabled else "apikey"
+return web.json_response(ret)
+
+app.router.add_get("/config", config_handler)
 
 async def not_found_handler(request):
 return web.Response(text="404: Not Found", status=404)
-
+
+app.router.add_get("/favicon.ico", not_found_handler)
+
+# go through and register all g_app extensions
+for handler in g_app.server_add_get:
+handler_fn = handler[1]
+
+async def managed_handler(request, handler_fn=handler_fn):
+try:
+return await handler_fn(request)
+except Exception as e:
+return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+app.router.add_get(handler[0], managed_handler, **handler[2])
+for handler in g_app.server_add_post:
+handler_fn = handler[1]
+
+async def managed_handler(request, handler_fn=handler_fn):
+try:
+return await handler_fn(request)
+except Exception as e:
+return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+app.router.add_post(handler[0], managed_handler, **handler[2])
+for handler in g_app.server_add_put:
+handler_fn = handler[1]
+
+async def managed_handler(request, handler_fn=handler_fn):
+try:
+return await handler_fn(request)
+except Exception as e:
+return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+app.router.add_put(handler[0], managed_handler, **handler[2])
+for handler in g_app.server_add_delete:
+handler_fn = handler[1]
+
+async def managed_handler(request, handler_fn=handler_fn):
+try:
+return await handler_fn(request)
+except Exception as e:
+return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+app.router.add_delete(handler[0], managed_handler, **handler[2])
+for handler in g_app.server_add_patch:
+handler_fn = handler[1]
+
+async def managed_handler(request, handler_fn=handler_fn):
+try:
+return await handler_fn(request)
+except Exception as e:
+return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
+
+app.router.add_patch(handler[0], managed_handler, **handler[2])
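Each extension route registered above is wrapped in a `managed_handler` that turns uncaught exceptions into a JSON 500 response. Binding the loop variable through the `handler_fn=handler_fn` default argument is what keeps each wrapper pointed at its own handler; without it, every closure created in the loop would see only the last `handler_fn`. A minimal sketch of the difference, using made-up handlers for illustration:

```python
import asyncio

async def a(_): return "a"
async def b(_): return "b"

wrapped_wrong, wrapped_right = [], []
for fn in (a, b):
    async def late(request):            # closes over the loop variable (late binding)
        return await fn(request)
    async def bound(request, fn=fn):    # default argument freezes the current value
        return await fn(request)
    wrapped_wrong.append(late)
    wrapped_right.append(bound)

async def demo():
    print([await h(None) for h in wrapped_wrong])  # ['b', 'b'] - both see the last fn
    print([await h(None) for h in wrapped_right])  # ['a', 'b'] - each keeps its own

asyncio.run(demo())
```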
 
 # Serve index.html from root
 async def index_handler(request):
 index_content = read_resource_file_bytes("index.html")
-
-
-
-
+
+importmaps = {"imports": g_app.import_maps}
+importmaps_script = '<script type="importmap">\n' + json.dumps(importmaps, indent=4) + "\n</script>"
+index_content = index_content.replace(
+b'<script type="importmap"></script>',
+importmaps_script.encode("utf-8"),
+)
+
+if len(g_app.index_headers) > 0:
+html_header = ""
+for header in g_app.index_headers:
+html_header += header
+# replace </head> with html_header
+index_content = index_content.replace(b"</head>", html_header.encode("utf-8") + b"\n</head>")
+
+if len(g_app.index_footers) > 0:
+html_footer = ""
+for footer in g_app.index_footers:
+html_footer += footer
+# replace </body> with html_footer
+index_content = index_content.replace(b"</body>", html_footer.encode("utf-8") + b"\n</body>")
+
+return web.Response(body=index_content, content_type="text/html")
+
+app.router.add_get("/", index_handler)
 
 # Serve index.html as fallback route (SPA routing)
-app.router.add_route(
+app.router.add_route("*", "/{tail:.*}", index_handler)
 
 # Setup file watcher for config files
 async def start_background_tasks(app):
 """Start background tasks when the app starts"""
 # Start watching config files in the background
-asyncio.create_task(watch_config_files(g_config_path,
+asyncio.create_task(watch_config_files(g_config_path, home_providers_path))
 
 app.on_startup.append(start_background_tasks)
 
+# go through and register all g_app extensions
+
 print(f"Starting server on port {port}...")
-web.run_app(app, host=
-exit(0)
+web.run_app(app, host="0.0.0.0", port=port, print=_log)
+g_app.exit(0)
 
 if cli_args.enable is not None:
-if cli_args.enable.endswith(
+if cli_args.enable.endswith(","):
 cli_args.enable = cli_args.enable[:-1].strip()
 enable_providers = [cli_args.enable]
-all_providers = g_config[
+all_providers = g_config["providers"].keys()
 msgs = []
 if len(extra_args) > 0:
 for arg in extra_args:
-if arg.endswith(
+if arg.endswith(","):
 arg = arg[:-1].strip()
 if arg in all_providers:
 enable_providers.append(arg)
 
 for provider in enable_providers:
-if provider not in g_config[
-print(f"Provider {provider} not found")
+if provider not in g_config["providers"]:
+print(f"Provider '{provider}' not found")
 print(f"Available providers: {', '.join(g_config['providers'].keys())}")
 exit(1)
-if provider in g_config[
+if provider in g_config["providers"]:
 provider_config, msg = enable_provider(provider)
 print(f"\nEnabled provider {provider}:")
 printdump(provider_config)
@@ -2042,22 +3859,22 @@ def main():
 print_status()
 if len(msgs) > 0:
 print("\n" + "\n".join(msgs))
-exit(0)
+g_app.exit(0)
 
 if cli_args.disable is not None:
-if cli_args.disable.endswith(
+if cli_args.disable.endswith(","):
 cli_args.disable = cli_args.disable[:-1].strip()
 disable_providers = [cli_args.disable]
-all_providers = g_config[
+all_providers = g_config["providers"].keys()
 if len(extra_args) > 0:
 for arg in extra_args:
-if arg.endswith(
+if arg.endswith(","):
 arg = arg[:-1].strip()
 if arg in all_providers:
 disable_providers.append(arg)
 
 for provider in disable_providers:
-if provider not in g_config[
+if provider not in g_config["providers"]:
 print(f"Provider {provider} not found")
 print(f"Available providers: {', '.join(g_config['providers'].keys())}")
 exit(1)
@@ -2065,30 +3882,42 @@ def main():
 print(f"\nDisabled provider {provider}")
 
 print_status()
-exit(0)
+g_app.exit(0)
 
 if cli_args.default is not None:
 default_model = cli_args.default
-
-if
+provider_model = get_provider_model(default_model)
+if provider_model is None:
 print(f"Model {default_model} not found")
-print(f"Available models: {', '.join(all_models)}")
 exit(1)
-default_text = g_config[
-default_text[
+default_text = g_config["defaults"]["text"]
+default_text["model"] = default_model
 save_config(g_config)
 print(f"\nDefault model set to: {default_model}")
-exit(0)
-
-if
+g_app.exit(0)
+
+if (
+cli_args.chat is not None
+or cli_args.image is not None
+or cli_args.audio is not None
+or cli_args.file is not None
+or cli_args.out is not None
+or len(extra_args) > 0
+):
 try:
-chat = g_config[
+chat = g_config["defaults"]["text"]
 if cli_args.image is not None:
-chat = g_config[
+chat = g_config["defaults"]["image"]
 elif cli_args.audio is not None:
-chat = g_config[
+chat = g_config["defaults"]["audio"]
 elif cli_args.file is not None:
-chat = g_config[
+chat = g_config["defaults"]["file"]
+elif cli_args.out is not None:
+template = f"out:{cli_args.out}"
+if template not in g_config["defaults"]:
+print(f"Template for output modality '{cli_args.out}' not found")
+exit(1)
+chat = g_config["defaults"][template]
 if cli_args.chat is not None:
 chat_path = os.path.join(os.path.dirname(__file__), cli_args.chat)
 if not os.path.exists(chat_path):
@@ -2096,41 +3925,60 @@ def main():
 exit(1)
 _log(f"Using chat: {chat_path}")
 
-with open
+with open(chat_path) as f:
 chat_json = f.read()
 chat = json.loads(chat_json)
 
 if cli_args.system is not None:
-chat[
+chat["messages"].insert(0, {"role": "system", "content": cli_args.system})
 
 if len(extra_args) > 0:
-prompt =
+prompt = " ".join(extra_args)
+if not chat["messages"] or len(chat["messages"]) == 0:
+chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
+
 # replace content of last message if exists, else add
-last_msg = chat[
-if last_msg and last_msg[
-if isinstance(last_msg[
-last_msg[
+last_msg = chat["messages"][-1] if "messages" in chat else None
+if last_msg and last_msg["role"] == "user":
+if isinstance(last_msg["content"], list):
+last_msg["content"][-1]["text"] = prompt
 else:
-last_msg[
+last_msg["content"] = prompt
 else:
-chat[
+chat["messages"].append({"role": "user", "content": prompt})
 
 # Parse args parameters if provided
 args = None
 if cli_args.args is not None:
 args = parse_args_params(cli_args.args)
 
-asyncio.run(
-
+asyncio.run(
+cli_chat(
+chat,
+tools=cli_args.tools,
+image=cli_args.image,
+audio=cli_args.audio,
+file=cli_args.file,
+args=args,
+raw=cli_args.raw,
+)
+)
+g_app.exit(0)
 except Exception as e:
 print(f"{cli_args.logprefix}Error: {e}")
 if cli_args.verbose:
 traceback.print_exc()
-exit(1)
+g_app.exit(1)
+
+handled = run_extension_cli()
 
-
-
+if not handled:
+# show usage from ArgumentParser
+parser.print_help()
+g_app.exit(0)
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":
+if MOCK or DEBUG:
+print(f"MOCK={MOCK} or DEBUG={DEBUG}")
 main()
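The reworked `chat_handler` forwards the optional `metadata.threadId` and `metadata.tools` fields from the request body into the completion context. A minimal client-side sketch of such a request, with the host, port, model, and thread id assumed for illustration:

```python
import asyncio
import aiohttp

async def send_chat():
    # Assumed local server address and model; adjust to your deployment.
    url = "http://localhost:8000/v1/chat/completions"
    body = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
        # Optional metadata read by the handler: thread to associate and tool selection.
        "metadata": {"threadId": "thread-123", "tools": "all"},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=body) as resp:
            print(resp.status, await resp.json())

asyncio.run(send_chat())
```

If authentication is enabled, the server also expects a valid session or API key on this route (checked by `g_app.check_auth`); otherwise it returns the 401 payload shown in the handler.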