llms-py 2.0.9__py3-none-any.whl → 3.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llms/__init__.py +4 -0
- llms/__main__.py +9 -0
- llms/db.py +359 -0
- llms/extensions/analytics/ui/index.mjs +1444 -0
- llms/extensions/app/README.md +20 -0
- llms/extensions/app/__init__.py +589 -0
- llms/extensions/app/db.py +536 -0
- {llms_py-2.0.9.data/data → llms/extensions/app}/ui/Recents.mjs +100 -73
- llms_py-2.0.9.data/data/ui/Sidebar.mjs → llms/extensions/app/ui/index.mjs +150 -79
- llms/extensions/app/ui/threadStore.mjs +433 -0
- llms/extensions/core_tools/CALCULATOR.md +32 -0
- llms/extensions/core_tools/__init__.py +637 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closebrackets.js +201 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/closetag.js +185 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/continuelist.js +101 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchbrackets.js +160 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/matchtags.js +66 -0
- llms/extensions/core_tools/ui/codemirror/addon/edit/trailingspace.js +27 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/active-line.js +72 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/mark-selection.js +119 -0
- llms/extensions/core_tools/ui/codemirror/addon/selection/selection-pointer.js +98 -0
- llms/extensions/core_tools/ui/codemirror/codemirror.css +344 -0
- llms/extensions/core_tools/ui/codemirror/codemirror.js +9884 -0
- llms/extensions/core_tools/ui/codemirror/doc/docs.css +225 -0
- llms/extensions/core_tools/ui/codemirror/doc/source_sans.woff +0 -0
- llms/extensions/core_tools/ui/codemirror/mode/clike/clike.js +942 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/index.html +118 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/javascript.js +962 -0
- llms/extensions/core_tools/ui/codemirror/mode/javascript/typescript.html +62 -0
- llms/extensions/core_tools/ui/codemirror/mode/python/python.js +402 -0
- llms/extensions/core_tools/ui/codemirror/theme/dracula.css +40 -0
- llms/extensions/core_tools/ui/codemirror/theme/mocha.css +135 -0
- llms/extensions/core_tools/ui/index.mjs +650 -0
- llms/extensions/gallery/README.md +61 -0
- llms/extensions/gallery/__init__.py +63 -0
- llms/extensions/gallery/db.py +243 -0
- llms/extensions/gallery/ui/index.mjs +482 -0
- llms/extensions/katex/README.md +39 -0
- llms/extensions/katex/__init__.py +6 -0
- llms/extensions/katex/ui/README.md +125 -0
- llms/extensions/katex/ui/contrib/auto-render.js +338 -0
- llms/extensions/katex/ui/contrib/auto-render.min.js +1 -0
- llms/extensions/katex/ui/contrib/auto-render.mjs +244 -0
- llms/extensions/katex/ui/contrib/copy-tex.js +127 -0
- llms/extensions/katex/ui/contrib/copy-tex.min.js +1 -0
- llms/extensions/katex/ui/contrib/copy-tex.mjs +105 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.js +109 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.min.js +1 -0
- llms/extensions/katex/ui/contrib/mathtex-script-type.mjs +24 -0
- llms/extensions/katex/ui/contrib/mhchem.js +3213 -0
- llms/extensions/katex/ui/contrib/mhchem.min.js +1 -0
- llms/extensions/katex/ui/contrib/mhchem.mjs +3109 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.js +887 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.min.js +1 -0
- llms/extensions/katex/ui/contrib/render-a11y-string.mjs +800 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_AMS-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Caligraphic-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Fraktur-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Main-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-BoldItalic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Math-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Bold.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Italic.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_SansSerif-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Script-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size1-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size2-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size3-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Size4-Regular.woff2 +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.ttf +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff +0 -0
- llms/extensions/katex/ui/fonts/KaTeX_Typewriter-Regular.woff2 +0 -0
- llms/extensions/katex/ui/index.mjs +92 -0
- llms/extensions/katex/ui/katex-swap.css +1230 -0
- llms/extensions/katex/ui/katex-swap.min.css +1 -0
- llms/extensions/katex/ui/katex.css +1230 -0
- llms/extensions/katex/ui/katex.js +19080 -0
- llms/extensions/katex/ui/katex.min.css +1 -0
- llms/extensions/katex/ui/katex.min.js +1 -0
- llms/extensions/katex/ui/katex.min.mjs +1 -0
- llms/extensions/katex/ui/katex.mjs +18547 -0
- llms/extensions/providers/__init__.py +22 -0
- llms/extensions/providers/anthropic.py +233 -0
- llms/extensions/providers/cerebras.py +37 -0
- llms/extensions/providers/chutes.py +153 -0
- llms/extensions/providers/google.py +481 -0
- llms/extensions/providers/nvidia.py +103 -0
- llms/extensions/providers/openai.py +154 -0
- llms/extensions/providers/openrouter.py +74 -0
- llms/extensions/providers/zai.py +182 -0
- llms/extensions/system_prompts/README.md +22 -0
- llms/extensions/system_prompts/__init__.py +45 -0
- llms/extensions/system_prompts/ui/index.mjs +280 -0
- llms/extensions/system_prompts/ui/prompts.json +1067 -0
- llms/extensions/tools/__init__.py +144 -0
- llms/extensions/tools/ui/index.mjs +706 -0
- llms/index.html +58 -0
- llms/llms.json +400 -0
- llms/main.py +4407 -0
- llms/providers-extra.json +394 -0
- llms/providers.json +1 -0
- llms/ui/App.mjs +188 -0
- llms/ui/ai.mjs +217 -0
- llms/ui/app.css +7081 -0
- llms/ui/ctx.mjs +412 -0
- llms/ui/index.mjs +131 -0
- llms/ui/lib/chart.js +14 -0
- llms/ui/lib/charts.mjs +16 -0
- llms/ui/lib/color.js +14 -0
- llms/ui/lib/servicestack-vue.mjs +37 -0
- llms/ui/lib/vue.min.mjs +13 -0
- llms/ui/lib/vue.mjs +18530 -0
- {llms_py-2.0.9.data/data → llms}/ui/markdown.mjs +33 -15
- llms/ui/modules/chat/ChatBody.mjs +976 -0
- llms/ui/modules/chat/SettingsDialog.mjs +374 -0
- llms/ui/modules/chat/index.mjs +991 -0
- llms/ui/modules/icons.mjs +46 -0
- llms/ui/modules/layout.mjs +271 -0
- llms/ui/modules/model-selector.mjs +811 -0
- llms/ui/tailwind.input.css +742 -0
- {llms_py-2.0.9.data/data → llms}/ui/typography.css +133 -7
- llms/ui/utils.mjs +261 -0
- llms_py-3.0.10.dist-info/METADATA +49 -0
- llms_py-3.0.10.dist-info/RECORD +177 -0
- llms_py-3.0.10.dist-info/entry_points.txt +2 -0
- {llms_py-2.0.9.dist-info → llms_py-3.0.10.dist-info}/licenses/LICENSE +1 -2
- llms.py +0 -1402
- llms_py-2.0.9.data/data/index.html +0 -64
- llms_py-2.0.9.data/data/llms.json +0 -447
- llms_py-2.0.9.data/data/requirements.txt +0 -1
- llms_py-2.0.9.data/data/ui/App.mjs +0 -20
- llms_py-2.0.9.data/data/ui/ChatPrompt.mjs +0 -389
- llms_py-2.0.9.data/data/ui/Main.mjs +0 -680
- llms_py-2.0.9.data/data/ui/app.css +0 -3951
- llms_py-2.0.9.data/data/ui/lib/servicestack-vue.min.mjs +0 -37
- llms_py-2.0.9.data/data/ui/lib/vue.min.mjs +0 -12
- llms_py-2.0.9.data/data/ui/tailwind.input.css +0 -261
- llms_py-2.0.9.data/data/ui/threadStore.mjs +0 -273
- llms_py-2.0.9.data/data/ui/utils.mjs +0 -114
- llms_py-2.0.9.data/data/ui.json +0 -1069
- llms_py-2.0.9.dist-info/METADATA +0 -941
- llms_py-2.0.9.dist-info/RECORD +0 -30
- llms_py-2.0.9.dist-info/entry_points.txt +0 -2
- {llms_py-2.0.9.data/data → llms}/ui/fav.svg +0 -0
- {llms_py-2.0.9.data/data → llms}/ui/lib/highlight.min.mjs +0 -0
- {llms_py-2.0.9.data/data → llms}/ui/lib/idb.min.mjs +0 -0
- {llms_py-2.0.9.data/data → llms}/ui/lib/marked.min.mjs +0 -0
- /llms_py-2.0.9.data/data/ui/lib/servicestack-client.min.mjs → /llms/ui/lib/servicestack-client.mjs +0 -0
- {llms_py-2.0.9.data/data → llms}/ui/lib/vue-router.min.mjs +0 -0
- {llms_py-2.0.9.dist-info → llms_py-3.0.10.dist-info}/WHEEL +0 -0
- {llms_py-2.0.9.dist-info → llms_py-3.0.10.dist-info}/top_level.txt +0 -0
llms/main.py
ADDED
|
@@ -0,0 +1,4407 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Demis Bellot, ServiceStack <https://servicestack.net>
|
|
4
|
+
# License: https://github.com/ServiceStack/llms/blob/main/LICENSE
|
|
5
|
+
|
|
6
|
+
# A lightweight CLI tool and OpenAI-compatible server for querying multiple Large Language Model (LLM) providers.
|
|
7
|
+
# Docs: https://github.com/ServiceStack/llms
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import asyncio
|
|
11
|
+
import base64
|
|
12
|
+
import contextlib
|
|
13
|
+
import hashlib
|
|
14
|
+
import importlib.util
|
|
15
|
+
import inspect
|
|
16
|
+
import json
|
|
17
|
+
import mimetypes
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import secrets
|
|
21
|
+
import shutil
|
|
22
|
+
import site
|
|
23
|
+
import subprocess
|
|
24
|
+
import sys
|
|
25
|
+
import time
|
|
26
|
+
import traceback
|
|
27
|
+
from datetime import datetime
|
|
28
|
+
from importlib import resources # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
|
|
29
|
+
from io import BytesIO
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from typing import Optional, get_type_hints
|
|
32
|
+
from urllib.parse import parse_qs, urlencode, urljoin
|
|
33
|
+
|
|
34
|
+
import aiohttp
|
|
35
|
+
from aiohttp import web
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
from PIL import Image
|
|
39
|
+
|
|
40
|
+
HAS_PIL = True
|
|
41
|
+
except ImportError:
|
|
42
|
+
HAS_PIL = False
|
|
43
|
+
|
|
44
|
+
VERSION = "3.0.10"
|
|
45
|
+
_ROOT = None
|
|
46
|
+
DEBUG = os.getenv("DEBUG") == "1"
|
|
47
|
+
MOCK = os.getenv("MOCK") == "1"
|
|
48
|
+
MOCK_DIR = os.getenv("MOCK_DIR")
|
|
49
|
+
DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
|
|
50
|
+
g_config_path = None
|
|
51
|
+
g_config = None
|
|
52
|
+
g_providers = None
|
|
53
|
+
g_handlers = {}
|
|
54
|
+
g_verbose = False
|
|
55
|
+
g_logprefix = ""
|
|
56
|
+
g_default_model = ""
|
|
57
|
+
g_sessions = {} # OAuth session storage: {session_token: {userId, userName, displayName, profileUrl, email, created}}
|
|
58
|
+
g_oauth_states = {} # CSRF protection: {state: {created, redirect_uri}}
|
|
59
|
+
g_app = None # ExtensionsContext Singleton
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _log(message):
    """Print *message* with the configured log prefix when verbose mode is on."""
    if not g_verbose:
        return
    print(f"{g_logprefix}{message}", flush=True)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _dbg(message):
    """Emit a DEBUG-prefixed line when the DEBUG env flag is set."""
    if not DEBUG:
        return
    print(f"DEBUG: {message}", flush=True)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _err(message, e):
    """Always print an ERROR line for *e*; add a full traceback when verbose."""
    print(f"ERROR: {message}: {e}", flush=True)
    if not g_verbose:
        return
    print(traceback.format_exc(), flush=True)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def printdump(obj):
    """Pretty-print *obj* (or its ``__dict__`` for plain objects) as indented JSON."""
    payload = getattr(obj, "__dict__", obj)
    print(json.dumps(payload, indent=2))
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def print_chat(chat):
    """Log a payload-truncated summary of a chat completion request."""
    summary = chat_summary(chat)
    _log(f"Chat: {summary}")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def chat_summary(chat):
    """Return a pretty-printed JSON view of a chat request with bulky inline
    payloads (images, audio, file data) replaced by "(byte count)" markers,
    so large base64 blobs never end up in the logs."""

    def _truncate(value):
        # Keep everything up to the first comma, replace the rest with its length.
        head = value.split(",", 1)[0]
        return head + f",({len(value) - len(head)})"

    # Deep-copy via JSON round-trip so the caller's request is never mutated.
    snapshot = json.loads(json.dumps(chat))
    for msg in snapshot["messages"]:
        parts = msg.get("content")
        if not isinstance(parts, list):
            continue
        for part in parts:
            if "image_url" in part:
                img = part["image_url"]
                if "url" in img:
                    img["url"] = _truncate(img["url"])
            elif "input_audio" in part:
                audio = part["input_audio"]
                if "data" in audio:
                    audio["data"] = f"({len(audio['data'])})"
            elif "file" in part and "file_data" in part["file"]:
                part["file"]["file_data"] = _truncate(part["file"]["file_data"])
    return json.dumps(snapshot, indent=2)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# File extensions recognized as image/audio attachments when inferring
# content types for uploaded or referenced files.
image_exts = ["png", "webp", "jpg", "jpeg", "gif", "bmp", "svg", "tiff", "ico"]
audio_exts = ["mp3", "wav", "ogg", "flac", "m4a", "opus", "webm"]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def is_file_path(path):
    """Best-effort check that *path* names an existing local file.

    Strings of 1024 chars or more are rejected up front (macOS caps paths
    at 1023) so huge data-URI-like strings are never passed to the OS.
    """
    if not path:
        return path
    return len(path) < 1024 and os.path.exists(path)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def is_url(url):
    """True when *url* is a non-empty http:// or https:// URL."""
    return url and url.startswith(("http://", "https://"))
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def get_filename(file):
    """Return the last path segment of *file*, or "file" when it has no '/'."""
    if "/" not in file:
        return "file"
    return file.rsplit("/", 1)[1]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def parse_args_params(args_str):
    """Parse a URL-encoded query string into a dict with typed values.

    Single values are coerced: "true"/"false" -> bool, digit strings -> int,
    float-parsable strings -> float, everything else stays a string.
    Repeated keys keep their raw string values as a list.
    """
    if not args_str:
        return {}

    def _coerce(raw):
        lowered = raw.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        if raw.isdigit():
            return int(raw)
        try:
            return float(raw)
        except ValueError:
            return raw

    out = {}
    for key, values in parse_qs(args_str, keep_blank_values=True).items():
        out[key] = _coerce(values[0]) if len(values) == 1 else values
    return out
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def apply_args_to_chat(chat, args_params):
    """Apply parsed query-string arguments onto an OpenAI-style chat request.

    String values for known keys are coerced to the type the API expects
    (int, float, bool, or a comma-split list for "stop"); unknown keys and
    non-string values are applied as-is. Mutates and returns *chat*.
    """
    if not args_params:
        return chat

    # Keys grouped by the type their API parameter expects.
    int_keys = {"max_completion_tokens", "max_tokens", "n", "seed", "top_logprobs"}
    float_keys = {"temperature", "top_p", "frequency_penalty", "presence_penalty"}
    bool_keys = {"store", "logprobs", "enable_thinking", "parallel_tool_calls", "stream"}

    for key, value in args_params.items():
        if isinstance(value, str):
            if key == "stop":
                # "stop" accepts a single string or a list of stop sequences.
                if "," in value:
                    value = value.split(",")
            elif key in int_keys:
                value = int(value)
            elif key in float_keys:
                value = float(value)
            elif key in bool_keys:
                # BUG FIX: bool("false") is True — any non-empty string was
                # treated as True. Interpret the string's content instead.
                value = value.strip().lower() in ("1", "true", "yes", "on")
        chat[key] = value

    return chat
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def is_base_64(data):
    """Return True only if *data* decodes as valid base64.

    BUG FIX: without validate=True, b64decode silently discards characters
    outside the base64 alphabet, so arbitrary text could be misclassified
    as base64. Strict validation rejects such strings.
    """
    try:
        base64.b64decode(data, validate=True)
        return True
    except Exception:
        return False
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def id_to_name(id):
    """Turn a dashed identifier like "gpt-4o-mini" into a Title Case name."""
    words = id.split("-")
    return " ".join(words).title()
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def pluralize(word, count):
    """Naively pluralize *word* (append "s") unless *count* is exactly 1."""
    return word if count == 1 else word + "s"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_file_mime_type(filename):
    """Guess *filename*'s MIME type, defaulting to application/octet-stream."""
    guessed, _encoding = mimetypes.guess_type(filename)
    if guessed:
        return guessed
    return "application/octet-stream"
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def price_to_string(price: float | int | str | None) -> str | None:
    """Convert numeric price to string without scientific notation.

    Detects and rounds up numbers with recurring 9s (e.g., 0.00014999999999999999)
    to avoid floating-point precision artifacts.

    Returns "0" for None/zero, the formatted decimal string otherwise,
    or None when *price* cannot be parsed as a number.
    """
    if price is None or price == 0 or price == "0":
        return "0"
    try:
        price_float = float(price)
        # Format with enough decimal places to avoid scientific notation
        formatted = format(price_float, ".20f")

        # Detect recurring 9s pattern (e.g., "...9999999")
        # If we have 4 or more consecutive 9s, round up
        if "9999" in formatted:
            # Round up by adding a small amount and reformatting
            # Find the position of the 9s to determine precision
            import decimal

            decimal.getcontext().prec = 28
            d = decimal.Decimal(str(price_float))
            # Round to one less decimal place than where the 9s start
            nines_pos = formatted.find("9999")
            if nines_pos > 0:
                # Round up at the position before the 9s
                decimal_places = nines_pos - formatted.find(".") - 1
                if decimal_places > 0:
                    # e.g. decimal_places=5 -> quantize to "0.00001"
                    quantize_str = "0." + "0" * (decimal_places - 1) + "1"
                    d = d.quantize(decimal.Decimal(quantize_str), rounding=decimal.ROUND_UP)
            result = str(d)
            # Remove trailing zeros
            if "." in result:
                result = result.rstrip("0").rstrip(".")
            return result

        # Normal case: strip trailing zeros
        return formatted.rstrip("0").rstrip(".")
    except (ValueError, TypeError):
        return None
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def convert_image_if_needed(image_bytes, mimetype="image/png"):
    """
    Convert and resize image to WebP if it exceeds configured limits.

    Args:
        image_bytes: Raw image bytes
        mimetype: Original image MIME type

    Returns:
        tuple: (converted_bytes, new_mimetype) or (original_bytes, original_mimetype) if no conversion needed
    """
    # No-op when Pillow is unavailable or conversion is not configured.
    if not HAS_PIL:
        return image_bytes, mimetype

    # Get conversion config
    convert_config = g_config.get("convert", {}).get("image", {}) if g_config else {}
    if not convert_config:
        return image_bytes, mimetype

    max_size_str = convert_config.get("max_size", "1536x1024")
    max_length = convert_config.get("max_length", 1.5 * 1024 * 1024)  # 1.5MB
    # NOTE(review): estimated_kb below is expressed in KB, but this default
    # max_length (1.5*1024*1024) looks byte-scaled — confirm the configured
    # unit; as written the size check only triggers for ~GB-scale images.

    try:
        # Parse max_size (e.g., "1536x1024")
        max_width, max_height = map(int, max_size_str.split("x"))

        # Open image
        with Image.open(BytesIO(image_bytes)) as img:
            original_width, original_height = img.size

            # Check if image exceeds limits
            needs_resize = original_width > max_width or original_height > max_height

            # Check if base64 length would exceed max_length (in KB)
            # Base64 encoding increases size by ~33%, so check raw bytes * 1.33 / 1024
            estimated_kb = (len(image_bytes) * 1.33) / 1024
            needs_conversion = estimated_kb > max_length

            if not needs_resize and not needs_conversion:
                return image_bytes, mimetype

            # Convert RGBA to RGB if necessary (WebP doesn't support transparency in RGB mode)
            if img.mode in ("RGBA", "LA", "P"):
                # Create a white background
                background = Image.new("RGB", img.size, (255, 255, 255))
                if img.mode == "P":
                    img = img.convert("RGBA")
                background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
                img = background
            elif img.mode != "RGB":
                img = img.convert("RGB")

            # Resize if needed (preserve aspect ratio)
            if needs_resize:
                img.thumbnail((max_width, max_height), Image.Resampling.LANCZOS)
                _log(f"Resized image from {original_width}x{original_height} to {img.size[0]}x{img.size[1]}")

            # Convert to WebP
            output = BytesIO()
            img.save(output, format="WEBP", quality=85, method=6)
            converted_bytes = output.getvalue()

            _log(
                f"Converted image to WebP: {len(image_bytes)} bytes -> {len(converted_bytes)} bytes ({len(converted_bytes) * 100 // len(image_bytes)}%)"
            )

            return converted_bytes, "image/webp"

    except Exception as e:
        _log(f"Error converting image: {e}")
        # Return original if conversion fails
        return image_bytes, mimetype
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def to_content(result):
    """Coerce a tool-call result into a string for a message content field.

    Scalars are stringified; containers are JSON-encoded; anything else
    falls back to str(). BUG FIX: sets were matched but json.dumps cannot
    serialize them (raised TypeError) — convert sets to lists first.
    """
    if isinstance(result, (str, int, float, bool)):
        return str(result)
    if isinstance(result, set):
        return json.dumps(list(result))
    if isinstance(result, (list, tuple, dict)):
        return json.dumps(result)
    return str(result)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def function_to_tool_definition(func):
    """Build an OpenAI tool-definition dict from a Python function.

    The JSON schema is derived from the function's signature: int/float/bool
    hints map to integer/number/boolean, everything else (including missing
    hints) to string; parameters without defaults become required.
    """
    hints = get_type_hints(func)
    type_names = {int: "integer", float: "number", bool: "boolean"}

    properties = {}
    required = []
    for name, param in inspect.signature(func).parameters.items():
        json_type = type_names.get(hints.get(name, str), "string")
        properties[name] = {"type": json_type}
        if param.default is inspect.Parameter.empty:
            required.append(name)

    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func.__doc__ or "",
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
async def download_file(url):
    """Download *url* using a throwaway HTTP session.

    Returns the (content_bytes, info_dict) pair from session_download_file.
    """
    async with aiohttp.ClientSession() as session:
        return await session_download_file(session, url)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
async def session_download_file(session, url, default_mimetype="application/octet-stream"):
    """Download *url* via *session* and return (content_bytes, info_dict).

    The info dict carries "url", "type" (MIME), "name" (filename) and "ext".
    The filename is taken from a quoted Content-Disposition filename when
    present, otherwise derived from the URL; the MIME type comes from
    Content-Type, a filename guess, or *default_mimetype*.

    Raises the underlying exception after logging on any HTTP/parse error.
    """
    try:
        async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
            response.raise_for_status()
            content = await response.read()
            mimetype = response.headers.get("Content-Type")
            disposition = response.headers.get("Content-Disposition")
            # Strip charset/params, e.g. "text/html; charset=utf-8" -> "text/html"
            if mimetype and ";" in mimetype:
                mimetype = mimetype.split(";")[0]
            ext = None
            # BUG FIX: the previous str.index('filename="') calls raised
            # ValueError whenever a Content-Disposition header lacked a quoted
            # filename= parameter (e.g. bare "attachment"), failing the whole
            # download. Parse tolerantly and fall back to the URL.
            filename = None
            if disposition:
                match = re.search(r'filename="([^"]*)"', disposition)
                if match:
                    filename = match.group(1)
                    if not mimetype:
                        mimetype = mimetypes.guess_type(filename)[0] or default_mimetype
            if filename is None:
                filename = url.split("/")[-1]
                if "." not in filename:
                    # No extension in the URL: synthesize one from the MIME type.
                    if mimetype is None:
                        mimetype = default_mimetype
                    ext = mimetypes.guess_extension(mimetype) or mimetype.split("/")[1]
                    filename = f"(unknown).{ext}"

            if not ext:
                ext = Path(filename).suffix.lstrip(".")

            info = {
                "url": url,
                "type": mimetype,
                "name": filename,
                "ext": ext,
            }
            return content, info
    except Exception as e:
        _err(f"Error downloading file: {url}", e)
        raise e
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def read_binary_file(url):
    """Read a local file and return (content_bytes, info_dict).

    If a sidecar "<stem>.info.json" exists next to the file, its contents
    are returned as the info dict; otherwise an info dict is synthesized
    from the file's stat/name. Logs and re-raises on any failure.
    """
    try:
        path = Path(url)
        with open(url, "rb") as f:
            content = f.read()
        # BUG FIX: `path.stem + ".info.json"` dropped the directory, so the
        # sidecar was looked up relative to the CWD instead of next to the
        # file. Resolve it alongside the file itself.
        info_path = path.with_name(path.stem + ".info.json")
        if info_path.exists():
            with open(info_path) as f_info:
                info = json.load(f_info)
            return content, info

        # No sidecar: synthesize metadata from the filesystem.
        stat = path.stat()
        info = {
            "date": int(stat.st_mtime),
            "name": path.name,
            "ext": path.suffix.lstrip("."),
            "type": mimetypes.guess_type(path.name)[0],
            "url": f"/~cache/{path.name[:2]}/{path.name}",
        }
        return content, info
    except Exception as e:
        _err(f"Error reading file: {url}", e)
        raise e
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
async def process_chat(chat, provider_id=None):
    """Normalize a chat completion request in-place before dispatch.

    Ensures a ``stream`` flag exists, strips empty ``tools`` lists, and
    inlines every image/audio/file attachment referenced in the messages as
    base64 — downloading http(s) URLs, reading local file paths, and
    resolving ``/~cache/`` references via get_cache_path.

    Args:
        chat: OpenAI-style chat completion request dict (mutated in place).
        provider_id: optional provider id; "alibaba" gets audio as a full
            data URI plus an explicit ``format`` field.

    Returns:
        The same (mutated) chat dict.

    Raises:
        Exception: when *chat* is falsy, or an attachment reference is
            neither a URL, an existing file path, nor inline base64/data URI.
    """
    if not chat:
        raise Exception("No chat provided")
    if "stream" not in chat:
        chat["stream"] = False
    # Some providers don't support empty tools
    if "tools" in chat and (chat["tools"] is None or len(chat["tools"]) == 0):
        del chat["tools"]
    if "messages" not in chat:
        return chat

    # One shared HTTP session for all attachment downloads in this request.
    async with aiohttp.ClientSession() as session:
        for message in chat["messages"]:
            if "content" not in message:
                continue

            # Only structured (list) content can carry attachments.
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if "type" not in item:
                        continue
                    if item["type"] == "image_url" and "image_url" in item:
                        image_url = item["image_url"]
                        if "url" in image_url:
                            url = image_url["url"]
                            # Resolve app-cache references to local paths.
                            if url.startswith("/~cache/"):
                                url = get_cache_path(url[8:])
                            if is_url(url):
                                _log(f"Downloading image: {url}")
                                content, info = await session_download_file(session, url, default_mimetype="image/png")
                                mimetype = info["type"]
                                # convert/resize image if needed
                                content, mimetype = convert_image_if_needed(content, mimetype)
                                # convert to data uri
                                image_url["url"] = f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
                            elif is_file_path(url):
                                _log(f"Reading image: {url}")
                                content, info = read_binary_file(url)
                                mimetype = info["type"]
                                # convert/resize image if needed
                                content, mimetype = convert_image_if_needed(content, mimetype)
                                # convert to data uri
                                image_url["url"] = f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
                            elif url.startswith("data:"):
                                # Extract existing data URI and process it
                                if ";base64," in url:
                                    prefix = url.split(";base64,")[0]
                                    mimetype = prefix.split(":")[1] if ":" in prefix else "image/png"
                                    base64_data = url.split(";base64,")[1]
                                    content = base64.b64decode(base64_data)
                                    # convert/resize image if needed
                                    content, mimetype = convert_image_if_needed(content, mimetype)
                                    # update data uri with potentially converted image
                                    image_url["url"] = (
                                        f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
                                    )
                            else:
                                raise Exception(f"Invalid image: {url}")
                    elif item["type"] == "input_audio" and "input_audio" in item:
                        input_audio = item["input_audio"]
                        if "data" in input_audio:
                            url = input_audio["data"]
                            if url.startswith("/~cache/"):
                                url = get_cache_path(url[8:])
                            if is_url(url):
                                _log(f"Downloading audio: {url}")
                                content, info = await session_download_file(session, url, default_mimetype="audio/mp3")
                                mimetype = info["type"]
                                # convert to base64
                                input_audio["data"] = base64.b64encode(content).decode("utf-8")
                                # Alibaba expects a full data URI and a format field.
                                if provider_id == "alibaba":
                                    input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                                input_audio["format"] = mimetype.rsplit("/", 1)[1]
                            elif is_file_path(url):
                                _log(f"Reading audio: {url}")
                                content, info = read_binary_file(url)
                                mimetype = info["type"]
                                # convert to base64
                                input_audio["data"] = base64.b64encode(content).decode("utf-8")
                                if provider_id == "alibaba":
                                    input_audio["data"] = f"data:{mimetype};base64,{input_audio['data']}"
                                input_audio["format"] = mimetype.rsplit("/", 1)[1]
                            elif is_base_64(url):
                                pass  # use base64 data as-is
                            else:
                                raise Exception(f"Invalid audio: {url}")
                    elif item["type"] == "file" and "file" in item:
                        file = item["file"]
                        if "file_data" in file:
                            url = file["file_data"]
                            if url.startswith("/~cache/"):
                                url = get_cache_path(url[8:])
                            if is_url(url):
                                _log(f"Downloading file: {url}")
                                content, info = await session_download_file(
                                    session, url, default_mimetype="application/pdf"
                                )
                                mimetype = info["type"]
                                file["filename"] = info["name"]
                                file["file_data"] = (
                                    f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
                                )
                            elif is_file_path(url):
                                _log(f"Reading file: {url}")
                                content, info = read_binary_file(url)
                                mimetype = info["type"]
                                file["filename"] = info["name"]
                                file["file_data"] = (
                                    f"data:{mimetype};base64,{base64.b64encode(content).decode('utf-8')}"
                                )
                            elif url.startswith("data:"):
                                if "filename" not in file:
                                    file["filename"] = "file"
                                pass  # use base64 data as-is
                            else:
                                raise Exception(f"Invalid file: {url}")
    return chat
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def image_ext_from_mimetype(mimetype, default="png"):
    """Map an image mimetype (e.g. "image/png") to a bare file extension.

    Falls back to *default* when the string is not a "type/subtype" mimetype
    or when the stdlib has no known extension for it.
    """
    if "/" not in mimetype:
        return default
    guessed = mimetypes.guess_extension(mimetype)
    return guessed.lstrip(".") if guessed else default
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def audio_ext_from_format(format, default="mp3"):
    """Normalize an audio format label to a file extension.

    "mpeg" (from "audio/mpeg") becomes "mp3"; an empty/None format yields *default*.
    """
    return "mp3" if format == "mpeg" else (format or default)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def file_ext_from_mimetype(mimetype, default="pdf"):
    """Map a generic file mimetype to a bare file extension.

    Same contract as image_ext_from_mimetype but with a "pdf" default.
    """
    if "/" not in mimetype:
        return default
    guessed = mimetypes.guess_extension(mimetype)
    return guessed.lstrip(".") if guessed else default
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def cache_message_inline_data(m):
    """
    Replaces and caches any inline data URIs in the message content.

    Mutates the message *m* in place: for each multi-part content item
    (image_url / input_audio / file) that carries inline base64 data, the
    data is written to the local cache and the inline payload is replaced
    with the resulting /~cache/... URL. Caching failures are logged and
    the item is left unchanged. String content is never modified.
    """
    if "content" not in m:
        return

    content = m["content"]
    if isinstance(content, list):
        for item in content:
            if item.get("type") == "image_url":
                image_url = item.get("image_url", {})
                url = image_url.get("url")
                if url and url.startswith("data:"):
                    # Extract base64 and mimetype; the split raises ValueError
                    # (caught below) when the URI isn't ";base64,"-encoded.
                    try:
                        header, base64_data = url.split(";base64,")
                        # header is like "data:image/png"
                        ext = image_ext_from_mimetype(header.split(":")[1])
                        filename = f"image.{ext}"  # Hash will handle uniqueness
                        cache_url, _ = save_image_to_cache(base64_data, filename, {}, ignore_info=True)
                        image_url["url"] = cache_url
                    except Exception as e:
                        _log(f"Error caching inline image: {e}")

            elif item.get("type") == "input_audio":
                input_audio = item.get("input_audio", {})
                data = input_audio.get("data")
                if data:
                    # Handle data URI or raw base64: raw base64 is cached as-is.
                    base64_data = data
                    if data.startswith("data:"):
                        with contextlib.suppress(ValueError):
                            header, base64_data = data.split(";base64,")

                    fmt = audio_ext_from_format(input_audio.get("format"))
                    filename = f"audio.{fmt}"

                    try:
                        cache_url, _ = save_bytes_to_cache(base64_data, filename, {}, ignore_info=True)
                        input_audio["data"] = cache_url
                    except Exception as e:
                        _log(f"Error caching inline audio: {e}")

            elif item.get("type") == "file":
                file_info = item.get("file", {})
                file_data = file_info.get("file_data")
                if file_data and file_data.startswith("data:"):
                    try:
                        header, base64_data = file_data.split(";base64,")
                        mimetype = header.split(":")[1]
                        # Try to get extension from filename if available, else mimetype
                        filename = file_info.get("filename", "file")
                        if "." not in filename:
                            ext = file_ext_from_mimetype(mimetype)
                            filename = f"(unknown).{ext}"

                        # Full info is kept here so the cached display name can be
                        # written back onto the message.
                        cache_url, info = save_bytes_to_cache(base64_data, filename)
                        file_info["file_data"] = cache_url
                        file_info["filename"] = info["name"]
                    except Exception as e:
                        _log(f"Error caching inline file: {e}")
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
class HTTPError(Exception):
    """Exception carrying a failed HTTP response: status, reason, body and headers."""

    def __init__(self, status, reason, body, headers=None):
        super().__init__(f"HTTP {status} {reason}")
        self.status = status
        self.reason = reason
        self.body = body
        self.headers = headers
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def save_bytes_to_cache(base64_data, filename, file_info=None, ignore_info=False):
    """
    Persist binary content in the content-addressed local cache.

    :param base64_data: a base64 ``str`` (decoded before hashing) or raw ``bytes``.
    :param filename: original name; its extension is kept on the cached file and
        its mimetype recorded in the ``.info.json`` sidecar.
    :param file_info: optional extra metadata merged into the sidecar.
    :param ignore_info: when True, skip loading/returning the sidecar on cache hits.
    :returns: ``(url, info)`` where url is ``/~cache/<aa>/<sha256>.<ext>`` and info
        is the metadata dict (or None on a hit with ``ignore_info``).
    """
    ext = filename.split(".")[-1]
    mimetype = get_file_mime_type(filename)
    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
    # Content-addressing: identical bytes always map to the same cache entry.
    sha256_hash = hashlib.sha256(content).hexdigest()

    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash

    # Use first 2 chars for subdir to avoid too many files in one dir
    subdir = sha256_hash[:2]
    relative_path = f"{subdir}/{save_filename}"
    full_path = get_cache_path(relative_path)
    url = f"/~cache/{relative_path}"

    # if file and its .info.json already exists, return it
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    if os.path.exists(full_path) and os.path.exists(info_path):
        _dbg(f"Cached bytes exists: {relative_path}")
        if ignore_info:
            return url, None
        return url, json_from_file(info_path)

    os.makedirs(os.path.dirname(full_path), exist_ok=True)

    with open(full_path, "wb") as f:
        f.write(content)
    info = {
        "date": int(time.time()),
        "url": url,
        "size": len(content),
        "type": mimetype,
        "name": filename,
    }
    if file_info:
        info.update(file_info)

    # Save metadata sidecar (info_path was already computed above)
    with open(info_path, "w") as f:
        json.dump(info, f)

    _dbg(f"Saved cached bytes and info: {relative_path}")

    # Notify extensions (e.g. analytics) that a new cache entry was written.
    g_app.on_cache_saved_filters({"url": url, "info": info})

    return url, info
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def save_audio_to_cache(base64_data, filename, audio_info, ignore_info=False):
    """Cache audio bytes; thin alias over save_bytes_to_cache."""
    return save_bytes_to_cache(base64_data, filename, file_info=audio_info, ignore_info=ignore_info)
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def save_video_to_cache(base64_data, filename, file_info, ignore_info=False):
    """Cache video bytes; thin alias over save_bytes_to_cache."""
    return save_bytes_to_cache(base64_data, filename, file_info=file_info, ignore_info=ignore_info)
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def save_image_to_cache(base64_data, filename, image_info, ignore_info=False):
    """
    Persist an image in the content-addressed local cache.

    Like save_bytes_to_cache, but additionally records the image's
    width/height in the .info.json sidecar when Pillow is available.

    :param base64_data: base64 ``str`` (decoded first) or raw ``bytes``.
    :param filename: original name; extension kept, mimetype recorded.
    :param image_info: extra metadata merged into the sidecar (required, unlike
        the optional file_info of save_bytes_to_cache).
    :param ignore_info: skip loading/returning the sidecar on cache hits.
    :returns: ``(url, info)`` — /~cache/ URL plus metadata dict (None on a hit
        with ``ignore_info``).
    """
    ext = filename.split(".")[-1]
    mimetype = get_file_mime_type(filename)
    content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
    # Content-addressing: identical bytes reuse the same cache entry.
    sha256_hash = hashlib.sha256(content).hexdigest()

    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash

    # Use first 2 chars for subdir to avoid too many files in one dir
    subdir = sha256_hash[:2]
    relative_path = f"{subdir}/{save_filename}"
    full_path = get_cache_path(relative_path)
    url = f"/~cache/{relative_path}"

    # if file and its .info.json already exists, return it
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    if os.path.exists(full_path) and os.path.exists(info_path):
        _dbg(f"Saved image exists: {relative_path}")
        if ignore_info:
            return url, None
        return url, json_from_file(info_path)

    os.makedirs(os.path.dirname(full_path), exist_ok=True)

    with open(full_path, "wb") as f:
        f.write(content)
    info = {
        "date": int(time.time()),
        "url": url,
        "size": len(content),
        "type": mimetype,
        "name": filename,
    }
    info.update(image_info)

    # If image, get dimensions (best-effort; skipped when Pillow is missing
    # or the bytes aren't a decodable image)
    if HAS_PIL and mimetype.startswith("image/"):
        try:
            with Image.open(BytesIO(content)) as img:
                info["width"] = img.width
                info["height"] = img.height
        except Exception:
            pass

    if "width" in info and "height" in info:
        _log(f"Saved image to cache: {full_path} ({len(content)} bytes) {info['width']}x{info['height']}")
    else:
        _log(f"Saved image to cache: {full_path} ({len(content)} bytes)")

    # Save metadata
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    with open(info_path, "w") as f:
        json.dump(info, f)

    _dbg(f"Saved image and info: {relative_path}")

    # Notify extensions that a new cache entry was written.
    g_app.on_cache_saved_filters({"url": url, "info": info})

    return url, info
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
async def response_json(response):
    """
    Read an aiohttp response body and parse it as JSON.

    :param response: an aiohttp ClientResponse (any object with ``text()``,
        ``status``, ``reason`` and ``headers``).
    :raises HTTPError: for any 4xx/5xx status, carrying the body text and headers
        so callers can surface the provider's error payload.
    :returns: the parsed JSON body.
    """
    text = await response.text()
    if response.status >= 400:
        _dbg(f"HTTP {response.status} {response.reason}: {text}")
        raise HTTPError(response.status, reason=response.reason, body=text, headers=dict(response.headers))
    # Note: raise_for_status() was removed here — status < 400 is guaranteed
    # at this point, so it could never raise.
    return json.loads(text)
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def chat_to_prompt(chat):
    """Concatenate the text of every user message in *chat* into one prompt.

    Handles both plain-string content and the multi-part list form (only
    "text" parts contribute). Pieces are separated by newlines.
    """

    def user_texts():
        for msg in chat["messages"] if "messages" in chat else []:
            if msg["role"] != "user":
                continue
            body = msg["content"]
            if isinstance(body, str):
                yield body
            elif isinstance(body, list):
                for part in body:
                    if part["type"] == "text":
                        yield part["text"]

    prompt = ""
    for text in user_texts():
        # Only insert a separator once something has accumulated.
        prompt = f"{prompt}\n{text}" if prompt else prompt + text
    return prompt
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
def chat_to_system_prompt(chat):
    """Return the first system message text in *chat*, or None when absent.

    String content is returned directly; list content yields its first
    "text" part.
    """
    for message in chat["messages"] if "messages" in chat else []:
        if message["role"] != "system":
            continue
        body = message["content"]
        if isinstance(body, str):
            return body
        if isinstance(body, list):
            for part in body:
                if part["type"] == "text":
                    return part["text"]
    return None
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
def chat_to_username(chat):
    """Return the user recorded under chat["metadata"]["user"], else None."""
    if "metadata" in chat:
        metadata = chat["metadata"]
        if "user" in metadata:
            return metadata["user"]
    return None
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def chat_to_aspect_ratio(chat):
    """Return chat["image_config"]["aspect_ratio"] when present, else None."""
    if "image_config" in chat:
        config = chat["image_config"]
        if "aspect_ratio" in config:
            return config["aspect_ratio"]
    return None
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def last_user_prompt(chat):
    """Return the most recent user text in *chat*, or "" when there is none.

    Later user messages (and later "text" parts within one message) override
    earlier ones; non-text parts are ignored.
    """
    prompt = ""
    for message in chat["messages"] if "messages" in chat else []:
        if message["role"] != "user":
            continue
        body = message["content"]
        if isinstance(body, str):
            prompt = body
        elif isinstance(body, list):
            for part in body:
                if part["type"] == "text":
                    prompt = part["text"]
    return prompt
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def chat_response_to_message(openai_response):
    """
    Returns an assistant message from the OpenAI Response.
    Handles normalizing text, image, and audio responses into the message content.
    """
    timestamp = int(time.time() * 1000)  # openai_response.get("created")

    def as_assistant(content):
        return {"role": "assistant", "content": content, "timestamp": timestamp}

    # Unwrap the choices array when given a full response envelope.
    choices = openai_response
    if isinstance(openai_response, dict) and "choices" in openai_response:
        choices = openai_response["choices"]

    if isinstance(choices, list) and choices:
        choice = choices[0]
    else:
        choice = choices

    if isinstance(choice, str):
        return as_assistant(choice)
    if not isinstance(choice, dict):
        return as_assistant(str(choice))

    message = choice.get("message", choice)
    if not isinstance(message, dict):
        return as_assistant(message)

    message["timestamp"] = timestamp
    return message
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
def to_file_info(chat, info=None, response=None):
    """
    Build .info.json metadata for a cached file from the chat request.

    Copies model / last user prompt / image_config / username from *chat* into
    *info* without overwriting keys the caller already set.

    :param chat: the OpenAI-style chat request dict.
    :param info: optional dict to extend (mutated and returned).
    :param response: unused; kept for interface compatibility with callers.
    :returns: the metadata dict.
    """
    prompt = last_user_prompt(chat)
    ret = info or {}
    # .get avoids a KeyError when the chat has no "model" key
    if chat.get("model") and "model" not in ret:
        ret["model"] = chat["model"]
    if prompt and "prompt" not in ret:
        ret["prompt"] = prompt
    if "image_config" in chat:
        ret.update(chat["image_config"])
    user = chat_to_username(chat)
    if user:
        ret["user"] = user
    return ret
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
# Image Generator Providers
|
|
889
|
+
# Image Generator Providers
class GeneratorBase:
    """Base class for image/audio generator providers.

    Subclasses override to_response/chat; the default chat() returns a
    placeholder response with an inline SVG warning icon.

    NOTE(review): validate() reads self.env and self.name, which this base
    __init__ does not set — presumably subclasses set them; confirm before
    calling validate() on a bare GeneratorBase.
    """

    def __init__(self, **kwargs):
        # id of the generator provider
        self.id = kwargs.get("id")
        # base API URL
        self.api = kwargs.get("api")
        self.api_key = kwargs.get("api_key")
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        self.chat_url = f"{self.api}/chat/completions"
        # default assistant text when a generation succeeds without one
        self.default_content = "I've generated the image for you."

    def validate(self, **kwargs):
        """Return an error string when the provider is misconfigured, else None."""
        if not self.api_key:
            api_keys = ", ".join(self.env)
            return f"Provider '{self.name}' requires API Key {api_keys}"
        return None

    def test(self, **kwargs):
        """Return True when validate() passes; logs and returns False otherwise."""
        error_msg = self.validate(**kwargs)
        if error_msg:
            _log(error_msg)
            return False
        return True

    async def load(self):
        # no-op hook; subclasses may load models or other startup state here
        pass

    def gen_summary(self, gen):
        """Summarize gen response for logging."""
        # deep-copy via JSON round-trip so callers' dicts are never mutated
        clone = json.loads(json.dumps(gen))
        return json.dumps(clone, indent=2)

    def chat_summary(self, chat):
        # delegate to the module-level helper
        return chat_summary(chat)

    def process_chat(self, chat, provider_id=None):
        # delegate to the module-level helper
        return process_chat(chat, provider_id)

    async def response_json(self, response):
        # delegate to the module-level helper (raises HTTPError on 4xx/5xx)
        return await response_json(response)

    def get_headers(self, provider, chat):
        """Build request headers, preferring the wrapping provider's API key."""
        headers = self.headers.copy()
        if provider is not None:
            headers["Authorization"] = f"Bearer {provider.api_key}"
        elif self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def to_response(self, response, chat, started_at):
        # subclasses must convert their provider response to OpenAI shape
        raise NotImplementedError

    async def chat(self, chat, provider=None, context=None):
        # Placeholder implementation: returns "Not Implemented" with a
        # data-URI SVG warning icon as the generated image.
        return {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "Not Implemented",
                        "images": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyNCIgaGVpZ2h0PSIyNCIgdmlld0JveD0iMCAwIDI0IDI0Ij48cGF0aCBmaWxsPSJjdXJyZW50Q29sb3IiIGQ9Ik0xMiAyMGE4IDggMCAxIDAgMC0xNmE4IDggMCAwIDAgMCAxNm0wIDJDNi40NzcgMjIgMiAxNy41MjMgMiAxMlM2LjQ3NyAyIDEyIDJzMTAgNC40NzcgMTAgMTBzLTQuNDc3IDEwLTEwIDEwbS0xLTZoMnYyaC0yem0wLTEwaDJ2OGgtMnoiLz48L3N2Zz4=",
                                },
                            }
                        ],
                    }
                }
            ]
        }
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
# OpenAI Providers
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
class OpenAiCompatible:
    """Chat provider speaking the OpenAI /chat/completions protocol.

    Configured entirely from kwargs (typically llms.json); maintains a model
    map, provider-level request overrides (temperature, seed, ...), and an
    optional registry of modality sub-providers (image/audio generators).
    """

    sdk = "@ai-sdk/openai-compatible"

    def __init__(self, **kwargs):
        # id and api (base URL) are mandatory
        required_args = ["id", "api"]
        for arg in required_args:
            if arg not in kwargs:
                raise ValueError(f"Missing required argument: {arg}")

        self.id = kwargs.get("id")
        self.api = kwargs.get("api").strip("/")
        # names of environment variables that may supply the API key
        self.env = kwargs.get("env", [])
        self.api_key = kwargs.get("api_key")
        self.name = kwargs.get("name", id_to_name(self.id))
        self.set_models(**kwargs)

        self.chat_url = f"{self.api}/chat/completions"

        self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
        if self.api_key is not None:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

        # Provider-level request overrides; each is None when unset and is only
        # written into outgoing chats when explicitly configured (see chat()).
        self.frequency_penalty = float(kwargs["frequency_penalty"]) if "frequency_penalty" in kwargs else None
        self.max_completion_tokens = int(kwargs["max_completion_tokens"]) if "max_completion_tokens" in kwargs else None
        self.n = int(kwargs["n"]) if "n" in kwargs else None
        self.parallel_tool_calls = bool(kwargs["parallel_tool_calls"]) if "parallel_tool_calls" in kwargs else None
        self.presence_penalty = float(kwargs["presence_penalty"]) if "presence_penalty" in kwargs else None
        self.prompt_cache_key = kwargs.get("prompt_cache_key")
        self.reasoning_effort = kwargs.get("reasoning_effort")
        self.safety_identifier = kwargs.get("safety_identifier")
        self.seed = int(kwargs["seed"]) if "seed" in kwargs else None
        self.service_tier = kwargs.get("service_tier")
        self.stop = kwargs.get("stop")
        self.store = bool(kwargs["store"]) if "store" in kwargs else None
        self.temperature = float(kwargs["temperature"]) if "temperature" in kwargs else None
        self.top_logprobs = int(kwargs["top_logprobs"]) if "top_logprobs" in kwargs else None
        self.top_p = float(kwargs["top_p"]) if "top_p" in kwargs else None
        self.verbosity = kwargs.get("verbosity")
        self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
        self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
        self.check = kwargs.get("check")
        # modality name -> generator provider handling that modality
        self.modalities = kwargs.get("modalities", {})

    def set_models(self, **kwargs):
        """Initialize self.models / self.map_models from kwargs, applying filters."""
        models = kwargs.get("models", {})
        self.map_models = kwargs.get("map_models", {})
        # if 'map_models' is provided, only include models in `map_models[model_id] = provider_model_id`
        if self.map_models:
            self.models = {}
            for provider_model_id in self.map_models.values():
                if provider_model_id in models:
                    self.models[provider_model_id] = models[provider_model_id]
        else:
            self.models = models

        include_models = kwargs.get("include_models")  # string regex pattern
        # only include models that match the regex pattern
        if include_models:
            _log(f"Filtering {len(self.models)} models, only including models that match regex: {include_models}")
            self.models = {k: v for k, v in self.models.items() if re.search(include_models, k)}

        exclude_models = kwargs.get("exclude_models")  # string regex pattern
        # exclude models that match the regex pattern
        if exclude_models:
            _log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
            self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}

    def validate(self, **kwargs):
        """Return an error string when the provider is misconfigured, else None."""
        if not self.api_key:
            api_keys = ", ".join(self.env)
            return f"Provider '{self.name}' requires API Key {api_keys}"
        return None

    def test(self, **kwargs):
        """Return True when validate() passes; logs and returns False otherwise."""
        error_msg = self.validate(**kwargs)
        if error_msg:
            _log(error_msg)
            return False
        return True

    async def load(self):
        # lazy-load models when none were configured statically
        if not self.models:
            await self.load_models()

    def model_info(self, model):
        """Return the model metadata dict for *model* (case-insensitive), else None."""
        provider_model = self.provider_model(model) or model
        for model_id, model_info in self.models.items():
            if model_id.lower() == provider_model.lower():
                return model_info
        return None

    def model_cost(self, model):
        """Return the model's cost dict ({"input":..., "output":...}) or None."""
        model_info = self.model_info(model)
        return model_info.get("cost") if model_info else None

    def provider_model(self, model):
        """Resolve a user-facing model name to this provider's model id.

        Tries, in order: the map_models aliases (keys then values), model ids
        and display names in self.models, the short name after "/", and finally
        retries with the last path segment of a namespaced input. Returns None
        when nothing matches.
        """
        # convert model to lowercase for case-insensitive comparison
        model_lower = model.lower()

        # if model is a map model id, return the provider model id
        for model_id, provider_model in self.map_models.items():
            if model_id.lower() == model_lower:
                return provider_model

        # if model is a provider model id, try again with just the model name
        for provider_model in self.map_models.values():
            if provider_model.lower() == model_lower:
                return provider_model

        # if model is a model id, try again with just the model id or name
        for model_id, provider_model_info in self.models.items():
            id = provider_model_info.get("id") or model_id
            if model_id.lower() == model_lower or id.lower() == model_lower:
                return id
            name = provider_model_info.get("name")
            if name and name.lower() == model_lower:
                return id

        # fallback to trying again with just the model short name
        for model_id, provider_model_info in self.models.items():
            id = provider_model_info.get("id") or model_id
            if "/" in id:
                model_name = id.split("/")[-1]
                if model_name.lower() == model_lower:
                    return id

        # if model is a full provider model id, try again with just the model name
        if "/" in model:
            last_part = model.split("/")[-1]
            return self.provider_model(last_part)

        return None

    def response_json(self, response):
        # delegate to the module-level helper (raises HTTPError on 4xx/5xx)
        return response_json(response)

    def to_response(self, response, chat, started_at, context=None):
        """Annotate the provider response with duration/pricing metadata."""
        if "metadata" not in response:
            response["metadata"] = {}
        response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
        if chat is not None and "model" in chat:
            pricing = self.model_cost(chat["model"])
            if pricing and "input" in pricing and "output" in pricing:
                response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
        if context is not None:
            # expose the raw provider response to downstream filters
            context["providerResponse"] = response
        return response

    def chat_summary(self, chat):
        # delegate to the module-level helper
        return chat_summary(chat)

    def process_chat(self, chat, provider_id=None):
        # delegate to the module-level helper
        return process_chat(chat, provider_id)

    async def chat(self, chat, context=None):
        """Send *chat* to the provider and return the (annotated) JSON response.

        Mutates *chat*: resolves the model id, applies provider-level overrides,
        and temporarily strips "metadata". Non-text modalities are delegated to
        the matching generator in self.modalities.
        """
        chat["model"] = self.provider_model(chat["model"]) or chat["model"]

        modalities = chat.get("modalities") or []
        if len(modalities) > 0:
            for modality in modalities:
                # use default implementation for text modalities
                if modality == "text":
                    continue
                modality_provider = self.modalities.get(modality)
                if modality_provider:
                    return await modality_provider.chat(chat, self, context=context)
                else:
                    raise Exception(f"Provider {self.name} does not support '{modality}' modality")

        # with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
        #     f.write(json.dumps(chat, indent=2))

        # Apply each configured provider-level override onto the request.
        if self.frequency_penalty is not None:
            chat["frequency_penalty"] = self.frequency_penalty
        if self.max_completion_tokens is not None:
            chat["max_completion_tokens"] = self.max_completion_tokens
        if self.n is not None:
            chat["n"] = self.n
        if self.parallel_tool_calls is not None:
            chat["parallel_tool_calls"] = self.parallel_tool_calls
        if self.presence_penalty is not None:
            chat["presence_penalty"] = self.presence_penalty
        if self.prompt_cache_key is not None:
            chat["prompt_cache_key"] = self.prompt_cache_key
        if self.reasoning_effort is not None:
            chat["reasoning_effort"] = self.reasoning_effort
        if self.safety_identifier is not None:
            chat["safety_identifier"] = self.safety_identifier
        if self.seed is not None:
            chat["seed"] = self.seed
        if self.service_tier is not None:
            chat["service_tier"] = self.service_tier
        if self.stop is not None:
            chat["stop"] = self.stop
        if self.store is not None:
            chat["store"] = self.store
        if self.temperature is not None:
            chat["temperature"] = self.temperature
        if self.top_logprobs is not None:
            chat["top_logprobs"] = self.top_logprobs
        if self.top_p is not None:
            chat["top_p"] = self.top_p
        if self.verbosity is not None:
            chat["verbosity"] = self.verbosity
        if self.enable_thinking is not None:
            chat["enable_thinking"] = self.enable_thinking

        chat = await process_chat(chat, provider_id=self.id)
        _log(f"POST {self.chat_url}")
        _log(chat_summary(chat))
        # remove metadata if any (conflicts with some providers, e.g. Z.ai)
        metadata = chat.pop("metadata", None)

        async with aiohttp.ClientSession() as session:
            started_at = time.time()
            async with session.post(
                self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=aiohttp.ClientTimeout(total=120)
            ) as response:
                # restore metadata before handing chat to to_response
                chat["metadata"] = metadata
                return self.to_response(await response_json(response), chat, started_at, context=context)
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
class MistralProvider(OpenAiCompatible):
    """OpenAI-compatible provider preconfigured for the Mistral API."""

    sdk = "@ai-sdk/mistral"

    def __init__(self, **kwargs):
        # default the API base URL unless the config supplies one
        kwargs.setdefault("api", "https://api.mistral.ai/v1")
        super().__init__(**kwargs)
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
class GroqProvider(OpenAiCompatible):
    """OpenAI-compatible provider preconfigured for the Groq API."""

    sdk = "@ai-sdk/groq"

    def __init__(self, **kwargs):
        # default the API base URL unless the config supplies one
        kwargs.setdefault("api", "https://api.groq.com/openai/v1")
        super().__init__(**kwargs)
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
class XaiProvider(OpenAiCompatible):
    """OpenAI-compatible provider preconfigured for the xAI API."""

    sdk = "@ai-sdk/xai"

    def __init__(self, **kwargs):
        # default the API base URL unless the config supplies one
        kwargs.setdefault("api", "https://api.x.ai/v1")
        super().__init__(**kwargs)
|
|
1213
|
+
|
|
1214
|
+
|
|
1215
|
+
class CodestralProvider(OpenAiCompatible):
    """OpenAI-compatible provider for Mistral's Codestral endpoint.

    Unlike MistralProvider, no default API URL is set here — the config must
    supply "api" (OpenAiCompatible raises ValueError otherwise).
    """

    sdk = "codestral"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
class OllamaProvider(OpenAiCompatible):
    """Provider for a local Ollama server.

    Discovers installed models at runtime via GET /api/tags and exposes them
    as zero-cost text models; chats go through Ollama's OpenAI-compatible
    /v1/chat/completions endpoint. No API key is required.
    """

    sdk = "ollama"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Ollama's OpenAI-compatible endpoint is at /v1/chat/completions
        self.chat_url = f"{self.api}/v1/chat/completions"

    async def load(self):
        # discover models dynamically when none were configured statically
        if not self.models:
            await self.load_models()

    async def get_models(self):
        """Fetch installed models from Ollama; returns {model_id: model_id}."""
        ret = {}
        try:
            async with aiohttp.ClientSession() as session:
                _log(f"GET {self.api}/api/tags")
                async with session.get(
                    f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    data = await response_json(response)
                    for model in data.get("models", []):
                        model_id = model["model"]
                        # strip the implicit ":latest" tag for cleaner ids
                        if model_id.endswith(":latest"):
                            model_id = model_id[:-7]
                        ret[model_id] = model_id
                    _log(f"Loaded Ollama models: {ret}")
        except Exception as e:
            _log(f"Error getting Ollama models: {e}")
            # return empty dict if ollama is not available
        return ret

    async def load_models(self):
        """Load models if all_models was requested"""

        # Map models to provider models {model_id:model_id}
        model_map = await self.get_models()
        if self.map_models:
            # Restrict discovered models to those referenced by map_models
            # (matched by either alias key or provider model id value).
            map_model_values = set(self.map_models.values())
            to = {}
            for k, v in model_map.items():
                if k in self.map_models:
                    to[k] = v
                if v in map_model_values:
                    to[k] = v
            model_map = to
        else:
            self.map_models = model_map
        # Build minimal model metadata: local models are text-only and free.
        models = {}
        for k, v in model_map.items():
            models[k] = {
                "id": k,
                "name": v.replace(":", " "),
                "modalities": {"input": ["text"], "output": ["text"]},
                "cost": {
                    "input": 0,
                    "output": 0,
                },
            }
        self.models = models

    def validate(self, **kwargs):
        # local server: no API key required, always valid
        return None
|
|
1285
|
+
|
|
1286
|
+
|
|
1287
|
+
class LMStudioProvider(OllamaProvider):
    """Provider for a local LM Studio server.

    Same dynamic-discovery behavior as OllamaProvider, but models come from
    the OpenAI-style GET /models endpoint and chats go to /chat/completions
    directly (no /v1 prefix added).
    """

    sdk = "lmstudio"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # LM Studio serves the OpenAI-compatible API at the configured base URL
        self.chat_url = f"{self.api}/chat/completions"

    async def get_models(self):
        """Fetch loaded models from LM Studio; returns {model_id: model_id}."""
        ret = {}
        try:
            async with aiohttp.ClientSession() as session:
                _log(f"GET {self.api}/models")
                async with session.get(
                    f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    data = await response_json(response)
                    for model in data.get("data", []):
                        id = model["id"]
                        ret[id] = id
                    _log(f"Loaded LMStudio models: {ret}")
        except Exception as e:
            _log(f"Error getting LMStudio models: {e}")
            # return empty dict if LM Studio is not available
        return ret
|
|
1311
|
+
|
|
1312
|
+
|
|
1313
|
+
def get_provider_model(model_name):
    """Resolve *model_name* via the registered providers; None when unknown.

    Providers are consulted in registration order and the first non-empty
    resolution wins.
    """
    return next(
        (resolved for handler in g_handlers.values() if (resolved := handler.provider_model(model_name))),
        None,
    )
|
|
1319
|
+
|
|
1320
|
+
|
|
1321
|
+
def get_models():
    """Return a sorted list of the unique model ids across all providers."""
    unique_ids = set()
    for handler in g_handlers.values():
        unique_ids.update(handler.models)
    return sorted(unique_ids)
|
|
1329
|
+
|
|
1330
|
+
|
|
1331
|
+
def get_active_models():
    """Return one entry per distinct model name across enabled providers.

    Each entry is a copy of the provider's model dict plus a "provider" key
    identifying which provider supplied it (first provider wins for duplicate
    names). Entries with no "name" are logged and skipped. Sorted by "id".
    """
    active = []
    seen_names = set()
    for provider_id, provider in g_handlers.items():
        for model in provider.models.values():
            name = model.get("name")
            if not name:
                _log(f"Provider {provider_id} model {model} has no name")
                continue
            if name in seen_names:
                continue
            seen_names.add(name)
            entry = dict(model)
            entry["provider"] = provider_id
            active.append(entry)
    active.sort(key=lambda entry: entry["id"])
    return active
|
|
1347
|
+
|
|
1348
|
+
|
|
1349
|
+
def api_providers():
    """Summarize each enabled provider as {"id", "name", "models"} for the API."""
    return [
        {"id": provider_id, "name": handler.name, "models": handler.models}
        for provider_id, handler in g_handlers.items()
    ]
|
|
1354
|
+
|
|
1355
|
+
|
|
1356
|
+
def to_error_message(e):
    """Return the most descriptive error text available on exception *e*.

    Preference order: an explicit .message attribute (aiohttp-style errors),
    then an HTTP-style .status (stringified), then str(e).
    """
    # check if has 'message' attribute, then 'status'
    for attr, convert in (("message", lambda v: v), ("status", str)):
        if hasattr(e, attr):
            return convert(getattr(e, attr))
    return str(e)
|
|
1363
|
+
|
|
1364
|
+
|
|
1365
|
+
def to_error_response(e, stacktrace=False):
    """Wrap exception *e* in a ServiceStack-style responseStatus envelope.

    When *stacktrace* is truthy, the current traceback is attached as
    "stackTrace" (so this should be called from inside the except block).
    """
    response_status = {"errorCode": "Error", "message": to_error_message(e)}
    if stacktrace:
        response_status["stackTrace"] = traceback.format_exc()
    return {"responseStatus": response_status}
|
|
1370
|
+
|
|
1371
|
+
|
|
1372
|
+
def create_error_response(message, error_code="Error", stack_trace=None):
    """Build a ServiceStack-style error envelope from plain strings."""
    status = {"errorCode": error_code, "message": message}
    if stack_trace:
        status["stackTrace"] = stack_trace
    return {"responseStatus": status}
|
|
1377
|
+
|
|
1378
|
+
|
|
1379
|
+
def should_cancel_thread(context):
    """Return the context's "cancelled" flag (False when absent), logging the
    thread id when cancellation was requested."""
    cancelled = context.get("cancelled", False)
    if cancelled:
        thread_id = context.get("threadId")
        _dbg(f"Thread cancelled {thread_id}")
    return cancelled
|
|
1385
|
+
|
|
1386
|
+
|
|
1387
|
+
def g_chat_request(template=None, text=None, model=None, system_prompt=None):
    """Build a chat-completion request dict from a named default template.

    Looks up g_config["defaults"][template or "text"], overrides the model,
    prepends an optional system prompt, and substitutes *text* into the last
    user message (replacing its text, or appending a new user message when
    the last message is not a user message).

    Raises Exception when the named template is not configured.
    """
    import copy

    chat_template = g_config["defaults"].get(template or "text")
    if not chat_template:
        raise Exception(f"Chat template '{template}' not found")

    # Bug fix: a shallow .copy() shared the nested "messages" list with the
    # template stored in g_config, so the insert/append below mutated the
    # global default template across calls. Deep-copy keeps it pristine.
    chat = copy.deepcopy(chat_template)
    if model:
        chat["model"] = model
    if system_prompt is not None:
        chat["messages"].insert(0, {"role": "system", "content": system_prompt})
    if text is not None:
        if not chat["messages"] or len(chat["messages"]) == 0:
            chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]

        # replace content of last message if exists, else add
        last_msg = chat["messages"][-1] if "messages" in chat else None
        if last_msg and last_msg["role"] == "user":
            if isinstance(last_msg["content"], list):
                last_msg["content"][-1]["text"] = text
            else:
                last_msg["content"] = text
        else:
            chat["messages"].append({"role": "user", "content": text})

    return chat
|
|
1412
|
+
|
|
1413
|
+
|
|
1414
|
+
def tool_result_part(result: dict, function_name: Optional[str] = None, function_args: Optional[dict] = None):
    """Convert one structured tool result into a (markdown_text, resource) pair.

    *result* is a dict with a "type" of "text", "image", "audio" or "file";
    media types carry base64 "data" that gets saved to the local cache, and
    the returned resource is an OpenAI-style content part referencing the
    cached URL. Unknown types are JSON-serialized (falling back to str()).
    Returns (None, None) when media data is missing or serialization fails.
    """
    args = function_args or {}
    result_type = result.get("type")
    prompt = args.get("prompt") or args.get("text") or args.get("message")
    if result_type == "text":
        return result.get("text"), None
    elif result_type == "image":
        fmt = result.get("format") or args.get("format") or "png"
        filename = result.get("filename") or args.get("filename") or f"{function_name}-{int(time.time())}.{fmt}"
        mime_type = get_file_mime_type(filename)
        image_info = {"type": mime_type}
        if prompt:
            image_info["prompt"] = prompt
        if "model" in args:
            image_info["model"] = args["model"]
        if "aspect_ratio" in args:
            image_info["aspect_ratio"] = args["aspect_ratio"]
        base64_data = result.get("data")
        if not base64_data:
            _dbg(f"Image data not found for {function_name}")
            return None, None
        url, _ = save_image_to_cache(base64_data, filename, image_info=image_info, ignore_info=True)
        resource = {
            "type": "image_url",
            "image_url": {
                "url": url,
            },
        }
        # Bug fix: previously this emitted only "\n", dropping the link to the
        # saved image entirely. Emit a markdown image link, mirroring the
        # audio/file branches below.
        text = f"![{args.get('prompt') or filename}]({url})\n"
        return text, resource
    elif result_type == "audio":
        fmt = result.get("format") or args.get("format") or "mp3"
        filename = result.get("filename") or args.get("filename") or f"{function_name}-{int(time.time())}.{fmt}"
        mime_type = get_file_mime_type(filename)
        audio_info = {"type": mime_type}
        if prompt:
            audio_info["prompt"] = prompt
        if "model" in args:
            audio_info["model"] = args["model"]
        base64_data = result.get("data")
        if not base64_data:
            _dbg(f"Audio data not found for {function_name}")
            return None, None
        url, _ = save_audio_to_cache(base64_data, filename, audio_info=audio_info, ignore_info=True)
        resource = {
            "type": "audio_url",
            "audio_url": {
                "url": url,
            },
        }
        text = f"[{args.get('prompt') or filename}]({url})\n"
        return text, resource
    elif result_type == "file":
        filename = result.get("filename") or args.get("filename") or result.get("name") or args.get("name")
        fmt = result.get("format") or args.get("format") or (get_filename(filename) if filename else "txt")
        if not filename:
            filename = f"{function_name}-{int(time.time())}.{fmt}"

        mime_type = get_file_mime_type(filename)
        file_info = {"type": mime_type}
        if prompt:
            file_info["prompt"] = prompt
        if "model" in args:
            file_info["model"] = args["model"]
        base64_data = result.get("data")
        if not base64_data:
            _dbg(f"File data not found for {function_name}")
            return None, None
        url, info = save_bytes_to_cache(base64_data, filename, file_info=file_info)
        resource = {
            "type": "file",
            "file": {
                "file_data": url,
                "filename": info["name"],
            },
        }
        text = f"[{args.get('prompt') or filename}]({url})\n"
        return text, resource
    else:
        # Unknown result type: degrade gracefully to JSON, then str, then None.
        try:
            return json.dumps(result), None
        except Exception as e:
            _dbg(f"Error converting result to JSON: {e}")
            try:
                return str(result), None
            except Exception as e:
                _dbg(f"Error converting result to string: {e}")
                return None, None
|
|
1502
|
+
|
|
1503
|
+
|
|
1504
|
+
def g_tool_result(result, function_name: Optional[str] = None, function_args: Optional[dict] = None):
    """Normalize a tool's return value into (joined_text, resource_parts).

    Dicts are converted as a single structured part, lists part-by-part, and
    anything else is stringified with no resources.
    """
    args = function_args or {}
    resources = []

    if isinstance(result, dict):
        parts_in = [result]
    elif isinstance(result, list):
        parts_in = result
    else:
        parts_in = None

    if parts_in is None:
        texts = [str(result)]
    else:
        texts = []
        for part in parts_in:
            text, res = tool_result_part(part, function_name, args)
            if text:
                texts.append(text)
            if res:
                resources.append(res)

    return "\n".join(texts), resources
|
|
1526
|
+
|
|
1527
|
+
|
|
1528
|
+
async def g_exec_tool(function_name, function_args):
    """Execute a registered tool (sync or async) and normalize its result.

    Returns the (text, resources) pair from g_tool_result on success, or an
    error string paired with None when the tool is missing or raises.
    """
    if function_name not in g_app.tools:
        return f"Error: Tool '{function_name}' not found", None
    try:
        func = g_app.tools[function_name]
        is_async = inspect.iscoroutinefunction(func)
        _dbg(f"Executing {'async' if is_async else 'sync'} tool '{function_name}' with args: {function_args}")
        result = await func(**function_args) if is_async else func(**function_args)
        return g_tool_result(result, function_name, function_args)
    except Exception as e:
        return f"Error executing tool '{function_name}': {to_error_message(e)}", None
|
|
1541
|
+
|
|
1542
|
+
|
|
1543
|
+
def group_resources(resources: list):
    """
    converts list of parts into a grouped dictionary, e.g:
    [{"type: "image_url", "image_url": {"url": "/image.jpg"}}] =>
    {"images": [{"type": "image_url", "image_url": {"url": "/image.jpg"}}] }

    Parts without a "type" are skipped; unrecognized types land in "others".
    """
    type_to_group = {
        "image_url": "images",
        "audio_url": "audios",
        "file_urls": "files",
        "file": "files",
        "text": "texts",
    }
    grouped = {}
    for part in resources:
        part_type = part.get("type")
        if not part_type:
            continue
        bucket = type_to_group.get(part_type, "others")
        grouped.setdefault(bucket, []).append(part)
    return grouped
|
|
1568
|
+
|
|
1569
|
+
|
|
1570
|
+
async def g_chat_completion(chat, context=None):
    """Run a chat completion with provider failover and a tool-execution loop.

    Finds every enabled provider that serves chat["model"] and tries them in
    order. For the active provider: applies request filters, then loops up to
    10 iterations executing any tool_calls the model emits (appending tool
    results back into the conversation) until a final response arrives.
    Usage tokens and cost are accumulated across tool iterations. Response
    filters run once on the final response.

    Returns the final response dict, or None when the context was cancelled
    mid-flight. Raises the first provider's exception when all providers
    fail, or immediately when the model is missing/unknown.
    """
    try:
        model = chat.get("model")
        if not model:
            raise Exception("Model not specified")

        if context is None:
            context = {"chat": chat, "tools": "all"}

        # get first provider that has the model
        candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
        if len(candidate_providers) == 0:
            raise (Exception(f"Model {model} not found"))
    except Exception as e:
        await g_app.on_chat_error(e, context or {"chat": chat})
        raise e

    started_at = time.time()
    first_exception = None
    provider_name = "Unknown"
    for name in candidate_providers:
        try:
            provider_name = name
            provider = g_handlers[name]
            _log(f"provider: {name} {type(provider).__name__}")
            # reset the clock per provider attempt so duration reflects the
            # provider that actually answered
            started_at = time.time()
            context["startedAt"] = datetime.now()
            context["provider"] = name
            model_info = provider.model_info(model)
            context["modelCost"] = model_info.get("cost", provider.model_cost(model)) or {"input": 0, "output": 0}
            context["modelInfo"] = model_info

            # Accumulate usage across tool calls
            total_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
            accumulated_cost = 0.0

            # Inject global tools if present
            current_chat = g_app.create_chat_with_tools(chat, use_tools=context.get("tools", "all"))

            # Apply pre-chat filters ONCE
            context["chat"] = current_chat
            for filter_func in g_app.chat_request_filters:
                await filter_func(current_chat, context)

            # Tool execution loop
            max_iterations = 10
            tool_history = []
            final_response = None

            for _ in range(max_iterations):
                if should_cancel_thread(context):
                    return

                response = await provider.chat(current_chat, context=context)

                if should_cancel_thread(context):
                    return

                # Aggregate usage
                if "usage" in response:
                    usage = response["usage"]
                    total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
                    total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
                    total_usage["total_tokens"] += usage.get("total_tokens", 0)

                    # Calculate cost for this step if available
                    if "cost" in response and isinstance(response["cost"], (int, float)):
                        accumulated_cost += response["cost"]
                    elif "cost" in usage and isinstance(usage["cost"], (int, float)):
                        accumulated_cost += usage["cost"]

                # Check for tool_calls in the response
                choice = response.get("choices", [])[0] if response.get("choices") else {}
                message = choice.get("message", {})
                tool_calls = message.get("tool_calls")
                supports_tool_calls = model_info.get("tool_call", False)

                if tool_calls and supports_tool_calls:
                    # Append the assistant's message with tool calls to history
                    if "messages" not in current_chat:
                        current_chat["messages"] = []
                    if "timestamp" not in message:
                        message["timestamp"] = int(time.time() * 1000)
                    current_chat["messages"].append(message)
                    tool_history.append(message)

                    await g_app.on_chat_tool(current_chat, context)

                    for tool_call in tool_calls:
                        function_name = tool_call["function"]["name"]
                        try:
                            function_args = json.loads(tool_call["function"]["arguments"])
                        except Exception as e:
                            # NOTE(review): on this path `resources` is never
                            # assigned, so group_resources(resources) below
                            # raises NameError (or reuses a stale value from a
                            # prior tool_call) — confirm and fix upstream.
                            tool_result = f"Error: Failed to parse JSON arguments for tool '{function_name}': {to_error_message(e)}"
                        else:
                            tool_result, resources = await g_exec_tool(function_name, function_args)

                        # Append tool result to history
                        tool_msg = {"role": "tool", "tool_call_id": tool_call["id"], "content": to_content(tool_result)}

                        tool_msg.update(group_resources(resources))

                        current_chat["messages"].append(tool_msg)
                        tool_history.append(tool_msg)

                        await g_app.on_chat_tool(current_chat, context)

                    if should_cancel_thread(context):
                        return

                    # Continue loop to send tool results back to LLM
                    continue

                # If no tool calls, this is the final response
                if tool_history:
                    response["tool_history"] = tool_history

                # Update final response with aggregated usage
                if "usage" not in response:
                    response["usage"] = {}
                # convert to int seconds
                context["duration"] = duration = int(time.time() - started_at)
                total_usage.update({"duration": duration})
                response["usage"].update(total_usage)
                # If we accumulated cost, set it on the response
                if accumulated_cost > 0:
                    response["cost"] = accumulated_cost

                final_response = response
                break  # Exit tool loop

            if final_response:
                # Apply post-chat filters ONCE on final response
                for filter_func in g_app.chat_response_filters:
                    await filter_func(final_response, context)

                if DEBUG:
                    _dbg(json.dumps(final_response, indent=2))

                return final_response

        except Exception as e:
            # Remember only the FIRST failure (and its traceback) for the
            # final raise; later providers may still succeed.
            if first_exception is None:
                first_exception = e
                context["stackTrace"] = traceback.format_exc()
            _err(f"Provider {provider_name} failed", first_exception)
            await g_app.on_chat_error(e, context)

            continue

    # If we get here, all providers failed
    raise first_exception
|
|
1726
|
+
|
|
1727
|
+
|
|
1728
|
+
async def cli_chat(chat, tools=None, image=None, audio=None, file=None, args=None, raw=False):
    """Run one chat completion from the CLI and print the result.

    Optionally attaches an image / audio / file reference to the first user
    message (replacing an existing part of the same kind if present), sends
    the request through g_app.chat_completion, then prints the assistant's
    answer (or the raw JSON response when raw=True) plus any generated file
    URLs. HTTP, connection and timeout failures are printed and exit with
    code 1.
    """
    if g_default_model:
        chat["model"] = g_default_model

    # Apply args parameters to chat request
    if args:
        chat = apply_args_to_chat(chat, args)

    # process_chat downloads the image, just adding the reference here
    if image is not None:
        # NOTE(review): if chat has no user message, first_message stays None
        # and the membership test below raises TypeError — assumes a user
        # message always exists; confirm against callers.
        first_message = None
        for message in chat["messages"]:
            if message["role"] == "user":
                first_message = message
                break
        image_content = {"type": "image_url", "image_url": {"url": image}}
        if "content" in first_message:
            if isinstance(first_message["content"], list):
                image_url = None
                for item in first_message["content"]:
                    if "image_url" in item:
                        image_url = item["image_url"]
                # If no image_url, add one
                if image_url is None:
                    first_message["content"].insert(0, image_content)
                else:
                    image_url["url"] = image
            else:
                first_message["content"] = [image_content, {"type": "text", "text": first_message["content"]}]
    if audio is not None:
        first_message = None
        for message in chat["messages"]:
            if message["role"] == "user":
                first_message = message
                break
        audio_content = {"type": "input_audio", "input_audio": {"data": audio, "format": "mp3"}}
        if "content" in first_message:
            if isinstance(first_message["content"], list):
                input_audio = None
                for item in first_message["content"]:
                    if "input_audio" in item:
                        input_audio = item["input_audio"]
                # If no input_audio, add one
                if input_audio is None:
                    first_message["content"].insert(0, audio_content)
                else:
                    input_audio["data"] = audio
            else:
                first_message["content"] = [audio_content, {"type": "text", "text": first_message["content"]}]
    if file is not None:
        first_message = None
        for message in chat["messages"]:
            if message["role"] == "user":
                first_message = message
                break
        file_content = {"type": "file", "file": {"filename": get_filename(file), "file_data": file}}
        if "content" in first_message:
            if isinstance(first_message["content"], list):
                file_data = None
                for item in first_message["content"]:
                    if "file" in item:
                        file_data = item["file"]
                # If no file_data, add one
                if file_data is None:
                    first_message["content"].insert(0, file_content)
                else:
                    file_data["filename"] = get_filename(file)
                    file_data["file_data"] = file
            else:
                first_message["content"] = [file_content, {"type": "text", "text": first_message["content"]}]

    if g_verbose:
        printdump(chat)

    try:
        context = {
            "tools": tools or "all",
        }
        response = await g_app.chat_completion(chat, context=context)

        if raw:
            print(json.dumps(response, indent=2))
            exit(0)
        else:
            msg = response["choices"][0]["message"]
            # NOTE(review): when only "answer" is present this still reads
            # msg["content"] and would raise KeyError — confirm intent.
            if "content" in msg or "answer" in msg:
                print(msg["content"])

        # Collect URLs of any media the model generated so we can echo where
        # they were saved.
        generated_files = []
        for choice in response["choices"]:
            if "message" in choice:
                msg = choice["message"]
                if "images" in msg:
                    for image in msg["images"]:
                        image_url = image["image_url"]["url"]
                        generated_files.append(image_url)
                if "audios" in msg:
                    for audio in msg["audios"]:
                        audio_url = audio["audio_url"]["url"]
                        generated_files.append(audio_url)

        if len(generated_files) > 0:
            print("\nSaved files:")
            for file in generated_files:
                # cache-relative URLs are printed both as a local path and a
                # localhost URL
                if file.startswith("/~cache"):
                    print(get_cache_path(file[8:]))
                    print(urljoin("http://localhost:8000", file))
                else:
                    print(file)

    except HTTPError as e:
        # HTTP error (4xx, 5xx)
        print(f"{e}:\n{e.body}")
        g_app.exit(1)
    except aiohttp.ClientConnectionError as e:
        # Connection issues
        print(f"Connection error: {e}")
        g_app.exit(1)
    except asyncio.TimeoutError as e:
        # Timeout
        print(f"Timeout error: {e}")
        g_app.exit(1)
|
|
1850
|
+
|
|
1851
|
+
|
|
1852
|
+
def config_str(key):
    """Return g_config[key], mapping missing keys AND falsy values to None.

    Replaces the error-prone ``key in d and d[key] or None`` idiom with a
    single lookup; the falsy-to-None mapping is preserved.
    """
    return g_config.get(key) or None
|
|
1854
|
+
|
|
1855
|
+
|
|
1856
|
+
def load_config(config, providers, verbose=None):
    """Install *config* and *providers* as the module-level configuration.

    *verbose*, when truthy, also enables global verbose logging; a falsy
    value leaves the current g_verbose setting untouched.
    """
    global g_config, g_providers, g_verbose
    g_config = config
    g_providers = providers
    if verbose:
        g_verbose = verbose
|
|
1862
|
+
|
|
1863
|
+
|
|
1864
|
+
def init_llms(config, providers):
    """Initialize the active provider handlers from configuration.

    Installs the config, resolves top-level "$ENV_VAR" string values from the
    environment, then builds a provider instance for every enabled entry in
    g_config["providers"], keeping only those whose test(...) passes.

    Returns the resulting handler dict (also stored as g_handlers).
    """
    global g_config, g_handlers

    load_config(config, providers)
    g_handlers = {}
    # iterate over config and replace $ENV with env value
    for key, value in g_config.items():
        if isinstance(value, str) and value.startswith("$"):
            g_config[key] = os.getenv(value[1:], "")

    # if g_verbose:
    #     printdump(g_config)
    providers = g_config["providers"]

    for id, orig in providers.items():
        # entries default to enabled; only an explicit enabled=False skips
        if "enabled" in orig and not orig["enabled"]:
            continue

        provider, constructor_kwargs = create_provider_from_definition(id, orig)
        # only register providers that construct and pass their self-test
        if provider and provider.test(**constructor_kwargs):
            g_handlers[id] = provider
    return g_handlers
|
|
1886
|
+
|
|
1887
|
+
|
|
1888
|
+
def create_provider_from_definition(id, orig):
    """Build a provider instance and its constructor kwargs from a config entry.

    Merges the entry over the matching base definition from g_providers
    (matched by the entry's "id", defaulting to the config key *id*).
    Returns (provider_or_None, constructor_kwargs).
    """
    definition = dict(orig)
    provider_id = definition.setdefault("id", id)
    base_definition = g_providers.get(provider_id)
    constructor_kwargs = create_provider_kwargs(definition, base_definition)
    instance = create_provider(constructor_kwargs)
    return instance, constructor_kwargs
|
|
1897
|
+
|
|
1898
|
+
|
|
1899
|
+
def create_provider_kwargs(definition, provider=None):
    """Merge a config entry over an optional base provider definition into
    the kwargs used to construct a provider.

    Resolves "$ENV_VAR" api_key values (and falls back to the first set
    variable listed under "env"), deep-ish copies list/dict values so the
    constructed provider can't mutate the config, injects the default
    headers, and instantiates any nested "modalities" sub-providers.

    Returns the kwargs dict, or None when a modality sub-provider could not
    be created (callers treat None as "use the raw definition").
    """
    if provider:
        provider = provider.copy()
        provider.update(definition)
    else:
        provider = definition.copy()

    # Replace API keys with environment variables if they start with $
    if "api_key" in provider:
        value = provider["api_key"]
        if isinstance(value, str) and value.startswith("$"):
            provider["api_key"] = os.getenv(value[1:], "")

    # no explicit key: take the first environment variable from "env" that
    # is actually set
    if "api_key" not in provider and "env" in provider:
        for env_var in provider["env"]:
            val = os.getenv(env_var)
            if val:
                provider["api_key"] = val
                break

    # Create a copy of provider
    constructor_kwargs = dict(provider.items())
    # Create a copy of all list and dict values
    for key, value in constructor_kwargs.items():
        if isinstance(value, (list, dict)):
            constructor_kwargs[key] = value.copy()
    constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()

    if "modalities" in definition:
        constructor_kwargs["modalities"] = {}
        for modality, modality_definition in definition["modalities"].items():
            modality_provider = create_provider(modality_definition)
            if not modality_provider:
                # a broken modality invalidates the whole kwargs set
                return None
            constructor_kwargs["modalities"][modality] = modality_provider

    return constructor_kwargs
|
|
1936
|
+
|
|
1937
|
+
|
|
1938
|
+
def create_provider(provider):
    """Instantiate the registered provider class whose ``sdk`` matches the
    definition's ``npm`` id.

    Returns None (after logging) for non-dict input, a missing "npm" id, or
    when no registered provider class matches.
    """
    if not isinstance(provider, dict):
        return None
    provider_label = provider.get("id", provider.get("name", "unknown"))
    npm_sdk = provider.get("npm")
    if not npm_sdk:
        _log(f"Provider {provider_label} is missing 'npm' sdk")
        return None

    for provider_type in g_app.all_providers:
        if provider_type.sdk != npm_sdk:
            continue
        kwargs = create_provider_kwargs(provider)
        if kwargs is None:
            # kwargs construction failed; fall back to the raw definition
            kwargs = provider
        return provider_type(**kwargs)

    _log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
    return None
|
|
1956
|
+
|
|
1957
|
+
|
|
1958
|
+
async def load_llms():
    """Load (fetch models for) every enabled provider handler.

    Iterates handler values directly (the keys were unused), and drops the
    no-op ``global`` statement — g_handlers is only read here.
    """
    _log("Loading providers...")
    for provider in g_handlers.values():
        await provider.load()
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
def save_config(config):
    """Install *config* as the active configuration and persist it as JSON to
    the file at g_config_path."""
    global g_config, g_config_path
    g_config = config
    with open(g_config_path, "w", encoding="utf-8") as f:
        json.dump(g_config, f, indent=4)
    _log(f"Saved config to {g_config_path}")
|
|
1971
|
+
|
|
1972
|
+
|
|
1973
|
+
def github_url(filename):
    """Return the raw.githubusercontent.com URL for *filename* in the
    ServiceStack/llms repository's llms/ directory.

    Fix: interpolate *filename* into the URL — the parameter was previously
    unused, yielding the same broken URL for every file.
    """
    return f"https://raw.githubusercontent.com/ServiceStack/llms/refs/heads/main/llms/{filename}"
|
|
1975
|
+
|
|
1976
|
+
|
|
1977
|
+
async def get_text(url):
    """GET *url* and return the response body as text.

    Raises HTTPError (carrying the body and response headers) for any
    4xx/5xx status, so callers get the error payload as well as the code.
    """
    async with aiohttp.ClientSession() as session:
        _log(f"GET {url}")
        async with session.get(url) as resp:
            # read the body before checking status so it can be attached to
            # the error
            text = await resp.text()
            if resp.status >= 400:
                raise HTTPError(resp.status, reason=resp.reason, body=text, headers=dict(resp.headers))
            return text
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
async def save_text_url(url, save_path):
    """Download *url* and write its text to *save_path* (creating parent
    directories as needed). Returns the downloaded text."""
    content = await get_text(url)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as out:
        out.write(content)
    return content
|
|
1993
|
+
|
|
1994
|
+
|
|
1995
|
+
async def save_default_config(config_path):
    """Download the default llms.json into *config_path* and install it as
    the active g_config."""
    global g_config
    downloaded = await save_text_url(github_url("llms.json"), config_path)
    g_config = json.loads(downloaded)
|
|
1999
|
+
|
|
2000
|
+
|
|
2001
|
+
async def update_providers(home_providers_path):
    """Refresh the cached provider/model catalog from models.dev.

    Downloads the full catalog, keeps only the providers configured in
    g_config["providers"], overlays any extra models defined in a sibling
    providers-extra.json, writes the filtered catalog to
    *home_providers_path*, and installs it as the module-level g_providers.
    """
    global g_providers
    text = await get_text("https://models.dev/api.json")
    all_providers = json.loads(text)
    extra_providers = {}
    extra_providers_path = home_providers_path.replace("providers.json", "providers-extra.json")
    if os.path.exists(extra_providers_path):
        with open(extra_providers_path) as f:
            extra_providers = json.load(f)

    filtered_providers = {}
    for id, provider in all_providers.items():
        if id in g_config["providers"]:
            filtered_providers[id] = provider
            if id in extra_providers and "models" in extra_providers[id]:
                for model_id, model in extra_providers[id]["models"].items():
                    # backfill id/name so extra models match the models.dev shape
                    if "id" not in model:
                        model["id"] = model_id
                    if "name" not in model:
                        model["name"] = id_to_name(model["id"])
                    filtered_providers[id]["models"][model_id] = model

    os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
    with open(home_providers_path, "w", encoding="utf-8") as f:
        json.dump(filtered_providers, f)

    g_providers = filtered_providers
|
|
2028
|
+
|
|
2029
|
+
|
|
2030
|
+
def provider_status():
    """Return (enabled, disabled) lists of provider ids, each sorted.

    Enabled means an active handler exists; disabled covers every other
    configured provider.
    """
    enabled = sorted(g_handlers)
    disabled = sorted(pid for pid in g_config["providers"] if pid not in g_handlers)
    return enabled, disabled
|
|
2036
|
+
|
|
2037
|
+
|
|
2038
|
+
def print_status():
    """Print the enabled/disabled provider summary to stdout."""
    enabled, disabled = provider_status()
    if enabled:
        print(f"\nEnabled: {', '.join(enabled)}")
    else:
        print("\nEnabled: None")
    if disabled:
        print(f"Disabled: {', '.join(disabled)}")
    else:
        print("Disabled: None")
|
|
2048
|
+
|
|
2049
|
+
|
|
2050
|
+
def home_llms_path(filename):
    """Resolve *filename* to an absolute path under the llms home directory
    ($LLMS_HOME when set, otherwise ~/.llms).

    Fix: use os.path.expanduser("~") instead of os.getenv("HOME") — HOME is
    typically unset on Windows, where the old code crashed joining None.
    """
    home_dir = os.getenv("LLMS_HOME") or os.path.join(os.path.expanduser("~"), ".llms")
    relative_path = os.path.join(home_dir, filename)
    # return resolved full absolute path
    return os.path.abspath(os.path.normpath(relative_path))
|
|
2055
|
+
|
|
2056
|
+
|
|
2057
|
+
def get_cache_path(path=""):
    """Return the absolute cache directory path, or the path of *path* inside
    the cache directory when given."""
    suffix = f"cache/{path}" if path else "cache"
    return home_llms_path(suffix)
|
|
2059
|
+
|
|
2060
|
+
|
|
2061
|
+
def get_config_path():
    """Locate the active llms.json.

    Search order: $LLMS_CONFIG_PATH (when set), ./llms.json, then
    ~/.llms/llms.json. Relative candidates are resolved against this
    module's directory. Returns the first existing path, else None.
    """
    candidates = [
        "./llms.json",
        home_llms_path("llms.json"),
    ]
    env_path = os.getenv("LLMS_CONFIG_PATH")
    if env_path:
        candidates.insert(0, env_path)

    base_dir = os.path.dirname(__file__)
    for candidate in candidates:
        resolved = os.path.normpath(os.path.join(base_dir, candidate))
        if os.path.exists(resolved):
            return resolved
    return None
|
|
2075
|
+
|
|
2076
|
+
|
|
2077
|
+
def enable_provider(provider):
    """Enable a configured provider, persist the config, and rebuild handlers.

    Returns (provider_config, None) on success, or (None, error_message)
    when the provider is unknown or its configuration fails validation.

    Fix: look the provider up with .get() — the old direct indexing raised
    KeyError for an unknown id before the "not found" guard could run.
    """
    provider_config = g_config["providers"].get(provider)
    if not provider_config:
        return None, f"Provider {provider} not found"

    provider, constructor_kwargs = create_provider_from_definition(provider, provider_config)
    msg = provider.validate(**constructor_kwargs)
    if msg:
        return None, msg

    provider_config["enabled"] = True
    save_config(g_config)
    init_llms(g_config, g_providers)
    return provider_config, msg
|
|
2092
|
+
|
|
2093
|
+
|
|
2094
|
+
def disable_provider(provider):
    """Mark *provider* as disabled, persist the config, and rebuild the
    active handlers."""
    g_config["providers"][provider]["enabled"] = False
    save_config(g_config)
    init_llms(g_config, g_providers)
|
|
2099
|
+
|
|
2100
|
+
|
|
2101
|
+
def resolve_root():
    """Locate the directory (or Traversable) that holds the app's static
    resources (index.html + ui/), probing installed-package, data-file,
    Homebrew and development layouts in turn. Always returns something:
    falls back to this file's directory."""
    # Try to find the resource root directory
    # When installed as a package, static files may be in different locations

    # Method 1: Try importlib.resources for package data (Python 3.9+)
    try:
        try:
            # Try to access the package resources
            pkg_files = resources.files("llms")
            # Check if ui directory exists in package resources
            if hasattr(pkg_files, "is_dir") and (pkg_files / "ui").is_dir():
                _log(f"RESOURCE ROOT (package): {pkg_files}")
                return pkg_files
        except (FileNotFoundError, AttributeError, TypeError):
            # Package doesn't have the resources, try other methods
            pass
    except ImportError:
        # importlib.resources not available (Python < 3.9)
        pass

    # Method 1b: Look for the installed package and check for UI files
    try:
        import llms

        # If llms is a package, check its directory
        if hasattr(llms, "__path__"):
            # It's a package
            package_path = Path(llms.__path__[0])

            # Check if UI files are in the package directory
            if (package_path / "index.html").exists() and (package_path / "ui").is_dir():
                _log(f"RESOURCE ROOT (package directory): {package_path}")
                return package_path
        else:
            # It's a module
            module_path = Path(llms.__file__).resolve().parent

            # Check if UI files are in the same directory as the module
            if (module_path / "index.html").exists() and (module_path / "ui").is_dir():
                _log(f"RESOURCE ROOT (module directory): {module_path}")
                return module_path

            # Check parent directory (sometimes data files are installed one level up)
            parent_path = module_path.parent
            if (parent_path / "index.html").exists() and (parent_path / "ui").is_dir():
                _log(f"RESOURCE ROOT (module parent): {parent_path}")
                return parent_path

    except (ImportError, AttributeError):
        pass

    # Method 2: Try to find data files in sys.prefix (where data_files are installed)
    # Get all possible installation directories
    possible_roots = [
        Path(sys.prefix),  # Standard installation
        Path(sys.prefix) / "share",  # Some distributions
        Path(sys.base_prefix),  # Virtual environments
        Path(sys.base_prefix) / "share",
    ]

    # Add site-packages directories
    for site_dir in site.getsitepackages():
        possible_roots.extend(
            [
                Path(site_dir),
                Path(site_dir).parent,
                Path(site_dir).parent / "share",
            ]
        )

    # Add user site directory
    try:
        user_site = site.getusersitepackages()
        if user_site:
            possible_roots.extend(
                [
                    Path(user_site),
                    Path(user_site).parent,
                    Path(user_site).parent / "share",
                ]
            )
    except AttributeError:
        pass

    # Method 2b: Look for data files in common macOS Homebrew locations
    # Homebrew often installs data files in different locations
    homebrew_roots = []
    if sys.platform == "darwin":  # macOS
        homebrew_prefixes = ["/opt/homebrew", "/usr/local"]  # Apple Silicon and Intel
        for prefix in homebrew_prefixes:
            if Path(prefix).exists():
                homebrew_roots.extend(
                    [
                        Path(prefix),
                        Path(prefix) / "share",
                        # NOTE(review): hard-coded 3.11 path kept alongside the
                        # version-derived one below — presumably a legacy layout;
                        # confirm it is still needed.
                        Path(prefix) / "lib" / "python3.11" / "site-packages",
                        Path(prefix)
                        / "lib"
                        / f"python{sys.version_info.major}.{sys.version_info.minor}"
                        / "site-packages",
                    ]
                )

    possible_roots.extend(homebrew_roots)

    for root in possible_roots:
        try:
            if root.exists() and (root / "index.html").exists() and (root / "ui").is_dir():
                _log(f"RESOURCE ROOT (data files): {root}")
                return root
        except (OSError, PermissionError):
            # Unreadable candidate directories are simply skipped
            continue

    # Method 3: Development mode - look relative to this file
    # __file__ is *this* module; look in same directory first, then parent
    dev_roots = [
        Path(__file__).resolve().parent,  # Same directory as llms.py
        Path(__file__).resolve().parent.parent,  # Parent directory (repo root)
    ]

    for root in dev_roots:
        try:
            if (root / "index.html").exists() and (root / "ui").is_dir():
                _log(f"RESOURCE ROOT (development): {root}")
                return root
        except (OSError, PermissionError):
            continue

    # Fallback: use the directory containing this file
    from_file = Path(__file__).resolve().parent
    _log(f"RESOURCE ROOT (fallback): {from_file}")
    return from_file
def resource_exists(resource_path):
    """Return True if a resource exists.

    Handles pathlib.Path / importlib.resources Traversable objects (which
    expose is_file()) as well as plain path strings. Returns False — instead
    of the previous implicit None — when the check fails or nothing exists.
    """
    try:
        if hasattr(resource_path, "is_file"):
            return resource_path.is_file()
        return os.path.exists(resource_path)
    except (OSError, AttributeError):
        return False
def read_resource_text(resource_path):
    """Read text from a Path/Traversable (via read_text) or a plain path string."""
    reader = getattr(resource_path, "read_text", None)
    if reader is not None:
        return reader()
    with open(resource_path, encoding="utf-8") as handle:
        return handle.read()
def read_resource_file_bytes(resource_file):
    """Return the bytes of *resource_file* under the global _ROOT, or
    (implicitly) None when it is missing or unreadable. Errors are logged,
    never raised."""
    try:
        if hasattr(_ROOT, "joinpath"):
            # importlib.resources Traversable
            index_resource = _ROOT.joinpath(resource_file)
            if index_resource.is_file():
                return index_resource.read_bytes()
        else:
            # Regular Path object
            index_path = _ROOT / resource_file
            if index_path.exists():
                return index_path.read_bytes()
    except (OSError, PermissionError, AttributeError) as e:
        _log(f"Error reading resource bytes: {e}")
    # NOTE: falls through to an implicit None on miss/error — callers must
    # treat None as "not found".
async def check_models(provider_name, model_names=None):
    """
    Check validity of models for a specific provider by sending a ping message.

    Args:
        provider_name: Name of the provider to check
        model_names: List of specific model names to check, or None to check all models

    Results are printed to stdout; nothing is returned. A single ["all"]
    entry behaves the same as passing None.
    """
    # Unknown/disabled providers are reported rather than raising.
    if provider_name not in g_handlers:
        print(f"Provider '{provider_name}' not found or not enabled")
        print(f"Available providers: {', '.join(g_handlers.keys())}")
        return

    provider = g_handlers[provider_name]
    models_to_check = []

    # Determine which models to check
    if model_names is None or (len(model_names) == 1 and model_names[0] == "all"):
        # Check all models for this provider
        models_to_check = list(provider.models.keys())
    else:
        # Check only specified models; unknown names are reported and skipped
        for model_name in model_names:
            provider_model = provider.provider_model(model_name)
            if provider_model:
                models_to_check.append(model_name)
            else:
                print(f"Model '{model_name}' not found in provider '{provider_name}'")

    if not models_to_check:
        print(f"No models to check for provider '{provider_name}'")
        return

    print(
        f"\nChecking {len(models_to_check)} model{'' if len(models_to_check) == 1 else 's'} for provider '{provider_name}':\n"
    )

    # Test each model sequentially; check_provider_model prints one line per model
    for model in models_to_check:
        await check_provider_model(provider, model)

    print()
async def check_provider_model(provider, model):
    """Send one ping chat to *model* on *provider*, print a ✓/✗ result line
    with latency, and return True on success.

    Error bodies from several provider error formats (OpenRouter-style
    {"error": {...}}, plain string errors, codestral-style
    {"message": {"detail": [...]}}) are parsed best-effort for a readable
    message.
    """
    # Create a simple ping chat request — provider-specific check payload,
    # falling back to the configured default. .copy() is shallow: only the
    # top-level "model" key is mutated below.
    chat = (provider.check or g_config["defaults"]["check"]).copy()
    chat["model"] = model

    success = False
    started_at = time.time()
    try:
        # Try to get a response from the model
        response = await provider.chat(chat)
        duration_ms = int((time.time() - started_at) * 1000)

        # Check if we got a valid response
        if response and "choices" in response and len(response["choices"]) > 0:
            success = True
            print(f"  ✓ {model:<40} ({duration_ms}ms)")
        else:
            print(f"  ✗ {model:<40} Invalid response format")
    except HTTPError as e:
        duration_ms = int((time.time() - started_at) * 1000)
        error_msg = f"HTTP {e.status}"
        try:
            # Try to parse error body for more details
            error_body = json.loads(e.body) if e.body else {}
            if "error" in error_body:
                error = error_body["error"]
                if isinstance(error, dict):
                    if "message" in error and isinstance(error["message"], str):
                        # OpenRouter
                        error_msg = error["message"]
                        if "code" in error:
                            error_msg = f"{error['code']} {error_msg}"
                        if "metadata" in error and "raw" in error["metadata"]:
                            error_msg += f" - {error['metadata']['raw']}"
                        if "provider" in error:
                            error_msg += f" ({error['provider']})"
                elif isinstance(error, str):
                    error_msg = error
            elif "message" in error_body:
                if isinstance(error_body["message"], str):
                    error_msg = error_body["message"]
                elif (
                    isinstance(error_body["message"], dict)
                    and "detail" in error_body["message"]
                    and isinstance(error_body["message"]["detail"], list)
                ):
                    # codestral error format
                    error_msg = error_body["message"]["detail"][0]["msg"]
                    if (
                        "loc" in error_body["message"]["detail"][0]
                        and len(error_body["message"]["detail"][0]["loc"]) > 0
                    ):
                        error_msg += f" (in {' '.join(error_body['message']['detail'][0]['loc'])})"
        except Exception as parse_error:
            # Parsing is best-effort; fall back to a truncated raw body
            _log(f"Error parsing error body: {parse_error}")
            error_msg = e.body[:100] if e.body else f"HTTP {e.status}"
        print(f"  ✗ {model:<40} {error_msg}")
    except asyncio.TimeoutError:
        duration_ms = int((time.time() - started_at) * 1000)
        print(f"  ✗ {model:<40} Timeout after {duration_ms}ms")
    except Exception as e:
        duration_ms = int((time.time() - started_at) * 1000)
        error_msg = str(e)[:100]
        print(f"  ✗ {model:<40} {error_msg}")
    return success
def text_from_resource(filename):
    """Return the text of *filename* under the resource root, or None when it
    is missing or unreadable (errors are logged)."""
    # _ROOT is only read here, so the previous `global _ROOT` declaration was
    # unnecessary and has been removed.
    resource_path = _ROOT / filename
    if resource_exists(resource_path):
        try:
            return read_resource_text(resource_path)
        except (OSError, AttributeError) as e:
            _log(f"Error reading resource config (unknown): {e}")
    return None
def text_from_file(filename):
    """Return the UTF-8 contents of *filename*, or None when it doesn't exist."""
    if not os.path.exists(filename):
        return None
    with open(filename, encoding="utf-8") as handle:
        return handle.read()
def json_from_file(filename):
    """Parse *filename* as JSON, returning None when the file is absent."""
    if not os.path.exists(filename):
        return None
    with open(filename, encoding="utf-8") as handle:
        return json.load(handle)
async def text_from_resource_or_url(filename):
    """Return *filename* from bundled resources, falling back to downloading
    it from GitHub. Re-raises the download error after logging it."""
    text = text_from_resource(filename)
    if not text:
        # Pre-bind resource_url so the except handler cannot hit an unbound
        # local (NameError) if github_url() itself raises.
        resource_url = filename
        try:
            resource_url = github_url(filename)
            text = await get_text(resource_url)
        except Exception as e:
            _log(f"Error downloading JSON from {resource_url}: {e}")
            raise e
    return text
async def save_home_configs():
    """Seed the llms home directory with default llms.json, providers.json and
    providers-extra.json.

    Existing files are never overwritten. When a default cannot be obtained
    (no bundled resource and download failed) a message is printed and the
    process exits with status 1.
    """
    # (resource filename, log label) — folded into one loop instead of the
    # previous three copy-pasted fetch/write stanzas. Log strings unchanged.
    targets = [
        ("llms.json", "default config"),
        ("providers.json", "default providers config"),
        ("providers-extra.json", "default extra providers config"),
    ]
    paths = [home_llms_path(name) for name, _ in targets]
    if all(os.path.exists(p) for p in paths):
        return  # everything already present

    os.makedirs(os.path.dirname(paths[0]), exist_ok=True)
    try:
        for (name, label), path in zip(targets, paths):
            if not os.path.exists(path):
                content = await text_from_resource_or_url(name)
                with open(path, "w", encoding="utf-8") as f:
                    f.write(content)
                _log(f"Created {label} at {path}")
    except Exception:
        print("Could not create llms.json. Create one with --init or use --config <path>")
        exit(1)
def load_config_json(config_json):
    """Parse a config JSON string, migrating pre-v3 configs to the bundled
    v3 template (preserving auth/defaults/limits/convert) and backing up the
    old file as <path>.YYYY-MM-DD.bak. Returns the config dict, or None when
    *config_json* is None."""
    if config_json is None:
        return None
    config = json.loads(config_json)
    # Migrate anything without a version or with version < 3
    if not config or "version" not in config or config["version"] < 3:
        preserve_keys = ["auth", "defaults", "limits", "convert"]
        new_config = json.loads(text_from_resource("llms.json"))
        if config:
            # Carry user-specific sections across into the fresh template
            for key in preserve_keys:
                if key in config:
                    new_config[key] = config[key]
        config = new_config
        # move old config to YYYY-MM-DD.bak
        new_path = f"{g_config_path}.{datetime.now().strftime('%Y-%m-%d')}.bak"
        if os.path.exists(new_path):
            os.remove(new_path)
        os.rename(g_config_path, new_path)
        print(f"llms.json migrated. old config moved to {new_path}")
        # save new config
        # NOTE(review): this persists the module-global g_config, not the
        # freshly migrated local `config` — confirm g_config is already bound
        # to the same object here, otherwise the migration is not saved.
        save_config(g_config)
    return config
async def reload_providers():
    """Re-initialize provider handlers from the current config, reload model
    lists, and return the new handlers dict."""
    # Only g_handlers is assigned here; g_config is merely read, so it was
    # removed from the `global` declaration.
    global g_handlers
    g_handlers = init_llms(g_config, g_providers)
    await load_llms()
    _log(f"{len(g_handlers)} providers loaded")
    return g_handlers
async def watch_config_files(config_path, providers_path, interval=1):
    """Watch config files and reload providers when they change.

    Polls both files every *interval* seconds and, when either mtime advances,
    re-reads *config_path* into the global g_config and calls
    reload_providers(). Runs forever; intended to be scheduled as a task.
    """
    global g_config

    config_path = Path(config_path)
    providers_path = Path(providers_path)

    _log(f"Watching config file: {config_path}")
    _log(f"Watching providers file: {providers_path}")

    def get_latest_mtime():
        # Return (newest mtime, name of the newest file); 0 when neither exists
        ret = 0
        name = "llms.json"
        if config_path.is_file():
            ret = config_path.stat().st_mtime
            name = config_path.name
        if providers_path.is_file() and providers_path.stat().st_mtime > ret:
            ret = providers_path.stat().st_mtime
            name = providers_path.name
        return ret, name

    latest_mtime, name = get_latest_mtime()

    while True:
        await asyncio.sleep(interval)

        # Check llms.json
        try:
            new_mtime, name = get_latest_mtime()
            if new_mtime > latest_mtime:
                _log(f"Config file changed: {name}")
                latest_mtime = new_mtime

                try:
                    # Reload llms.json
                    # NOTE(review): only config_path is re-read here even when
                    # the providers file changed — presumably reload_providers()
                    # re-reads providers itself; confirm.
                    with open(config_path) as f:
                        g_config = json.load(f)

                    # Reload providers
                    await reload_providers()
                    _log("Providers reloaded successfully")
                except Exception as e:
                    # Keep watching even if a reload fails (e.g. invalid JSON)
                    _log(f"Error reloading config: {e}")
        except FileNotFoundError:
            # File disappeared between checks; try again next tick
            pass
def get_session_token(request):
    """Extract the session token from the request.

    Checks, in order: the ?session= query parameter, the X-Session-Token
    header, then the llms-token cookie.
    """
    token = request.query.get("session")
    if not token:
        token = request.headers.get("X-Session-Token")
    if not token:
        token = request.cookies.get("llms-token")
    return token
class AppExtensions:
    """
    APIs extensions can use to extend the app

    Central registry the server consults: extensions append filters, routes,
    tools, import maps and index-page fragments here, and the chat pipeline
    invokes the registered hooks.
    """

    def __init__(self, cli_args, extra_args):
        # Parsed CLI namespace and the unparsed remainder
        self.cli_args = cli_args
        self.extra_args = extra_args
        self.config = None  # set later via set_config()
        self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
        # Registered extension hooks (appended to by ExtensionContext.register_*)
        self.ui_extensions = []
        self.chat_request_filters = []
        self.extensions = []
        self.loaded = False
        self.chat_tool_filters = []
        self.chat_response_filters = []
        self.chat_error_filters = []
        # Pending (path, handler, kwargs) route registrations per HTTP method
        self.server_add_get = []
        self.server_add_post = []
        self.server_add_put = []
        self.server_add_delete = []
        self.server_add_patch = []
        self.cache_saved_filters = []
        self.shutdown_handlers = []
        # Tool registry: callables, OpenAI-style definitions, and groupings
        self.tools = {}
        self.tool_definitions = []
        self.tool_groups = {}
        # Raw HTML fragments injected into the index page
        self.index_headers = []
        self.index_footers = []
        # Chat request arguments extensions may set, mapped to expected types
        self.request_args = {
            "image_config": dict,  # e.g. { "aspect_ratio": "1:1" }
            "temperature": float,  # e.g: 0.7
            "max_completion_tokens": int,  # e.g: 2048
            "seed": int,  # e.g: 42
            "top_p": float,  # e.g: 0.9
            "frequency_penalty": float,  # e.g: 0.5
            "presence_penalty": float,  # e.g: 0.5
            "stop": list,  # e.g: ["Stop"]
            "reasoning_effort": str,  # e.g: minimal, low, medium, high
            "verbosity": str,  # e.g: low, medium, high
            "service_tier": str,  # e.g: auto, default
            "top_logprobs": int,
            "safety_identifier": str,
            "store": bool,
            "enable_thinking": bool,
        }
        # Built-in provider classes; extensions may append via add_provider()
        self.all_providers = [
            OpenAiCompatible,
            MistralProvider,
            GroqProvider,
            XaiProvider,
            CodestralProvider,
            OllamaProvider,
            LMStudioProvider,
        ]
        # Supported image aspect ratios mapped to pixel dimensions
        self.aspect_ratios = {
            "1:1": "1024×1024",
            "2:3": "832×1248",
            "3:2": "1248×832",
            "3:4": "864×1184",
            "4:3": "1184×864",
            "4:5": "896×1152",
            "5:4": "1152×896",
            "9:16": "768×1344",
            "16:9": "1344×768",
            "21:9": "1536×672",
        }
        # Browser import-map entries (module specifier -> served path)
        self.import_maps = {
            "vue-prod": "/ui/lib/vue.min.mjs",
            "vue": "/ui/lib/vue.mjs",
            "vue-router": "/ui/lib/vue-router.min.mjs",
            "@servicestack/client": "/ui/lib/servicestack-client.mjs",
            "@servicestack/vue": "/ui/lib/servicestack-vue.mjs",
            "idb": "/ui/lib/idb.min.mjs",
            "marked": "/ui/lib/marked.min.mjs",
            "highlight.js": "/ui/lib/highlight.min.mjs",
            "chart.js": "/ui/lib/chart.js",
            "color.js": "/ui/lib/color.js",
            "ctx.mjs": "/ui/ctx.mjs",
        }

    def set_config(self, config):
        """Attach the loaded config and derive whether auth is enabled."""
        self.config = config
        self.auth_enabled = self.config.get("auth", {}).get("enabled", False)

    # Authentication middleware helper
    def check_auth(self, request):
        """Check if request is authenticated. Returns (is_authenticated, user_data)"""
        if not self.auth_enabled:
            return True, None

        # Check for OAuth session token
        session_token = get_session_token(request)
        if session_token and session_token in g_sessions:
            return True, g_sessions[session_token]

        # Check for API key
        # NOTE(review): any non-empty Bearer token is accepted here — the key
        # value itself is not validated in this method; confirm validation
        # happens elsewhere.
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            api_key = auth_header[7:]
            if api_key:
                return True, {"authProvider": "apikey"}

        return False, None

    def get_session(self, request):
        """Return the session dict for the request's token, or None."""
        session_token = get_session_token(request)

        if not session_token or session_token not in g_sessions:
            return None

        session_data = g_sessions[session_token]
        return session_data

    def get_username(self, request):
        """Return the authenticated userName for the request, or None."""
        session = self.get_session(request)
        if session:
            return session.get("userName")
        return None

    def get_user_path(self, username=None):
        """Return the per-user data directory ("user/default" when anonymous)."""
        if username:
            return home_llms_path(os.path.join("user", username))
        return home_llms_path(os.path.join("user", "default"))

    def chat_request(self, template=None, text=None, model=None, system_prompt=None):
        # Delegate to the module-level chat request builder
        return g_chat_request(template=template, text=text, model=model, system_prompt=system_prompt)

    async def chat_completion(self, chat, context=None):
        # Delegate to the module-level chat completion pipeline
        response = await g_chat_completion(chat, context)
        return response

    def on_cache_saved_filters(self, context):
        """Invoke every registered cache-saved filter with *context*."""
        # _log(f"on_cache_saved_filters {len(self.cache_saved_filters)}: {context['url']}")
        for filter_func in self.cache_saved_filters:
            filter_func(context)

    async def on_chat_error(self, e, context):
        """Run all chat-error filters; filter failures are logged, not raised."""
        # Apply chat error filters
        if "stackTrace" not in context:
            context["stackTrace"] = traceback.format_exc()
        for filter_func in self.chat_error_filters:
            try:
                await filter_func(e, context)
            # NOTE(review): this `as e` shadows the original error parameter
            # for the remaining iterations — confirm intentional.
            except Exception as e:
                _err("chat error filter failed", e)

    async def on_chat_tool(self, chat, context):
        """Run all chat-tool filters in registration order (errors propagate)."""
        m_len = len(chat.get("messages", []))
        t_len = len(self.chat_tool_filters)
        _dbg(
            f"on_tool_call for thread {context.get('threadId', None)} with {m_len} {pluralize('message', m_len)}, invoking {t_len} {pluralize('filter', t_len)}:"
        )
        for filter_func in self.chat_tool_filters:
            await filter_func(chat, context)

    def exit(self, exit_code=0):
        """Run registered shutdown handlers, then terminate via sys.exit."""
        if len(self.shutdown_handlers) > 0:
            _dbg(f"running {len(self.shutdown_handlers)} shutdown handlers...")
            for handler in self.shutdown_handlers:
                handler()

        _dbg(f"exit({exit_code})")
        sys.exit(exit_code)

    def create_chat_with_tools(self, chat, use_tools="all"):
        """Return a shallow copy of *chat* with registered tool definitions
        injected, but only when the chat has no tools of its own.

        use_tools is "all" or a comma-separated list of tool names to include.
        """
        # Inject global tools if present
        current_chat = chat.copy()
        tools = current_chat.get("tools")
        if tools is None:
            tools = current_chat["tools"] = []
        if self.tool_definitions and len(tools) == 0:
            include_all_tools = use_tools == "all"
            only_tools_list = use_tools.split(",")

            # NOTE(review): split() always yields >= 1 entry, so this condition
            # and the "tools" re-check below are effectively always-true /
            # dead — kept as-is to preserve behavior.
            if include_all_tools or len(only_tools_list) > 0:
                if "tools" not in current_chat:
                    current_chat["tools"] = []

                existing_tools = {t["function"]["name"] for t in current_chat["tools"]}
                for tool_def in self.tool_definitions:
                    name = tool_def["function"]["name"]
                    if name not in existing_tools and (include_all_tools or name in only_tools_list):
                        current_chat["tools"].append(tool_def)
        return current_chat
def handler_name(handler):
    """Return the callable's __name__, or "unknown" when it has none."""
    return getattr(handler, "__name__", "unknown")
class ExtensionContext:
|
|
2731
|
+
    def __init__(self, app, path):
        """Per-extension facade over the shared AppExtensions registry.

        *path* is the extension's source path; its basename (minus .py)
        becomes the extension name and the /ext/<name> URL prefix.
        """
        self.app = app
        # Convenience aliases into the shared app object
        self.cli_args = app.cli_args
        self.extra_args = app.extra_args
        self.error_auth_required = app.error_auth_required
        self.path = path
        self.name = os.path.basename(path)
        if self.name.endswith(".py"):
            self.name = self.name[:-3]
        self.ext_prefix = f"/ext/{self.name}"
        # Module-level flags snapshotted at construction time
        self.MOCK = MOCK
        self.MOCK_DIR = MOCK_DIR
        self.debug = DEBUG
        self.verbose = g_verbose
        self.aspect_ratios = app.aspect_ratios
        self.request_args = app.request_args
        self.disabled = False
    def chat_to_prompt(self, chat):
        # Delegate to the module-level chat_to_prompt helper.
        return chat_to_prompt(chat)
    def chat_to_system_prompt(self, chat):
        # Delegate to the module-level chat_to_system_prompt helper.
        return chat_to_system_prompt(chat)
    def chat_response_to_message(self, response):
        # Delegate to the module-level chat_response_to_message helper.
        return chat_response_to_message(response)
    def last_user_prompt(self, chat):
        # Delegate to the module-level last_user_prompt helper.
        return last_user_prompt(chat)
    def to_file_info(self, chat, info=None, response=None):
        # Delegate to the module-level to_file_info helper.
        return to_file_info(chat, info=info, response=response)
    def save_image_to_cache(self, base64_data, filename, image_info, ignore_info=False):
        # Delegate to the module-level save_image_to_cache helper.
        return save_image_to_cache(base64_data, filename, image_info, ignore_info=ignore_info)
    def save_bytes_to_cache(self, bytes_data, filename, file_info):
        # Delegate to the module-level save_bytes_to_cache helper.
        return save_bytes_to_cache(bytes_data, filename, file_info)
    def text_from_file(self, path):
        # Delegate to the module-level text_from_file helper.
        return text_from_file(path)
    def json_from_file(self, path):
        # Delegate to the module-level json_from_file helper.
        return json_from_file(path)
    def download_file(self, url):
        # Delegate to the module-level download_file helper.
        return download_file(url)
    def session_download_file(self, session, url):
        # Delegate to the module-level session_download_file helper.
        return session_download_file(session, url)
    def read_binary_file(self, url):
        # Delegate to the module-level read_binary_file helper.
        return read_binary_file(url)
    def log(self, message):
        """Print *message* prefixed with the extension name when verbose is
        on; always returns the message unchanged so calls can be inlined."""
        if self.verbose:
            print(f"[{self.name}] {message}", flush=True)
        return message
    def log_json(self, obj):
        """Pretty-print *obj* as JSON when verbose is on; returns obj unchanged."""
        if self.verbose:
            print(f"[{self.name}] {json.dumps(obj, indent=2)}", flush=True)
        return obj
    def dbg(self, message):
        # Debug-level print, gated on the debug flag (no return value).
        if self.debug:
            print(f"DEBUG [{self.name}]: {message}", flush=True)
    def err(self, message, e):
        """Print an error with the exception; include a traceback when verbose."""
        print(f"ERROR [{self.name}]: {message}", e)
        if self.verbose:
            print(traceback.format_exc(), flush=True)
    def error_message(self, e):
        # Delegate to the module-level to_error_message helper.
        return to_error_message(e)
    def error_response(self, e, stacktrace=False):
        # Delegate to the module-level to_error_response helper.
        return to_error_response(e, stacktrace=stacktrace)
    def add_provider(self, provider):
        # Register an additional provider class with the shared app registry.
        self.log(f"Registered provider: {provider.__name__}")
        self.app.all_providers.append(provider)
    def register_ui_extension(self, index):
        # Register a UI entry module served under this extension's /ext prefix.
        path = os.path.join(self.ext_prefix, index)
        self.log(f"Registered UI extension: {path}")
        self.app.ui_extensions.append({"id": self.name, "path": path})
    def register_chat_request_filter(self, handler):
        # Record a filter invoked for each incoming chat request.
        self.log(f"Registered chat request filter: {handler_name(handler)}")
        self.app.chat_request_filters.append(handler)
    def register_chat_tool_filter(self, handler):
        # Record an async filter invoked on tool calls (see on_chat_tool).
        self.log(f"Registered chat tool filter: {handler_name(handler)}")
        self.app.chat_tool_filters.append(handler)
    def register_chat_response_filter(self, handler):
        # Record a filter invoked on each chat response.
        self.log(f"Registered chat response filter: {handler_name(handler)}")
        self.app.chat_response_filters.append(handler)
    def register_chat_error_filter(self, handler):
        # Record an async filter invoked on chat errors (see on_chat_error).
        self.log(f"Registered chat error filter: {handler_name(handler)}")
        self.app.chat_error_filters.append(handler)
    def register_cache_saved_filter(self, handler):
        # Record a filter invoked after an item is saved to the cache.
        self.log(f"Registered cache saved filter: {handler_name(handler)}")
        self.app.cache_saved_filters.append(handler)
    def register_shutdown_handler(self, handler):
        # Record a callable run during AppExtensions.exit().
        self.log(f"Registered shutdown handler: {handler_name(handler)}")
        self.app.shutdown_handlers.append(handler)
    def add_static_files(self, ext_dir):
        """Serve files from *ext_dir* under this extension's /ext prefix."""
        self.log(f"Registered static files: {ext_dir}")

        async def serve_static(request):
            # NOTE(security review): the wildcard path from the URL is joined
            # into ext_dir without normalization — confirm the framework/router
            # rejects ".." segments, otherwise this allows path traversal.
            path = request.match_info["path"]
            file_path = os.path.join(ext_dir, path)
            if os.path.exists(file_path):
                return web.FileResponse(file_path)
            return web.Response(status=404)

        # {path:.*} is an aiohttp-style catch-all route parameter
        self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))
    def web_path(self, method, path):
        """Return *path* prefixed with /ext/<name> (or the bare prefix when
        path is empty), logging the registration at debug level."""
        full_path = os.path.join(self.ext_prefix, path) if path else self.ext_prefix
        self.dbg(f"Registered {method:<6} {full_path}")
        return full_path
    def add_get(self, path, handler, **kwargs):
        # Queue a GET route under this extension's prefix.
        self.app.server_add_get.append((self.web_path("GET", path), handler, kwargs))
    def add_post(self, path, handler, **kwargs):
        # Queue a POST route under this extension's prefix.
        self.app.server_add_post.append((self.web_path("POST", path), handler, kwargs))
def add_put(self, path, handler, **kwargs):
    """Register a PUT route for *handler* under this extension's prefix."""
    route = self.web_path("PUT", path)
    self.app.server_add_put.append((route, handler, kwargs))
|
|
2868
|
+
|
|
2869
|
+
def add_delete(self, path, handler, **kwargs):
    """Register a DELETE route for *handler* under this extension's prefix."""
    route = self.web_path("DELETE", path)
    self.app.server_add_delete.append((route, handler, kwargs))
|
|
2871
|
+
|
|
2872
|
+
def add_patch(self, path, handler, **kwargs):
    """Register a PATCH route for *handler* under this extension's prefix."""
    route = self.web_path("PATCH", path)
    self.app.server_add_patch.append((route, handler, kwargs))
|
|
2874
|
+
|
|
2875
|
+
def add_importmaps(self, dict):
    """Merge the given name→URL mapping into the app's ES-module import maps.

    NOTE: the parameter shadows the builtin ``dict``; kept as-is because the
    name is part of the keyword-call interface.
    """
    self.app.import_maps.update(dict)
|
|
2877
|
+
|
|
2878
|
+
def add_index_header(self, html):
    """Append an HTML snippet to be injected into the index page header."""
    self.app.index_headers.append(html)
|
|
2880
|
+
|
|
2881
|
+
def add_index_footer(self, html):
    """Append an HTML snippet to be injected into the index page footer."""
    self.app.index_footers.append(html)
|
|
2883
|
+
|
|
2884
|
+
def get_config(self):
    """Return the process-wide configuration object (module global)."""
    return g_config
|
|
2886
|
+
|
|
2887
|
+
def get_cache_path(self, path=""):
    """Return the cache path for *path* (delegates to the module-level helper)."""
    return get_cache_path(path)
|
|
2889
|
+
|
|
2890
|
+
def get_file_mime_type(self, filename):
    """Return the MIME type for *filename* (delegates to the module-level helper)."""
    return get_file_mime_type(filename)
|
|
2892
|
+
|
|
2893
|
+
def chat_request(self, template=None, text=None, model=None, system_prompt=None):
    """Build a chat request via the app, forwarding all arguments unchanged."""
    return self.app.chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
|
|
2895
|
+
|
|
2896
|
+
def chat_completion(self, chat, context=None):
    """Run a chat completion via the app, forwarding *chat* and *context*."""
    return self.app.chat_completion(chat, context=context)
|
|
2898
|
+
|
|
2899
|
+
def get_providers(self):
    """Return the mapping of enabled provider handlers (module global)."""
    return g_handlers
|
|
2901
|
+
|
|
2902
|
+
def get_provider(self, name):
    """Return the provider handler registered under *name*, or None."""
    return g_handlers.get(name)
|
|
2904
|
+
|
|
2905
|
+
def sanitize_tool_def(self, tool_def):
    """Inline JSON-schema ``$defs`` references into a tool definition.

    For a ``{"type": "function"}`` tool definition, each property schema of
    the form ``{"$ref": "#/$defs/Name", ...}`` is merged in place with the
    referenced definition, and the ``$defs`` section is then removed, so
    clients never have to resolve references themselves. For example::

        {"aspect_ratio": {"$ref": "#/$defs/AspectRatio", "default": "1:1"}}

    becomes::

        {"aspect_ratio": {"type": "string", "enum": [...], "default": "1:1"}}

    Args:
        tool_def: An OpenAI-style tool definition dict (mutated in place).

    Returns:
        The same ``tool_def`` dict, with references resolved.
    """
    # Renamed local: the original shadowed the builtin `type`.
    tool_type = tool_def.get("type")
    if tool_type == "function":
        func_def = tool_def.get("function", {})
        parameters = func_def.get("parameters", {})
        defs = parameters.get("$defs", {})
        properties = parameters.get("properties", {})
        # Only the property schemas are needed; keys were unused.
        for prop_def in properties.values():
            if "$ref" in prop_def:
                ref = prop_def["$ref"]
                if ref.startswith("#/$defs/"):
                    def_name = ref.replace("#/$defs/", "")
                    if def_name in defs:
                        # Keys from the referenced $def overwrite duplicates
                        # in the property, matching the original behavior.
                        prop_def.update(defs[def_name])
                        del prop_def["$ref"]
        if "$defs" in parameters:
            del parameters["$defs"]
    return tool_def
|
|
2956
|
+
|
|
2957
|
+
def register_tool(self, func, tool_def=None, group=None):
    """Register *func* as a callable tool, replacing any same-named tool.

    Args:
        func: The callable implementing the tool.
        tool_def: Optional OpenAI-style tool definition; derived from *func*
            when omitted.
        group: Tool group to file the tool under (defaults to "custom").
    """
    if tool_def is None:
        tool_def = function_to_tool_definition(func)

    name = tool_def["function"]["name"]
    if name in self.app.tools:
        self.log(f"Overriding existing tool: {name}")
        # Drop the stale definition and any existing group memberships
        # before re-adding below.
        self.app.tool_definitions = [
            t for t in self.app.tool_definitions if t["function"]["name"] != name
        ]
        for members in self.app.tool_groups.values():
            if name in members:
                members.remove(name)
    else:
        self.log(f"Registered tool: {name}")

    self.app.tools[name] = func
    self.app.tool_definitions.append(self.sanitize_tool_def(tool_def))
    self.app.tool_groups.setdefault(group or "custom", []).append(name)
|
|
2978
|
+
|
|
2979
|
+
def get_tool_definition(self, name):
    """Return the registered tool definition matching *name*, or None."""
    return next(
        (td for td in self.app.tool_definitions if td["function"]["name"] == name),
        None,
    )
|
|
2984
|
+
|
|
2985
|
+
def group_resources(self, resources: list):
    """Group *resources* (delegates to the module-level helper)."""
    return group_resources(resources)
|
|
2987
|
+
|
|
2988
|
+
def check_auth(self, request):
    """Check authentication for *request* via the app."""
    return self.app.check_auth(request)
|
|
2990
|
+
|
|
2991
|
+
def get_session(self, request):
    """Return the session for *request* via the app."""
    return self.app.get_session(request)
|
|
2993
|
+
|
|
2994
|
+
def get_username(self, request):
    """Return the authenticated username for *request* via the app."""
    return self.app.get_username(request)
|
|
2996
|
+
|
|
2997
|
+
def get_user_path(self, username=None):
    """Return the per-user storage path via the app (None → current user)."""
    return self.app.get_user_path(username)
|
|
2999
|
+
|
|
3000
|
+
def context_to_username(self, context):
    """Resolve the username from a chat *context*, or None when unavailable.

    Returns None when *context* is falsy or carries no "request" entry.
    """
    if not context or "request" not in context:
        return None
    return self.get_username(context["request"])
|
|
3004
|
+
|
|
3005
|
+
def should_cancel_thread(self, context):
    """Return whether the thread in *context* was cancelled (module-level helper)."""
    return should_cancel_thread(context)
|
|
3007
|
+
|
|
3008
|
+
def cache_message_inline_data(self, message):
    """Cache inline data of *message* (delegates to the module-level helper)."""
    return cache_message_inline_data(message)
|
|
3010
|
+
|
|
3011
|
+
async def exec_tool(self, name, args):
    """Execute the registered tool *name* with *args* (module-level helper)."""
    return await g_exec_tool(name, args)
|
|
3013
|
+
|
|
3014
|
+
def tool_result(self, result, function_name: Optional[str] = None, function_args: Optional[dict] = None):
    """Wrap *result* as a tool result (delegates to the module-level helper)."""
    return g_tool_result(result, function_name, function_args)
|
|
3016
|
+
|
|
3017
|
+
def tool_result_part(self, result: dict, function_name: Optional[str] = None, function_args: Optional[dict] = None):
    """Wrap *result* as a tool-result message part (module-level helper)."""
    return tool_result_part(result, function_name, function_args)
|
|
3019
|
+
|
|
3020
|
+
def to_content(self, result):
    """Convert *result* to message content (delegates to the module-level helper)."""
    return to_content(result)
|
|
3022
|
+
|
|
3023
|
+
def create_chat_with_tools(self, chat, use_tools="all"):
    """Attach tool definitions to *chat* via the app ("all" by default)."""
    return self.app.create_chat_with_tools(chat, use_tools)
|
|
3025
|
+
|
|
3026
|
+
def chat_to_aspect_ratio(self, chat):
    """Extract the aspect ratio from *chat* (delegates to the module-level helper)."""
    return chat_to_aspect_ratio(chat)
|
|
3028
|
+
|
|
3029
|
+
|
|
3030
|
+
def get_extensions_path():
    """Return the user extensions directory.

    Uses the LLMS_EXTENSIONS_DIR environment variable when set, otherwise
    the "extensions" folder under the llms home directory.
    """
    default_path = home_llms_path("extensions")
    return os.getenv("LLMS_EXTENSIONS_DIR", default_path)
|
|
3032
|
+
|
|
3033
|
+
|
|
3034
|
+
def get_disabled_extensions():
    """Return the disabled extension names.

    Combines the built-in DISABLE_EXTENSIONS list with any names from the
    config's "disable_extensions" entry, without duplicates.
    """
    disabled = DISABLE_EXTENSIONS.copy()
    if g_config:
        for name in g_config.get("disable_extensions", []):
            if name not in disabled:
                disabled.append(name)
    return disabled
|
|
3041
|
+
|
|
3042
|
+
|
|
3043
|
+
def get_extensions_dirs():
    """Return all extension directories.

    Builtin extensions come first, then user-installed ones. A builtin is
    skipped when a same-named entry exists in the user extensions directory
    (override) or when its name is disabled.
    """
    extensions_path = get_extensions_path()
    os.makedirs(extensions_path, exist_ok=True)

    # allow overriding builtin extensions
    override_extensions = []
    if os.path.exists(extensions_path):
        override_extensions = os.listdir(extensions_path)

    ret = []
    disabled_extensions = get_disabled_extensions()

    builtin_extensions_dir = _ROOT / "extensions"
    if not os.path.exists(builtin_extensions_dir):
        # look for local ./extensions dir from script
        builtin_extensions_dir = os.path.join(os.path.dirname(__file__), "extensions")

    _dbg(f"Loading extensions from {builtin_extensions_dir}")
    if os.path.exists(builtin_extensions_dir):
        for entry in os.listdir(builtin_extensions_dir):
            full = os.path.join(builtin_extensions_dir, entry)
            if not os.path.isdir(full):
                continue
            if entry in override_extensions or entry in disabled_extensions:
                continue
            ret.append(full)

    if os.path.exists(extensions_path):
        for entry in os.listdir(extensions_path):
            full = os.path.join(extensions_path, entry)
            if os.path.isdir(full) and entry not in disabled_extensions:
                ret.append(full)

    return ret
|
|
3081
|
+
|
|
3082
|
+
|
|
3083
|
+
def init_extensions(parser):
    """Give every extension a chance to extend the CLI argument *parser*.

    Loads each extension's __init__.py and, when the module defines a
    callable ``__parser__``, calls it with *parser*. Failures are logged
    per-extension and do not stop the remaining extensions.
    """
    for item_path in get_extensions_dirs():
        item = os.path.basename(item_path)

        if not os.path.isdir(item_path):
            continue
        try:
            # check for __parser__ function if exists in __init.__.py and call it with parser
            init_file = os.path.join(item_path, "__init__.py")
            if not os.path.exists(init_file):
                continue
            spec = importlib.util.spec_from_file_location(item, init_file)
            if not (spec and spec.loader):
                continue
            module = importlib.util.module_from_spec(spec)
            sys.modules[item] = module
            spec.loader.exec_module(module)

            parser_func = getattr(module, "__parser__", None)
            if callable(parser_func):
                parser_func(parser)
                _log(f"Extension {item} parser loaded")
        except Exception as e:
            _err(f"Failed to load extension {item} parser", e)
|
|
3107
|
+
|
|
3108
|
+
|
|
3109
|
+
def install_extensions():
    """Scan the extension directories and run each extension's ``__install__(ctx)``.

    Returns a list of dicts ``{"name", "module", "ctx", "load", "run"}`` for
    every successfully installed (and not disabled) extension, or None when
    no extension directories exist at all.
    """

    extension_dirs = get_extensions_dirs()
    ext_count = len(list(extension_dirs))
    if ext_count == 0:
        _log("No extensions found")
        return

    disabled_extensions = get_disabled_extensions()
    if len(disabled_extensions) > 0:
        _log(f"Disabled extensions: {', '.join(disabled_extensions)}")

    _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")

    extensions = []

    for item_path in extension_dirs:
        item = os.path.basename(item_path)

        if os.path.isdir(item_path):
            # Make the extension directory importable so its local modules resolve.
            sys.path.append(item_path)
            try:
                ctx = ExtensionContext(g_app, item_path)
                init_file = os.path.join(item_path, "__init__.py")
                if os.path.exists(init_file):
                    spec = importlib.util.spec_from_file_location(item, init_file)
                    if spec and spec.loader:
                        module = importlib.util.module_from_spec(spec)
                        sys.modules[item] = module
                        spec.loader.exec_module(module)

                        install_func = getattr(module, "__install__", None)
                        if callable(install_func):
                            install_func(ctx)
                            _log(f"Extension {item} installed")
                        else:
                            _dbg(f"Extension {item} has no __install__ function")
                    else:
                        # NOTE(review): this branch fires when the import spec/loader
                        # could not be created, but the message says "no __init__.py" —
                        # confirm the message pairing is intended.
                        _dbg(f"Extension {item} has no __init__.py")
                else:
                    _dbg(f"Extension {init_file} not found")

                if ctx.disabled:
                    _log(f"Extension {item} was disabled")
                    continue

                # if ui folder exists, serve as static files at /ext/{item}/
                ui_path = os.path.join(item_path, "ui")
                if os.path.exists(ui_path):
                    ctx.add_static_files(ui_path)

                    # Register UI extension if index.mjs exists (/ext/{item}/index.mjs)
                    if os.path.exists(os.path.join(ui_path, "index.mjs")):
                        ctx.register_ui_extension("index.mjs")

                # include __load__ and __run__ hooks if they exist
                # NOTE(review): if __init__.py was missing above, `module` is unbound
                # here and this getattr raises NameError (caught by the except below,
                # logged as an install failure) — confirm this is the intended outcome.
                load_func = getattr(module, "__load__", None)
                if callable(load_func) and not inspect.iscoroutinefunction(load_func):
                    # __load__ hooks are awaited later by load_extensions(); a sync
                    # hook is dropped with a warning rather than mis-awaited.
                    _log(f"Warning: Extension {item} __load__ must be async")
                    load_func = None

                run_func = getattr(module, "__run__", None)
                if callable(run_func) and inspect.iscoroutinefunction(run_func):
                    # __run__ hooks are called synchronously by run_extension_cli();
                    # an async hook is dropped with a warning.
                    _log(f"Warning: Extension {item} __run__ must be sync")
                    run_func = None

                extensions.append({"name": item, "module": module, "ctx": ctx, "load": load_func, "run": run_func})
            except Exception as e:
                _err(f"Failed to install extension {item}", e)
        else:
            _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")

    return extensions
|
|
3186
|
+
|
|
3187
|
+
|
|
3188
|
+
async def load_extensions():
    """Await the ``__load__(ctx)`` coroutine of every installed extension concurrently.

    Failures are logged per-extension; one failing extension does not abort
    the others.
    """
    pending = [
        {"name": ext["name"], "task": ext["load"](ext["ctx"])}
        for ext in g_app.extensions
        if ext.get("load")
    ]

    if len(pending) > 0:
        _log(f"Loading {len(pending)} extensions...")
        results = await asyncio.gather(*(p["task"] for p in pending), return_exceptions=True)
        # gather preserves order, so each result lines up with its extension.
        for entry, result in zip(pending, results):
            if isinstance(result, Exception):
                _err(f"Failed to load extension {entry['name']}", result)
|
3207
|
+
|
|
3208
|
+
def run_extension_cli():
    """Run the CLI hook of the first extension that defines ``__run__(ctx)``.

    Returns whatever ``__run__`` returns (truthy when the invocation was
    handled), or False when no extension defines a hook or loading failed.
    """
    for item_path in get_extensions_dirs():
        item = os.path.basename(item_path)

        if os.path.isdir(item_path):
            init_file = os.path.join(item_path, "__init__.py")
            if os.path.exists(init_file):
                ctx = ExtensionContext(g_app, item_path)
                try:
                    spec = importlib.util.spec_from_file_location(item, init_file)
                    if spec and spec.loader:
                        module = importlib.util.module_from_spec(spec)
                        sys.modules[item] = module
                        spec.loader.exec_module(module)

                        # Check for __run__ function if exists in __init__.py and call it with ctx
                        run_func = getattr(module, "__run__", None)
                        if callable(run_func):
                            _log(f"Running extension {item}...")
                            handled = run_func(ctx)
                            # NOTE(review): returns after the FIRST extension that
                            # defines __run__ — later extensions are never consulted;
                            # confirm this short-circuit is intended.
                            return handled

                except Exception as e:
                    _err(f"Failed to run extension {item}", e)
    return False
|
|
3236
|
+
|
|
3237
|
+
|
|
3238
|
+
def main():
|
|
3239
|
+
global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
|
|
3240
|
+
|
|
3241
|
+
_ROOT = os.getenv("LLMS_ROOT", resolve_root())
|
|
3242
|
+
if not _ROOT:
|
|
3243
|
+
print("Resource root not found")
|
|
3244
|
+
exit(1)
|
|
3245
|
+
|
|
3246
|
+
parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
|
|
3247
|
+
parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
|
|
3248
|
+
parser.add_argument("--providers", default=None, help="Path to models.dev providers file", metavar="FILE")
|
|
3249
|
+
parser.add_argument("-m", "--model", default=None, help="Model to use")
|
|
3250
|
+
|
|
3251
|
+
parser.add_argument("--chat", default=None, help="OpenAI Chat Completion Request to send", metavar="REQUEST")
|
|
3252
|
+
parser.add_argument(
|
|
3253
|
+
"-s", "--system", default=None, help="System prompt to use for chat completion", metavar="PROMPT"
|
|
3254
|
+
)
|
|
3255
|
+
parser.add_argument(
|
|
3256
|
+
"--tools", default=None, help="Tools to use for chat completion (all|none|<tool>,<tool>...)", metavar="TOOLS"
|
|
3257
|
+
)
|
|
3258
|
+
parser.add_argument("--image", default=None, help="Image input to use in chat completion")
|
|
3259
|
+
parser.add_argument("--audio", default=None, help="Audio input to use in chat completion")
|
|
3260
|
+
parser.add_argument("--file", default=None, help="File input to use in chat completion")
|
|
3261
|
+
parser.add_argument("--out", default=None, help="Image or Video Generation Request", metavar="MODALITY")
|
|
3262
|
+
parser.add_argument(
|
|
3263
|
+
"--args",
|
|
3264
|
+
default=None,
|
|
3265
|
+
help='URL-encoded parameters to add to chat request (e.g. "temperature=0.7&seed=111")',
|
|
3266
|
+
metavar="PARAMS",
|
|
3267
|
+
)
|
|
3268
|
+
parser.add_argument("--raw", action="store_true", help="Return raw AI JSON response")
|
|
3269
|
+
|
|
3270
|
+
parser.add_argument(
|
|
3271
|
+
"--list", action="store_true", help="Show list of enabled providers and their models (alias ls provider?)"
|
|
3272
|
+
)
|
|
3273
|
+
parser.add_argument("--check", default=None, help="Check validity of models for a provider", metavar="PROVIDER")
|
|
3274
|
+
|
|
3275
|
+
parser.add_argument(
|
|
3276
|
+
"--serve", default=None, help="Port to start an OpenAI Chat compatible server on", metavar="PORT"
|
|
3277
|
+
)
|
|
3278
|
+
|
|
3279
|
+
parser.add_argument("--enable", default=None, help="Enable a provider", metavar="PROVIDER")
|
|
3280
|
+
parser.add_argument("--disable", default=None, help="Disable a provider", metavar="PROVIDER")
|
|
3281
|
+
parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")
|
|
3282
|
+
|
|
3283
|
+
parser.add_argument("--init", action="store_true", help="Create a default llms.json")
|
|
3284
|
+
parser.add_argument("--update-providers", action="store_true", help="Update local models.dev providers.json")
|
|
3285
|
+
|
|
3286
|
+
parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
|
|
3287
|
+
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
|
3288
|
+
|
|
3289
|
+
parser.add_argument(
|
|
3290
|
+
"--add",
|
|
3291
|
+
nargs="?",
|
|
3292
|
+
const="ls",
|
|
3293
|
+
default=None,
|
|
3294
|
+
help="Install an extension (lists available extensions if no name provided)",
|
|
3295
|
+
metavar="EXTENSION",
|
|
3296
|
+
)
|
|
3297
|
+
parser.add_argument(
|
|
3298
|
+
"--remove",
|
|
3299
|
+
nargs="?",
|
|
3300
|
+
const="ls",
|
|
3301
|
+
default=None,
|
|
3302
|
+
help="Remove an extension (lists installed extensions if no name provided)",
|
|
3303
|
+
metavar="EXTENSION",
|
|
3304
|
+
)
|
|
3305
|
+
|
|
3306
|
+
parser.add_argument(
|
|
3307
|
+
"--update",
|
|
3308
|
+
nargs="?",
|
|
3309
|
+
const="ls",
|
|
3310
|
+
default=None,
|
|
3311
|
+
help="Update an extension (use 'all' to update all extensions)",
|
|
3312
|
+
metavar="EXTENSION",
|
|
3313
|
+
)
|
|
3314
|
+
|
|
3315
|
+
# Load parser extensions, go through all extensions and load their parser arguments
|
|
3316
|
+
init_extensions(parser)
|
|
3317
|
+
|
|
3318
|
+
cli_args, extra_args = parser.parse_known_args()
|
|
3319
|
+
|
|
3320
|
+
g_app = AppExtensions(cli_args, extra_args)
|
|
3321
|
+
|
|
3322
|
+
# Check for verbose mode from CLI argument or environment variables
|
|
3323
|
+
verbose_env = os.getenv("VERBOSE", "").lower()
|
|
3324
|
+
if cli_args.verbose or verbose_env in ("1", "true"):
|
|
3325
|
+
g_verbose = True
|
|
3326
|
+
# printdump(cli_args)
|
|
3327
|
+
if cli_args.model:
|
|
3328
|
+
g_default_model = cli_args.model
|
|
3329
|
+
if cli_args.logprefix:
|
|
3330
|
+
g_logprefix = cli_args.logprefix
|
|
3331
|
+
|
|
3332
|
+
home_config_path = home_llms_path("llms.json")
|
|
3333
|
+
home_providers_path = home_llms_path("providers.json")
|
|
3334
|
+
home_providers_extra_path = home_llms_path("providers-extra.json")
|
|
3335
|
+
|
|
3336
|
+
if cli_args.init:
|
|
3337
|
+
if os.path.exists(home_config_path):
|
|
3338
|
+
print(f"llms.json already exists at {home_config_path}")
|
|
3339
|
+
else:
|
|
3340
|
+
asyncio.run(save_default_config(home_config_path))
|
|
3341
|
+
print(f"Created default config at {home_config_path}")
|
|
3342
|
+
|
|
3343
|
+
if os.path.exists(home_providers_path):
|
|
3344
|
+
print(f"providers.json already exists at {home_providers_path}")
|
|
3345
|
+
else:
|
|
3346
|
+
asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
|
|
3347
|
+
print(f"Created default providers config at {home_providers_path}")
|
|
3348
|
+
|
|
3349
|
+
if os.path.exists(home_providers_extra_path):
|
|
3350
|
+
print(f"providers-extra.json already exists at {home_providers_extra_path}")
|
|
3351
|
+
else:
|
|
3352
|
+
asyncio.run(save_text_url(github_url("providers-extra.json"), home_providers_extra_path))
|
|
3353
|
+
print(f"Created default extra providers config at {home_providers_extra_path}")
|
|
3354
|
+
exit(0)
|
|
3355
|
+
|
|
3356
|
+
if cli_args.providers:
|
|
3357
|
+
if not os.path.exists(cli_args.providers):
|
|
3358
|
+
print(f"providers.json not found at {cli_args.providers}")
|
|
3359
|
+
exit(1)
|
|
3360
|
+
g_providers = json.loads(text_from_file(cli_args.providers))
|
|
3361
|
+
|
|
3362
|
+
if cli_args.config:
|
|
3363
|
+
# read contents
|
|
3364
|
+
g_config_path = cli_args.config
|
|
3365
|
+
with open(g_config_path, encoding="utf-8") as f:
|
|
3366
|
+
config_json = f.read()
|
|
3367
|
+
g_config = load_config_json(config_json)
|
|
3368
|
+
|
|
3369
|
+
config_dir = os.path.dirname(g_config_path)
|
|
3370
|
+
|
|
3371
|
+
if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
|
|
3372
|
+
g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
|
|
3373
|
+
|
|
3374
|
+
else:
|
|
3375
|
+
# ensure llms.json and providers.json exist in home directory
|
|
3376
|
+
asyncio.run(save_home_configs())
|
|
3377
|
+
g_config_path = home_config_path
|
|
3378
|
+
g_config = load_config_json(text_from_file(g_config_path))
|
|
3379
|
+
|
|
3380
|
+
g_app.set_config(g_config)
|
|
3381
|
+
|
|
3382
|
+
if not g_providers:
|
|
3383
|
+
g_providers = json.loads(text_from_file(home_providers_path))
|
|
3384
|
+
|
|
3385
|
+
if cli_args.update_providers:
|
|
3386
|
+
asyncio.run(update_providers(home_providers_path))
|
|
3387
|
+
print(f"Updated {home_providers_path}")
|
|
3388
|
+
exit(0)
|
|
3389
|
+
|
|
3390
|
+
# if home_providers_path is older than 1 day, update providers list
|
|
3391
|
+
if (
|
|
3392
|
+
os.path.exists(home_providers_path)
|
|
3393
|
+
and (time.time() - os.path.getmtime(home_providers_path)) > 86400
|
|
3394
|
+
and os.getenv("LLMS_DISABLE_UPDATE", "") != "1"
|
|
3395
|
+
):
|
|
3396
|
+
try:
|
|
3397
|
+
asyncio.run(update_providers(home_providers_path))
|
|
3398
|
+
_log(f"Updated {home_providers_path}")
|
|
3399
|
+
except Exception as e:
|
|
3400
|
+
_err("Failed to update providers", e)
|
|
3401
|
+
|
|
3402
|
+
if cli_args.add is not None:
|
|
3403
|
+
if cli_args.add == "ls":
|
|
3404
|
+
|
|
3405
|
+
async def list_extensions():
|
|
3406
|
+
print("\nAvailable extensions:")
|
|
3407
|
+
text = await get_text("https://api.github.com/orgs/llmspy/repos?per_page=100&sort=updated")
|
|
3408
|
+
repos = json.loads(text)
|
|
3409
|
+
max_name_length = 0
|
|
3410
|
+
for repo in repos:
|
|
3411
|
+
max_name_length = max(max_name_length, len(repo["name"]))
|
|
3412
|
+
|
|
3413
|
+
for repo in repos:
|
|
3414
|
+
print(f" {repo['name']:<{max_name_length + 2}} {repo['description']}")
|
|
3415
|
+
|
|
3416
|
+
print("\nUsage:")
|
|
3417
|
+
print(" llms --add <extension>")
|
|
3418
|
+
print(" llms --add <github-user>/<repo>")
|
|
3419
|
+
|
|
3420
|
+
asyncio.run(list_extensions())
|
|
3421
|
+
exit(0)
|
|
3422
|
+
|
|
3423
|
+
async def install_extension(name):
|
|
3424
|
+
# Determine git URL and target directory name
|
|
3425
|
+
if "/" in name:
|
|
3426
|
+
git_url = f"https://github.com/{name}"
|
|
3427
|
+
target_name = name.split("/")[-1]
|
|
3428
|
+
else:
|
|
3429
|
+
git_url = f"https://github.com/llmspy/{name}"
|
|
3430
|
+
target_name = name
|
|
3431
|
+
|
|
3432
|
+
# check extension is not already installed
|
|
3433
|
+
extensions_path = get_extensions_path()
|
|
3434
|
+
target_path = os.path.join(extensions_path, target_name)
|
|
3435
|
+
|
|
3436
|
+
if os.path.exists(target_path):
|
|
3437
|
+
print(f"Extension {target_name} is already installed at {target_path}")
|
|
3438
|
+
return
|
|
3439
|
+
|
|
3440
|
+
print(f"Installing extension: {name}")
|
|
3441
|
+
print(f"Cloning from {git_url} to {target_path}...")
|
|
3442
|
+
|
|
3443
|
+
try:
|
|
3444
|
+
subprocess.run(["git", "clone", git_url, target_path], check=True)
|
|
3445
|
+
|
|
3446
|
+
# Check for requirements.txt
|
|
3447
|
+
requirements_path = os.path.join(target_path, "requirements.txt")
|
|
3448
|
+
if os.path.exists(requirements_path):
|
|
3449
|
+
print(f"Installing dependencies from {requirements_path}...")
|
|
3450
|
+
|
|
3451
|
+
# Check if uv is installed
|
|
3452
|
+
has_uv = False
|
|
3453
|
+
try:
|
|
3454
|
+
subprocess.run(
|
|
3455
|
+
["uv", "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
|
|
3456
|
+
)
|
|
3457
|
+
has_uv = True
|
|
3458
|
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
3459
|
+
pass
|
|
3460
|
+
|
|
3461
|
+
if has_uv:
|
|
3462
|
+
subprocess.run(
|
|
3463
|
+
["uv", "pip", "install", "-p", sys.executable, "-r", "requirements.txt"],
|
|
3464
|
+
cwd=target_path,
|
|
3465
|
+
check=True,
|
|
3466
|
+
)
|
|
3467
|
+
else:
|
|
3468
|
+
subprocess.run(
|
|
3469
|
+
[sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
|
|
3470
|
+
cwd=target_path,
|
|
3471
|
+
check=True,
|
|
3472
|
+
)
|
|
3473
|
+
print("Dependencies installed successfully.")
|
|
3474
|
+
|
|
3475
|
+
print(f"Extension {target_name} installed successfully.")
|
|
3476
|
+
|
|
3477
|
+
except subprocess.CalledProcessError as e:
|
|
3478
|
+
print(f"Failed to install extension: {e}")
|
|
3479
|
+
# cleanup if clone failed but directory was created (unlikely with simple git clone but good practice)
|
|
3480
|
+
if os.path.exists(target_path) and not os.listdir(target_path):
|
|
3481
|
+
os.rmdir(target_path)
|
|
3482
|
+
|
|
3483
|
+
asyncio.run(install_extension(cli_args.add))
|
|
3484
|
+
exit(0)
|
|
3485
|
+
|
|
3486
|
+
if cli_args.remove is not None:
|
|
3487
|
+
if cli_args.remove == "ls":
|
|
3488
|
+
# List installed extensions
|
|
3489
|
+
extensions_path = get_extensions_path()
|
|
3490
|
+
extensions = os.listdir(extensions_path)
|
|
3491
|
+
if len(extensions) == 0:
|
|
3492
|
+
print("No extensions installed.")
|
|
3493
|
+
exit(0)
|
|
3494
|
+
print("Installed extensions:")
|
|
3495
|
+
for extension in extensions:
|
|
3496
|
+
print(f" {extension}")
|
|
3497
|
+
exit(0)
|
|
3498
|
+
# Remove an extension
|
|
3499
|
+
extension_name = cli_args.remove
|
|
3500
|
+
extensions_path = get_extensions_path()
|
|
3501
|
+
target_path = os.path.join(extensions_path, extension_name)
|
|
3502
|
+
|
|
3503
|
+
if not os.path.exists(target_path):
|
|
3504
|
+
print(f"Extension {extension_name} not found at {target_path}")
|
|
3505
|
+
exit(1)
|
|
3506
|
+
|
|
3507
|
+
print(f"Removing extension: {extension_name}...")
|
|
3508
|
+
try:
|
|
3509
|
+
shutil.rmtree(target_path)
|
|
3510
|
+
print(f"Extension {extension_name} removed successfully.")
|
|
3511
|
+
except Exception as e:
|
|
3512
|
+
print(f"Failed to remove extension: {e}")
|
|
3513
|
+
exit(1)
|
|
3514
|
+
|
|
3515
|
+
exit(0)
|
|
3516
|
+
|
|
3517
|
+
if cli_args.update:
|
|
3518
|
+
if cli_args.update == "ls":
|
|
3519
|
+
# List installed extensions
|
|
3520
|
+
extensions_path = get_extensions_path()
|
|
3521
|
+
extensions = os.listdir(extensions_path)
|
|
3522
|
+
if len(extensions) == 0:
|
|
3523
|
+
print("No extensions installed.")
|
|
3524
|
+
exit(0)
|
|
3525
|
+
print("Installed extensions:")
|
|
3526
|
+
for extension in extensions:
|
|
3527
|
+
print(f" {extension}")
|
|
3528
|
+
|
|
3529
|
+
print("\nUsage:")
|
|
3530
|
+
print(" llms --update <extension>")
|
|
3531
|
+
print(" llms --update all")
|
|
3532
|
+
exit(0)
|
|
3533
|
+
|
|
3534
|
+
async def update_extensions(extension_name):
|
|
3535
|
+
extensions_path = get_extensions_path()
|
|
3536
|
+
for extension in os.listdir(extensions_path):
|
|
3537
|
+
extension_path = os.path.join(extensions_path, extension)
|
|
3538
|
+
if os.path.isdir(extension_path):
|
|
3539
|
+
if extension_name != "all" and extension != extension_name:
|
|
3540
|
+
continue
|
|
3541
|
+
result = subprocess.run(["git", "pull"], cwd=extension_path, capture_output=True)
|
|
3542
|
+
if result.returncode != 0:
|
|
3543
|
+
print(f"Failed to update extension {extension}: {result.stderr.decode('utf-8')}")
|
|
3544
|
+
continue
|
|
3545
|
+
print(f"Updated extension {extension}")
|
|
3546
|
+
_log(result.stdout.decode("utf-8"))
|
|
3547
|
+
|
|
3548
|
+
asyncio.run(update_extensions(cli_args.update))
|
|
3549
|
+
exit(0)
|
|
3550
|
+
|
|
3551
|
+
g_app.extensions = install_extensions()
|
|
3552
|
+
|
|
3553
|
+
# Use a persistent event loop to ensure async connections (like MCP)
|
|
3554
|
+
# established in load_extensions() remain active during cli_chat()
|
|
3555
|
+
loop = asyncio.new_event_loop()
|
|
3556
|
+
asyncio.set_event_loop(loop)
|
|
3557
|
+
|
|
3558
|
+
loop.run_until_complete(reload_providers())
|
|
3559
|
+
loop.run_until_complete(load_extensions())
|
|
3560
|
+
g_app.loaded = True
|
|
3561
|
+
|
|
3562
|
+
# print names
|
|
3563
|
+
_log(f"enabled providers: {', '.join(g_handlers.keys())}")
|
|
3564
|
+
|
|
3565
|
+
filter_list = []
|
|
3566
|
+
if len(extra_args) > 0:
|
|
3567
|
+
arg = extra_args[0]
|
|
3568
|
+
if arg == "ls":
|
|
3569
|
+
cli_args.list = True
|
|
3570
|
+
if len(extra_args) > 1:
|
|
3571
|
+
filter_list = extra_args[1:]
|
|
3572
|
+
|
|
3573
|
+
if cli_args.list:
|
|
3574
|
+
# Show list of enabled providers and their models
|
|
3575
|
+
enabled = []
|
|
3576
|
+
provider_count = 0
|
|
3577
|
+
model_count = 0
|
|
3578
|
+
|
|
3579
|
+
max_model_length = 0
|
|
3580
|
+
for name, provider in g_handlers.items():
|
|
3581
|
+
if len(filter_list) > 0 and name not in filter_list:
|
|
3582
|
+
continue
|
|
3583
|
+
for model in provider.models:
|
|
3584
|
+
max_model_length = max(max_model_length, len(model))
|
|
3585
|
+
|
|
3586
|
+
for name, provider in g_handlers.items():
|
|
3587
|
+
if len(filter_list) > 0 and name not in filter_list:
|
|
3588
|
+
continue
|
|
3589
|
+
provider_count += 1
|
|
3590
|
+
print(f"{name}:")
|
|
3591
|
+
enabled.append(name)
|
|
3592
|
+
for model in provider.models:
|
|
3593
|
+
model_count += 1
|
|
3594
|
+
model_cost_info = None
|
|
3595
|
+
if "cost" in provider.models[model]:
|
|
3596
|
+
model_cost = provider.models[model]["cost"]
|
|
3597
|
+
if "input" in model_cost and "output" in model_cost:
|
|
3598
|
+
if model_cost["input"] == 0 and model_cost["output"] == 0:
|
|
3599
|
+
model_cost_info = " 0"
|
|
3600
|
+
else:
|
|
3601
|
+
model_cost_info = f"{model_cost['input']:5} / {model_cost['output']}"
|
|
3602
|
+
print(f" {model:{max_model_length}} {model_cost_info or ''}")
|
|
3603
|
+
|
|
3604
|
+
print(f"\n{model_count} models available from {provider_count} providers")
|
|
3605
|
+
|
|
3606
|
+
print_status()
|
|
3607
|
+
g_app.exit(0)
|
|
3608
|
+
|
|
3609
|
+
if cli_args.check is not None:
|
|
3610
|
+
# Check validity of models for a provider
|
|
3611
|
+
provider_name = cli_args.check
|
|
3612
|
+
model_names = extra_args if len(extra_args) > 0 else None
|
|
3613
|
+
provider_name = cli_args.check
|
|
3614
|
+
model_names = extra_args if len(extra_args) > 0 else None
|
|
3615
|
+
loop.run_until_complete(check_models(provider_name, model_names))
|
|
3616
|
+
g_app.exit(0)
|
|
3617
|
+
|
|
3618
|
+
if cli_args.serve is not None:
|
|
3619
|
+
# Disable inactive providers and save to config before starting server
|
|
3620
|
+
all_providers = g_config["providers"].keys()
|
|
3621
|
+
enabled_providers = list(g_handlers.keys())
|
|
3622
|
+
disable_providers = []
|
|
3623
|
+
for provider in all_providers:
|
|
3624
|
+
provider_config = g_config["providers"][provider]
|
|
3625
|
+
if provider not in enabled_providers and "enabled" in provider_config and provider_config["enabled"]:
|
|
3626
|
+
provider_config["enabled"] = False
|
|
3627
|
+
disable_providers.append(provider)
|
|
3628
|
+
|
|
3629
|
+
if len(disable_providers) > 0:
|
|
3630
|
+
_log(f"Disabled unavailable providers: {', '.join(disable_providers)}")
|
|
3631
|
+
save_config(g_config)
|
|
3632
|
+
|
|
3633
|
+
# Start server
|
|
3634
|
+
port = int(cli_args.serve)
|
|
3635
|
+
|
|
3636
|
+
# Validate auth configuration if enabled
|
|
3637
|
+
auth_enabled = g_config.get("auth", {}).get("enabled", False)
|
|
3638
|
+
if auth_enabled:
|
|
3639
|
+
github_config = g_config.get("auth", {}).get("github", {})
|
|
3640
|
+
client_id = github_config.get("client_id", "")
|
|
3641
|
+
client_secret = github_config.get("client_secret", "")
|
|
3642
|
+
|
|
3643
|
+
# Expand environment variables
|
|
3644
|
+
if client_id.startswith("$"):
|
|
3645
|
+
client_id = client_id[1:]
|
|
3646
|
+
if client_secret.startswith("$"):
|
|
3647
|
+
client_secret = client_secret[1:]
|
|
3648
|
+
|
|
3649
|
+
client_id = os.getenv(client_id, client_id)
|
|
3650
|
+
client_secret = os.getenv(client_secret, client_secret)
|
|
3651
|
+
|
|
3652
|
+
if (
|
|
3653
|
+
not client_id
|
|
3654
|
+
or not client_secret
|
|
3655
|
+
or client_id == "GITHUB_CLIENT_ID"
|
|
3656
|
+
or client_secret == "GITHUB_CLIENT_SECRET"
|
|
3657
|
+
):
|
|
3658
|
+
print("ERROR: Authentication is enabled but GitHub OAuth is not properly configured.")
|
|
3659
|
+
print("Please set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET environment variables,")
|
|
3660
|
+
print("or disable authentication by setting 'auth.enabled' to false in llms.json")
|
|
3661
|
+
exit(1)
|
|
3662
|
+
|
|
3663
|
+
_log("Authentication enabled - GitHub OAuth configured")
|
|
3664
|
+
|
|
3665
|
+
client_max_size = g_config.get("limits", {}).get(
|
|
3666
|
+
"client_max_size", 20 * 1024 * 1024
|
|
3667
|
+
) # 20MB max request size (to handle base64 encoding overhead)
|
|
3668
|
+
_log(f"client_max_size set to {client_max_size} bytes ({client_max_size / 1024 / 1024:.1f}MB)")
|
|
3669
|
+
app = web.Application(client_max_size=client_max_size)
|
|
3670
|
+
|
|
3671
|
+
async def chat_handler(request):
    """OpenAI-compatible chat completions endpoint (POST /v1/chat/completions).

    Requires authentication when enabled; forwards the request body plus a
    context dict (chat, request, user, threadId, tools) to g_app.chat_completion.
    """
    # Check authentication if enabled
    is_authenticated, user_data = g_app.check_auth(request)
    if not is_authenticated:
        return web.json_response(g_app.error_auth_required, status=401)

    try:
        chat = await request.json()
        context = {"chat": chat, "request": request, "user": g_app.get_username(request)}
        # Optional per-request metadata supplied by the UI client
        metadata = chat.get("metadata", {})
        context["threadId"] = metadata.get("threadId", None)
        context["tools"] = metadata.get("tools", "all")
        response = await g_app.chat_completion(chat, context)
        return web.json_response(response)
    except Exception as e:
        # Top-level boundary: surface any failure as a JSON 500 response
        return web.json_response(to_error_response(e), status=500)

app.router.add_post("/v1/chat/completions", chat_handler)
|
|
3689
|
+
|
|
3690
|
+
async def active_models_handler(request):
    """Return the list of models currently available (GET /models)."""
    return web.json_response(get_active_models())

app.router.add_get("/models", active_models_handler)

async def active_providers_handler(request):
    """Return the API view of configured providers (GET /providers)."""
    return web.json_response(api_providers())

app.router.add_get("/providers", active_providers_handler)
|
|
3699
|
+
|
|
3700
|
+
async def status_handler(request):
    """Report all configured providers split into enabled/disabled (GET /status)."""
    enabled_names, disabled_names = provider_status()
    payload = {
        "all": list(g_config["providers"].keys()),
        "enabled": enabled_names,
        "disabled": disabled_names,
    }
    return web.json_response(payload)

app.router.add_get("/status", status_handler)
|
|
3711
|
+
|
|
3712
|
+
async def provider_handler(request):
    """Enable or disable a provider at runtime (POST /providers/{provider}).

    Body: {"enable": true} or {"disable": true}. Responds with the updated
    enabled/disabled lists and any feedback message produced by enabling.
    """
    provider = request.match_info.get("provider", "")
    data = await request.json()
    msg = None
    if provider:
        if data.get("enable", False):
            provider_config, msg = enable_provider(provider)
            _log(f"Enabled provider {provider} {msg}")
            # Only reload models when enabling produced no feedback message
            if not msg:
                await load_llms()
        elif data.get("disable", False):
            disable_provider(provider)
            _log(f"Disabled provider {provider}")
    enabled, disabled = provider_status()
    return web.json_response(
        {
            "enabled": enabled,
            "disabled": disabled,
            "feedback": msg or "",
        }
    )

app.router.add_post("/providers/{provider}", provider_handler)
|
|
3735
|
+
|
|
3736
|
+
async def upload_handler(request):
    """Store an uploaded file in the content-addressed cache (POST /upload).

    Reads the first multipart field named "file", optionally converts images,
    stores the content under its SHA-256 hash, and writes a sidecar
    `.info.json` with the upload metadata that is returned to the caller.
    """
    # Check authentication if enabled
    is_authenticated, user_data = g_app.check_auth(request)
    if not is_authenticated:
        return web.json_response(g_app.error_auth_required, status=401)

    reader = await request.multipart()

    # Read first file field (skip any non-"file" parts)
    field = await reader.next()
    while field and field.name != "file":
        field = await reader.next()

    if not field:
        return web.json_response(create_error_response("No file provided"), status=400)

    filename = field.filename or "file"
    content = await field.read()
    mimetype = get_file_mime_type(filename)

    # If image, resize if needed
    if mimetype.startswith("image/"):
        content, mimetype = convert_image_if_needed(content, mimetype)

    # Calculate SHA256 — the hash becomes the stored filename so identical
    # uploads dedupe to a single cache entry
    sha256_hash = hashlib.sha256(content).hexdigest()
    ext = filename.rsplit(".", 1)[1] if "." in filename else ""
    if not ext:
        # Derive an extension from the mimetype ("." prefix stripped)
        ext = mimetypes.guess_extension(mimetype) or ""
        if ext.startswith("."):
            ext = ext[1:]

    if not ext:
        ext = "bin"

    save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash

    # Use first 2 chars for subdir to avoid too many files in one dir
    subdir = sha256_hash[:2]
    relative_path = f"{subdir}/{save_filename}"
    full_path = get_cache_path(relative_path)

    # if file and its .info.json already exists, return it (dedupe hit)
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    if os.path.exists(full_path) and os.path.exists(info_path):
        return web.json_response(json_from_file(info_path))

    os.makedirs(os.path.dirname(full_path), exist_ok=True)

    with open(full_path, "wb") as f:
        f.write(content)

    url = f"/~cache/{relative_path}"
    response_data = {
        "date": int(time.time()),
        "url": url,
        "size": len(content),
        "type": mimetype,
        "name": filename,
    }

    # If image, get dimensions (best effort — Pillow may be unavailable)
    if HAS_PIL and mimetype.startswith("image/"):
        try:
            with Image.open(BytesIO(content)) as img:
                response_data["width"] = img.width
                response_data["height"] = img.height
        except Exception:
            pass

    # Save metadata sidecar next to the cached file
    info_path = os.path.splitext(full_path)[0] + ".info.json"
    with open(info_path, "w") as f:
        json.dump(response_data, f)

    # Notify extensions that a new cache entry exists
    g_app.on_cache_saved_filters({"url": url, "info": response_data})

    return web.json_response(response_data)

app.router.add_post("/upload", upload_handler)
|
|
3816
|
+
|
|
3817
|
+
async def extensions_handler(request):
    """Expose the registered UI extensions to the client (GET /ext)."""
    return web.json_response(g_app.ui_extensions)

app.router.add_get("/ext", extensions_handler)
|
|
3821
|
+
|
|
3822
|
+
async def cache_handler(request):
    """Serve files from the upload cache (GET /~cache/{path}).

    Query flags:
      ?info      — return the `.info.json` metadata sidecar instead of the file
      ?download  — serve the file as an attachment with its original name

    Fixes over the previous version:
      * the traversal guard used a raw string-prefix `startswith` check, which
        would accept a sibling directory sharing the cache root as a prefix
        (e.g. `/cache-evil` vs `/cache`); `os.path.commonpath` is used instead
      * the `except` branch could NameError on `requested_path`/`cache_root`
        when resolution itself failed before they were bound
      * the `Content-Disposition` header now uses the computed `filename`
        (it was previously a dead local and the header held a placeholder)
    """
    path = request.match_info["tail"]
    full_path = get_cache_path(path)
    info_path = os.path.splitext(full_path)[0] + ".info.json"

    def _inside_cache(candidate):
        """True when *candidate* resolves to a path under the cache root."""
        cache_root = str(Path(get_cache_path()).resolve())
        requested = str(Path(candidate).resolve())
        # commonpath is component-wise, so "/cache-evil" does not match "/cache"
        return os.path.commonpath([cache_root, requested]) == cache_root

    if "info" in request.query:
        if not os.path.exists(info_path):
            return web.Response(text="404: Not Found", status=404)

        # Check for directory traversal for info path
        try:
            if not _inside_cache(info_path):
                _dbg(f"Forbidden: {info_path} is not inside the cache root")
                return web.Response(text="403: Forbidden", status=403)
        except Exception as e:
            _err(f"Forbidden: could not validate {info_path}", e)
            return web.Response(text="403: Forbidden", status=403)

        with open(info_path) as f:
            content = f.read()
        return web.Response(text=content, content_type="application/json")

    if not os.path.exists(full_path):
        return web.Response(text="404: Not Found", status=404)

    # Check for directory traversal
    try:
        if not _inside_cache(full_path):
            _dbg(f"Forbidden: {full_path} is not inside the cache root")
            return web.Response(text="403: Forbidden", status=403)
    except Exception as e:
        _err(f"Forbidden: could not validate {full_path}", e)
        return web.Response(text="403: Forbidden", status=403)

    mimetype = get_file_mime_type(full_path)
    if "download" in request.query:
        # download file as an attachment, preferring the stored upload metadata
        info = json_from_file(info_path) or {}
        mimetype = info.get("type", mimetype)
        filename = info.get("name") or os.path.basename(full_path)
        mtime = info.get("date", os.path.getmtime(full_path))
        mdate = datetime.fromtimestamp(mtime).isoformat()
        return web.FileResponse(
            full_path,
            headers={
                "Content-Disposition": f'attachment; filename="{filename}"; modification-date="{mdate}"',
                "Content-Type": mimetype,
            },
        )
    else:
        return web.FileResponse(full_path, headers={"Content-Type": mimetype})

app.router.add_get("/~cache/{tail:.*}", cache_handler)
|
|
3879
|
+
|
|
3880
|
+
# OAuth handlers
async def github_auth_handler(request):
    """Initiate GitHub OAuth flow (GET /auth/github).

    Generates a CSRF state token, records it in g_oauth_states, and redirects
    the browser to GitHub's authorization page.
    """
    if "auth" not in g_config or "github" not in g_config["auth"]:
        return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)

    auth_config = g_config["auth"]["github"]
    client_id = auth_config.get("client_id", "")
    redirect_uri = auth_config.get("redirect_uri", "")

    # Expand environment variables (leading "$" marks an env-var name)
    if client_id.startswith("$"):
        client_id = client_id[1:]
    if redirect_uri.startswith("$"):
        redirect_uri = redirect_uri[1:]

    client_id = os.getenv(client_id, client_id)
    redirect_uri = os.getenv(redirect_uri, redirect_uri)

    if not client_id:
        return web.json_response(create_error_response("GitHub client_id not configured"), status=500)

    # Generate CSRF state token
    state = secrets.token_urlsafe(32)
    g_oauth_states[state] = {"created": time.time(), "redirect_uri": redirect_uri}

    # Clean up old states (older than 10 minutes)
    current_time = time.time()
    expired_states = [s for s, data in g_oauth_states.items() if current_time - data["created"] > 600]
    for s in expired_states:
        del g_oauth_states[s]

    # Build GitHub authorization URL
    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "state": state,
        "scope": "read:user user:email",
    }
    auth_url = f"https://github.com/login/oauth/authorize?{urlencode(params)}"

    return web.HTTPFound(auth_url)
|
|
3922
|
+
|
|
3923
|
+
def validate_user(github_username):
    """Check *github_username* against the optional `restrict_to` allow-list.

    Returns a 403 `web.Response` when the user is not allowed, or None when
    access is permitted (callers treat a truthy result as an error response).
    """
    auth_config = g_config["auth"]["github"]
    # Check if user is restricted
    restrict_to = auth_config.get("restrict_to", "")

    # Expand environment variables (leading "$" marks an env-var name)
    if restrict_to.startswith("$"):
        restrict_to = restrict_to[1:]

    # "GITHUB_USERS" is the unresolved placeholder: treat it as "no restriction"
    restrict_to = os.getenv(restrict_to, None if restrict_to == "GITHUB_USERS" else restrict_to)

    # If restrict_to is configured, validate the user
    if restrict_to:
        # Parse allowed users (comma or space delimited)
        allowed_users = [u.strip() for u in re.split(r"[,\s]+", restrict_to) if u.strip()]

        # Check if user is in the allowed list
        if not github_username or github_username not in allowed_users:
            _log(f"Access denied for user: {github_username}. Not in allowed list: {allowed_users}")
            return web.Response(
                text=f"Access denied. User '{github_username}' is not authorized to access this application.",
                status=403,
            )
    return None
|
|
3947
|
+
|
|
3948
|
+
async def github_callback_handler(request):
    """Handle GitHub OAuth callback (GET /auth/github/callback).

    Verifies the CSRF state, exchanges the authorization code for an access
    token, fetches the GitHub user profile, validates it against the optional
    allow-list, then creates a server-side session and sets the session cookie.
    """
    code = request.query.get("code")
    state = request.query.get("state")

    # Handle malformed URLs where query params are appended with & instead of ?
    if not code and "tail" in request.match_info:
        tail = request.match_info["tail"]
        if tail.startswith("&"):
            params = parse_qs(tail[1:])
            code = params.get("code", [None])[0]
            state = params.get("state", [None])[0]

    if not code or not state:
        return web.Response(text="Missing code or state parameter", status=400)

    # Verify state token (CSRF protection); each state is single-use
    if state not in g_oauth_states:
        return web.Response(text="Invalid state parameter", status=400)

    g_oauth_states.pop(state)

    if "auth" not in g_config or "github" not in g_config["auth"]:
        return web.json_response(create_error_response("GitHub OAuth not configured"), status=500)

    auth_config = g_config["auth"]["github"]
    client_id = auth_config.get("client_id", "")
    client_secret = auth_config.get("client_secret", "")
    redirect_uri = auth_config.get("redirect_uri", "")

    # Expand environment variables (leading "$" marks an env-var name)
    if client_id.startswith("$"):
        client_id = client_id[1:]
    if client_secret.startswith("$"):
        client_secret = client_secret[1:]
    if redirect_uri.startswith("$"):
        redirect_uri = redirect_uri[1:]

    client_id = os.getenv(client_id, client_id)
    client_secret = os.getenv(client_secret, client_secret)
    redirect_uri = os.getenv(redirect_uri, redirect_uri)

    if not client_id or not client_secret:
        return web.json_response(create_error_response("GitHub OAuth credentials not configured"), status=500)

    # Exchange code for access token
    async with aiohttp.ClientSession() as session:
        token_url = "https://github.com/login/oauth/access_token"
        token_data = {
            "client_id": client_id,
            "client_secret": client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
        }
        headers = {"Accept": "application/json"}

        async with session.post(token_url, data=token_data, headers=headers) as resp:
            token_response = await resp.json()
            access_token = token_response.get("access_token")

            if not access_token:
                error = token_response.get("error_description", "Failed to get access token")
                return web.json_response(create_error_response(f"OAuth error: {error}"), status=400)

        # Fetch user info
        user_url = "https://api.github.com/user"
        headers = {"Authorization": f"Bearer {access_token}", "Accept": "application/json"}

        async with session.get(user_url, headers=headers) as resp:
            user_data = await resp.json()

    # Validate user against the optional allow-list
    error_response = validate_user(user_data.get("login", ""))
    if error_response:
        return error_response

    # Create session
    session_token = secrets.token_urlsafe(32)
    g_sessions[session_token] = {
        "userId": str(user_data.get("id", "")),
        "userName": user_data.get("login", ""),
        "displayName": user_data.get("name", ""),
        "profileUrl": user_data.get("avatar_url", ""),
        "email": user_data.get("email", ""),
        "created": time.time(),
    }

    # Redirect to UI with session token (cookie lifetime: 24 hours)
    response = web.HTTPFound(f"/?session={session_token}")
    response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400)
    return response
|
|
4039
|
+
|
|
4040
|
+
async def session_handler(request):
    """Validate and return session info (GET /auth/session).

    Expired sessions (older than 24 hours) are purged *before* the token is
    validated, so a stale token gets a 401 instead of being honored one last
    time — the previous ordering returned the session data even for a token
    the cleanup pass was about to delete.
    """
    session_token = get_session_token(request)

    # Clean up old sessions (older than 24 hours) before validating
    current_time = time.time()
    expired_sessions = [token for token, data in g_sessions.items() if current_time - data["created"] > 86400]
    for token in expired_sessions:
        del g_sessions[token]

    if not session_token or session_token not in g_sessions:
        return web.json_response(create_error_response("Invalid or expired session"), status=401)

    session_data = g_sessions[session_token]
    return web.json_response({**session_data, "sessionToken": session_token})
|
|
4056
|
+
|
|
4057
|
+
async def logout_handler(request):
    """End OAuth session"""
    token = get_session_token(request)
    # Drop the server-side session when present; logout always succeeds.
    if token:
        g_sessions.pop(token, None)

    resp = web.json_response({"success": True})
    resp.del_cookie("llms-token")
    return resp
|
|
4067
|
+
|
|
4068
|
+
async def auth_handler(request):
    """Check authentication status and return user info (GET /auth).

    Returns the OAuth session's user profile when a valid session token is
    present, otherwise a 401 in the app's standard auth-error shape.
    """
    # Check for OAuth session token
    session_token = get_session_token(request)

    if session_token and session_token in g_sessions:
        session_data = g_sessions[session_token]
        return web.json_response(
            {
                "userId": session_data.get("userId", ""),
                "userName": session_data.get("userName", ""),
                "displayName": session_data.get("displayName", ""),
                "profileUrl": session_data.get("profileUrl", ""),
                "authProvider": "github",
            }
        )

    # Check for API key in Authorization header
    # auth_header = request.headers.get('Authorization', '')
    # if auth_header.startswith('Bearer '):
    #     # For API key auth, return a basic response
    #     # You can customize this based on your API key validation logic
    #     api_key = auth_header[7:]
    #     if api_key:  # Add your API key validation logic here
    #         return web.json_response({
    #             "userId": "1",
    #             "userName": "apiuser",
    #             "displayName": "API User",
    #             "profileUrl": "",
    #             "authProvider": "apikey"
    #         })

    # Not authenticated - return error in expected format
    return web.json_response(g_app.error_auth_required, status=401)

app.router.add_get("/auth", auth_handler)
app.router.add_get("/auth/github", github_auth_handler)
app.router.add_get("/auth/github/callback", github_callback_handler)
# Second registration tolerates malformed callback URLs ("...callback&code=")
app.router.add_get("/auth/github/callback{tail:.*}", github_callback_handler)
app.router.add_get("/auth/session", session_handler)
app.router.add_post("/auth/logout", logout_handler)
|
|
4109
|
+
|
|
4110
|
+
async def ui_static(request: web.Request) -> web.Response:
    """Serve a static UI asset from the packaged `ui/` directory (GET /ui/{path}).

    Works with both a real filesystem path and an importlib.resources
    Traversable as `_ROOT`.
    """
    path = Path(request.match_info["path"])

    try:
        # Handle both Path objects and importlib.resources Traversable objects
        if hasattr(_ROOT, "joinpath"):
            # importlib.resources Traversable
            # NOTE(review): this branch has no explicit traversal guard —
            # it presumably relies on Traversable never escaping the package;
            # confirm that holds for the resource loaders in use.
            resource = _ROOT.joinpath("ui").joinpath(str(path))
            if not resource.is_file():
                raise web.HTTPNotFound
            content = resource.read_bytes()
        else:
            # Regular Path object
            resource = _ROOT / "ui" / path
            if not resource.is_file():
                raise web.HTTPNotFound
            try:
                # basic directory-traversal guard (checks against _ROOT,
                # not the narrower _ROOT/ui subtree)
                resource.relative_to(Path(_ROOT))
            except ValueError as e:
                raise web.HTTPBadRequest(text="Invalid path") from e
            content = resource.read_bytes()

        content_type, _ = mimetypes.guess_type(str(path))
        if content_type is None:
            content_type = "application/octet-stream"
        return web.Response(body=content, content_type=content_type)
    except (OSError, PermissionError, AttributeError) as e:
        # Any filesystem/resource failure is reported as a plain 404
        raise web.HTTPNotFound from e

app.router.add_get("/ui/{path:.*}", ui_static, name="ui_static")
|
|
4140
|
+
|
|
4141
|
+
async def config_handler(request):
    """Return client-facing configuration (GET /config): defaults, provider
    status, and whether/which authentication the UI must perform.

    The previous `if "defaults" not in ret:` guard on a freshly-created empty
    dict was always true (dead code) and has been removed.
    """
    enabled, disabled = provider_status()
    ret = {
        "defaults": g_config["defaults"],
        "status": {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled},
        # Add auth configuration
        "requiresAuth": auth_enabled,
        "authType": "oauth" if auth_enabled else "apikey",
    }
    return web.json_response(ret)

app.router.add_get("/config", config_handler)
|
|
4153
|
+
|
|
4154
|
+
async def not_found_handler(request):
    """Plain 404 for paths we deliberately never serve (e.g. favicon)."""
    return web.Response(status=404, text="404: Not Found")

app.router.add_get("/favicon.ico", not_found_handler)
|
|
4158
|
+
|
|
4159
|
+
# go through and register all g_app extensions.
# The five copy-pasted per-verb loops are collapsed into one wrapper factory
# plus a (registrar, handlers) table; behavior is unchanged — every extension
# handler is wrapped so uncaught exceptions become a JSON 500 response
# (with a stacktrace when verbose) instead of aiohttp's default error page.
def _make_managed_handler(handler_fn):
    """Wrap an extension handler with uniform JSON error handling."""
    async def managed_handler(request):
        try:
            return await handler_fn(request)
        except Exception as e:
            return web.json_response(to_error_response(e, stacktrace=g_verbose), status=500)
    return managed_handler

# Each handler entry is (route_path, handler_fn, route_kwargs)
for add_route, handlers in (
    (app.router.add_get, g_app.server_add_get),
    (app.router.add_post, g_app.server_add_post),
    (app.router.add_put, g_app.server_add_put),
    (app.router.add_delete, g_app.server_add_delete),
    (app.router.add_patch, g_app.server_add_patch),
):
    for handler in handlers:
        add_route(handler[0], _make_managed_handler(handler[1]), **handler[2])
|
|
4210
|
+
|
|
4211
|
+
# Serve index.html from root
async def index_handler(request):
    """Render index.html, injecting the import map plus any extension-supplied
    header/footer HTML fragments.

    The header/footer fragments are assembled with str.join instead of the
    previous repeated `+=` concatenation loops.
    """
    index_content = read_resource_file_bytes("index.html")

    importmaps = {"imports": g_app.import_maps}
    importmaps_script = '<script type="importmap">\n' + json.dumps(importmaps, indent=4) + "\n</script>"
    # The placeholder tag in index.html is replaced with the populated map
    index_content = index_content.replace(
        b'<script type="importmap"></script>',
        importmaps_script.encode("utf-8"),
    )

    if len(g_app.index_headers) > 0:
        html_header = "".join(g_app.index_headers)
        # replace </head> with html_header
        index_content = index_content.replace(b"</head>", html_header.encode("utf-8") + b"\n</head>")

    if len(g_app.index_footers) > 0:
        html_footer = "".join(g_app.index_footers)
        # replace </body> with html_footer
        index_content = index_content.replace(b"</body>", html_footer.encode("utf-8") + b"\n</body>")

    return web.Response(body=index_content, content_type="text/html")

app.router.add_get("/", index_handler)

# Serve index.html as fallback route (SPA routing)
app.router.add_route("*", "/{tail:.*}", index_handler)
|
|
4242
|
+
|
|
4243
|
+
# Setup file watcher for config files
async def start_background_tasks(app):
    """Start background tasks when the app starts"""
    # Keep a strong reference to the watcher task on the app: asyncio only
    # holds weak references to running tasks, so an unreferenced task created
    # with create_task can be garbage-collected mid-flight.
    app["config_watcher"] = asyncio.create_task(watch_config_files(g_config_path, home_providers_path))

app.on_startup.append(start_background_tasks)

print(f"Starting server on port {port}...")
# Blocks until the server shuts down; aiohttp's startup banner goes to _log
web.run_app(app, host="0.0.0.0", port=port, print=_log)
g_app.exit(0)
|
|
4256
|
+
|
|
4257
|
+
if cli_args.enable is not None:
    # `llms --enable p1, p2, p3` — tolerate a trailing comma on each name
    if cli_args.enable.endswith(","):
        cli_args.enable = cli_args.enable[:-1].strip()
    enable_providers = [cli_args.enable]
    all_providers = g_config["providers"].keys()
    msgs = []
    # Additional provider names arrive as extra positional args
    if len(extra_args) > 0:
        for arg in extra_args:
            if arg.endswith(","):
                arg = arg[:-1].strip()
            if arg in all_providers:
                enable_providers.append(arg)

    for provider in enable_providers:
        if provider not in g_config["providers"]:
            print(f"Provider '{provider}' not found")
            print(f"Available providers: {', '.join(g_config['providers'].keys())}")
            exit(1)
        if provider in g_config["providers"]:
            provider_config, msg = enable_provider(provider)
            print(f"\nEnabled provider {provider}:")
            printdump(provider_config)
            # Collect feedback (e.g. missing API keys) to show at the end
            if msg:
                msgs.append(msg)

    print_status()
    if len(msgs) > 0:
        print("\n" + "\n".join(msgs))
    g_app.exit(0)
|
|
4286
|
+
|
|
4287
|
+
if cli_args.disable is not None:
    # `llms --disable p1, p2, p3` — tolerate a trailing comma on each name
    if cli_args.disable.endswith(","):
        cli_args.disable = cli_args.disable[:-1].strip()
    disable_providers = [cli_args.disable]
    all_providers = g_config["providers"].keys()
    # Additional provider names arrive as extra positional args
    if len(extra_args) > 0:
        for arg in extra_args:
            if arg.endswith(","):
                arg = arg[:-1].strip()
            if arg in all_providers:
                disable_providers.append(arg)

    for provider in disable_providers:
        if provider not in g_config["providers"]:
            print(f"Provider {provider} not found")
            print(f"Available providers: {', '.join(g_config['providers'].keys())}")
            exit(1)
        disable_provider(provider)
        print(f"\nDisabled provider {provider}")

    print_status()
    g_app.exit(0)
|
|
4309
|
+
|
|
4310
|
+
if cli_args.default is not None:
    # `llms --default <model>` — persist a new default text model
    default_model = cli_args.default
    provider_model = get_provider_model(default_model)
    if provider_model is None:
        print(f"Model {default_model} not found")
        exit(1)
    default_text = g_config["defaults"]["text"]
    default_text["model"] = default_model
    save_config(g_config)
    print(f"\nDefault model set to: {default_model}")
    g_app.exit(0)
|
|
4321
|
+
|
|
4322
|
+
if (
|
|
4323
|
+
cli_args.chat is not None
|
|
4324
|
+
or cli_args.image is not None
|
|
4325
|
+
or cli_args.audio is not None
|
|
4326
|
+
or cli_args.file is not None
|
|
4327
|
+
or cli_args.out is not None
|
|
4328
|
+
or len(extra_args) > 0
|
|
4329
|
+
):
|
|
4330
|
+
try:
|
|
4331
|
+
chat = g_config["defaults"]["text"]
|
|
4332
|
+
if cli_args.image is not None:
|
|
4333
|
+
chat = g_config["defaults"]["image"]
|
|
4334
|
+
elif cli_args.audio is not None:
|
|
4335
|
+
chat = g_config["defaults"]["audio"]
|
|
4336
|
+
elif cli_args.file is not None:
|
|
4337
|
+
chat = g_config["defaults"]["file"]
|
|
4338
|
+
elif cli_args.out is not None:
|
|
4339
|
+
template = f"out:{cli_args.out}"
|
|
4340
|
+
if template not in g_config["defaults"]:
|
|
4341
|
+
print(f"Template for output modality '{cli_args.out}' not found")
|
|
4342
|
+
exit(1)
|
|
4343
|
+
chat = g_config["defaults"][template]
|
|
4344
|
+
if cli_args.chat is not None:
|
|
4345
|
+
chat_path = os.path.join(os.path.dirname(__file__), cli_args.chat)
|
|
4346
|
+
if not os.path.exists(chat_path):
|
|
4347
|
+
print(f"Chat request template not found: {chat_path}")
|
|
4348
|
+
exit(1)
|
|
4349
|
+
_log(f"Using chat: {chat_path}")
|
|
4350
|
+
|
|
4351
|
+
with open(chat_path) as f:
|
|
4352
|
+
chat_json = f.read()
|
|
4353
|
+
chat = json.loads(chat_json)
|
|
4354
|
+
|
|
4355
|
+
if cli_args.system is not None:
|
|
4356
|
+
chat["messages"].insert(0, {"role": "system", "content": cli_args.system})
|
|
4357
|
+
|
|
4358
|
+
if len(extra_args) > 0:
|
|
4359
|
+
prompt = " ".join(extra_args)
|
|
4360
|
+
if not chat["messages"] or len(chat["messages"]) == 0:
|
|
4361
|
+
chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
|
|
4362
|
+
|
|
4363
|
+
# replace content of last message if exists, else add
|
|
4364
|
+
last_msg = chat["messages"][-1] if "messages" in chat else None
|
|
4365
|
+
if last_msg and last_msg["role"] == "user":
|
|
4366
|
+
if isinstance(last_msg["content"], list):
|
|
4367
|
+
last_msg["content"][-1]["text"] = prompt
|
|
4368
|
+
else:
|
|
4369
|
+
last_msg["content"] = prompt
|
|
4370
|
+
else:
|
|
4371
|
+
chat["messages"].append({"role": "user", "content": prompt})
|
|
4372
|
+
|
|
4373
|
+
# Parse args parameters if provided
|
|
4374
|
+
args = None
|
|
4375
|
+
if cli_args.args is not None:
|
|
4376
|
+
args = parse_args_params(cli_args.args)
|
|
4377
|
+
|
|
4378
|
+
loop.run_until_complete(
|
|
4379
|
+
cli_chat(
|
|
4380
|
+
chat,
|
|
4381
|
+
tools=cli_args.tools,
|
|
4382
|
+
image=cli_args.image,
|
|
4383
|
+
audio=cli_args.audio,
|
|
4384
|
+
file=cli_args.file,
|
|
4385
|
+
args=args,
|
|
4386
|
+
raw=cli_args.raw,
|
|
4387
|
+
)
|
|
4388
|
+
)
|
|
4389
|
+
g_app.exit(0)
|
|
4390
|
+
except Exception as e:
|
|
4391
|
+
print(f"{cli_args.logprefix}Error: {e}")
|
|
4392
|
+
if cli_args.verbose:
|
|
4393
|
+
traceback.print_exc()
|
|
4394
|
+
g_app.exit(1)
|
|
4395
|
+
|
|
4396
|
+
handled = run_extension_cli()
|
|
4397
|
+
|
|
4398
|
+
if not handled:
|
|
4399
|
+
# show usage from ArgumentParser
|
|
4400
|
+
parser.print_help()
|
|
4401
|
+
g_app.exit(0)
|
|
4402
|
+
|
|
4403
|
+
|
|
4404
|
+
if __name__ == "__main__":
    # Running as a script: surface mock/debug state before starting the CLI,
    # so test runs are clearly distinguishable from real invocations.
    diagnostics_on = MOCK or DEBUG
    if diagnostics_on:
        print(f"MOCK={MOCK} or DEBUG={DEBUG}")
    main()
|