llms-py 3.0.0b1__py3-none-any.whl → 3.0.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llms/__pycache__/__init__.cpython-312.pyc +0 -0
- llms/__pycache__/__init__.cpython-313.pyc +0 -0
- llms/__pycache__/__init__.cpython-314.pyc +0 -0
- llms/__pycache__/__main__.cpython-312.pyc +0 -0
- llms/__pycache__/__main__.cpython-314.pyc +0 -0
- llms/__pycache__/llms.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-312.pyc +0 -0
- llms/__pycache__/main.cpython-313.pyc +0 -0
- llms/__pycache__/main.cpython-314.pyc +0 -0
- llms/__pycache__/plugins.cpython-314.pyc +0 -0
- llms/index.html +27 -57
- llms/llms.json +48 -15
- llms/main.py +923 -624
- llms/providers/__pycache__/anthropic.cpython-314.pyc +0 -0
- llms/providers/__pycache__/chutes.cpython-314.pyc +0 -0
- llms/providers/__pycache__/google.cpython-314.pyc +0 -0
- llms/providers/__pycache__/nvidia.cpython-314.pyc +0 -0
- llms/providers/__pycache__/openai.cpython-314.pyc +0 -0
- llms/providers/__pycache__/openrouter.cpython-314.pyc +0 -0
- llms/providers/anthropic.py +189 -0
- llms/providers/chutes.py +152 -0
- llms/providers/google.py +306 -0
- llms/providers/nvidia.py +107 -0
- llms/providers/openai.py +159 -0
- llms/providers/openrouter.py +70 -0
- llms/providers-extra.json +356 -0
- llms/providers.json +1 -1
- llms/ui/App.mjs +150 -57
- llms/ui/ai.mjs +84 -50
- llms/ui/app.css +1 -4963
- llms/ui/ctx.mjs +196 -0
- llms/ui/index.mjs +117 -0
- llms/ui/lib/charts.mjs +9 -13
- llms/ui/markdown.mjs +6 -0
- llms/ui/{Analytics.mjs → modules/analytics.mjs} +76 -64
- llms/ui/{Main.mjs → modules/chat/ChatBody.mjs} +91 -179
- llms/ui/{SettingsDialog.mjs → modules/chat/SettingsDialog.mjs} +8 -8
- llms/ui/{ChatPrompt.mjs → modules/chat/index.mjs} +281 -96
- llms/ui/modules/layout.mjs +267 -0
- llms/ui/modules/model-selector.mjs +851 -0
- llms/ui/{Recents.mjs → modules/threads/Recents.mjs} +10 -11
- llms/ui/{Sidebar.mjs → modules/threads/index.mjs} +48 -45
- llms/ui/{threadStore.mjs → modules/threads/threadStore.mjs} +21 -7
- llms/ui/tailwind.input.css +441 -79
- llms/ui/utils.mjs +83 -123
- {llms_py-3.0.0b1.dist-info → llms_py-3.0.0b3.dist-info}/METADATA +1 -1
- llms_py-3.0.0b3.dist-info/RECORD +65 -0
- llms/ui/Avatar.mjs +0 -85
- llms/ui/Brand.mjs +0 -52
- llms/ui/ModelSelector.mjs +0 -693
- llms/ui/OAuthSignIn.mjs +0 -92
- llms/ui/ProviderIcon.mjs +0 -36
- llms/ui/ProviderStatus.mjs +0 -105
- llms/ui/SignIn.mjs +0 -64
- llms/ui/SystemPromptEditor.mjs +0 -31
- llms/ui/SystemPromptSelector.mjs +0 -56
- llms/ui/Welcome.mjs +0 -8
- llms/ui.json +0 -1069
- llms_py-3.0.0b1.dist-info/RECORD +0 -49
- {llms_py-3.0.0b1.dist-info → llms_py-3.0.0b3.dist-info}/WHEEL +0 -0
- {llms_py-3.0.0b1.dist-info → llms_py-3.0.0b3.dist-info}/entry_points.txt +0 -0
- {llms_py-3.0.0b1.dist-info → llms_py-3.0.0b3.dist-info}/licenses/LICENSE +0 -0
- {llms_py-3.0.0b1.dist-info → llms_py-3.0.0b3.dist-info}/top_level.txt +0 -0
llms/main.py
CHANGED
|
@@ -9,18 +9,20 @@
|
|
|
9
9
|
import argparse
|
|
10
10
|
import asyncio
|
|
11
11
|
import base64
|
|
12
|
-
from datetime import datetime
|
|
13
12
|
import hashlib
|
|
13
|
+
import importlib.util
|
|
14
14
|
import json
|
|
15
15
|
import mimetypes
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
18
|
import secrets
|
|
19
|
+
import shutil
|
|
19
20
|
import site
|
|
20
21
|
import subprocess
|
|
21
22
|
import sys
|
|
22
23
|
import time
|
|
23
24
|
import traceback
|
|
25
|
+
from datetime import datetime
|
|
24
26
|
from importlib import resources # Py≥3.9 (pip install importlib_resources for 3.7/3.8)
|
|
25
27
|
from io import BytesIO
|
|
26
28
|
from pathlib import Path
|
|
@@ -36,10 +38,13 @@ try:
|
|
|
36
38
|
except ImportError:
|
|
37
39
|
HAS_PIL = False
|
|
38
40
|
|
|
39
|
-
VERSION = "3.0.
|
|
41
|
+
VERSION = "3.0.0b3"
|
|
40
42
|
_ROOT = None
|
|
43
|
+
DEBUG = True # os.getenv("PYPI_SERVICESTACK") is not None
|
|
44
|
+
MOCK = False
|
|
45
|
+
MOCK_DIR = os.getenv("MOCK_DIR")
|
|
46
|
+
MOCK = os.getenv("MOCK") == "1"
|
|
41
47
|
g_config_path = None
|
|
42
|
-
g_ui_path = None
|
|
43
48
|
g_config = None
|
|
44
49
|
g_providers = None
|
|
45
50
|
g_handlers = {}
|
|
@@ -48,14 +53,25 @@ g_logprefix = ""
|
|
|
48
53
|
g_default_model = ""
|
|
49
54
|
g_sessions = {} # OAuth session storage: {session_token: {userId, userName, displayName, profileUrl, email, created}}
|
|
50
55
|
g_oauth_states = {} # CSRF protection: {state: {created, redirect_uri}}
|
|
56
|
+
g_app = None # ExtensionsContext Singleton
|
|
51
57
|
|
|
52
58
|
|
|
53
59
|
def _log(message):
|
|
54
|
-
"""Helper method for logging from the global polling task."""
|
|
55
60
|
if g_verbose:
|
|
56
61
|
print(f"{g_logprefix}{message}", flush=True)
|
|
57
62
|
|
|
58
63
|
|
|
64
|
+
def _dbg(message):
|
|
65
|
+
if DEBUG:
|
|
66
|
+
print(f"DEBUG: {message}", flush=True)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _err(message, e):
|
|
70
|
+
print(f"ERROR: {message}: {e}", flush=True)
|
|
71
|
+
if g_verbose:
|
|
72
|
+
print(traceback.format_exc(), flush=True)
|
|
73
|
+
|
|
74
|
+
|
|
59
75
|
def printdump(obj):
|
|
60
76
|
args = obj.__dict__ if hasattr(obj, "__dict__") else obj
|
|
61
77
|
print(json.dumps(args, indent=2))
|
|
@@ -88,17 +104,6 @@ def chat_summary(chat):
|
|
|
88
104
|
return json.dumps(clone, indent=2)
|
|
89
105
|
|
|
90
106
|
|
|
91
|
-
def gemini_chat_summary(gemini_chat):
|
|
92
|
-
"""Summarize Gemini chat completion request for logging. Replace inline_data with size of content only"""
|
|
93
|
-
clone = json.loads(json.dumps(gemini_chat))
|
|
94
|
-
for content in clone["contents"]:
|
|
95
|
-
for part in content["parts"]:
|
|
96
|
-
if "inline_data" in part:
|
|
97
|
-
data = part["inline_data"]["data"]
|
|
98
|
-
part["inline_data"]["data"] = f"({len(data)})"
|
|
99
|
-
return json.dumps(clone, indent=2)
|
|
100
|
-
|
|
101
|
-
|
|
102
107
|
image_exts = ["png", "webp", "jpg", "jpeg", "gif", "bmp", "svg", "tiff", "ico"]
|
|
103
108
|
audio_exts = ["mp3", "wav", "ogg", "flac", "m4a", "opus", "webm"]
|
|
104
109
|
|
|
@@ -192,6 +197,10 @@ def is_base_64(data):
|
|
|
192
197
|
return False
|
|
193
198
|
|
|
194
199
|
|
|
200
|
+
def id_to_name(id):
|
|
201
|
+
return id.replace("-", " ").title()
|
|
202
|
+
|
|
203
|
+
|
|
195
204
|
def get_file_mime_type(filename):
|
|
196
205
|
mime_type, _ = mimetypes.guess_type(filename)
|
|
197
206
|
return mime_type or "application/octet-stream"
|
|
@@ -453,6 +462,61 @@ class HTTPError(Exception):
|
|
|
453
462
|
super().__init__(f"HTTP {status} {reason}")
|
|
454
463
|
|
|
455
464
|
|
|
465
|
+
def save_image_to_cache(base64_data, filename, image_info):
|
|
466
|
+
ext = filename.split(".")[-1]
|
|
467
|
+
mimetype = get_file_mime_type(filename)
|
|
468
|
+
content = base64.b64decode(base64_data) if isinstance(base64_data, str) else base64_data
|
|
469
|
+
sha256_hash = hashlib.sha256(content).hexdigest()
|
|
470
|
+
|
|
471
|
+
save_filename = f"{sha256_hash}.{ext}" if ext else sha256_hash
|
|
472
|
+
|
|
473
|
+
# Use first 2 chars for subdir to avoid too many files in one dir
|
|
474
|
+
subdir = sha256_hash[:2]
|
|
475
|
+
relative_path = f"{subdir}/{save_filename}"
|
|
476
|
+
full_path = get_cache_path(relative_path)
|
|
477
|
+
|
|
478
|
+
url = f"~cache/{relative_path}"
|
|
479
|
+
|
|
480
|
+
# if file and its .info.json already exists, return it
|
|
481
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
482
|
+
if os.path.exists(full_path) and os.path.exists(info_path):
|
|
483
|
+
return url, json.load(open(info_path))
|
|
484
|
+
|
|
485
|
+
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
|
486
|
+
|
|
487
|
+
with open(full_path, "wb") as f:
|
|
488
|
+
f.write(content)
|
|
489
|
+
info = {
|
|
490
|
+
"date": int(time.time()),
|
|
491
|
+
"url": url,
|
|
492
|
+
"size": len(content),
|
|
493
|
+
"type": mimetype,
|
|
494
|
+
"name": filename,
|
|
495
|
+
}
|
|
496
|
+
info.update(image_info)
|
|
497
|
+
|
|
498
|
+
# If image, get dimensions
|
|
499
|
+
if HAS_PIL and mimetype.startswith("image/"):
|
|
500
|
+
try:
|
|
501
|
+
with Image.open(BytesIO(content)) as img:
|
|
502
|
+
info["width"] = img.width
|
|
503
|
+
info["height"] = img.height
|
|
504
|
+
except Exception:
|
|
505
|
+
pass
|
|
506
|
+
|
|
507
|
+
if "width" in info and "height" in info:
|
|
508
|
+
_log(f"Saved image to cache: {full_path} ({len(content)} bytes) {info['width']}x{info['height']}")
|
|
509
|
+
else:
|
|
510
|
+
_log(f"Saved image to cache: {full_path} ({len(content)} bytes)")
|
|
511
|
+
|
|
512
|
+
# Save metadata
|
|
513
|
+
info_path = os.path.splitext(full_path)[0] + ".info.json"
|
|
514
|
+
with open(info_path, "w") as f:
|
|
515
|
+
json.dump(info, f)
|
|
516
|
+
|
|
517
|
+
return url, info
|
|
518
|
+
|
|
519
|
+
|
|
456
520
|
async def response_json(response):
|
|
457
521
|
text = await response.text()
|
|
458
522
|
if response.status >= 400:
|
|
@@ -462,6 +526,120 @@ async def response_json(response):
|
|
|
462
526
|
return body
|
|
463
527
|
|
|
464
528
|
|
|
529
|
+
def chat_to_prompt(chat):
|
|
530
|
+
prompt = ""
|
|
531
|
+
if "messages" in chat:
|
|
532
|
+
for message in chat["messages"]:
|
|
533
|
+
if message["role"] == "user":
|
|
534
|
+
# if content is string
|
|
535
|
+
if isinstance(message["content"], str):
|
|
536
|
+
if prompt:
|
|
537
|
+
prompt += "\n"
|
|
538
|
+
prompt += message["content"]
|
|
539
|
+
elif isinstance(message["content"], list):
|
|
540
|
+
# if content is array of objects
|
|
541
|
+
for part in message["content"]:
|
|
542
|
+
if part["type"] == "text":
|
|
543
|
+
if prompt:
|
|
544
|
+
prompt += "\n"
|
|
545
|
+
prompt += part["text"]
|
|
546
|
+
return prompt
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
def last_user_prompt(chat):
|
|
550
|
+
prompt = ""
|
|
551
|
+
if "messages" in chat:
|
|
552
|
+
for message in chat["messages"]:
|
|
553
|
+
if message["role"] == "user":
|
|
554
|
+
# if content is string
|
|
555
|
+
if isinstance(message["content"], str):
|
|
556
|
+
prompt = message["content"]
|
|
557
|
+
elif isinstance(message["content"], list):
|
|
558
|
+
# if content is array of objects
|
|
559
|
+
for part in message["content"]:
|
|
560
|
+
if part["type"] == "text":
|
|
561
|
+
prompt = part["text"]
|
|
562
|
+
return prompt
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
# Image Generator Providers
|
|
566
|
+
class GeneratorBase:
|
|
567
|
+
def __init__(self, **kwargs):
|
|
568
|
+
self.id = kwargs.get("id")
|
|
569
|
+
self.api = kwargs.get("api")
|
|
570
|
+
self.api_key = kwargs.get("api_key")
|
|
571
|
+
self.headers = {
|
|
572
|
+
"Accept": "application/json",
|
|
573
|
+
"Content-Type": "application/json",
|
|
574
|
+
}
|
|
575
|
+
self.chat_url = f"{self.api}/chat/completions"
|
|
576
|
+
self.default_content = "I've generated the image for you."
|
|
577
|
+
|
|
578
|
+
def validate(self, **kwargs):
|
|
579
|
+
if not self.api_key:
|
|
580
|
+
api_keys = ", ".join(self.env)
|
|
581
|
+
return f"Provider '{self.name}' requires API Key {api_keys}"
|
|
582
|
+
return None
|
|
583
|
+
|
|
584
|
+
def test(self, **kwargs):
|
|
585
|
+
error_msg = self.validate(**kwargs)
|
|
586
|
+
if error_msg:
|
|
587
|
+
_log(error_msg)
|
|
588
|
+
return False
|
|
589
|
+
return True
|
|
590
|
+
|
|
591
|
+
async def load(self):
|
|
592
|
+
pass
|
|
593
|
+
|
|
594
|
+
def gen_summary(self, gen):
|
|
595
|
+
"""Summarize gen response for logging."""
|
|
596
|
+
clone = json.loads(json.dumps(gen))
|
|
597
|
+
return json.dumps(clone, indent=2)
|
|
598
|
+
|
|
599
|
+
def chat_summary(self, chat):
|
|
600
|
+
return chat_summary(chat)
|
|
601
|
+
|
|
602
|
+
def process_chat(self, chat, provider_id=None):
|
|
603
|
+
return process_chat(chat, provider_id)
|
|
604
|
+
|
|
605
|
+
async def response_json(self, response):
|
|
606
|
+
return await response_json(response)
|
|
607
|
+
|
|
608
|
+
def get_headers(self, provider, chat):
|
|
609
|
+
headers = self.headers.copy()
|
|
610
|
+
if provider is not None:
|
|
611
|
+
headers["Authorization"] = f"Bearer {provider.api_key}"
|
|
612
|
+
elif self.api_key:
|
|
613
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
614
|
+
return headers
|
|
615
|
+
|
|
616
|
+
def to_response(self, response, chat, started_at):
|
|
617
|
+
raise NotImplementedError
|
|
618
|
+
|
|
619
|
+
async def chat(self, chat, provider=None):
|
|
620
|
+
return {
|
|
621
|
+
"choices": [
|
|
622
|
+
{
|
|
623
|
+
"message": {
|
|
624
|
+
"role": "assistant",
|
|
625
|
+
"content": "Not Implemented",
|
|
626
|
+
"images": [
|
|
627
|
+
{
|
|
628
|
+
"type": "image_url",
|
|
629
|
+
"image_url": {
|
|
630
|
+
"url": "",
|
|
631
|
+
},
|
|
632
|
+
}
|
|
633
|
+
],
|
|
634
|
+
}
|
|
635
|
+
}
|
|
636
|
+
]
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
# OpenAI Providers
|
|
641
|
+
|
|
642
|
+
|
|
465
643
|
class OpenAiCompatible:
|
|
466
644
|
sdk = "@ai-sdk/openai-compatible"
|
|
467
645
|
|
|
@@ -473,8 +651,9 @@ class OpenAiCompatible:
|
|
|
473
651
|
|
|
474
652
|
self.id = kwargs.get("id")
|
|
475
653
|
self.api = kwargs.get("api").strip("/")
|
|
654
|
+
self.env = kwargs.get("env", [])
|
|
476
655
|
self.api_key = kwargs.get("api_key")
|
|
477
|
-
self.name = kwargs.get("name", self.id
|
|
656
|
+
self.name = kwargs.get("name", id_to_name(self.id))
|
|
478
657
|
self.set_models(**kwargs)
|
|
479
658
|
|
|
480
659
|
self.chat_url = f"{self.api}/chat/completions"
|
|
@@ -502,6 +681,7 @@ class OpenAiCompatible:
|
|
|
502
681
|
self.stream = bool(kwargs["stream"]) if "stream" in kwargs else None
|
|
503
682
|
self.enable_thinking = bool(kwargs["enable_thinking"]) if "enable_thinking" in kwargs else None
|
|
504
683
|
self.check = kwargs.get("check")
|
|
684
|
+
self.modalities = kwargs.get("modalities", {})
|
|
505
685
|
|
|
506
686
|
def set_models(self, **kwargs):
|
|
507
687
|
models = kwargs.get("models", {})
|
|
@@ -527,11 +707,18 @@ class OpenAiCompatible:
|
|
|
527
707
|
_log(f"Filtering {len(self.models)} models, excluding models that match regex: {exclude_models}")
|
|
528
708
|
self.models = {k: v for k, v in self.models.items() if not re.search(exclude_models, k)}
|
|
529
709
|
|
|
710
|
+
def validate(self, **kwargs):
|
|
711
|
+
if not self.api_key:
|
|
712
|
+
api_keys = ", ".join(self.env)
|
|
713
|
+
return f"Provider '{self.name}' requires API Key {api_keys}"
|
|
714
|
+
return None
|
|
715
|
+
|
|
530
716
|
def test(self, **kwargs):
|
|
531
|
-
|
|
532
|
-
if
|
|
533
|
-
_log(
|
|
534
|
-
|
|
717
|
+
error_msg = self.validate(**kwargs)
|
|
718
|
+
if error_msg:
|
|
719
|
+
_log(error_msg)
|
|
720
|
+
return False
|
|
721
|
+
return True
|
|
535
722
|
|
|
536
723
|
async def load(self):
|
|
537
724
|
if not self.models:
|
|
@@ -579,56 +766,11 @@ class OpenAiCompatible:
|
|
|
579
766
|
if "/" in model:
|
|
580
767
|
last_part = model.split("/")[-1]
|
|
581
768
|
return self.provider_model(last_part)
|
|
582
|
-
return None
|
|
583
|
-
|
|
584
|
-
def validate_modalities(self, chat):
|
|
585
|
-
model_id = chat.get("model")
|
|
586
|
-
if not model_id or not self.models:
|
|
587
|
-
return
|
|
588
769
|
|
|
589
|
-
|
|
590
|
-
# Try to find model info using provider_model logic (already resolved to ID)
|
|
591
|
-
if model_id in self.models:
|
|
592
|
-
model_info = self.models[model_id]
|
|
593
|
-
else:
|
|
594
|
-
# Fallback scan
|
|
595
|
-
for m_id, m_info in self.models.items():
|
|
596
|
-
if m_id == model_id or m_info.get("id") == model_id:
|
|
597
|
-
model_info = m_info
|
|
598
|
-
break
|
|
599
|
-
|
|
600
|
-
print(f"DEBUG: Validate modalities: model={model_id}, found_info={model_info is not None}")
|
|
601
|
-
if model_info:
|
|
602
|
-
print(f"DEBUG: Modalities: {model_info.get('modalities')}")
|
|
603
|
-
|
|
604
|
-
if not model_info:
|
|
605
|
-
return
|
|
606
|
-
|
|
607
|
-
modalities = model_info.get("modalities", {})
|
|
608
|
-
input_modalities = modalities.get("input", [])
|
|
609
|
-
|
|
610
|
-
# Check for unsupported modalities
|
|
611
|
-
has_audio = False
|
|
612
|
-
has_image = False
|
|
613
|
-
for message in chat.get("messages", []):
|
|
614
|
-
content = message.get("content")
|
|
615
|
-
if isinstance(content, list):
|
|
616
|
-
for item in content:
|
|
617
|
-
type_ = item.get("type")
|
|
618
|
-
if type_ == "input_audio" or "input_audio" in item:
|
|
619
|
-
has_audio = True
|
|
620
|
-
elif type_ == "image_url" or "image_url" in item:
|
|
621
|
-
has_image = True
|
|
622
|
-
|
|
623
|
-
if has_audio and "audio" not in input_modalities:
|
|
624
|
-
raise Exception(
|
|
625
|
-
f"Model '{model_id}' does not support audio input. Supported modalities: {', '.join(input_modalities)}"
|
|
626
|
-
)
|
|
770
|
+
return None
|
|
627
771
|
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
f"Model '{model_id}' does not support image input. Supported modalities: {', '.join(input_modalities)}"
|
|
631
|
-
)
|
|
772
|
+
def response_json(self, response):
|
|
773
|
+
return response_json(response)
|
|
632
774
|
|
|
633
775
|
def to_response(self, response, chat, started_at):
|
|
634
776
|
if "metadata" not in response:
|
|
@@ -638,13 +780,27 @@ class OpenAiCompatible:
|
|
|
638
780
|
pricing = self.model_cost(chat["model"])
|
|
639
781
|
if pricing and "input" in pricing and "output" in pricing:
|
|
640
782
|
response["metadata"]["pricing"] = f"{pricing['input']}/{pricing['output']}"
|
|
641
|
-
_log(json.dumps(response, indent=2))
|
|
642
783
|
return response
|
|
643
784
|
|
|
785
|
+
def chat_summary(self, chat):
|
|
786
|
+
return chat_summary(chat)
|
|
787
|
+
|
|
788
|
+
def process_chat(self, chat, provider_id=None):
|
|
789
|
+
return process_chat(chat, provider_id)
|
|
790
|
+
|
|
644
791
|
async def chat(self, chat):
|
|
645
792
|
chat["model"] = self.provider_model(chat["model"]) or chat["model"]
|
|
646
793
|
|
|
647
|
-
|
|
794
|
+
if "modalities" in chat:
|
|
795
|
+
for modality in chat["modalities"]:
|
|
796
|
+
# use default implementation for text modalities
|
|
797
|
+
if modality == "text":
|
|
798
|
+
continue
|
|
799
|
+
modality_provider = self.modalities.get(modality)
|
|
800
|
+
if modality_provider:
|
|
801
|
+
return await modality_provider.chat(chat, self)
|
|
802
|
+
else:
|
|
803
|
+
raise Exception(f"Provider {self.name} does not support '{modality}' modality")
|
|
648
804
|
|
|
649
805
|
# with open(os.path.join(os.path.dirname(__file__), 'chat.wip.json'), "w") as f:
|
|
650
806
|
# f.write(json.dumps(chat, indent=2))
|
|
@@ -698,193 +854,6 @@ class OpenAiCompatible:
|
|
|
698
854
|
return self.to_response(await response_json(response), chat, started_at)
|
|
699
855
|
|
|
700
856
|
|
|
701
|
-
class OpenAiProvider(OpenAiCompatible):
|
|
702
|
-
sdk = "@ai-sdk/openai"
|
|
703
|
-
|
|
704
|
-
def __init__(self, **kwargs):
|
|
705
|
-
if "api" not in kwargs:
|
|
706
|
-
kwargs["api"] = "https://api.openai.com/v1"
|
|
707
|
-
super().__init__(**kwargs)
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
class AnthropicProvider(OpenAiCompatible):
|
|
711
|
-
sdk = "@ai-sdk/anthropic"
|
|
712
|
-
|
|
713
|
-
def __init__(self, **kwargs):
|
|
714
|
-
if "api" not in kwargs:
|
|
715
|
-
kwargs["api"] = "https://api.anthropic.com/v1"
|
|
716
|
-
super().__init__(**kwargs)
|
|
717
|
-
|
|
718
|
-
# Anthropic uses x-api-key header instead of Authorization
|
|
719
|
-
if self.api_key:
|
|
720
|
-
self.headers = self.headers.copy()
|
|
721
|
-
if "Authorization" in self.headers:
|
|
722
|
-
del self.headers["Authorization"]
|
|
723
|
-
self.headers["x-api-key"] = self.api_key
|
|
724
|
-
|
|
725
|
-
if "anthropic-version" not in self.headers:
|
|
726
|
-
self.headers = self.headers.copy()
|
|
727
|
-
self.headers["anthropic-version"] = "2023-06-01"
|
|
728
|
-
self.chat_url = f"{self.api}/messages"
|
|
729
|
-
|
|
730
|
-
async def chat(self, chat):
|
|
731
|
-
chat["model"] = self.provider_model(chat["model"]) or chat["model"]
|
|
732
|
-
|
|
733
|
-
chat = await process_chat(chat, provider_id=self.id)
|
|
734
|
-
|
|
735
|
-
# Transform OpenAI format to Anthropic format
|
|
736
|
-
anthropic_request = {
|
|
737
|
-
"model": chat["model"],
|
|
738
|
-
"messages": [],
|
|
739
|
-
}
|
|
740
|
-
|
|
741
|
-
# Extract system message (Anthropic uses top-level 'system' parameter)
|
|
742
|
-
system_messages = []
|
|
743
|
-
for message in chat.get("messages", []):
|
|
744
|
-
if message.get("role") == "system":
|
|
745
|
-
content = message.get("content", "")
|
|
746
|
-
if isinstance(content, str):
|
|
747
|
-
system_messages.append(content)
|
|
748
|
-
elif isinstance(content, list):
|
|
749
|
-
for item in content:
|
|
750
|
-
if item.get("type") == "text":
|
|
751
|
-
system_messages.append(item.get("text", ""))
|
|
752
|
-
|
|
753
|
-
if system_messages:
|
|
754
|
-
anthropic_request["system"] = "\n".join(system_messages)
|
|
755
|
-
|
|
756
|
-
# Transform messages (exclude system messages)
|
|
757
|
-
for message in chat.get("messages", []):
|
|
758
|
-
if message.get("role") == "system":
|
|
759
|
-
continue
|
|
760
|
-
|
|
761
|
-
anthropic_message = {"role": message.get("role"), "content": []}
|
|
762
|
-
|
|
763
|
-
content = message.get("content", "")
|
|
764
|
-
if isinstance(content, str):
|
|
765
|
-
anthropic_message["content"] = content
|
|
766
|
-
elif isinstance(content, list):
|
|
767
|
-
for item in content:
|
|
768
|
-
if item.get("type") == "text":
|
|
769
|
-
anthropic_message["content"].append({"type": "text", "text": item.get("text", "")})
|
|
770
|
-
elif item.get("type") == "image_url" and "image_url" in item:
|
|
771
|
-
# Transform OpenAI image_url format to Anthropic format
|
|
772
|
-
image_url = item["image_url"].get("url", "")
|
|
773
|
-
if image_url.startswith("data:"):
|
|
774
|
-
# Extract media type and base64 data
|
|
775
|
-
parts = image_url.split(";base64,", 1)
|
|
776
|
-
if len(parts) == 2:
|
|
777
|
-
media_type = parts[0].replace("data:", "")
|
|
778
|
-
base64_data = parts[1]
|
|
779
|
-
anthropic_message["content"].append(
|
|
780
|
-
{
|
|
781
|
-
"type": "image",
|
|
782
|
-
"source": {"type": "base64", "media_type": media_type, "data": base64_data},
|
|
783
|
-
}
|
|
784
|
-
)
|
|
785
|
-
|
|
786
|
-
anthropic_request["messages"].append(anthropic_message)
|
|
787
|
-
|
|
788
|
-
# Handle max_tokens (required by Anthropic, uses max_tokens not max_completion_tokens)
|
|
789
|
-
if "max_completion_tokens" in chat:
|
|
790
|
-
anthropic_request["max_tokens"] = chat["max_completion_tokens"]
|
|
791
|
-
elif "max_tokens" in chat:
|
|
792
|
-
anthropic_request["max_tokens"] = chat["max_tokens"]
|
|
793
|
-
else:
|
|
794
|
-
# Anthropic requires max_tokens, set a default
|
|
795
|
-
anthropic_request["max_tokens"] = 4096
|
|
796
|
-
|
|
797
|
-
# Copy other supported parameters
|
|
798
|
-
if "temperature" in chat:
|
|
799
|
-
anthropic_request["temperature"] = chat["temperature"]
|
|
800
|
-
if "top_p" in chat:
|
|
801
|
-
anthropic_request["top_p"] = chat["top_p"]
|
|
802
|
-
if "top_k" in chat:
|
|
803
|
-
anthropic_request["top_k"] = chat["top_k"]
|
|
804
|
-
if "stop" in chat:
|
|
805
|
-
anthropic_request["stop_sequences"] = chat["stop"] if isinstance(chat["stop"], list) else [chat["stop"]]
|
|
806
|
-
if "stream" in chat:
|
|
807
|
-
anthropic_request["stream"] = chat["stream"]
|
|
808
|
-
if "tools" in chat:
|
|
809
|
-
anthropic_request["tools"] = chat["tools"]
|
|
810
|
-
if "tool_choice" in chat:
|
|
811
|
-
anthropic_request["tool_choice"] = chat["tool_choice"]
|
|
812
|
-
|
|
813
|
-
_log(f"POST {self.chat_url}")
|
|
814
|
-
_log(f"Anthropic Request: {json.dumps(anthropic_request, indent=2)}")
|
|
815
|
-
|
|
816
|
-
async with aiohttp.ClientSession() as session:
|
|
817
|
-
started_at = time.time()
|
|
818
|
-
async with session.post(
|
|
819
|
-
self.chat_url,
|
|
820
|
-
headers=self.headers,
|
|
821
|
-
data=json.dumps(anthropic_request),
|
|
822
|
-
timeout=aiohttp.ClientTimeout(total=120),
|
|
823
|
-
) as response:
|
|
824
|
-
return self.to_response(await response_json(response), chat, started_at)
|
|
825
|
-
|
|
826
|
-
def to_response(self, response, chat, started_at):
|
|
827
|
-
"""Convert Anthropic response format to OpenAI-compatible format."""
|
|
828
|
-
# Transform Anthropic response to OpenAI format
|
|
829
|
-
openai_response = {
|
|
830
|
-
"id": response.get("id", ""),
|
|
831
|
-
"object": "chat.completion",
|
|
832
|
-
"created": int(started_at),
|
|
833
|
-
"model": response.get("model", ""),
|
|
834
|
-
"choices": [],
|
|
835
|
-
"usage": {},
|
|
836
|
-
}
|
|
837
|
-
|
|
838
|
-
# Transform content blocks to message content
|
|
839
|
-
content_parts = []
|
|
840
|
-
thinking_parts = []
|
|
841
|
-
|
|
842
|
-
for block in response.get("content", []):
|
|
843
|
-
if block.get("type") == "text":
|
|
844
|
-
content_parts.append(block.get("text", ""))
|
|
845
|
-
elif block.get("type") == "thinking":
|
|
846
|
-
# Store thinking blocks separately (some models include reasoning)
|
|
847
|
-
thinking_parts.append(block.get("thinking", ""))
|
|
848
|
-
|
|
849
|
-
# Combine all text content
|
|
850
|
-
message_content = "\n".join(content_parts) if content_parts else ""
|
|
851
|
-
|
|
852
|
-
# Create the choice object
|
|
853
|
-
choice = {
|
|
854
|
-
"index": 0,
|
|
855
|
-
"message": {"role": "assistant", "content": message_content},
|
|
856
|
-
"finish_reason": response.get("stop_reason", "stop"),
|
|
857
|
-
}
|
|
858
|
-
|
|
859
|
-
# Add thinking as metadata if present
|
|
860
|
-
if thinking_parts:
|
|
861
|
-
choice["message"]["thinking"] = "\n".join(thinking_parts)
|
|
862
|
-
|
|
863
|
-
openai_response["choices"].append(choice)
|
|
864
|
-
|
|
865
|
-
# Transform usage
|
|
866
|
-
if "usage" in response:
|
|
867
|
-
usage = response["usage"]
|
|
868
|
-
openai_response["usage"] = {
|
|
869
|
-
"prompt_tokens": usage.get("input_tokens", 0),
|
|
870
|
-
"completion_tokens": usage.get("output_tokens", 0),
|
|
871
|
-
"total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
|
|
872
|
-
}
|
|
873
|
-
|
|
874
|
-
# Add metadata
|
|
875
|
-
if "metadata" not in openai_response:
|
|
876
|
-
openai_response["metadata"] = {}
|
|
877
|
-
openai_response["metadata"]["duration"] = int((time.time() - started_at) * 1000)
|
|
878
|
-
|
|
879
|
-
if chat is not None and "model" in chat:
|
|
880
|
-
cost = self.model_cost(chat["model"])
|
|
881
|
-
if cost and "input" in cost and "output" in cost:
|
|
882
|
-
openai_response["metadata"]["pricing"] = f"{cost['input']}/{cost['output']}"
|
|
883
|
-
|
|
884
|
-
_log(json.dumps(openai_response, indent=2))
|
|
885
|
-
return openai_response
|
|
886
|
-
|
|
887
|
-
|
|
888
857
|
class MistralProvider(OpenAiCompatible):
|
|
889
858
|
sdk = "@ai-sdk/mistral"
|
|
890
859
|
|
|
@@ -941,11 +910,10 @@ class OllamaProvider(OpenAiCompatible):
|
|
|
941
910
|
) as response:
|
|
942
911
|
data = await response_json(response)
|
|
943
912
|
for model in data.get("models", []):
|
|
944
|
-
|
|
945
|
-
if
|
|
946
|
-
|
|
947
|
-
model_id =
|
|
948
|
-
ret[model_id] = name
|
|
913
|
+
model_id = model["model"]
|
|
914
|
+
if model_id.endswith(":latest"):
|
|
915
|
+
model_id = model_id[:-7]
|
|
916
|
+
ret[model_id] = model_id
|
|
949
917
|
_log(f"Loaded Ollama models: {ret}")
|
|
950
918
|
except Exception as e:
|
|
951
919
|
_log(f"Error getting Ollama models: {e}")
|
|
@@ -981,8 +949,8 @@ class OllamaProvider(OpenAiCompatible):
|
|
|
981
949
|
}
|
|
982
950
|
self.models = models
|
|
983
951
|
|
|
984
|
-
def
|
|
985
|
-
return
|
|
952
|
+
def validate(self, **kwargs):
|
|
953
|
+
return None
|
|
986
954
|
|
|
987
955
|
|
|
988
956
|
class LMStudioProvider(OllamaProvider):
|
|
@@ -1011,237 +979,6 @@ class LMStudioProvider(OllamaProvider):
|
|
|
1011
979
|
return ret
|
|
1012
980
|
|
|
1013
981
|
|
|
1014
|
-
# class GoogleOpenAiProvider(OpenAiCompatible):
|
|
1015
|
-
# sdk = "google-openai-compatible"
|
|
1016
|
-
|
|
1017
|
-
# def __init__(self, api_key, **kwargs):
|
|
1018
|
-
# super().__init__(api="https://generativelanguage.googleapis.com", api_key=api_key, **kwargs)
|
|
1019
|
-
# self.chat_url = "https://generativelanguage.googleapis.com/v1beta/chat/completions"
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
class GoogleProvider(OpenAiCompatible):
|
|
1023
|
-
sdk = "@ai-sdk/google"
|
|
1024
|
-
|
|
1025
|
-
def __init__(self, **kwargs):
|
|
1026
|
-
new_kwargs = {"api": "https://generativelanguage.googleapis.com", **kwargs}
|
|
1027
|
-
super().__init__(**new_kwargs)
|
|
1028
|
-
self.safety_settings = kwargs.get("safety_settings")
|
|
1029
|
-
self.thinking_config = kwargs.get("thinking_config")
|
|
1030
|
-
self.curl = kwargs.get("curl")
|
|
1031
|
-
self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
|
|
1032
|
-
# Google fails when using Authorization header, use query string param instead
|
|
1033
|
-
if "Authorization" in self.headers:
|
|
1034
|
-
del self.headers["Authorization"]
|
|
1035
|
-
|
|
1036
|
-
async def chat(self, chat):
|
|
1037
|
-
chat["model"] = self.provider_model(chat["model"]) or chat["model"]
|
|
1038
|
-
|
|
1039
|
-
chat = await process_chat(chat)
|
|
1040
|
-
generation_config = {}
|
|
1041
|
-
|
|
1042
|
-
# Filter out system messages and convert to proper Gemini format
|
|
1043
|
-
contents = []
|
|
1044
|
-
system_prompt = None
|
|
1045
|
-
|
|
1046
|
-
async with aiohttp.ClientSession() as session:
|
|
1047
|
-
for message in chat["messages"]:
|
|
1048
|
-
if message["role"] == "system":
|
|
1049
|
-
content = message["content"]
|
|
1050
|
-
if isinstance(content, list):
|
|
1051
|
-
for item in content:
|
|
1052
|
-
if "text" in item:
|
|
1053
|
-
system_prompt = item["text"]
|
|
1054
|
-
break
|
|
1055
|
-
elif isinstance(content, str):
|
|
1056
|
-
system_prompt = content
|
|
1057
|
-
elif "content" in message:
|
|
1058
|
-
if isinstance(message["content"], list):
|
|
1059
|
-
parts = []
|
|
1060
|
-
for item in message["content"]:
|
|
1061
|
-
if "type" in item:
|
|
1062
|
-
if item["type"] == "image_url" and "image_url" in item:
|
|
1063
|
-
image_url = item["image_url"]
|
|
1064
|
-
if "url" not in image_url:
|
|
1065
|
-
continue
|
|
1066
|
-
url = image_url["url"]
|
|
1067
|
-
if not url.startswith("data:"):
|
|
1068
|
-
raise (Exception("Image was not downloaded: " + url))
|
|
1069
|
-
# Extract mime type from data uri
|
|
1070
|
-
mimetype = url.split(";", 1)[0].split(":", 1)[1] if ";" in url else "image/png"
|
|
1071
|
-
base64_data = url.split(",", 1)[1]
|
|
1072
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
|
|
1073
|
-
elif item["type"] == "input_audio" and "input_audio" in item:
|
|
1074
|
-
input_audio = item["input_audio"]
|
|
1075
|
-
if "data" not in input_audio:
|
|
1076
|
-
continue
|
|
1077
|
-
data = input_audio["data"]
|
|
1078
|
-
format = input_audio["format"]
|
|
1079
|
-
mimetype = f"audio/{format}"
|
|
1080
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": data}})
|
|
1081
|
-
elif item["type"] == "file" and "file" in item:
|
|
1082
|
-
file = item["file"]
|
|
1083
|
-
if "file_data" not in file:
|
|
1084
|
-
continue
|
|
1085
|
-
data = file["file_data"]
|
|
1086
|
-
if not data.startswith("data:"):
|
|
1087
|
-
raise (Exception("File was not downloaded: " + data))
|
|
1088
|
-
# Extract mime type from data uri
|
|
1089
|
-
mimetype = (
|
|
1090
|
-
data.split(";", 1)[0].split(":", 1)[1]
|
|
1091
|
-
if ";" in data
|
|
1092
|
-
else "application/octet-stream"
|
|
1093
|
-
)
|
|
1094
|
-
base64_data = data.split(",", 1)[1]
|
|
1095
|
-
parts.append({"inline_data": {"mime_type": mimetype, "data": base64_data}})
|
|
1096
|
-
if "text" in item:
|
|
1097
|
-
text = item["text"]
|
|
1098
|
-
parts.append({"text": text})
|
|
1099
|
-
if len(parts) > 0:
|
|
1100
|
-
contents.append(
|
|
1101
|
-
{
|
|
1102
|
-
"role": message["role"]
|
|
1103
|
-
if "role" in message and message["role"] == "user"
|
|
1104
|
-
else "model",
|
|
1105
|
-
"parts": parts,
|
|
1106
|
-
}
|
|
1107
|
-
)
|
|
1108
|
-
else:
|
|
1109
|
-
content = message["content"]
|
|
1110
|
-
contents.append(
|
|
1111
|
-
{
|
|
1112
|
-
"role": message["role"] if "role" in message and message["role"] == "user" else "model",
|
|
1113
|
-
"parts": [{"text": content}],
|
|
1114
|
-
}
|
|
1115
|
-
)
|
|
1116
|
-
|
|
1117
|
-
gemini_chat = {
|
|
1118
|
-
"contents": contents,
|
|
1119
|
-
}
|
|
1120
|
-
|
|
1121
|
-
if self.safety_settings:
|
|
1122
|
-
gemini_chat["safetySettings"] = self.safety_settings
|
|
1123
|
-
|
|
1124
|
-
# Add system instruction if present
|
|
1125
|
-
if system_prompt is not None:
|
|
1126
|
-
gemini_chat["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
|
1127
|
-
|
|
1128
|
-
if "max_completion_tokens" in chat:
|
|
1129
|
-
generation_config["maxOutputTokens"] = chat["max_completion_tokens"]
|
|
1130
|
-
if "stop" in chat:
|
|
1131
|
-
generation_config["stopSequences"] = [chat["stop"]]
|
|
1132
|
-
if "temperature" in chat:
|
|
1133
|
-
generation_config["temperature"] = chat["temperature"]
|
|
1134
|
-
if "top_p" in chat:
|
|
1135
|
-
generation_config["topP"] = chat["top_p"]
|
|
1136
|
-
if "top_logprobs" in chat:
|
|
1137
|
-
generation_config["topK"] = chat["top_logprobs"]
|
|
1138
|
-
|
|
1139
|
-
if "thinkingConfig" in chat:
|
|
1140
|
-
generation_config["thinkingConfig"] = chat["thinkingConfig"]
|
|
1141
|
-
elif self.thinking_config:
|
|
1142
|
-
generation_config["thinkingConfig"] = self.thinking_config
|
|
1143
|
-
|
|
1144
|
-
if len(generation_config) > 0:
|
|
1145
|
-
gemini_chat["generationConfig"] = generation_config
|
|
1146
|
-
|
|
1147
|
-
started_at = int(time.time() * 1000)
|
|
1148
|
-
gemini_chat_url = f"https://generativelanguage.googleapis.com/v1beta/models/{chat['model']}:generateContent?key={self.api_key}"
|
|
1149
|
-
|
|
1150
|
-
_log(f"POST {gemini_chat_url}")
|
|
1151
|
-
_log(gemini_chat_summary(gemini_chat))
|
|
1152
|
-
started_at = time.time()
|
|
1153
|
-
|
|
1154
|
-
if self.curl:
|
|
1155
|
-
curl_args = [
|
|
1156
|
-
"curl",
|
|
1157
|
-
"-X",
|
|
1158
|
-
"POST",
|
|
1159
|
-
"-H",
|
|
1160
|
-
"Content-Type: application/json",
|
|
1161
|
-
"-d",
|
|
1162
|
-
json.dumps(gemini_chat),
|
|
1163
|
-
gemini_chat_url,
|
|
1164
|
-
]
|
|
1165
|
-
try:
|
|
1166
|
-
o = subprocess.run(curl_args, check=True, capture_output=True, text=True, timeout=120)
|
|
1167
|
-
obj = json.loads(o.stdout)
|
|
1168
|
-
except Exception as e:
|
|
1169
|
-
raise Exception(f"Error executing curl: {e}") from e
|
|
1170
|
-
else:
|
|
1171
|
-
async with session.post(
|
|
1172
|
-
gemini_chat_url,
|
|
1173
|
-
headers=self.headers,
|
|
1174
|
-
data=json.dumps(gemini_chat),
|
|
1175
|
-
timeout=aiohttp.ClientTimeout(total=120),
|
|
1176
|
-
) as res:
|
|
1177
|
-
obj = await response_json(res)
|
|
1178
|
-
_log(f"google response:\n{json.dumps(obj, indent=2)}")
|
|
1179
|
-
|
|
1180
|
-
response = {
|
|
1181
|
-
"id": f"chatcmpl-{started_at}",
|
|
1182
|
-
"created": started_at,
|
|
1183
|
-
"model": obj.get("modelVersion", chat["model"]),
|
|
1184
|
-
}
|
|
1185
|
-
choices = []
|
|
1186
|
-
if "error" in obj:
|
|
1187
|
-
_log(f"Error: {obj['error']}")
|
|
1188
|
-
raise Exception(obj["error"]["message"])
|
|
1189
|
-
for i, candidate in enumerate(obj["candidates"]):
|
|
1190
|
-
role = "assistant"
|
|
1191
|
-
if "content" in candidate and "role" in candidate["content"]:
|
|
1192
|
-
role = "assistant" if candidate["content"]["role"] == "model" else candidate["content"]["role"]
|
|
1193
|
-
|
|
1194
|
-
# Safely extract content from all text parts
|
|
1195
|
-
content = ""
|
|
1196
|
-
reasoning = ""
|
|
1197
|
-
if "content" in candidate and "parts" in candidate["content"]:
|
|
1198
|
-
text_parts = []
|
|
1199
|
-
reasoning_parts = []
|
|
1200
|
-
for part in candidate["content"]["parts"]:
|
|
1201
|
-
if "text" in part:
|
|
1202
|
-
if "thought" in part and part["thought"]:
|
|
1203
|
-
reasoning_parts.append(part["text"])
|
|
1204
|
-
else:
|
|
1205
|
-
text_parts.append(part["text"])
|
|
1206
|
-
content = " ".join(text_parts)
|
|
1207
|
-
reasoning = " ".join(reasoning_parts)
|
|
1208
|
-
|
|
1209
|
-
choice = {
|
|
1210
|
-
"index": i,
|
|
1211
|
-
"finish_reason": candidate.get("finishReason", "stop"),
|
|
1212
|
-
"message": {
|
|
1213
|
-
"role": role,
|
|
1214
|
-
"content": content,
|
|
1215
|
-
},
|
|
1216
|
-
}
|
|
1217
|
-
if reasoning:
|
|
1218
|
-
choice["message"]["reasoning"] = reasoning
|
|
1219
|
-
choices.append(choice)
|
|
1220
|
-
response["choices"] = choices
|
|
1221
|
-
if "usageMetadata" in obj:
|
|
1222
|
-
usage = obj["usageMetadata"]
|
|
1223
|
-
response["usage"] = {
|
|
1224
|
-
"completion_tokens": usage["candidatesTokenCount"],
|
|
1225
|
-
"total_tokens": usage["totalTokenCount"],
|
|
1226
|
-
"prompt_tokens": usage["promptTokenCount"],
|
|
1227
|
-
}
|
|
1228
|
-
return self.to_response(response, chat, started_at)
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
ALL_PROVIDERS = [
|
|
1232
|
-
OpenAiCompatible,
|
|
1233
|
-
OpenAiProvider,
|
|
1234
|
-
AnthropicProvider,
|
|
1235
|
-
MistralProvider,
|
|
1236
|
-
GroqProvider,
|
|
1237
|
-
XaiProvider,
|
|
1238
|
-
CodestralProvider,
|
|
1239
|
-
GoogleProvider,
|
|
1240
|
-
OllamaProvider,
|
|
1241
|
-
LMStudioProvider,
|
|
1242
|
-
]
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
982
|
def get_provider_model(model_name):
|
|
1246
983
|
for provider in g_handlers.values():
|
|
1247
984
|
provider_model = provider.provider_model(model_name)
|
|
@@ -1389,8 +1126,29 @@ async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False
|
|
|
1389
1126
|
print(json.dumps(response, indent=2))
|
|
1390
1127
|
exit(0)
|
|
1391
1128
|
else:
|
|
1392
|
-
|
|
1393
|
-
|
|
1129
|
+
msg = response["choices"][0]["message"]
|
|
1130
|
+
if "answer" in msg:
|
|
1131
|
+
answer = msg["content"]
|
|
1132
|
+
print(answer)
|
|
1133
|
+
|
|
1134
|
+
generated_files = []
|
|
1135
|
+
for choice in response["choices"]:
|
|
1136
|
+
if "message" in choice:
|
|
1137
|
+
msg = choice["message"]
|
|
1138
|
+
if "images" in msg:
|
|
1139
|
+
for image in msg["images"]:
|
|
1140
|
+
image_url = image["image_url"]["url"]
|
|
1141
|
+
generated_files.append(image_url)
|
|
1142
|
+
|
|
1143
|
+
if len(generated_files) > 0:
|
|
1144
|
+
print("\nSaved files:")
|
|
1145
|
+
for file in generated_files:
|
|
1146
|
+
if file.startswith("~cache"):
|
|
1147
|
+
print(get_cache_path(file[7:]))
|
|
1148
|
+
_log(f"http://localhost:8000/{file}")
|
|
1149
|
+
else:
|
|
1150
|
+
print(file)
|
|
1151
|
+
|
|
1394
1152
|
except HTTPError as e:
|
|
1395
1153
|
# HTTP error (4xx, 5xx)
|
|
1396
1154
|
print(f"{e}:\n{e.body}")
|
|
@@ -1432,22 +1190,26 @@ def init_llms(config, providers):
|
|
|
1432
1190
|
providers = g_config["providers"]
|
|
1433
1191
|
|
|
1434
1192
|
for id, orig in providers.items():
|
|
1435
|
-
|
|
1436
|
-
if "enabled" in definition and not definition["enabled"]:
|
|
1193
|
+
if "enabled" in orig and not orig["enabled"]:
|
|
1437
1194
|
continue
|
|
1438
1195
|
|
|
1439
|
-
|
|
1440
|
-
if "id" not in definition:
|
|
1441
|
-
definition["id"] = provider_id
|
|
1442
|
-
provider = g_providers.get(provider_id)
|
|
1443
|
-
constructor_kwargs = create_provider_kwargs(definition, provider)
|
|
1444
|
-
provider = create_provider(constructor_kwargs)
|
|
1445
|
-
|
|
1196
|
+
provider, constructor_kwargs = create_provider_from_definition(id, orig)
|
|
1446
1197
|
if provider and provider.test(**constructor_kwargs):
|
|
1447
1198
|
g_handlers[id] = provider
|
|
1448
1199
|
return g_handlers
|
|
1449
1200
|
|
|
1450
1201
|
|
|
1202
|
+
def create_provider_from_definition(id, orig):
    """Build a provider instance from one entry of the providers config.

    Args:
        id: Key of the entry in the providers config; used as the provider id
            when the definition does not carry its own ``"id"``.
        orig: The provider definition. It is shallow-copied, so the caller's
            dict is not given an ``"id"`` key as a side effect.

    Returns:
        ``(provider, constructor_kwargs)`` - the result of ``create_provider``
        and the kwargs it was built from.
    """
    definition = orig.copy()
    # An explicit "id" in the definition wins over the config key.
    provider_id = definition.setdefault("id", id)
    known_provider = g_providers.get(provider_id)
    constructor_kwargs = create_provider_kwargs(definition, known_provider)
    return create_provider(constructor_kwargs), constructor_kwargs
|
|
1211
|
+
|
|
1212
|
+
|
|
1451
1213
|
def create_provider_kwargs(definition, provider=None):
|
|
1452
1214
|
if provider:
|
|
1453
1215
|
provider = provider.copy()
|
|
@@ -1475,6 +1237,15 @@ def create_provider_kwargs(definition, provider=None):
|
|
|
1475
1237
|
if isinstance(value, (list, dict)):
|
|
1476
1238
|
constructor_kwargs[key] = value.copy()
|
|
1477
1239
|
constructor_kwargs["headers"] = g_config["defaults"]["headers"].copy()
|
|
1240
|
+
|
|
1241
|
+
if "modalities" in definition:
|
|
1242
|
+
constructor_kwargs["modalities"] = {}
|
|
1243
|
+
for modality, modality_definition in definition["modalities"].items():
|
|
1244
|
+
modality_provider = create_provider(modality_definition)
|
|
1245
|
+
if not modality_provider:
|
|
1246
|
+
return None
|
|
1247
|
+
constructor_kwargs["modalities"][modality] = modality_provider
|
|
1248
|
+
|
|
1478
1249
|
return constructor_kwargs
|
|
1479
1250
|
|
|
1480
1251
|
|
|
@@ -1487,9 +1258,11 @@ def create_provider(provider):
|
|
|
1487
1258
|
_log(f"Provider {provider_label} is missing 'npm' sdk")
|
|
1488
1259
|
return None
|
|
1489
1260
|
|
|
1490
|
-
for provider_type in
|
|
1261
|
+
for provider_type in g_app.all_providers:
|
|
1491
1262
|
if provider_type.sdk == npm_sdk:
|
|
1492
1263
|
kwargs = create_provider_kwargs(provider)
|
|
1264
|
+
if kwargs is None:
|
|
1265
|
+
kwargs = provider
|
|
1493
1266
|
return provider_type(**kwargs)
|
|
1494
1267
|
|
|
1495
1268
|
_log(f"Could not find provider {provider_label} with npm sdk {npm_sdk}")
|
|
@@ -1543,11 +1316,23 @@ async def update_providers(home_providers_path):
|
|
|
1543
1316
|
global g_providers
|
|
1544
1317
|
text = await get_text("https://models.dev/api.json")
|
|
1545
1318
|
all_providers = json.loads(text)
|
|
1319
|
+
extra_providers = {}
|
|
1320
|
+
extra_providers_path = home_providers_path.replace("providers.json", "providers-extra.json")
|
|
1321
|
+
if os.path.exists(extra_providers_path):
|
|
1322
|
+
with open(extra_providers_path) as f:
|
|
1323
|
+
extra_providers = json.load(f)
|
|
1546
1324
|
|
|
1547
1325
|
filtered_providers = {}
|
|
1548
1326
|
for id, provider in all_providers.items():
|
|
1549
1327
|
if id in g_config["providers"]:
|
|
1550
1328
|
filtered_providers[id] = provider
|
|
1329
|
+
if id in extra_providers and "models" in extra_providers[id]:
|
|
1330
|
+
for model_id, model in extra_providers[id]["models"].items():
|
|
1331
|
+
if "id" not in model:
|
|
1332
|
+
model["id"] = model_id
|
|
1333
|
+
if "name" not in model:
|
|
1334
|
+
model["name"] = id_to_name(model["id"])
|
|
1335
|
+
filtered_providers[id]["models"][model_id] = model
|
|
1551
1336
|
|
|
1552
1337
|
os.makedirs(os.path.dirname(home_providers_path), exist_ok=True)
|
|
1553
1338
|
with open(home_providers_path, "w", encoding="utf-8") as f:
|
|
@@ -1600,26 +1385,18 @@ def get_config_path():
|
|
|
1600
1385
|
return None
|
|
1601
1386
|
|
|
1602
1387
|
|
|
1603
|
-
def get_ui_path():
|
|
1604
|
-
ui_paths = [home_llms_path("ui.json"), "ui.json"]
|
|
1605
|
-
for ui_path in ui_paths:
|
|
1606
|
-
if os.path.exists(ui_path):
|
|
1607
|
-
return ui_path
|
|
1608
|
-
return None
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
1388
|
def enable_provider(provider):
    """Validate a configured provider, mark it enabled, persist and reload.

    Args:
        provider: Key of the provider in ``g_config["providers"]``.

    Returns:
        ``(provider_config, msg)`` on success, where ``msg`` is the falsy
        result of the provider's ``validate``. ``(None, message)`` when the
        provider is unknown, cannot be constructed, or fails validation; in
        that case the config is left untouched.
    """
    # .get() so an unknown id reaches the "not found" branch instead of
    # raising KeyError.
    provider_config = g_config["providers"].get(provider)
    if not provider_config:
        return None, f"Provider {provider} not found"

    handler, constructor_kwargs = create_provider_from_definition(provider, provider_config)
    # create_provider returns None when no matching provider type exists, and
    # create_provider_kwargs can return None; neither can be validated.
    if not handler or constructor_kwargs is None:
        return None, f"Provider {provider} could not be created"

    msg = handler.validate(**constructor_kwargs)
    if msg:
        return None, msg

    provider_config["enabled"] = True
    save_config(g_config)
    init_llms(g_config, g_providers)
    return provider_config, msg
|
|
@@ -1944,9 +1721,14 @@ async def text_from_resource_or_url(filename):
|
|
|
1944
1721
|
|
|
1945
1722
|
async def save_home_configs():
|
|
1946
1723
|
home_config_path = home_llms_path("llms.json")
|
|
1947
|
-
home_ui_path = home_llms_path("ui.json")
|
|
1948
1724
|
home_providers_path = home_llms_path("providers.json")
|
|
1949
|
-
|
|
1725
|
+
home_providers_extra_path = home_llms_path("providers-extra.json")
|
|
1726
|
+
|
|
1727
|
+
if (
|
|
1728
|
+
os.path.exists(home_config_path)
|
|
1729
|
+
and os.path.exists(home_providers_path)
|
|
1730
|
+
and os.path.exists(home_providers_extra_path)
|
|
1731
|
+
):
|
|
1950
1732
|
return
|
|
1951
1733
|
|
|
1952
1734
|
llms_home = os.path.dirname(home_config_path)
|
|
@@ -1958,17 +1740,17 @@ async def save_home_configs():
|
|
|
1958
1740
|
f.write(config_json)
|
|
1959
1741
|
_log(f"Created default config at {home_config_path}")
|
|
1960
1742
|
|
|
1961
|
-
if not os.path.exists(home_ui_path):
|
|
1962
|
-
ui_json = await text_from_resource_or_url("ui.json")
|
|
1963
|
-
with open(home_ui_path, "w", encoding="utf-8") as f:
|
|
1964
|
-
f.write(ui_json)
|
|
1965
|
-
_log(f"Created default ui config at {home_ui_path}")
|
|
1966
|
-
|
|
1967
1743
|
if not os.path.exists(home_providers_path):
|
|
1968
1744
|
providers_json = await text_from_resource_or_url("providers.json")
|
|
1969
1745
|
with open(home_providers_path, "w", encoding="utf-8") as f:
|
|
1970
1746
|
f.write(providers_json)
|
|
1971
1747
|
_log(f"Created default providers config at {home_providers_path}")
|
|
1748
|
+
|
|
1749
|
+
if not os.path.exists(home_providers_extra_path):
|
|
1750
|
+
extra_json = await text_from_resource_or_url("providers-extra.json")
|
|
1751
|
+
with open(home_providers_extra_path, "w", encoding="utf-8") as f:
|
|
1752
|
+
f.write(extra_json)
|
|
1753
|
+
_log(f"Created default extra providers config at {home_providers_extra_path}")
|
|
1972
1754
|
except Exception:
|
|
1973
1755
|
print("Could not create llms.json. Create one with --init or use --config <path>")
|
|
1974
1756
|
exit(1)
|
|
@@ -2005,62 +1787,348 @@ async def reload_providers():
|
|
|
2005
1787
|
return g_handlers
|
|
2006
1788
|
|
|
2007
1789
|
|
|
2008
|
-
async def watch_config_files(config_path,
|
|
1790
|
+
async def watch_config_files(config_path, providers_path, interval=1):
    """Watch config files and reload providers when they change.

    Every ``interval`` seconds the modification times of ``config_path`` and
    ``providers_path`` are compared with the newest timestamp seen so far.
    When either file is newer, ``g_config`` is re-read from ``config_path``
    and ``reload_providers()`` is awaited. Reload errors are logged and the
    loop keeps running; the function never returns on its own.
    """
    global g_config

    config_path = Path(config_path)
    providers_path = Path(providers_path)

    _log(f"Watching config file: {config_path}")
    _log(f"Watching providers file: {providers_path}")

    def _mtime(path):
        # None when the path is not a regular file, or vanished between the
        # is_file() check and stat() (e.g. an editor doing an atomic replace).
        try:
            return path.stat().st_mtime if path.is_file() else None
        except FileNotFoundError:
            return None

    def get_latest_mtime():
        # Returns (newest mtime, name of the file that carries it).
        latest, changed_name = 0, "llms.json"
        config_mtime = _mtime(config_path)
        if config_mtime is not None:
            latest, changed_name = config_mtime, config_path.name
        providers_mtime = _mtime(providers_path)
        if providers_mtime is not None and providers_mtime > latest:
            latest, changed_name = providers_mtime, providers_path.name
        return latest, changed_name

    latest_mtime, name = get_latest_mtime()

    while True:
        await asyncio.sleep(interval)

        new_mtime, name = get_latest_mtime()
        if new_mtime <= latest_mtime:
            continue

        _log(f"Config file changed: {name}")
        # Record the timestamp before reloading so a broken file is reported
        # once rather than on every poll.
        latest_mtime = new_mtime

        try:
            # Reload llms.json (the config files are written as UTF-8)
            with open(config_path, encoding="utf-8") as f:
                g_config = json.load(f)

            # Reload providers
            await reload_providers()
            _log("Providers reloaded successfully")
        except Exception as e:
            _log(f"Error reloading config: {e}")
|
|
2045
1835
|
|
|
2046
|
-
|
|
2047
|
-
|
|
1836
|
+
|
|
1837
|
+
def get_session_token(request):
    """Return the session token carried by ``request``.

    Sources are tried in order and the first truthy value wins: the
    ``session`` query parameter, the ``X-Session-Token`` header, then the
    ``llms-token`` cookie (whose value is returned as-is, even if missing).
    """
    token = request.query.get("session")
    if not token:
        token = request.headers.get("X-Session-Token")
    if not token:
        token = request.cookies.get("llms-token")
    return token
|
|
1839
|
+
|
|
1840
|
+
|
|
1841
|
+
class AppExtensions:
    """
    Shared registry that extensions populate to extend the app.

    One instance is handed to every ``ExtensionContext``; the context's
    ``register_*`` / ``add_*`` methods append to the lists held here.
    """

    def __init__(self, cli_args, extra_args):
        # Command-line arguments, stored unchanged for extensions to read.
        self.cli_args = cli_args
        self.extra_args = extra_args

        # {"id": ..., "path": ...} entries added by register_ui_extension.
        self.ui_extensions = []

        # Handlers added by register_chat_request_filter / _response_filter.
        self.chat_request_filters = []
        self.chat_response_filters = []

        # (path, handler, kwargs) tuples added by add_get / add_post /
        # add_static_files.
        self.server_add_get = []
        self.server_add_post = []

        # Provider classes that create_provider matches against a
        # definition's npm sdk; extensions append more via add_provider.
        self.all_providers = [
            OpenAiCompatible,
            MistralProvider,
            GroqProvider,
            XaiProvider,
            CodestralProvider,
            OllamaProvider,
            LMStudioProvider,
        ]

        # Aspect ratio -> "width×height" string.
        self.aspect_ratios = {
            "1:1": "1024×1024",
            "2:3": "832×1248",
            "3:2": "1248×832",
            "3:4": "864×1184",
            "4:3": "1184×864",
            "4:5": "896×1152",
            "5:4": "1152×896",
            "9:16": "768×1344",
            "16:9": "1344×768",
            "21:9": "1536×672",
        }
|
|
1875
|
+
|
|
1876
|
+
|
|
1877
|
+
class ExtensionContext:
    """Per-extension facade over the app that is passed to extension hooks.

    Wraps the shared ``AppExtensions`` registry and a handful of module-level
    helpers so an extension can log, register providers/filters/routes and
    query config or session state. Routes are registered below
    ``/ext/<name>``, where ``<name>`` is the extension's file or directory
    name without a trailing ``.py``.
    """

    def __init__(self, app, path):
        self.app = app
        self.path = path
        self.name = os.path.basename(path)
        if self.name.endswith(".py"):
            self.name = self.name[:-3]
        self.ext_prefix = f"/ext/{self.name}"
        # Snapshots of module-level flags taken when the context is created.
        self.MOCK = MOCK
        self.MOCK_DIR = MOCK_DIR
        self.debug = DEBUG
        self.verbose = g_verbose

    def chat_to_prompt(self, chat):
        """Delegate to the module-level ``chat_to_prompt``."""
        return chat_to_prompt(chat)

    def last_user_prompt(self, chat):
        """Delegate to the module-level ``last_user_prompt``."""
        return last_user_prompt(chat)

    def save_image_to_cache(self, base64_data, filename, image_info):
        """Delegate to the module-level ``save_image_to_cache``."""
        return save_image_to_cache(base64_data, filename, image_info)

    def text_from_file(self, path):
        """Delegate to the module-level ``text_from_file``."""
        return text_from_file(path)

    def log(self, message):
        """Print ``message`` prefixed with the extension name when verbose; return it."""
        if self.verbose:
            print(f"[{self.name}] {message}", flush=True)
        return message

    def log_json(self, obj):
        """Print ``obj`` as indented JSON when verbose; return ``obj`` unchanged."""
        if self.verbose:
            print(f"[{self.name}] {json.dumps(obj, indent=2)}", flush=True)
        return obj

    def dbg(self, message):
        """Print ``message`` as a DEBUG line when debug is enabled."""
        if self.debug:
            print(f"DEBUG [{self.name}]: {message}", flush=True)

    def err(self, message, e):
        """Always print ``message`` and exception ``e``; add the traceback when verbose."""
        print(f"ERROR [{self.name}]: {message}", e)
        if self.verbose:
            print(traceback.format_exc(), flush=True)

    def add_provider(self, provider):
        """Append ``provider`` to the app's list of provider types."""
        self.log(f"Registered provider: {provider}")
        self.app.all_providers.append(provider)

    def register_ui_extension(self, index):
        """Record ``index`` (relative to the extension prefix) as a UI extension entry."""
        path = os.path.join(self.ext_prefix, index)
        self.log(f"Registered UI extension: {path}")
        self.app.ui_extensions.append({"id": self.name, "path": path})

    def register_chat_request_filter(self, handler):
        """Append ``handler`` to the app's chat request filters."""
        self.log(f"Registered chat request filter: {handler}")
        self.app.chat_request_filters.append(handler)

    def register_chat_response_filter(self, handler):
        """Append ``handler`` to the app's chat response filters."""
        self.log(f"Registered chat response filter: {handler}")
        self.app.chat_response_filters.append(handler)

    @staticmethod
    def _safe_static_path(ext_dir, path):
        """Return the absolute path of file ``path`` inside ``ext_dir``, else None.

        ``path`` comes straight from the request URL, so it may be absolute or
        contain ``..`` segments. Anything that resolves outside ``ext_dir``,
        does not exist, or is not a regular file yields ``None``.
        """
        root = os.path.abspath(ext_dir)
        candidate = os.path.abspath(os.path.join(root, path))
        try:
            inside = os.path.commonpath([root, candidate]) == root
        except ValueError:
            # commonpath raises e.g. for paths on different Windows drives.
            inside = False
        if inside and os.path.isfile(candidate):
            return candidate
        return None

    def add_static_files(self, ext_dir):
        """Serve the files below ``ext_dir`` at ``<ext_prefix>/{path}``.

        Requests that resolve outside ``ext_dir`` or to a non-file get a 404.
        """
        self.log(f"Registered static files: {ext_dir}")

        async def serve_static(request):
            file_path = self._safe_static_path(ext_dir, request.match_info["path"])
            if file_path is not None:
                return web.FileResponse(file_path)
            return web.Response(status=404)

        self.app.server_add_get.append((os.path.join(self.ext_prefix, "{path:.*}"), serve_static, {}))

    def add_get(self, path, handler, **kwargs):
        """Queue a GET route for ``handler`` at ``path`` under the extension prefix."""
        self.dbg(f"Registered GET: {os.path.join(self.ext_prefix, path)}")
        self.app.server_add_get.append((os.path.join(self.ext_prefix, path), handler, kwargs))

    def add_post(self, path, handler, **kwargs):
        """Queue a POST route for ``handler`` at ``path`` under the extension prefix."""
        self.dbg(f"Registered POST: {os.path.join(self.ext_prefix, path)}")
        self.app.server_add_post.append((os.path.join(self.ext_prefix, path), handler, kwargs))

    def get_config(self):
        """Return the module-level ``g_config``."""
        return g_config

    def chat_completion(self, chat):
        """Delegate to the module-level ``chat_completion``."""
        return chat_completion(chat)

    def get_providers(self):
        """Return the module-level ``g_handlers`` mapping."""
        return g_handlers

    def get_provider(self, name):
        """Return the handler registered as ``name`` in ``g_handlers``, or None."""
        return g_handlers.get(name)

    def get_session(self, request):
        """Return the session data for the request's token, or None if unknown."""
        session_token = get_session_token(request)
        if not session_token or session_token not in g_sessions:
            return None
        return g_sessions[session_token]

    def get_username(self, request):
        """Return the session's ``userName``, or None when there is no session."""
        session = self.get_session(request)
        if session:
            return session.get("userName")
        return None
|
|
1984
|
+
|
|
1985
|
+
|
|
1986
|
+
def load_builtin_extensions():
    """Load the provider modules bundled in ``_ROOT / "providers"``.

    Every ``*.py`` file except ``__init__.py`` is executed as a module and
    registered in ``sys.modules`` as ``llms.providers.<name>``. If the module
    exposes a callable ``__install__``, it is called with an
    ``ExtensionContext`` for that file. A failure in one module is reported
    through ``_err`` and does not stop the remaining modules from loading.
    """
    providers_path = _ROOT / "providers"
    if not providers_path.exists():
        return

    for filename in os.listdir(providers_path):
        is_provider_module = filename.endswith(".py") and filename != "__init__.py"
        if not is_provider_module:
            continue

        source_path = providers_path / filename
        module_name = filename[:-3]

        try:
            spec = importlib.util.spec_from_file_location(module_name, source_path)
            if not (spec and spec.loader):
                continue

            module = importlib.util.module_from_spec(spec)
            # Registered before execution, under the package-qualified name.
            sys.modules[f"llms.providers.{module_name}"] = module
            spec.loader.exec_module(module)

            install_hook = getattr(module, "__install__", None)
            if callable(install_hook):
                install_hook(ExtensionContext(g_app, source_path))
                _log(f"Loaded builtin extension: {module_name}")
        except Exception as e:
            _err(f"Failed to load builtin extension {module_name}", e)
|
|
2011
|
+
|
|
2012
|
+
|
|
2013
|
+
def get_extensions_path():
    """Return the user's extensions directory, ``~/.llms/extensions``, as a string."""
    return str(Path.home() / ".llms" / "extensions")
|
|
2015
|
+
|
|
2016
|
+
|
|
2017
|
+
def init_extensions(parser):
    """Let installed extensions add their own command-line arguments.

    Each directory under the extensions path that has an ``__init__.py`` is
    executed as a module (registered in ``sys.modules`` under the directory
    name). If it defines a callable ``__parser__``, that hook is called with
    ``parser``. Errors are reported per extension through ``_err``.
    """
    extensions_path = get_extensions_path()
    os.makedirs(extensions_path, exist_ok=True)

    for item in os.listdir(extensions_path):
        item_path = os.path.join(extensions_path, item)
        if not os.path.isdir(item_path):
            continue

        try:
            # Only packages (directories with an __init__.py) can carry a
            # __parser__ hook.
            init_file = os.path.join(item_path, "__init__.py")
            if not os.path.exists(init_file):
                continue

            spec = importlib.util.spec_from_file_location(item, init_file)
            if not (spec and spec.loader):
                continue

            module = importlib.util.module_from_spec(spec)
            sys.modules[item] = module
            spec.loader.exec_module(module)

            parser_hook = getattr(module, "__parser__", None)
            if callable(parser_hook):
                parser_hook(parser)
                _log(f"Extension {item} parser loaded")
        except Exception as e:
            _err(f"Failed to load extension {item} parser", e)
|
|
2040
|
+
|
|
2041
|
+
|
|
2042
|
+
def install_extensions():
    """
    Scans ~/.llms/extensions/ for directories with __init__.py and loads them as extensions.
    Calls the `__install__(ctx)` function in the extension module.

    For every extension whose module executes without error, a `ui`
    sub-directory (if present) is served as static files under the extension
    prefix, and `ui/index.mjs` (if present) is registered as a UI extension.
    Failures are reported per extension and do not stop the others.
    """
    extensions_path = get_extensions_path()
    os.makedirs(extensions_path, exist_ok=True)

    # List once. Only directories can be extensions, so stray files
    # (e.g. .DS_Store) must not inflate the count that is logged.
    entries = os.listdir(extensions_path)
    ext_count = sum(1 for entry in entries if os.path.isdir(os.path.join(extensions_path, entry)))
    if ext_count == 0:
        _log("No extensions found")
        return

    _log(f"Installing {ext_count} extension{'' if ext_count == 1 else 's'}...")

    # Guarded so repeated calls do not add the same entry to sys.path again.
    if extensions_path not in sys.path:
        sys.path.append(extensions_path)

    for item in entries:
        item_path = os.path.join(extensions_path, item)
        if not os.path.isdir(item_path):
            _dbg(f"Extension {item} not found: {item_path} is not a directory {os.path.exists(item_path)}")
            continue

        init_file = os.path.join(item_path, "__init__.py")
        if not os.path.exists(init_file):
            _dbg(f"Extension {init_file} not found")
            continue

        ctx = ExtensionContext(g_app, item_path)
        try:
            spec = importlib.util.spec_from_file_location(item, init_file)
            if spec and spec.loader:
                module = importlib.util.module_from_spec(spec)
                sys.modules[item] = module
                spec.loader.exec_module(module)

                install_func = getattr(module, "__install__", None)
                if callable(install_func):
                    install_func(ctx)
                    _log(f"Extension {item} installed")
                else:
                    _dbg(f"Extension {item} has no __install__ function")
            else:
                _dbg(f"Extension {item} has no __init__.py")

            # if ui folder exists, serve as static files at /ext/{item}/
            ui_path = os.path.join(item_path, "ui")
            if os.path.exists(ui_path):
                ctx.add_static_files(ui_path)

                # Register UI extension if index.mjs exists (/ext/{item}/index.mjs)
                if os.path.exists(os.path.join(ui_path, "index.mjs")):
                    ctx.register_ui_extension("index.mjs")

        except Exception as e:
            _err(f"Failed to install extension {item}", e)
|
|
2096
|
+
|
|
2097
|
+
|
|
2098
|
+
def run_extension_cli():
    """
    Run the CLI for an extension.

    Executes every extension package under the extensions directory and calls
    its `__run__(ctx)` hook when it defines one. Returns the first truthy hook
    result; returns False when no extension handled the invocation. An
    extension that raises is reported through `_err` and skipped.
    """
    extensions_path = get_extensions_path()
    os.makedirs(extensions_path, exist_ok=True)

    for item in os.listdir(extensions_path):
        item_path = os.path.join(extensions_path, item)
        if not os.path.isdir(item_path):
            continue

        init_file = os.path.join(item_path, "__init__.py")
        if not os.path.exists(init_file):
            continue

        ctx = ExtensionContext(g_app, item_path)
        try:
            spec = importlib.util.spec_from_file_location(item, init_file)
            if not (spec and spec.loader):
                continue

            module = importlib.util.module_from_spec(spec)
            sys.modules[item] = module
            spec.loader.exec_module(module)

            # Check for __run__ function if exists in __init__.py and call it with ctx
            run_func = getattr(module, "__run__", None)
            if callable(run_func):
                handled = run_func(ctx)
                _log(f"Extension {item} was run")
                # Stop only when the extension actually handled the command;
                # otherwise the remaining extensions still get their turn.
                if handled:
                    return handled
        except Exception as e:
            _err(f"Failed to run extension {item}", e)
    return False
|
|
2060
2128
|
|
|
2061
2129
|
|
|
2062
2130
|
def main():
|
|
2063
|
-
global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path,
|
|
2131
|
+
global _ROOT, g_verbose, g_default_model, g_logprefix, g_providers, g_config, g_config_path, g_app
|
|
2064
2132
|
|
|
2065
2133
|
parser = argparse.ArgumentParser(description=f"llms v{VERSION}")
|
|
2066
2134
|
parser.add_argument("--config", default=None, help="Path to config file", metavar="FILE")
|
|
@@ -2074,6 +2142,7 @@ def main():
|
|
|
2074
2142
|
parser.add_argument("--image", default=None, help="Image input to use in chat completion")
|
|
2075
2143
|
parser.add_argument("--audio", default=None, help="Audio input to use in chat completion")
|
|
2076
2144
|
parser.add_argument("--file", default=None, help="File input to use in chat completion")
|
|
2145
|
+
parser.add_argument("--out", default=None, help="Image or Video Generation Request", metavar="MODALITY")
|
|
2077
2146
|
parser.add_argument(
|
|
2078
2147
|
"--args",
|
|
2079
2148
|
default=None,
|
|
@@ -2096,14 +2165,46 @@ def main():
|
|
|
2096
2165
|
parser.add_argument("--default", default=None, help="Configure the default model to use", metavar="MODEL")
|
|
2097
2166
|
|
|
2098
2167
|
parser.add_argument("--init", action="store_true", help="Create a default llms.json")
|
|
2099
|
-
parser.add_argument("--update", action="store_true", help="Update local models.dev providers.json")
|
|
2168
|
+
parser.add_argument("--update-providers", action="store_true", help="Update local models.dev providers.json")
|
|
2169
|
+
parser.add_argument("--update-extensions", action="store_true", help="Update installed extensions")
|
|
2100
2170
|
|
|
2101
2171
|
parser.add_argument("--root", default=None, help="Change root directory for UI files", metavar="PATH")
|
|
2102
2172
|
parser.add_argument("--logprefix", default="", help="Prefix used in log messages", metavar="PREFIX")
|
|
2103
2173
|
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
|
2104
2174
|
|
|
2175
|
+
parser.add_argument(
|
|
2176
|
+
"--add",
|
|
2177
|
+
nargs="?",
|
|
2178
|
+
const="ls",
|
|
2179
|
+
default=None,
|
|
2180
|
+
help="Install an extension (lists available extensions if no name provided)",
|
|
2181
|
+
metavar="EXTENSION",
|
|
2182
|
+
)
|
|
2183
|
+
parser.add_argument(
|
|
2184
|
+
"--remove",
|
|
2185
|
+
nargs="?",
|
|
2186
|
+
const="ls",
|
|
2187
|
+
default=None,
|
|
2188
|
+
help="Remove an extension (lists installed extensions if no name provided)",
|
|
2189
|
+
metavar="EXTENSION",
|
|
2190
|
+
)
|
|
2191
|
+
|
|
2192
|
+
parser.add_argument(
|
|
2193
|
+
"--update",
|
|
2194
|
+
nargs="?",
|
|
2195
|
+
const="ls",
|
|
2196
|
+
default=None,
|
|
2197
|
+
help="Update an extension (use 'all' to update all extensions)",
|
|
2198
|
+
metavar="EXTENSION",
|
|
2199
|
+
)
|
|
2200
|
+
|
|
2201
|
+
# Load parser extensions, go through all extensions and load their parser arguments
|
|
2202
|
+
init_extensions(parser)
|
|
2203
|
+
|
|
2105
2204
|
cli_args, extra_args = parser.parse_known_args()
|
|
2106
2205
|
|
|
2206
|
+
g_app = AppExtensions(cli_args, extra_args)
|
|
2207
|
+
|
|
2107
2208
|
# Check for verbose mode from CLI argument or environment variables
|
|
2108
2209
|
verbose_env = os.environ.get("VERBOSE", "").lower()
|
|
2109
2210
|
if cli_args.verbose or verbose_env in ("1", "true"):
|
|
@@ -2120,8 +2221,8 @@ def main():
|
|
|
2120
2221
|
exit(1)
|
|
2121
2222
|
|
|
2122
2223
|
home_config_path = home_llms_path("llms.json")
|
|
2123
|
-
home_ui_path = home_llms_path("ui.json")
|
|
2124
2224
|
home_providers_path = home_llms_path("providers.json")
|
|
2225
|
+
home_providers_extra_path = home_llms_path("providers-extra.json")
|
|
2125
2226
|
|
|
2126
2227
|
if cli_args.init:
|
|
2127
2228
|
if os.path.exists(home_config_path):
|
|
@@ -2130,17 +2231,17 @@ def main():
|
|
|
2130
2231
|
asyncio.run(save_default_config(home_config_path))
|
|
2131
2232
|
print(f"Created default config at {home_config_path}")
|
|
2132
2233
|
|
|
2133
|
-
if os.path.exists(home_ui_path):
|
|
2134
|
-
print(f"ui.json already exists at {home_ui_path}")
|
|
2135
|
-
else:
|
|
2136
|
-
asyncio.run(save_text_url(github_url("ui.json"), home_ui_path))
|
|
2137
|
-
print(f"Created default ui config at {home_ui_path}")
|
|
2138
|
-
|
|
2139
2234
|
if os.path.exists(home_providers_path):
|
|
2140
2235
|
print(f"providers.json already exists at {home_providers_path}")
|
|
2141
2236
|
else:
|
|
2142
2237
|
asyncio.run(save_text_url(github_url("providers.json"), home_providers_path))
|
|
2143
2238
|
print(f"Created default providers config at {home_providers_path}")
|
|
2239
|
+
|
|
2240
|
+
if os.path.exists(home_providers_extra_path):
|
|
2241
|
+
print(f"providers-extra.json already exists at {home_providers_extra_path}")
|
|
2242
|
+
else:
|
|
2243
|
+
asyncio.run(save_text_url(github_url("providers-extra.json"), home_providers_extra_path))
|
|
2244
|
+
print(f"Created default extra providers config at {home_providers_extra_path}")
|
|
2144
2245
|
exit(0)
|
|
2145
2246
|
|
|
2146
2247
|
if cli_args.providers:
|
|
@@ -2157,38 +2258,171 @@ def main():
|
|
|
2157
2258
|
g_config = load_config_json(config_json)
|
|
2158
2259
|
|
|
2159
2260
|
config_dir = os.path.dirname(g_config_path)
|
|
2160
|
-
# look for ui.json in same directory as config
|
|
2161
|
-
ui_path = os.path.join(config_dir, "ui.json")
|
|
2162
|
-
if os.path.exists(ui_path):
|
|
2163
|
-
g_ui_path = ui_path
|
|
2164
|
-
else:
|
|
2165
|
-
if not os.path.exists(home_ui_path):
|
|
2166
|
-
ui_json = text_from_resource("ui.json")
|
|
2167
|
-
with open(home_ui_path, "w", encoding="utf-8") as f:
|
|
2168
|
-
f.write(ui_json)
|
|
2169
|
-
_log(f"Created default ui config at {home_ui_path}")
|
|
2170
|
-
g_ui_path = home_ui_path
|
|
2171
2261
|
|
|
2172
2262
|
if not g_providers and os.path.exists(os.path.join(config_dir, "providers.json")):
|
|
2173
2263
|
g_providers = json.loads(text_from_file(os.path.join(config_dir, "providers.json")))
|
|
2174
2264
|
|
|
2175
2265
|
else:
|
|
2176
|
-
# ensure llms.json and
|
|
2266
|
+
# ensure llms.json and providers.json exist in home directory
|
|
2177
2267
|
asyncio.run(save_home_configs())
|
|
2178
2268
|
g_config_path = home_config_path
|
|
2179
|
-
g_ui_path = home_ui_path
|
|
2180
2269
|
g_config = load_config_json(text_from_file(g_config_path))
|
|
2181
2270
|
|
|
2182
2271
|
if not g_providers:
|
|
2183
2272
|
g_providers = json.loads(text_from_file(home_providers_path))
|
|
2184
2273
|
|
|
2185
|
-
if cli_args.
|
|
2274
|
+
if cli_args.update_providers:
|
|
2186
2275
|
asyncio.run(update_providers(home_providers_path))
|
|
2187
2276
|
print(f"Updated {home_providers_path}")
|
|
2188
2277
|
exit(0)
|
|
2189
2278
|
|
|
2279
|
+
# if home_providers_path is older than 1 day, update providers list
|
|
2280
|
+
if (
|
|
2281
|
+
os.path.exists(home_providers_path)
|
|
2282
|
+
and (time.time() - os.path.getmtime(home_providers_path)) > 86400
|
|
2283
|
+
and os.environ.get("LLMS_DISABLE_UPDATE", "") != "1"
|
|
2284
|
+
):
|
|
2285
|
+
try:
|
|
2286
|
+
asyncio.run(update_providers(home_providers_path))
|
|
2287
|
+
_log(f"Updated {home_providers_path}")
|
|
2288
|
+
except Exception as e:
|
|
2289
|
+
_err("Failed to update providers", e)
|
|
2290
|
+
|
|
2291
|
+
if cli_args.add is not None:
|
|
2292
|
+
if cli_args.add == "ls":
|
|
2293
|
+
|
|
2294
|
+
async def list_extensions():
|
|
2295
|
+
print("\nAvailable extensions:")
|
|
2296
|
+
text = await get_text("https://api.github.com/orgs/llmspy/repos?per_page=100&sort=updated")
|
|
2297
|
+
repos = json.loads(text)
|
|
2298
|
+
max_name_length = 0
|
|
2299
|
+
for repo in repos:
|
|
2300
|
+
max_name_length = max(max_name_length, len(repo["name"]))
|
|
2301
|
+
|
|
2302
|
+
for repo in repos:
|
|
2303
|
+
print(f" {repo['name']:<{max_name_length + 2}} {repo['description']}")
|
|
2304
|
+
|
|
2305
|
+
print("\nUsage:")
|
|
2306
|
+
print(" llms --add <extension>")
|
|
2307
|
+
print(" llms --add <github-user>/<repo>")
|
|
2308
|
+
|
|
2309
|
+
asyncio.run(list_extensions())
|
|
2310
|
+
exit(0)
|
|
2311
|
+
|
|
2312
|
+
async def install_extension(name):
|
|
2313
|
+
# Determine git URL and target directory name
|
|
2314
|
+
if "/" in name:
|
|
2315
|
+
git_url = f"https://github.com/{name}"
|
|
2316
|
+
target_name = name.split("/")[-1]
|
|
2317
|
+
else:
|
|
2318
|
+
git_url = f"https://github.com/llmspy/{name}"
|
|
2319
|
+
target_name = name
|
|
2320
|
+
|
|
2321
|
+
# check extension is not already installed
|
|
2322
|
+
extensions_path = get_extensions_path()
|
|
2323
|
+
target_path = os.path.join(extensions_path, target_name)
|
|
2324
|
+
|
|
2325
|
+
if os.path.exists(target_path):
|
|
2326
|
+
print(f"Extension {target_name} is already installed at {target_path}")
|
|
2327
|
+
return
|
|
2328
|
+
|
|
2329
|
+
print(f"Installing extension: {name}")
|
|
2330
|
+
print(f"Cloning from {git_url} to {target_path}...")
|
|
2331
|
+
|
|
2332
|
+
try:
|
|
2333
|
+
subprocess.run(["git", "clone", git_url, target_path], check=True)
|
|
2334
|
+
|
|
2335
|
+
# Check for requirements.txt
|
|
2336
|
+
requirements_path = os.path.join(target_path, "requirements.txt")
|
|
2337
|
+
if os.path.exists(requirements_path):
|
|
2338
|
+
print(f"Installing dependencies from {requirements_path}...")
|
|
2339
|
+
subprocess.run(
|
|
2340
|
+
[sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=target_path, check=True
|
|
2341
|
+
)
|
|
2342
|
+
print("Dependencies installed successfully.")
|
|
2343
|
+
|
|
2344
|
+
print(f"Extension {target_name} installed successfully.")
|
|
2345
|
+
|
|
2346
|
+
except subprocess.CalledProcessError as e:
|
|
2347
|
+
print(f"Failed to install extension: {e}")
|
|
2348
|
+
# cleanup if clone failed but directory was created (unlikely with simple git clone but good practice)
|
|
2349
|
+
if os.path.exists(target_path) and not os.listdir(target_path):
|
|
2350
|
+
os.rmdir(target_path)
|
|
2351
|
+
|
|
2352
|
+
asyncio.run(install_extension(cli_args.add))
|
|
2353
|
+
exit(0)
|
|
2354
|
+
|
|
2355
|
+
if cli_args.remove is not None:
|
|
2356
|
+
if cli_args.remove == "ls":
|
|
2357
|
+
# List installed extensions
|
|
2358
|
+
extensions_path = get_extensions_path()
|
|
2359
|
+
extensions = os.listdir(extensions_path)
|
|
2360
|
+
if len(extensions) == 0:
|
|
2361
|
+
print("No extensions installed.")
|
|
2362
|
+
exit(0)
|
|
2363
|
+
print("Installed extensions:")
|
|
2364
|
+
for extension in extensions:
|
|
2365
|
+
print(f" {extension}")
|
|
2366
|
+
exit(0)
|
|
2367
|
+
# Remove an extension
|
|
2368
|
+
extension_name = cli_args.remove
|
|
2369
|
+
extensions_path = get_extensions_path()
|
|
2370
|
+
target_path = os.path.join(extensions_path, extension_name)
|
|
2371
|
+
|
|
2372
|
+
if not os.path.exists(target_path):
|
|
2373
|
+
print(f"Extension {extension_name} not found at {target_path}")
|
|
2374
|
+
exit(1)
|
|
2375
|
+
|
|
2376
|
+
print(f"Removing extension: {extension_name}...")
|
|
2377
|
+
try:
|
|
2378
|
+
shutil.rmtree(target_path)
|
|
2379
|
+
print(f"Extension {extension_name} removed successfully.")
|
|
2380
|
+
except Exception as e:
|
|
2381
|
+
print(f"Failed to remove extension: {e}")
|
|
2382
|
+
exit(1)
|
|
2383
|
+
|
|
2384
|
+
exit(0)
|
|
2385
|
+
|
|
2386
|
+
if cli_args.update:
|
|
2387
|
+
if cli_args.update == "ls":
|
|
2388
|
+
# List installed extensions
|
|
2389
|
+
extensions_path = get_extensions_path()
|
|
2390
|
+
extensions = os.listdir(extensions_path)
|
|
2391
|
+
if len(extensions) == 0:
|
|
2392
|
+
print("No extensions installed.")
|
|
2393
|
+
exit(0)
|
|
2394
|
+
print("Installed extensions:")
|
|
2395
|
+
for extension in extensions:
|
|
2396
|
+
print(f" {extension}")
|
|
2397
|
+
|
|
2398
|
+
print("\nUsage:")
|
|
2399
|
+
print(" llms --update <extension>")
|
|
2400
|
+
print(" llms --update all")
|
|
2401
|
+
exit(0)
|
|
2402
|
+
|
|
2403
|
+
async def update_extensions(extension_name):
|
|
2404
|
+
extensions_path = get_extensions_path()
|
|
2405
|
+
for extension in os.listdir(extensions_path):
|
|
2406
|
+
extension_path = os.path.join(extensions_path, extension)
|
|
2407
|
+
if os.path.isdir(extension_path):
|
|
2408
|
+
if extension_name != "all" and extension != extension_name:
|
|
2409
|
+
continue
|
|
2410
|
+
result = subprocess.run(["git", "pull"], cwd=extension_path, capture_output=True)
|
|
2411
|
+
if result.returncode != 0:
|
|
2412
|
+
print(f"Failed to update extension {extension}: {result.stderr.decode('utf-8')}")
|
|
2413
|
+
continue
|
|
2414
|
+
print(f"Updated extension {extension}")
|
|
2415
|
+
_log(result.stdout.decode("utf-8"))
|
|
2416
|
+
|
|
2417
|
+
asyncio.run(update_extensions(cli_args.update))
|
|
2418
|
+
exit(0)
|
|
2419
|
+
|
|
2420
|
+
load_builtin_extensions()
|
|
2421
|
+
|
|
2190
2422
|
asyncio.run(reload_providers())
|
|
2191
2423
|
|
|
2424
|
+
install_extensions()
|
|
2425
|
+
|
|
2192
2426
|
# print names
|
|
2193
2427
|
_log(f"enabled providers: {', '.join(g_handlers.keys())}")
|
|
2194
2428
|
|
|
@@ -2261,10 +2495,6 @@ def main():
|
|
|
2261
2495
|
# Start server
|
|
2262
2496
|
port = int(cli_args.serve)
|
|
2263
2497
|
|
|
2264
|
-
if not os.path.exists(g_ui_path):
|
|
2265
|
-
print(f"UI not found at {g_ui_path}")
|
|
2266
|
-
exit(1)
|
|
2267
|
-
|
|
2268
2498
|
# Validate auth configuration if enabled
|
|
2269
2499
|
auth_enabled = g_config.get("auth", {}).get("enabled", False)
|
|
2270
2500
|
if auth_enabled:
|
|
@@ -2274,11 +2504,19 @@ def main():
|
|
|
2274
2504
|
|
|
2275
2505
|
# Expand environment variables
|
|
2276
2506
|
if client_id.startswith("$"):
|
|
2277
|
-
client_id =
|
|
2507
|
+
client_id = client_id[1:]
|
|
2278
2508
|
if client_secret.startswith("$"):
|
|
2279
|
-
client_secret =
|
|
2509
|
+
client_secret = client_secret[1:]
|
|
2280
2510
|
|
|
2281
|
-
|
|
2511
|
+
client_id = os.environ.get(client_id, client_id)
|
|
2512
|
+
client_secret = os.environ.get(client_secret, client_secret)
|
|
2513
|
+
|
|
2514
|
+
if (
|
|
2515
|
+
not client_id
|
|
2516
|
+
or not client_secret
|
|
2517
|
+
or client_id == "GITHUB_CLIENT_ID"
|
|
2518
|
+
or client_secret == "GITHUB_CLIENT_SECRET"
|
|
2519
|
+
):
|
|
2282
2520
|
print("ERROR: Authentication is enabled but GitHub OAuth is not properly configured.")
|
|
2283
2521
|
print("Please set GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET environment variables,")
|
|
2284
2522
|
print("or disable authentication by setting 'auth.enabled' to false in llms.json")
|
|
@@ -2299,7 +2537,7 @@ def main():
|
|
|
2299
2537
|
return True, None
|
|
2300
2538
|
|
|
2301
2539
|
# Check for OAuth session token
|
|
2302
|
-
session_token =
|
|
2540
|
+
session_token = get_session_token(request)
|
|
2303
2541
|
if session_token and session_token in g_sessions:
|
|
2304
2542
|
return True, g_sessions[session_token]
|
|
2305
2543
|
|
|
@@ -2329,13 +2567,32 @@ def main():
|
|
|
2329
2567
|
|
|
2330
2568
|
try:
|
|
2331
2569
|
chat = await request.json()
|
|
2570
|
+
|
|
2571
|
+
# Apply pre-chat filters
|
|
2572
|
+
context = {"request": request}
|
|
2573
|
+
# Apply pre-chat filters
|
|
2574
|
+
context = {"request": request}
|
|
2575
|
+
for filter_func in g_app.chat_request_filters:
|
|
2576
|
+
chat = await filter_func(chat, context)
|
|
2577
|
+
|
|
2332
2578
|
response = await chat_completion(chat)
|
|
2579
|
+
|
|
2580
|
+
# Apply post-chat filters
|
|
2581
|
+
# Apply post-chat filters
|
|
2582
|
+
for filter_func in g_app.chat_response_filters:
|
|
2583
|
+
response = await filter_func(response, context)
|
|
2584
|
+
|
|
2333
2585
|
return web.json_response(response)
|
|
2334
2586
|
except Exception as e:
|
|
2335
2587
|
return web.json_response({"error": str(e)}, status=500)
|
|
2336
2588
|
|
|
2337
2589
|
app.router.add_post("/v1/chat/completions", chat_handler)
|
|
2338
2590
|
|
|
2591
|
+
async def extensions_handler(request):
|
|
2592
|
+
return web.json_response(g_app.ui_extensions)
|
|
2593
|
+
|
|
2594
|
+
app.router.add_get("/ext", extensions_handler)
|
|
2595
|
+
|
|
2339
2596
|
async def models_handler(request):
|
|
2340
2597
|
return web.json_response(get_models())
|
|
2341
2598
|
|
|
@@ -2370,8 +2627,9 @@ def main():
|
|
|
2370
2627
|
if provider:
|
|
2371
2628
|
if data.get("enable", False):
|
|
2372
2629
|
provider_config, msg = enable_provider(provider)
|
|
2373
|
-
_log(f"Enabled provider {provider}")
|
|
2374
|
-
|
|
2630
|
+
_log(f"Enabled provider {provider} {msg}")
|
|
2631
|
+
if not msg:
|
|
2632
|
+
await load_llms()
|
|
2375
2633
|
elif data.get("disable", False):
|
|
2376
2634
|
disable_provider(provider)
|
|
2377
2635
|
_log(f"Disabled provider {provider}")
|
|
@@ -2491,7 +2749,7 @@ def main():
|
|
|
2491
2749
|
except Exception:
|
|
2492
2750
|
return web.Response(text="403: Forbidden", status=403)
|
|
2493
2751
|
|
|
2494
|
-
with open(info_path
|
|
2752
|
+
with open(info_path) as f:
|
|
2495
2753
|
content = f.read()
|
|
2496
2754
|
return web.Response(text=content, content_type="application/json")
|
|
2497
2755
|
|
|
@@ -2527,9 +2785,12 @@ def main():
|
|
|
2527
2785
|
|
|
2528
2786
|
# Expand environment variables
|
|
2529
2787
|
if client_id.startswith("$"):
|
|
2530
|
-
client_id =
|
|
2788
|
+
client_id = client_id[1:]
|
|
2531
2789
|
if redirect_uri.startswith("$"):
|
|
2532
|
-
redirect_uri =
|
|
2790
|
+
redirect_uri = redirect_uri[1:]
|
|
2791
|
+
|
|
2792
|
+
client_id = os.environ.get(client_id, client_id)
|
|
2793
|
+
redirect_uri = os.environ.get(redirect_uri, redirect_uri)
|
|
2533
2794
|
|
|
2534
2795
|
if not client_id:
|
|
2535
2796
|
return web.json_response({"error": "GitHub client_id not configured"}, status=500)
|
|
@@ -2562,7 +2823,9 @@ def main():
|
|
|
2562
2823
|
|
|
2563
2824
|
# Expand environment variables
|
|
2564
2825
|
if restrict_to.startswith("$"):
|
|
2565
|
-
restrict_to =
|
|
2826
|
+
restrict_to = restrict_to[1:]
|
|
2827
|
+
|
|
2828
|
+
restrict_to = os.environ.get(restrict_to, None if restrict_to == "GITHUB_USERS" else restrict_to)
|
|
2566
2829
|
|
|
2567
2830
|
# If restrict_to is configured, validate the user
|
|
2568
2831
|
if restrict_to:
|
|
@@ -2583,6 +2846,14 @@ def main():
|
|
|
2583
2846
|
code = request.query.get("code")
|
|
2584
2847
|
state = request.query.get("state")
|
|
2585
2848
|
|
|
2849
|
+
# Handle malformed URLs where query params are appended with & instead of ?
|
|
2850
|
+
if not code and "tail" in request.match_info:
|
|
2851
|
+
tail = request.match_info["tail"]
|
|
2852
|
+
if tail.startswith("&"):
|
|
2853
|
+
params = parse_qs(tail[1:])
|
|
2854
|
+
code = params.get("code", [None])[0]
|
|
2855
|
+
state = params.get("state", [None])[0]
|
|
2856
|
+
|
|
2586
2857
|
if not code or not state:
|
|
2587
2858
|
return web.Response(text="Missing code or state parameter", status=400)
|
|
2588
2859
|
|
|
@@ -2602,11 +2873,15 @@ def main():
|
|
|
2602
2873
|
|
|
2603
2874
|
# Expand environment variables
|
|
2604
2875
|
if client_id.startswith("$"):
|
|
2605
|
-
client_id =
|
|
2876
|
+
client_id = client_id[1:]
|
|
2606
2877
|
if client_secret.startswith("$"):
|
|
2607
|
-
client_secret =
|
|
2878
|
+
client_secret = client_secret[1:]
|
|
2608
2879
|
if redirect_uri.startswith("$"):
|
|
2609
|
-
redirect_uri =
|
|
2880
|
+
redirect_uri = redirect_uri[1:]
|
|
2881
|
+
|
|
2882
|
+
client_id = os.environ.get(client_id, client_id)
|
|
2883
|
+
client_secret = os.environ.get(client_secret, client_secret)
|
|
2884
|
+
redirect_uri = os.environ.get(redirect_uri, redirect_uri)
|
|
2610
2885
|
|
|
2611
2886
|
if not client_id or not client_secret:
|
|
2612
2887
|
return web.json_response({"error": "GitHub OAuth credentials not configured"}, status=500)
|
|
@@ -2654,11 +2929,13 @@ def main():
|
|
|
2654
2929
|
}
|
|
2655
2930
|
|
|
2656
2931
|
# Redirect to UI with session token
|
|
2657
|
-
|
|
2932
|
+
response = web.HTTPFound(f"/?session={session_token}")
|
|
2933
|
+
response.set_cookie("llms-token", session_token, httponly=True, path="/", max_age=86400)
|
|
2934
|
+
return response
|
|
2658
2935
|
|
|
2659
2936
|
async def session_handler(request):
|
|
2660
2937
|
"""Validate and return session info"""
|
|
2661
|
-
session_token =
|
|
2938
|
+
session_token = get_session_token(request)
|
|
2662
2939
|
|
|
2663
2940
|
if not session_token or session_token not in g_sessions:
|
|
2664
2941
|
return web.json_response({"error": "Invalid or expired session"}, status=401)
|
|
@@ -2675,17 +2952,19 @@ def main():
|
|
|
2675
2952
|
|
|
2676
2953
|
async def logout_handler(request):
|
|
2677
2954
|
"""End OAuth session"""
|
|
2678
|
-
session_token =
|
|
2955
|
+
session_token = get_session_token(request)
|
|
2679
2956
|
|
|
2680
2957
|
if session_token and session_token in g_sessions:
|
|
2681
2958
|
del g_sessions[session_token]
|
|
2682
2959
|
|
|
2683
|
-
|
|
2960
|
+
response = web.json_response({"success": True})
|
|
2961
|
+
response.del_cookie("llms-token")
|
|
2962
|
+
return response
|
|
2684
2963
|
|
|
2685
2964
|
async def auth_handler(request):
|
|
2686
2965
|
"""Check authentication status and return user info"""
|
|
2687
2966
|
# Check for OAuth session token
|
|
2688
|
-
session_token =
|
|
2967
|
+
session_token = get_session_token(request)
|
|
2689
2968
|
|
|
2690
2969
|
if session_token and session_token in g_sessions:
|
|
2691
2970
|
session_data = g_sessions[session_token]
|
|
@@ -2722,6 +3001,7 @@ def main():
|
|
|
2722
3001
|
app.router.add_get("/auth", auth_handler)
|
|
2723
3002
|
app.router.add_get("/auth/github", github_auth_handler)
|
|
2724
3003
|
app.router.add_get("/auth/github/callback", github_callback_handler)
|
|
3004
|
+
app.router.add_get("/auth/github/callback{tail:.*}", github_callback_handler)
|
|
2725
3005
|
app.router.add_get("/auth/session", session_handler)
|
|
2726
3006
|
app.router.add_post("/auth/logout", logout_handler)
|
|
2727
3007
|
|
|
@@ -2756,25 +3036,30 @@ def main():
|
|
|
2756
3036
|
|
|
2757
3037
|
app.router.add_get("/ui/{path:.*}", ui_static, name="ui_static")
|
|
2758
3038
|
|
|
2759
|
-
async def
|
|
2760
|
-
|
|
2761
|
-
|
|
2762
|
-
|
|
2763
|
-
|
|
2764
|
-
|
|
2765
|
-
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
|
|
2769
|
-
return web.json_response(ui)
|
|
3039
|
+
async def config_handler(request):
|
|
3040
|
+
ret = {}
|
|
3041
|
+
if "defaults" not in ret:
|
|
3042
|
+
ret["defaults"] = g_config["defaults"]
|
|
3043
|
+
enabled, disabled = provider_status()
|
|
3044
|
+
ret["status"] = {"all": list(g_config["providers"].keys()), "enabled": enabled, "disabled": disabled}
|
|
3045
|
+
# Add auth configuration
|
|
3046
|
+
ret["requiresAuth"] = auth_enabled
|
|
3047
|
+
ret["authType"] = "oauth" if auth_enabled else "apikey"
|
|
3048
|
+
return web.json_response(ret)
|
|
2770
3049
|
|
|
2771
|
-
app.router.add_get("/config",
|
|
3050
|
+
app.router.add_get("/config", config_handler)
|
|
2772
3051
|
|
|
2773
3052
|
async def not_found_handler(request):
|
|
2774
3053
|
return web.Response(text="404: Not Found", status=404)
|
|
2775
3054
|
|
|
2776
3055
|
app.router.add_get("/favicon.ico", not_found_handler)
|
|
2777
3056
|
|
|
3057
|
+
# go through and register all g_app extensions
|
|
3058
|
+
for handler in g_app.server_add_get:
|
|
3059
|
+
app.router.add_get(handler[0], handler[1], **handler[2])
|
|
3060
|
+
for handler in g_app.server_add_post:
|
|
3061
|
+
app.router.add_post(handler[0], handler[1], **handler[2])
|
|
3062
|
+
|
|
2778
3063
|
# Serve index.html from root
|
|
2779
3064
|
async def index_handler(request):
|
|
2780
3065
|
index_content = read_resource_file_bytes("index.html")
|
|
@@ -2791,10 +3076,12 @@ def main():
|
|
|
2791
3076
|
async def start_background_tasks(app):
|
|
2792
3077
|
"""Start background tasks when the app starts"""
|
|
2793
3078
|
# Start watching config files in the background
|
|
2794
|
-
asyncio.create_task(watch_config_files(g_config_path,
|
|
3079
|
+
asyncio.create_task(watch_config_files(g_config_path, home_providers_path))
|
|
2795
3080
|
|
|
2796
3081
|
app.on_startup.append(start_background_tasks)
|
|
2797
3082
|
|
|
3083
|
+
# go through and register all g_app extensions
|
|
3084
|
+
|
|
2798
3085
|
print(f"Starting server on port {port}...")
|
|
2799
3086
|
web.run_app(app, host="0.0.0.0", port=port, print=_log)
|
|
2800
3087
|
exit(0)
|
|
@@ -2869,6 +3156,7 @@ def main():
|
|
|
2869
3156
|
or cli_args.image is not None
|
|
2870
3157
|
or cli_args.audio is not None
|
|
2871
3158
|
or cli_args.file is not None
|
|
3159
|
+
or cli_args.out is not None
|
|
2872
3160
|
or len(extra_args) > 0
|
|
2873
3161
|
):
|
|
2874
3162
|
try:
|
|
@@ -2879,6 +3167,12 @@ def main():
|
|
|
2879
3167
|
chat = g_config["defaults"]["audio"]
|
|
2880
3168
|
elif cli_args.file is not None:
|
|
2881
3169
|
chat = g_config["defaults"]["file"]
|
|
3170
|
+
elif cli_args.out is not None:
|
|
3171
|
+
template = f"out:{cli_args.out}"
|
|
3172
|
+
if template not in g_config["defaults"]:
|
|
3173
|
+
print(f"Template for output modality '{cli_args.out}' not found")
|
|
3174
|
+
exit(1)
|
|
3175
|
+
chat = g_config["defaults"][template]
|
|
2882
3176
|
if cli_args.chat is not None:
|
|
2883
3177
|
chat_path = os.path.join(os.path.dirname(__file__), cli_args.chat)
|
|
2884
3178
|
if not os.path.exists(chat_path):
|
|
@@ -2922,9 +3216,14 @@ def main():
|
|
|
2922
3216
|
traceback.print_exc()
|
|
2923
3217
|
exit(1)
|
|
2924
3218
|
|
|
2925
|
-
|
|
2926
|
-
|
|
3219
|
+
handled = run_extension_cli()
|
|
3220
|
+
|
|
3221
|
+
if not handled:
|
|
3222
|
+
# show usage from ArgumentParser
|
|
3223
|
+
parser.print_help()
|
|
2927
3224
|
|
|
2928
3225
|
|
|
2929
3226
|
if __name__ == "__main__":
|
|
3227
|
+
if MOCK or DEBUG:
|
|
3228
|
+
print(f"MOCK={MOCK} or DEBUG={DEBUG}")
|
|
2930
3229
|
main()
|