lemonade-sdk 9.1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
lemonade/tools/server/serve.py
@@ -0,0 +1,2123 @@
+import sys
+import asyncio
+import statistics
+import time
+from threading import Thread, Event
+import logging
+import platform
+import tempfile
+import traceback
+from typing import Optional, Union, List
+import json
+from pathlib import Path
+import os
+import shutil
+from fastapi import FastAPI, HTTPException, status, Request, WebSocket, Form, UploadFile
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from starlette.websockets import WebSocketDisconnect, WebSocketState
+import uvicorn
+from uvicorn.config import Config
+from uvicorn.server import Server as UvicornServer
+from tabulate import tabulate
+
+from openai.types.completion import Completion, CompletionChoice
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
+from openai.types.completion_choice import Logprobs
+from openai.types.model import Model
+from openai.types.responses import (
+    Response,
+    ResponseOutputMessage,
+    ResponseOutputText,
+    ResponseCreatedEvent,
+    ResponseTextDeltaEvent,
+    ResponseCompletedEvent,
+)
+
+import lemonade.api as lemonade_api
+from lemonade.tools.server.wrapped_server import WrappedServer
+from lemonade.tools.server.llamacpp import LlamaServer
+from lemonade.tools.server.flm import FlmServer
+from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
+from lemonade.tools.server.webapp import get_webapp_html
+from lemonade.tools.server.utils.port import lifespan
+
+from lemonade_server.model_manager import ModelManager
+from lemonade_server.pydantic_models import (
+    DEFAULT_PORT,
+    DEFAULT_HOST,
+    DEFAULT_LOG_LEVEL,
+    DEFAULT_LLAMACPP_BACKEND,
+    DEFAULT_CTX_SIZE,
+    LoadConfig,
+    CompletionRequest,
+    ChatCompletionRequest,
+    EmbeddingsRequest,
+    RerankingRequest,
+    ResponsesRequest,
+    PullConfig,
+    DeleteConfig,
+    LogLevelConfig,
+)
+from lemonade_server.settings import save_setting
+
+# Set to a high number to allow for interesting experiences in real apps
+# Tests should use the max_new_tokens argument to set a lower value
+DEFAULT_MAX_NEW_TOKENS = 1500
+
+if platform.system() in ["Windows", "Darwin"]:
+    # pylint: disable=ungrouped-imports
+    from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator
+
+
+class ServerLogFilter(logging.Filter):
+    def __init__(self, server):
+        super().__init__()
+        self.server = server
+        self.noisy_paths = {
+            "/api/v1/health",
+            "/api/v0/health",
+            "/api/v1/models",
+            "/api/v0/models",
+        }
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        msg = record.getMessage()
+
+        # Filter out websocket logs
+        if "> TEXT" in msg:
+            return False
+
+        # Filter out noisy HTTP routes if debug logs are OFF
+        if not self.server.debug_logging_enabled:
+            if any(path in msg for path in self.noisy_paths):
+                return False
+
+        # Otherwise, allow the log
+        return True
+
+
+async def log_streamer(websocket: WebSocket, path: str, interval: float = 1.0):
+    logger = logging.getLogger()
+    await websocket.accept()
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            f.seek(0)  # start at the beginning of the file
+            while True:
+                # Try reading a line
+                line = f.readline()
+                if not line:
+                    await asyncio.sleep(interval)
+                    continue
+
+                # Send defensively: if disconnected, bail out
+                if websocket.application_state != WebSocketState.CONNECTED:
+                    # Server-side state says we're not connected anymore
+                    break
+
+                try:
+                    await websocket.send_text(line)
+                except WebSocketDisconnect:
+                    # Client closed; normal path out
+                    break
+                except RuntimeError as re:
+                    # Starlette will raise this if a close has already been sent
+                    logger.debug("RuntimeError during send: %s", re)
+                    break
+
+    except WebSocketDisconnect:
+        # Client closed the socket; do not try to send or close again
+        pass
+    except Exception as e:  # pylint: disable=broad-except
+        # Log server-side; do not attempt to send error over a possibly closed socket
+        logger.exception("Error in log_streamer: %s", e)
+    finally:
+        # Only close if Starlette still thinks we're connected.
+        # This prevents "Cannot call send once a close message has been sent."
+        try:
+            if websocket.application_state == WebSocketState.CONNECTED:
+                await websocket.close()
+        except Exception:  # pylint: disable=broad-except
+            # If close itself races, swallow; we're shutting down anyway.
+            pass
+
+
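For reference, `log_streamer` is the handler style behind the `{prefix}/logs/ws` websocket route registered in `setup_routes` further down. Below is a minimal client sketch using the third-party `websockets` package; the port is an assumption, since the real default comes from `DEFAULT_PORT`, whose value is not shown in this diff.

```python
# Sketch of a log-tailing client for the /api/v1/logs/ws route.
# Assumptions: server on localhost:8000, `pip install websockets`.
import asyncio

import websockets


async def tail_server_logs():
    uri = "ws://localhost:8000/api/v1/logs/ws"  # port is an assumption
    async with websockets.connect(uri) as ws:
        # The server sends one log line per message and polls the log
        # file roughly once per `interval` seconds when it is idle.
        async for line in ws:
            print(line, end="")


if __name__ == "__main__":
    asyncio.run(tail_server_logs())
```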
+class ServerModel(Model):
+    """
+    An extension of OpenAI's Model class that adds
+    checkpoint and recipe attributes.
+    """
+
+    checkpoint: str
+    recipe: str
+
+
+class GeneratorThread(Thread):
+    """
+    Thread class designed for use with streaming generation within
+    an LLM server. It needs access to the streamer in order to help
+    the completions APIs escape the "for text in streamer" loop.
+    It also provides exception handling that works nicely with HTTP
+    servers by providing the stack trace and making the exception
+    information available to the main thread.
+    """
+
+    def __init__(self, streamer, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exception = None
+        self.streamer = streamer
+
+    def run(self):
+        try:
+            if self._target:
+                self._target(*self._args, **self._kwargs)
+        except Exception as e:  # pylint: disable=broad-except
+            self.exception = e
+            logging.error(f"Exception raised in generate thread: {e}")
+            traceback.print_exc()
+            self.streamer.done()
+
+
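As a usage sketch (not part of the package), this is how a `GeneratorThread`-style worker surfaces exceptions to the request handler: the thread records the exception, signals the streamer so the consuming loop exits, and the caller checks after `join()`. The streamer below is a stand-in with only a `done()` method.

```python
# Sketch: the exception-propagation pattern used by GeneratorThread.
from threading import Thread


class DummyStreamer:
    """Stand-in for the real streamer; only a done() signal is needed."""

    def done(self):
        print("streamer signaled done, so the consuming loop can exit")


class CapturingThread(Thread):
    """Mirrors GeneratorThread: record the exception, then signal done()."""

    def __init__(self, streamer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exception = None
        self.streamer = streamer

    def run(self):
        try:
            if self._target:
                self._target(*self._args, **self._kwargs)
        except Exception as e:  # pylint: disable=broad-except
            self.exception = e
            self.streamer.done()


def generate():
    raise RuntimeError("generation failed")


worker = CapturingThread(DummyStreamer(), target=generate)
worker.start()
worker.join()
if worker.exception:
    # An HTTP handler would turn this into a 500 response instead.
    print(f"worker raised: {worker.exception}")
```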
+class StopOnEvent:
+    """
+    Custom stopping criteria that halts text generation when a specified event is set.
+
+    This allows for external control of generation, such as stopping a generation
+    before it reaches the maximum token limit.
+    """
+
+    def __init__(self, stop_event: Event):
+        super().__init__()
+        self.stop_event = stop_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.stop_event.is_set()
+
+
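`StopOnEvent` uses the `(input_ids, scores)` call signature of Hugging Face stopping criteria, so a generation loop can poll it once per token. A small sketch with a plain loop standing in for the model; the class is repeated here so the snippet runs on its own.

```python
# Sketch: halting a token loop with a StopOnEvent-style callable.
from threading import Event


class StopOnEvent:
    def __init__(self, stop_event: Event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.stop_event.is_set()


stop_event = Event()
should_stop = StopOnEvent(stop_event)

tokens_emitted = 0
for step in range(1000):
    tokens_emitted += 1  # ... produce one token here ...
    if step == 3:
        stop_event.set()  # e.g. triggered by the /halt endpoint
    if should_stop(input_ids=None, scores=None):
        break
print(f"stopped after {tokens_emitted} tokens")
```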
+class NoCacheStaticFiles(StaticFiles):
+    """Custom StaticFiles class with no-cache headers"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def file_response(self, *args, **kwargs) -> Response:
+        response = super().file_response(*args, **kwargs)
+        # Add no-cache headers for all static files
+        response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
+        response.headers["Pragma"] = "no-cache"
+        response.headers["Expires"] = "0"
+        return response
+
+
+class Server:
+    """
+    Open a web server that apps can use to communicate with the LLM.
+
+    The server exposes these endpoints:
+    - /api/v1/pull: install an LLM by its Lemonade Server Model Name.
+    - /api/v1/delete: delete an LLM by its Lemonade Server Model Name.
+    - /api/v1/load: load a model checkpoint.
+    - /api/v1/unload: unload a model checkpoint.
+    - /api/v1/health: check whether a model is loaded and ready to serve.
+    - /api/v1/stats: performance statistics for the generation.
+    - /api/v1/halt: stop an in-progress generation from making more tokens.
+    - /api/v1/completions: completion responses using HTTP chunked transfer encoding.
+    - /api/v1/chat/completions: chat completion responses using HTTP chunked transfer encoding.
+    - /api/v1/responses: responses API using HTTP chunked transfer encoding.
+    - /api/v1/models: list all available models.
+    - /api/v1/models/{model_id}: retrieve a specific model by ID.
+    """
+
+    def __init__(
+        self,
+        port: int = DEFAULT_PORT,
+        host: str = DEFAULT_HOST,
+        log_level: str = DEFAULT_LOG_LEVEL,
+        ctx_size: int = DEFAULT_CTX_SIZE,
+        tray: bool = False,
+        log_file: str = None,
+        llamacpp_backend: str = DEFAULT_LLAMACPP_BACKEND,
+    ):
+        super().__init__()
+
+        # Save args as members
+        self.port = port
+        self.host = host
+        self.log_level = log_level
+        self.ctx_size = ctx_size
+        self.tray = tray
+        self.log_file = log_file
+        self.llamacpp_backend = llamacpp_backend
+
+        # Initialize FastAPI app
+        self.app = FastAPI(lifespan=lifespan)
+
+        # Lifespan will load some tasks in the background, and then set the
+        # app.initialized flag to True when this is done
+        self.app.initialized = False
+
+        # Add CORS middleware
+        self.app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],  # Allows all origins
+            allow_credentials=True,
+            allow_methods=["*"],  # Allows all methods
+            allow_headers=["*"],  # Allows all headers
+        )
+
+        # Set up debug middleware if debug logging is enabled
+        # This must be done during app initialization, not at runtime
+        self.debug_logging_enabled = log_level == "debug"
+        if self.debug_logging_enabled:
+            self.setup_middleware_timer()
+
+        # Set up custom routes
+        self.setup_routes(["/api/v0", "/api/v1"])
+
+        # Set up Web App
+        self.app.get("/")(self.webapp)
+
+        # Mount a static assets dir for HTML responses, such
+        # as the Web App
+        static_dir = Path(__file__).parent / "static"
+        self.app.mount(
+            "/static", NoCacheStaticFiles(directory=static_dir), name="static_assets"
+        )
+
+        # Performance stats that are set during /ws and can be
+        # fetched in /stats
+        self.time_to_first_token = None
+        self.tokens_per_second = None
+        self.input_tokens = None
+        self.output_tokens = None
+        self.decode_token_times = None
+
+        # Store debug logging state
+        self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
+
+        # Flag that tells the LLM to stop generating text and end the response
+        self.stop_event = Event()
+
+        self.llm_loaded: LoadConfig = None
+        self.tokenizer = None
+
+        # Placeholders for model and configs
+        self.model = None
+
+        # Initialize semaphore for tracking active generations
+        self.max_concurrent_generations = 1
+        self._generate_semaphore = asyncio.Semaphore(self.max_concurrent_generations)
+
+        # Dictionary of installed LLMs, mapping model name to information about
+        # that model. Does not include non-installed models.
+        self.local_models = ModelManager().downloaded_models_enabled
+
+        # Add lock for load/unload operations
+        self._load_lock = asyncio.Lock()
+
+        # Subprocess handle for wrapped instance of llama_server.exe, etc.
+        self.wrapped_server: WrappedServer = None
+
+    def setup_routes(self, api_prefixes: list[str]):
+        for prefix in api_prefixes:
+            # Custom routes
+            self.app.post(f"{prefix}/pull")(self.pull)
+            self.app.post(f"{prefix}/delete")(self.delete)
+            self.app.post(f"{prefix}/load")(self.load_llm)
+            self.app.post(f"{prefix}/unload")(self.unload_llm)
+            self.app.get(f"{prefix}/health")(self.health)
+            self.app.get(f"{prefix}/halt")(self.halt_generation)
+            self.app.get(f"{prefix}/stats")(self.send_stats)
+            self.app.get(f"{prefix}/system-info")(self.get_system_info)
+            self.app.post(f"{prefix}/completions")(self.completions)
+            self.app.post(f"{prefix}/responses")(self.responses)
+            self.app.post(f"{prefix}/log-level")(self.set_log_level)
+            self.app.websocket(f"{prefix}/logs/ws")(self.logs_ws)
+            self.app.post(f"{prefix}/add-local-model")(self.add_local_model)
+
+            # OpenAI-compatible routes
+            self.app.post(f"{prefix}/chat/completions")(self.chat_completions)
+            self.app.post(f"{prefix}/embeddings")(self.embeddings)
+            self.app.get(f"{prefix}/models")(self.models)
+            self.app.get(f"{prefix}/models/{{model_id}}")(self.retrieve_model)
+
+            # JinaAI routes (jina.ai/reranker/)
+            self.app.post(f"{prefix}/reranking")(self.reranking)
+            self.app.post(f"{prefix}/rerank")(self.reranking)
+
+            # Migration routes
+            self.app.get(f"{prefix}/migration/incompatible-models")(
+                self.get_incompatible_models
+            )
+            self.app.post(f"{prefix}/migration/cleanup")(
+                self.cleanup_incompatible_models
+            )
+
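Since the chat, completions, embeddings, and models routes registered above follow the OpenAI wire format, the standard `openai` Python client can be pointed at the server. A sketch; the port and the model name are placeholders, not values taken from this diff.

```python
# Sketch: talking to the local server with the OpenAI client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/api/v1",  # port is an assumption
    api_key="unused",  # the client requires a non-empty key
)

completion = client.chat.completions.create(
    model="some-installed-model",  # placeholder Lemonade model name
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(completion.choices[0].message.content)
```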
+    async def add_local_model(
+        self,
+        model_name: str = Form(...),
+        checkpoint: str = Form(""),
+        recipe: str = Form(...),
+        reasoning: bool = Form(False),
+        vision: bool = Form(False),
+        mmproj: str = Form(None),
+        model_files: List[UploadFile] = None,
+    ):
+        from huggingface_hub.constants import HF_HUB_CACHE
+        from lemonade.tools.llamacpp.utils import parse_checkpoint
+
+        # Upload and register a local model from files.
+        try:
+            if not model_files:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="No model files provided for upload",
+                )
+
+            if not model_name.startswith("user."):
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="Model name must start with 'user.'",
+                )
+
+            valid_recipes = ["llamacpp", "oga-npu", "oga-hybrid", "oga-cpu"]
+            if recipe not in valid_recipes:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=f"Invalid recipe. Must be one of: {', '.join(valid_recipes)}",
+                )
+
+            if recipe == "llamacpp" and not any(
+                f.filename.lower().endswith(".gguf") for f in model_files
+            ):
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="At least one .gguf file is required for llamacpp",
+                )
+
+            # Check if model name already exists
+            if model_name in ModelManager().supported_models:
+                raise HTTPException(
+                    status_code=status.HTTP_409_CONFLICT,
+                    detail=(
+                        f"Model name '{model_name}' already exists. "
+                        "Please use a different name."
+                    ),
+                )
+
+            model_name_clean = model_name.replace("user.", "")
+
+            # Files are saved to models--{model_name_clean}
+            # Note: This is based on the user's custom model name, NOT the checkpoint field
+            repo_cache_name = model_name_clean.replace("/", "--")
+            snapshot_path = os.path.join(HF_HUB_CACHE, f"models--{repo_cache_name}")
+            os.makedirs(snapshot_path, exist_ok=True)
+
+            # Extract variant from checkpoint field if provided
+            # checkpoint field format: "folder:variant" or just "folder"
+            variant = None
+            if checkpoint and ":" in checkpoint:
+                _, variant = parse_checkpoint(checkpoint)
+                # variant now contains just the variant filename, with or without the
+                # .gguf extension (e.g., "LFM2-VL-1.6B-F16" or "LFM2-VL-1.6B-F16.gguf")
+
+            # Save uploaded files, preserving folder structure
+            for file in model_files:
+                relative_path = file.filename
+                path_parts = relative_path.split("/")
+
+                if len(path_parts) > 1:
+                    internal_path = "/".join(path_parts[1:])
+                    file_path = os.path.join(snapshot_path, internal_path)
+                else:
+                    file_path = os.path.join(snapshot_path, path_parts[0])
+
+                os.makedirs(os.path.dirname(file_path), exist_ok=True)
+                with open(file_path, "wb") as f:
+                    content = await file.read()
+                    f.write(content)
+
+            # Resolve actual file paths after upload (for faster loading later)
+            resolved_checkpoint = None
+            resolved_mmproj = None
+
+            # For OGA models, find genai_config.json
+            if recipe.startswith("oga-"):
+                for root, _, files in os.walk(snapshot_path):
+                    if "genai_config.json" in files:
+                        resolved_checkpoint = root
+                        break
+                if not resolved_checkpoint:
+                    resolved_checkpoint = snapshot_path
+
+            # For llamacpp models, find the GGUF file
+            elif recipe == "llamacpp":
+                gguf_file_found = None
+
+                # If variant is specified, look for that specific file
+                if variant:
+                    search_term = (
+                        variant if variant.endswith(".gguf") else f"{variant}.gguf"
+                    )
+                    for root, _, files in os.walk(snapshot_path):
+                        if search_term in files:
+                            gguf_file_found = os.path.join(root, search_term)
+                            break
+
+                # If no variant or variant not found, search for any .gguf file (excluding mmproj)
+                if not gguf_file_found:
+                    for root, _, files in os.walk(snapshot_path):
+                        gguf_files = [
+                            f
+                            for f in files
+                            if f.endswith(".gguf") and "mmproj" not in f.lower()
+                        ]
+                        if gguf_files:
+                            gguf_file_found = os.path.join(root, gguf_files[0])
+                            break
+
+                resolved_checkpoint = (
+                    gguf_file_found if gguf_file_found else snapshot_path
+                )
+
+            # Search for mmproj file if provided
+            if mmproj:
+                for root, _, files in os.walk(snapshot_path):
+                    if mmproj in files:
+                        resolved_mmproj = os.path.join(root, mmproj)
+                        break
+
+            # Build checkpoint for registration
+            # For llamacpp with resolved path, store the full path relative to HF_HUB_CACHE
+            if resolved_checkpoint:
+                # Store as relative path from HF_HUB_CACHE for portability
+                checkpoint_to_register = os.path.relpath(
+                    resolved_checkpoint, HF_HUB_CACHE
+                )
+            elif variant:
+                checkpoint_to_register = f"models--{repo_cache_name}:{variant}"
+            else:
+                checkpoint_to_register = f"models--{repo_cache_name}"
+
+            # Register the model
+            ModelManager().register_local_model(
+                model_name=model_name,
+                checkpoint=checkpoint_to_register,
+                recipe=recipe,
+                reasoning=reasoning,
+                vision=vision,
+                mmproj=resolved_mmproj if resolved_mmproj else mmproj,
+                snapshot_path=snapshot_path,
+            )
+
+            # Refresh local models
+            self.local_models = ModelManager().downloaded_models_enabled
+
+            return {
+                "status": "success",
+                "message": f"Model {model_name} uploaded and registered successfully",
+            }
+        except Exception as e:
+            # Clean up any partially uploaded files; snapshot_path may not exist
+            # yet if validation failed before the upload directory was created
+            if "snapshot_path" in locals() and os.path.exists(snapshot_path):
+                shutil.rmtree(snapshot_path)
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to upload model: {str(e)}",
+            )
+
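A sketch of calling the add-local-model route above with `requests`, following its validation rules: the name must start with `user.`, the recipe must be one of llamacpp, oga-npu, oga-hybrid, or oga-cpu, and a llamacpp upload needs at least one `.gguf` file. The file path and port are placeholders.

```python
# Sketch: registering a local GGUF model over HTTP (port is an assumption).
import requests

with open("my-model.gguf", "rb") as gguf:
    resp = requests.post(
        "http://localhost:8000/api/v1/add-local-model",
        data={
            "model_name": "user.my-model",  # must start with "user."
            "recipe": "llamacpp",
            "checkpoint": "",
        },
        # Field name matches the model_files parameter above; repeat the
        # tuple to upload more than one file.
        files=[("model_files", ("my-model.gguf", gguf, "application/octet-stream"))],
    )
print(resp.status_code, resp.json())
```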
+    async def set_log_level(self, config: LogLevelConfig):
+        """
+        Set the logging level of the server.
+        """
+        try:
+            log_level_upper = config.level.upper()
+            numeric_level = getattr(logging, log_level_upper, None)
+            if not isinstance(numeric_level, int):
+                raise ValueError(f"Invalid log level: {config.level}")
+
+            # Get the root logger
+            logger = logging.getLogger()
+            logger.setLevel(numeric_level)
+
+            # Update all handlers
+            for handler in logger.handlers:
+                handler.setLevel(numeric_level)
+
+            logging.getLogger("uvicorn.error").setLevel(numeric_level)
+            self.debug_logging_enabled = numeric_level <= logging.DEBUG
+
+            # Save the setting
+            save_setting("log_level", config.level)
+
+            logging.info(f"Log level changed to: {config.level}")
+            return {"status": "success", "message": f"Log level set to {config.level}"}
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to set log level: {str(e)}",
+            )
+
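The log-level route reads a body whose `level` field is accessed as `config.level` above, so a client call is a single JSON POST (port assumed):

```python
# Sketch: switching the running server to debug logging.
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/log-level",
    json={"level": "debug"},  # field name taken from config.level above
)
print(resp.json())  # e.g. {"status": "success", "message": "Log level set to debug"}
```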
+    def _log_request_parameters(self, request, endpoint_name: str):
+        """
+        Log request parameters excluding content fields like messages, prompt, or input.
+
+        Args:
+            request: Any request object (CompletionRequest, ChatCompletionRequest, etc.)
+            endpoint_name: Name of the endpoint for logging context
+        """
+        if not logging.getLogger().isEnabledFor(logging.DEBUG):
+            return
+
+        # Fields to exclude from logging (content fields)
+        excluded_fields = {"messages", "prompt", "input"}
+
+        # Get all attributes from the request object
+        request_params = {}
+        if hasattr(request, "__dict__"):
+            # For pydantic models, get the dict representation
+            if hasattr(request, "model_dump"):
+                all_params = request.model_dump()
+            elif hasattr(request, "dict"):
+                all_params = request.dict()
+            else:
+                all_params = request.__dict__
+
+            # Filter out excluded fields and add special handling for certain fields
+            for key, value in all_params.items():
+                if key not in excluded_fields:
+                    # Special handling for tools field - show count instead of full content
+                    if key == "tools" and value is not None:
+                        request_params[key] = (
+                            f"{len(value)} tools" if isinstance(value, list) else value
+                        )
+                    # Special handling for input type in responses
+                    elif key == "input" and hasattr(request, "input"):
+                        request_params["input_type"] = type(value).__name__
+                    else:
+                        request_params[key] = value
+
+        logging.debug(f"{endpoint_name} request parameters: {request_params}")
+
+    def _setup_server_common(
+        self,
+        tray: bool = False,
+        threaded_mode: bool = False,
+    ):
+        """
+        Common setup logic shared between run() and run_in_thread().
+
+        Args:
+            tray: Whether to run the server in tray mode
+            threaded_mode: Whether this is being set up for threaded execution
+        """
+
+        # Define TRACE level
+        logging.TRACE = 9  # Lower than DEBUG which is 10
+        logging.addLevelName(logging.TRACE, "TRACE")
+
+        # Add a convenience function at the module level
+        def trace(message, *args, **kwargs):
+            logging.log(logging.TRACE, message, *args, **kwargs)
+
+        logging.trace = trace
+
+        # Configure logging based on mode
+        if threaded_mode:
+            # Configure logging for warning level (to reduce noise in threaded execution)
+            logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
+        else:
+            # Configure logging to match uvicorn's format
+            logging_level = getattr(logging, self.log_level.upper())
+
+            # Set up file handler for logging to lemonade.log
+            uvicorn_formatter = uvicorn.logging.DefaultFormatter(
+                fmt="%(levelprefix)s %(message)s",
+                use_colors=True,
+            )
+            if not self.log_file:
+                self.log_file = tempfile.NamedTemporaryFile(
+                    prefix="lemonade_", suffix=".log", delete=False
+                ).name
+            file_handler = logging.FileHandler(
+                self.log_file, mode="a", encoding="utf-8"
+            )
+            file_handler.setLevel(logging_level)
+            file_handler.setFormatter(uvicorn_formatter)
+            file_handler.addFilter(ServerLogFilter(self))
+
+            # Set up console handler
+            console_handler = logging.StreamHandler()
+            console_handler.setLevel(logging_level)
+            console_handler.setFormatter(uvicorn_formatter)
+            console_handler.addFilter(ServerLogFilter(self))
+
+            # Configure root logger with both handlers
+            logging.basicConfig(
+                level=logging_level,
+                handlers=[file_handler, console_handler],
+                force=True,
+            )
+
+            # Update debug logging state after setting log level
+            self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
+            if tray:
+                # Duplicate stdout/stderr into the log file
+                sys.stdout = OutputDuplicator(self.log_file, sys.stdout)
+                sys.stderr = OutputDuplicator(self.log_file, sys.stderr)
+
+                # Open lemonade server in tray mode
+                # lambda function used for deferred instantiation and thread safety
+                LemonadeTray(
+                    self.log_file, self.port, lambda: self, log_level=self.log_level
+                ).run()
+                sys.exit(0)
+
+        # Let the app know what port it's running on, so
+        # that the lifespan can access it
+        self.app.port = self.port
+        # FastAPI already has a `host` function and we cannot use `_host` as
+        # PyLint will believe it's private
+        self.app.host_ = self.host
+
+    def run(self):
+        # Common setup
+        self._setup_server_common(
+            threaded_mode=False,
+            tray=self.tray,
+        )
+
+        uvicorn.run(self.app, host=self.host, port=self.port, log_level=self.log_level)
+
+    def run_in_thread(self, host: str = "localhost"):
+        """
+        Set up the server for running in a thread.
+        Returns a uvicorn server instance that can be controlled externally.
+        """
+        # Common setup
+        self._setup_server_common(
+            threaded_mode=True,
+            tray=False,
+        )
+
+        class CustomServer(UvicornServer):
+            """Custom Uvicorn server that can be properly shutdown from another thread"""
+
+            def install_signal_handlers(self):
+                pass
+
+        # Configure the server
+        config = Config(
+            app=self.app,
+            host=host,
+            port=self.port,
+            log_level=self.log_level,
+            log_config=None,
+        )
+
+        # Create and return the uvicorn server
+        return CustomServer(config=config)
+
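`run_in_thread` returns a uvicorn server object instead of blocking, so tests or an embedding application can drive it from a background thread and stop it with uvicorn's `should_exit` flag. A sketch; the import path is inferred from this diff and the port is an assumption.

```python
# Sketch: hosting the Lemonade server inside another Python process.
import time
from threading import Thread

from lemonade.tools.server.serve import Server  # module path inferred from this diff

server = Server(port=8000)  # port is an assumption
uvicorn_server = server.run_in_thread(host="localhost")

worker = Thread(target=uvicorn_server.run, daemon=True)
worker.start()

time.sleep(5)  # ... interact with http://localhost:8000 here ...

uvicorn_server.should_exit = True  # standard uvicorn shutdown flag
worker.join()
```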
+    async def _show_telemetry(self):
+        """
+        Show telemetry data in debug mode.
+        """
+        # Exit early if debug logging is disabled or no telemetry data is available
+        if not self.debug_logging_enabled or self.tokens_per_second is None:
+            return
+
+        # Prepare telemetry data (transposed format)
+        telemetry = [
+            ["Input tokens", self.input_tokens],
+            ["Output tokens", self.output_tokens],
+            ["TTFT (s)", f"{self.time_to_first_token:.2f}"],
+            ["TPS", f"{self.tokens_per_second:.2f}"],
+        ]
+
+        table = tabulate(
+            telemetry, headers=["Metric", "Value"], tablefmt="fancy_grid"
+        ).split("\n")
+
+        # Show telemetry in debug while complying with uvicorn's log indentation
+        logging.debug("\n ".join(table))
+
+    def webapp(self):
+        """
+        Serve the Web App to the user's browser.
+        """
+
+        return get_webapp_html(port=self.app.port)
+
+    def initialize_load_config(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> LoadConfig:
+        """
+        Turn the Request object into a partially-complete LoadConfig.
+
+        The load_llm() method is responsible for filling in the rest of
+        LoadConfig's parameters.
+        """
+
+        # Get model config
+        if "/" in request.model:
+            # We know the model is a Hugging Face checkpoint if it contains a /
+            # This scenario is only supported for run-as-thread
+            lc = LoadConfig(model_name="custom", checkpoint=request.model)
+        else:
+            # The model should be a reference to a built-in model
+            lc = LoadConfig(model_name=request.model)
+
+        return lc
+
+    async def completions(
+        self, completion_request: CompletionRequest, request: Request
+    ):
+        """
+        Stream completion responses using HTTP chunked transfer encoding.
+        """
+
+        lc = self.initialize_load_config(completion_request)
+
+        # Log request parameters (excluding message content for brevity)
+        self._log_request_parameters(completion_request, "Completions")
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm":
+            return self.wrapped_server.completion(completion_request)
+
+        # Check if the model supports reasoning
+        reasoning_first_token = self.llm_loaded.reasoning
+
+        # If the model supports reasoning, we:
+        # 1. add a <think> tag to the model's context
+        # 2. ensure that the first token is a <think> token
+        text = completion_request.prompt
+        if reasoning_first_token:
+            text += "<think>"
+
+        # Prepare generation arguments
+        generation_args = {
+            "message": text,
+            "stop": completion_request.stop,
+            "temperature": completion_request.temperature,
+            "repeat_penalty": completion_request.repeat_penalty,
+            "top_k": completion_request.top_k,
+            "top_p": completion_request.top_p,
+            "max_new_tokens": completion_request.max_tokens,
+        }
+
+        if completion_request.stream:
+
+            if completion_request.logprobs:
+                logging.warning("logprobs is not supported for streaming completion")
+            if completion_request.echo:
+                logging.warning(
+                    "`Echo` parameter is not supported for streaming completions"
+                )
+
+            # Stream the response
+            async def generate():
+                # Declare it's the same variable from outside scope
+                # This is necessary because the variable is modified
+                # in the inner function
+                nonlocal reasoning_first_token
+                try:
+                    async for token in self._generate_tokens(**generation_args):
+                        # Handle client disconnect: stop generation and exit
+                        if await request.is_disconnected():
+                            self.stop_event.set()
+                            break
+
+                        choice = CompletionChoice(
+                            text=(
+                                "<think>" + token if reasoning_first_token else token
+                            ),
+                            index=0,
+                            finish_reason="stop",
+                            logprobs=None,
+                        )
+
+                        completion = Completion(
+                            id="0",
+                            choices=[choice],
+                            model=self.llm_loaded.checkpoint,
+                            object="text_completion",
+                            created=int(time.time()),
+                        )
+
+                        # Format as SSE
+                        reasoning_first_token = False
+                        yield f"data: {completion.model_dump_json()}\n\n".encode(
+                            "utf-8"
+                        )
+
+                    # Send the [DONE] marker only if still connected
+                    if not await request.is_disconnected():
+                        yield b"data: [DONE]\n\n"
+                except asyncio.CancelledError:
+                    # Propagate cancellation to the generator loop
+                    self.stop_event.set()
+                    return
+
+            return StreamingResponse(
+                generate(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                },
+            )
+
+        # If streaming is not requested, collect all generated tokens into a single response
+        else:
+            full_response = text if completion_request.echo else ""
+            async for token in self._generate_tokens(**generation_args):
+                full_response += token
+
+            # If logprobs are requested, create a logprobs object
+            logprobs = None
+            if completion_request.logprobs:
+
+                # Compute the logprobs
+                text_offset, token_logprobs, tokens, top_logprobs = (
+                    self.model.compute_logprobs(
+                        text=full_response,
+                        tokenizer=self.tokenizer,
+                        logprobs=completion_request.logprobs,
+                    )
+                )
+                logprobs = Logprobs.model_construct(
+                    text_offset=text_offset,
+                    token_logprobs=token_logprobs,
+                    tokens=tokens,
+                    top_logprobs=top_logprobs,
+                )
+
+            choice = CompletionChoice(
+                text=full_response,
+                index=0,
+                finish_reason="stop",
+                logprobs=logprobs,
+            )
+
+            usage = CompletionUsage(
+                prompt_tokens=self.input_tokens,
+                completion_tokens=self.output_tokens,
+                total_tokens=self.input_tokens + self.output_tokens,
+            )
+
+            return Completion(
+                id="0",
+                choices=[choice],
+                usage=usage,
+                model=self.llm_loaded.checkpoint,
+                object="text_completion",
+                created=int(time.time()),
+            )
+
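The streaming branch above emits Server-Sent Events: one `data: <Completion JSON>` frame per token followed by a final `data: [DONE]`. A client sketch that consumes that framing directly; the port and model name are placeholders.

```python
# Sketch: reading the /completions SSE stream without an OpenAI client.
import json

import requests

payload = {"model": "some-installed-model", "prompt": "Once upon a time", "stream": True}
with requests.post(
    "http://localhost:8000/api/v1/completions", json=payload, stream=True
) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue  # blank separator lines between SSE events
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```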
+    async def chat_completions(
+        self, chat_completion_request: ChatCompletionRequest, request: Request
+    ):
+        """
+        Stream chat completion responses using HTTP chunked transfer encoding.
+        """
+
+        if chat_completion_request.logprobs:
+            logging.warning("logprobs is not supported on chat completion")
+
+        lc = self.initialize_load_config(chat_completion_request)
+
+        # Log request parameters (excluding message history for brevity)
+        self._log_request_parameters(chat_completion_request, "Chat completions")
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm":
+            if (
+                hasattr(chat_completion_request, "enable_thinking")
+                and chat_completion_request.enable_thinking is False
+                and "qwen3" in self.llm_loaded.model_name.lower()
+            ):
+
+                # Modify the last user message to include /no_think
+                if chat_completion_request.messages:
+                    for i in range(len(chat_completion_request.messages) - 1, -1, -1):
+                        if chat_completion_request.messages[i].get("role") == "user":
+                            original_content = chat_completion_request.messages[i][
+                                "content"
+                            ]
+                            chat_completion_request.messages[i][
+                                "content"
+                            ] = f"/no_think\n{original_content}"
+                            break
+            return self.wrapped_server.chat_completion(chat_completion_request)
+
+        # Convert chat messages to text using the model's chat template
+        text = self.apply_chat_template(
+            chat_completion_request.messages,
+            tools=chat_completion_request.tools,
+        )
+
+        # If the model supports reasoning, we:
+        # 1. add a <think> tag to the model's context
+        # 2. ensure that the first token is a <think> token
+        reasoning_first_token = self.llm_loaded.reasoning
+
+        if reasoning_first_token:
+            text += "<think>"
+        # Set the max_new_tokens parameter
+        if (
+            chat_completion_request.max_completion_tokens
+            and chat_completion_request.max_tokens
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=(
+                    "Both max_tokens and max_completion_tokens were provided. "
+                    "Please use only one of these parameters."
+                ),
+            )
+        max_new_tokens = (
+            chat_completion_request.max_completion_tokens
+            if chat_completion_request.max_completion_tokens
+            else chat_completion_request.max_tokens
+        )
+
+        # Prepare generation arguments
+        generation_args = {
+            "message": text,
+            "stop": chat_completion_request.stop,
+            "temperature": chat_completion_request.temperature,
+            "repeat_penalty": chat_completion_request.repeat_penalty,
+            "top_k": chat_completion_request.top_k,
+            "top_p": chat_completion_request.top_p,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if chat_completion_request.tools:
+            # Get the tool call pattern
+            tool_call_pattern = get_tool_call_pattern(
+                self.tokenizer.auto_tokenizer.added_tokens_decoder
+            )
+
+        if chat_completion_request.stream:
+
+            # Stream the response
+            async def generate():
+                # Declare it's the same variable from outside scope
+                # This is necessary because the variable is modified
+                # in the inner function
+                nonlocal reasoning_first_token
+
+                # Keep track of the full response for tool call extraction
+                full_response = ""
+
+                # Track whether we're still in the thinking phase (before </think> tag)
+                in_thinking_phase = self.llm_loaded.reasoning
+                reasoning_buffer = ""  # Accumulate reasoning tokens to detect </think>
+
+                try:
+                    async for token in self._generate_tokens(**generation_args):
+                        # Handle client disconnect: stop generation and exit
+                        if await request.is_disconnected():
+                            self.stop_event.set()
+                            break
+
+                        # Continuously look for tool calls embedded into the generated text
+                        openai_tool_calls = None
+                        if chat_completion_request.tools:
+
+                            # Append the token to the full response
+                            full_response += token
+
+                            tool_calls, _ = extract_tool_calls(
+                                full_response,
+                                tool_call_pattern,
+                            )
+
+                            # If there are tool calls, reset the full response for the next call
+                            if tool_calls:
+                                openai_tool_calls = []
+                                full_response = ""
+                                for tool_call in tool_calls:
+                                    openai_tool_calls.append(
+                                        ChoiceDeltaToolCall(
+                                            index=0,
+                                            id="-",
+                                            function=ChoiceDeltaToolCallFunction(
+                                                arguments=json.dumps(
+                                                    tool_call["arguments"]
+                                                ),
+                                                name=tool_call["name"],
+                                            ),
+                                            type="function",
+                                        )
+                                    )
+
+                        # Create a ChatCompletionChunk with reasoning_content support
+                        # If we're in reasoning mode and haven't seen </think> yet,
+                        # send tokens as reasoning_content instead of content
+                        delta_content = None
+                        delta_reasoning = None
+
+                        if reasoning_first_token:
+                            # First token - include opening tag in reasoning
+                            delta_reasoning = "<think>" + token
+                            reasoning_first_token = False
+                            reasoning_buffer = token
+                        elif in_thinking_phase:
+                            # Still in thinking phase - accumulate and check for </think>
+                            reasoning_buffer += token
+
+                            # Check if we've seen the closing tag
+                            if "</think>" in reasoning_buffer:
+                                # Split at the closing tag
+                                before_close, after_close = reasoning_buffer.split(
+                                    "</think>", 1
+                                )
+
+                                # Send everything before + closing tag as reasoning
+                                if before_close or not reasoning_buffer.startswith(
+                                    "</think>"
+                                ):
+                                    delta_reasoning = before_close + "</think>"
+                                else:
+                                    delta_reasoning = "</think>"
+
+                                # Everything after goes to content (will be sent in next iteration)
+                                # For now, mark that we've exited thinking phase
+                                in_thinking_phase = False
+
+                                # If there's content after </think>, we need to send it too
+                                # But we send it in the current chunk as regular content
+                                if after_close:
+                                    # We have both reasoning and content in this token
+                                    # Send reasoning first, content will accumulate
+                                    delta_content = after_close
+                            else:
+                                # Still accumulating thinking, send as reasoning_content
+                                delta_reasoning = token
+                        else:
+                            # Normal content (after thinking phase ended)
+                            delta_content = token
+
+                        chunk = ChatCompletionChunk.model_construct(
+                            id="0",
+                            object="chat.completion.chunk",
+                            created=int(time.time()),
+                            model=self.llm_loaded.checkpoint,
+                            choices=[
+                                Choice.model_construct(
+                                    index=0,
+                                    delta=ChoiceDelta(
+                                        content=delta_content,
+                                        reasoning_content=delta_reasoning,
+                                        function_call=None,
+                                        role="assistant",
+                                        tool_calls=openai_tool_calls,
+                                        refusal=None,
+                                    ),
+                                    finish_reason=None,
+                                    logprobs=None,
+                                )
+                            ],
+                        )
+
+                        # Format as SSE
+                        yield f"data: {chunk.model_dump_json()}\n\n".encode("utf-8")
+
+                    # Send the [DONE] marker only if still connected
+                    if not await request.is_disconnected():
+                        yield b"data: [DONE]\n\n"
+                except asyncio.CancelledError:
+                    self.stop_event.set()
+                    return
+
+            return StreamingResponse(
+                generate(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                },
+            )
+
+        # If streaming is not requested, collect all generated tokens into a single response
+        else:
+            full_response = "<think>" if reasoning_first_token else ""
+            async for token in self._generate_tokens(**generation_args):
+                full_response += token
+
+            # Extract tool calls from the response
+            openai_tool_calls = None
+            if chat_completion_request.tools:
+                tool_calls, full_response = extract_tool_calls(
+                    full_response, tool_call_pattern
+                )
+                if tool_calls:
+                    openai_tool_calls = []
+                    for tool_call in tool_calls:
+                        openai_tool_calls.append(
+                            ChatCompletionMessageToolCall(
+                                id="-",
+                                function=Function(
+                                    arguments=json.dumps(tool_call["arguments"]),
+                                    name=tool_call["name"],
+                                ),
+                                type="function",
+                            )
+                        )
+
+            ccm = ChatCompletionMessage(
+                content=full_response,
+                role="assistant",
+                refusal=None,
+                audio=None,
+                function_call=None,
+                tool_calls=openai_tool_calls,
+            )
+
+            choice = Choice(
+                finish_reason="stop",
+                index=0,
+                message=ccm,
+                logprobs=None,
+            )
+
+            usage = CompletionUsage(
+                prompt_tokens=self.input_tokens,
+                completion_tokens=self.output_tokens,
+                total_tokens=self.input_tokens + self.output_tokens,
+            )
+
+            return ChatCompletion(
+                id="0",
+                choices=[choice],
+                usage=usage,
+                model=self.llm_loaded.checkpoint,
+                object="chat.completion",
+                created=int(time.time()),
+            )
+
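For reasoning models, the chat stream above splits output between `delta.reasoning_content` (everything through the closing `</think>` tag) and `delta.content`. A sketch that separates the two on the client side, again reading the SSE frames directly; port and model name are placeholders.

```python
# Sketch: separating reasoning tokens from answer tokens in the chat stream.
import json

import requests

payload = {
    "model": "some-reasoning-model",  # placeholder
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    "stream": True,
}
thinking, answer = [], []
with requests.post(
    "http://localhost:8000/api/v1/chat/completions", json=payload, stream=True
) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line or not raw_line.startswith(b"data: "):
            continue
        data = raw_line[len(b"data: "):]
        if data == b"[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        if delta.get("reasoning_content"):
            thinking.append(delta["reasoning_content"])
        if delta.get("content"):
            answer.append(delta["content"])

print("reasoning:", "".join(thinking))
print("answer:", "".join(answer))
```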
+    async def embeddings(self, embeddings_request: EmbeddingsRequest):
+        """
+        Generate embeddings for the provided input.
+        """
+        # Initialize load config from embeddings request
+        lc = LoadConfig(model_name=embeddings_request.model)
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp":
+            try:
+                return self.wrapped_server.embeddings(embeddings_request)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                # Check if model has embeddings label
+                model_info = ModelManager().supported_models.get(
+                    self.llm_loaded.model_name, {}
+                )
+                if "embeddings" not in model_info.get("labels", []):
+                    raise HTTPException(
+                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                        detail="You tried to generate embeddings for a model that is "
+                        "not labeled as an embeddings model. Please use another model "
+                        "or re-register the current model with the 'embeddings' label.",
+                    ) from e
+                else:
+                    raise e
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Embeddings not supported for recipe: {self.llm_loaded.recipe}",
+            )
+
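Embeddings are only forwarded to the wrapped llama.cpp server, and only succeed for models registered with the `embeddings` label. Because the route is OpenAI-compatible, the same client pattern shown earlier applies; the model name is a placeholder.

```python
# Sketch: requesting embeddings from the local server (port is an assumption).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="unused")
result = client.embeddings.create(
    model="some-embeddings-model",  # must carry the "embeddings" label
    input=["lemonade serves local LLMs"],
)
print(len(result.data[0].embedding))
```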
+    async def reranking(self, reranking_request: RerankingRequest):
+        """
+        Rerank documents based on their relevance to a query.
+        """
+        # Initialize load config from reranking request
+        lc = LoadConfig(model_name=reranking_request.model)
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp":
+            try:
+                return self.wrapped_server.reranking(reranking_request)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                # Check if model has reranking label
+                model_info = ModelManager().supported_models.get(
+                    self.llm_loaded.model_name, {}
+                )
+                if "reranking" not in model_info.get("labels", []):
+                    raise HTTPException(
+                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                        detail="You tried to use reranking for a model that is "
+                        "not labeled as a reranking model. Please use another model "
+                        "or re-register the current model with the 'reranking' label.",
+                    ) from e
+                else:
+                    raise e
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Reranking not supported for recipe: {self.llm_loaded.recipe}",
+            )
+
+    def apply_chat_template(
+        self, messages: list[dict], tools: list[dict] | None = None
+    ):
+        """
+        Apply the model's chat template to the messages.
+        """
+        if self.tokenizer.chat_template:
+
+            return self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                tools=tools,
+            )
+
+        # Fallback to a standardized template if the model doesn't provide one
+        logging.warning("No chat template found. Using default template.")
+        formatted_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            role_marker = "<|assistant|>" if role == "assistant" else "<|user|>"
+            formatted_messages.append(f"{role_marker}\n{content} <|end|>")
+        return "\n".join(formatted_messages) + "\n<|assistant|>"
+
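When a tokenizer ships no chat template, the fallback above joins messages with `<|user|>` / `<|assistant|>` markers and an `<|end|>` terminator. A short sketch of the exact string it produces for a small history:

```python
# Sketch: the fallback formatting used when tokenizer.chat_template is missing.
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Tell me a joke"},
]

formatted_messages = []
for msg in messages:
    role = msg.get("role", "user")
    content = msg.get("content", "")
    role_marker = "<|assistant|>" if role == "assistant" else "<|user|>"
    formatted_messages.append(f"{role_marker}\n{content} <|end|>")
prompt = "\n".join(formatted_messages) + "\n<|assistant|>"

print(prompt)
# <|user|>
# Hi <|end|>
# <|assistant|>
# Hello! <|end|>
# <|user|>
# Tell me a joke <|end|>
# <|assistant|>
```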
    async def responses(self, responses_request: ResponsesRequest, request: Request):
        """
        Stream responses using HTTP chunked transfer encoding.
        """

        lc = self.initialize_load_config(responses_request)

        # Log request parameters (excluding message history for brevity)
        self._log_request_parameters(responses_request, "Responses")

        # Load the model if it's different from the currently loaded one
        await self.load_llm(lc)

        if self.llm_loaded.recipe == "llamacpp":
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Responses API not supported for recipe: {self.llm_loaded.recipe}",
            )

        # Convert chat messages to text using the model's chat template
        if isinstance(responses_request.input, str):
            text = responses_request.input
        else:
            text = self.apply_chat_template(responses_request.input)

        # If the model supports reasoning, we:
        # 1. add a <think> tag to the model's context
        # 2. ensure that the first token is a <think> token
        reasoning_first_token = self.llm_loaded.reasoning

        if reasoning_first_token:
            text += "<think>"

        # Prepare generation arguments
        generation_args = {
            "message": text,
            "temperature": responses_request.temperature,
            "repeat_penalty": responses_request.repeat_penalty,
            "top_k": responses_request.top_k,
            "top_p": responses_request.top_p,
            "max_new_tokens": responses_request.max_output_tokens,
        }

        if responses_request.stream:

            # Stream the response
            async def generate():
                # Declare it's the same variable from outside scope
                # This is necessary because the variable is modified
                # in the inner function
                nonlocal reasoning_first_token

                # Send initial creation event
                response = Response(
                    id="0",
                    model=self.llm_loaded.checkpoint,
                    created_at=int(time.time()),
                    object="response",
                    output=[],
                    parallel_tool_calls=True,
                    tool_choice="auto",
                    tools=[],
                )
                created_event = ResponseCreatedEvent(
                    response=response,
                    type="response.created",
                    sequence_number=0,
                )
                yield f"data: {created_event.model_dump_json()}\n\n".encode("utf-8")

                full_response = "<think>" if reasoning_first_token else ""

                try:
                    async for token in self._generate_tokens(**generation_args):
                        # Handle client disconnect: stop generation and exit
                        if await request.is_disconnected():
                            self.stop_event.set()
                            break

                        # Create an event
                        delta_event = ResponseTextDeltaEvent(
                            content_index=0,
                            delta=(
                                "<think>" + token if reasoning_first_token else token
                            ),
                            item_id="0 ",
                            logprobs=[],
                            output_index=0,
                            sequence_number=0,
                            type="response.output_text.delta",
                        )
                        full_response += token

                        # Format as SSE
                        reasoning_first_token = False
                        yield f"data: {delta_event.model_dump_json()}\n\n".encode(
                            "utf-8"
                        )

                    # Send the completed event (only if still connected)
                    if not await request.is_disconnected():
                        response_output_message = ResponseOutputMessage(
                            id="0",
                            content=[
                                ResponseOutputText(
                                    annotations=[],
                                    text=full_response,
                                    type="output_text",
                                )
                            ],
                            role="assistant",
                            status="completed",
                            type="message",
                        )
                        response = Response(
                            id="0",
                            model=self.llm_loaded.checkpoint,
                            created_at=int(time.time()),
                            object="response",
                            output=[response_output_message],
                            parallel_tool_calls=True,
                            tool_choice="auto",
                            tools=[],
                        )
                        completed_event = ResponseCompletedEvent(
                            response=response,
                            type="response.completed",
                            sequence_number=0,
                        )
                        yield f"data: {completed_event.model_dump_json()}\n\n".encode(
                            "utf-8"
                        )

                        # Send the [DONE] marker
                        yield b"data: [DONE]\n\n"
                except asyncio.CancelledError:
                    self.stop_event.set()
                    return

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # If streaming is not requested, collect all generated tokens into a single response
        else:
            full_response = "<think>" if reasoning_first_token else ""
            async for token in self._generate_tokens(**generation_args):
                full_response += token

            # Send the completed event
            response_output_message = ResponseOutputMessage(
                id="0",
                content=[
                    ResponseOutputText(
                        annotations=[],
                        text=full_response,
                        type="output_text",
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )
            return Response(
                id="0",
                model=self.llm_loaded.checkpoint,
                created_at=int(time.time()),
                object="response",
                output=[response_output_message],
                parallel_tool_calls=True,
                tool_choice="auto",
                tools=[],
            )

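    # A sketch of the event stream emitted by the streaming branch above
    # (payloads abbreviated; the token text is illustrative):
    #
    #   data: {"type": "response.created", "sequence_number": 0, ...}
    #
    #   data: {"type": "response.output_text.delta", "delta": "Hello", ...}
    #
    #   data: {"type": "response.completed", ...}
    #
    #   data: [DONE]
    #
    # Each chunk is a UTF-8 encoded Server-Sent Events "data:" line followed by a
    # blank line, matching the media_type="text/event-stream" StreamingResponse.
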
    async def _generate_tokens(
        self,
        message: str,
        stop: list[str] | str | None = None,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
        repeat_penalty: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ):
        """
        Core streaming completion logic, separated from response handling.
        Returns an async generator that yields tokens.
        """

        while not self.app.initialized:
            # Wait for the app's background tasks to finish before
            # allowing generation to proceed
            logging.debug("Waiting for server to fully initialize")
            await asyncio.sleep(0.5)
        # These should already be imported as part of the app initialization process,
        # they are just here to make 100% certain and to make the linter happy
        from transformers import TextIteratorStreamer, StoppingCriteriaList

        model = self.model
        tokenizer = self.tokenizer

        # Reset the early-exit flag before we start each generation
        self.stop_event.clear()

        input_ids = tokenizer(message, return_tensors="pt").input_ids

        # Process stop sequences
        stop_sequences = []
        if stop is not None:
            if isinstance(stop, str):
                stop_sequences = [stop]
            else:
                stop_sequences = stop[:4]  # Limit to 4 sequences as per spec

        # Set up the generation parameters
        if "oga-" in self.llm_loaded.recipe:
            from lemonade.tools.oga.utils import OrtGenaiStreamer

            streamer = OrtGenaiStreamer(tokenizer)
            self.input_tokens = len(input_ids)
        else:
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
            )
            self.input_tokens = len(input_ids[0])

        max_prompt_length = self.ctx_size  # Default fallback
        # For OGA models, try to read the actual max prompt length from config
        if "oga-" in self.llm_loaded.recipe:
            try:
                if model.config and model.config.get("max_prompt_length"):
                    max_prompt_length = model.config["max_prompt_length"]
                    logging.debug(
                        f"Using OGA model max_prompt_length: {max_prompt_length}"
                    )
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.debug(f"Could not read OGA model config, using ctx_size: {e}")

        # Apply truncation if input exceeds the limit
        if self.input_tokens > max_prompt_length:
            # Truncate input ids
            truncate_amount = self.input_tokens - max_prompt_length
            input_ids = input_ids[:max_prompt_length]
            # Update token count
            if "oga-" in self.llm_loaded.recipe:
                self.input_tokens = len(input_ids)
            else:
                self.input_tokens = len(input_ids[0])

            # Log warning message instead of raising exception
            truncation_message = (
                f"Input exceeded {max_prompt_length} tokens. "
                f"Truncated {truncate_amount} tokens from the beginning."
            )
            logging.warning(truncation_message)

        # Log the input tokens early to avoid this not showing due to potential crashes
        logging.debug(f"Input Tokens: {self.input_tokens}")
        logging.trace(f"Input Message: {message}")

        if self.llm_loaded.recipe.startswith("hf"):
            stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
        else:
            # HF expects StoppingCriteriaList, which requires torch
            # If we aren't using HF, we can just use a list of StopOnEvent to
            # avoid the torch dep
            stopping_criteria = [StopOnEvent(self.stop_event)]

        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": (
                max_new_tokens if max_new_tokens else DEFAULT_MAX_NEW_TOKENS
            ),
            "min_new_tokens": 1,
            "pad_token_id": tokenizer.eos_token_id,
            "stopping_criteria": stopping_criteria,
            "temperature": temperature,
            "repeat_penalty": repeat_penalty,
            "top_k": top_k,
            "top_p": top_p,
        }

        # Initialize performance variables
        generation_start_time = time.perf_counter()
        first_token = True
        self.decode_token_times = []
        self.output_tokens = 0

        # Begin generation
        thread = GeneratorThread(
            streamer, target=model.generate, kwargs=generation_kwargs
        )
        thread.start()

        # Acquire the generation semaphore
        await self._generate_semaphore.acquire()
        active_generations = (
            self.max_concurrent_generations
            - self._generate_semaphore._value  # pylint: disable=protected-access
        )

        logging.debug(f"Active generations: {active_generations}")

        try:
            # Generate the response using streaming
            new_text = ""
            for new_text in streamer:
                # Yield control back to the event loop
                # This gives the FastAPI server a chance to send the chunks to the client
                await asyncio.sleep(0)

                # Capture performance stats about this token
                self.output_tokens = self.output_tokens + 1
                if first_token:
                    self.time_to_first_token = (
                        time.perf_counter() - generation_start_time
                    )
                    first_token = False
                else:
                    self.decode_token_times.append(
                        time.perf_counter() - next_token_start_time
                    )
                next_token_start_time = time.perf_counter()

                # Remove the EOS token from the response if needed
                if hasattr(self.tokenizer, "eos_token"):
                    new_text = new_text.replace(self.tokenizer.eos_token, "")

                # Check for stop sequences
                if stop_sequences:
                    for stop_seq in stop_sequences:
                        if stop_seq in new_text:
                            # Make sure we yield the text up to before the stop sequence
                            new_text = new_text[: new_text.find(stop_seq)]
                            self.stop_event.set()

                yield new_text

                # Allow the user to finish the response early
                if self.stop_event.is_set():
                    logging.info("Stopping generation early.")
                    break

            if len(self.decode_token_times) > 0:
                self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
            else:
                self.tokens_per_second = 0

        finally:
            thread.join()

            # Release the semaphore when generation is complete (or if an error occurs)
            self._generate_semaphore.release()
            active_generations = (
                self.max_concurrent_generations
                - self._generate_semaphore._value  # pylint: disable=protected-access
            )

            # Check if an exception occurred in the generation thread
            # If it did, raise it as an HTTPException so that the client
            # knows they won't be getting a completion
            if thread.exception:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Completion failure: {thread.exception}",
                )

            # Display telemetry if in debug mode
            await self._show_telemetry()

    async def send_stats(self):
        """
        Send performance statistics to the client.
        """
        # If using wrapped server, get telemetry from the telemetry instance
        if self.llm_loaded and (
            self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm"
        ):
            return self.wrapped_server.telemetry.get_telemetry_data()

        # For built-in server, use the existing telemetry
        return {
            "time_to_first_token": self.time_to_first_token,
            "tokens_per_second": self.tokens_per_second,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "decode_token_times": self.decode_token_times,
        }

    async def halt_generation(self):
        """
        Allow the client to halt an in-progress generation.
        """

        self.stop_event.set()

        return {
            "terminated": True,
        }

    async def health(self):
        """
        Report server health information to the client.
        """

        return {
            "status": "ok",
            "checkpoint_loaded": (
                self.llm_loaded.checkpoint if self.llm_loaded else None
            ),
            "model_loaded": (
                self.llm_loaded.model_name
                if (self.llm_loaded and self.llm_loaded.model_name)
                else None
            ),
        }

    async def get_system_info(self, request: Request):
        """
        Return system and device enumeration information.
        Supports optional 'verbose' query parameter.
        """
        from lemonade.common.system_info import (
            get_system_info_dict,
            get_device_info_dict,
            get_system_info as get_system_info_obj,
        )

        # Get verbose parameter from query string (default to False)
        verbose = request.query_params.get("verbose", "false").lower() in ["true", "1"]

        info = get_system_info_dict()
        info["devices"] = get_device_info_dict()

        # Filter out verbose-only information if not in verbose mode
        if not verbose:
            essential_keys = ["OS Version", "Processor", "Physical Memory", "devices"]
            info = {k: v for k, v in info.items() if k in essential_keys}
        else:
            # In verbose mode, add Python packages at the end
            system_info_obj = get_system_info_obj()
            info["Python Packages"] = system_info_obj.get_python_packages()

        return info

    def model_load_failure(self, model_reference: str, message: Optional[str] = None):
        """
        Clean up after a model load failure, then log it and raise
        an HTTPException with details.
        """
        self.llm_loaded = None
        self.tokenizer = None
        self.model = None

        default_message = "see stack trace and error message below"
        if message:
            detail = message
        else:
            detail = default_message

        logging.exception(f"Tried to load LLM {model_reference} and failed: {detail}")

        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=detail,
        )

    async def pull(self, config: PullConfig):
        """
        Install a supported LLM by its Lemonade Model Name.
        """

        # Install the model
        ModelManager().download_models(
            [config.model_name],
            checkpoint=config.checkpoint,
            recipe=config.recipe,
            reasoning=config.reasoning,
            vision=config.vision,
            mmproj=config.mmproj,
            # The pull endpoint will download an upgraded model if available, even
            # if we already have a local copy of the model
            do_not_upgrade=False,
        )

        # Refresh the list of downloaded models, to ensure it
        # includes the model we just installed
        self.local_models = ModelManager().downloaded_models_enabled

    async def delete(self, config: DeleteConfig):
        """
        Delete a supported LLM by its Lemonade Model Name.
        """
        try:
            # If the model to be deleted is currently loaded, unload it first
            if self.llm_loaded and self.llm_loaded.model_name == config.model_name:
                await self.unload_llm(require_lock=True)

            # Delete the model
            ModelManager().delete_model(config.model_name)

            # Refresh the list of downloaded models
            self.local_models = ModelManager().downloaded_models_enabled

            return {
                "status": "success",
                "message": f"Deleted model: {config.model_name}",
            }
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=str(e),
            )
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to delete model {config.model_name}: {str(e)}",
            )

    async def load_llm(self, config: LoadConfig):
        """
        Load a registered LLM into system memory. Install the model first, if needed.
        config: the information required to load the model
        """
        from huggingface_hub.constants import HF_HUB_CACHE

        try:
            await self._load_lock.acquire()
            # Acquire all generate locks
            for _ in range(self.max_concurrent_generations):
                await self._generate_semaphore.acquire()

            # Make sure the model is already registered
            supported_models = ModelManager().supported_models

            # The `custom` name allows run-as-thread servers to bypass loading
            if config.model_name == "custom":
                config_to_use = config
            else:
                if config.model_name not in supported_models.keys():
                    self.model_load_failure(
                        config.model_name,
                        message=(
                            f"Load request for model_name={config.model_name} "
                            "not registered with Lemonade Server. You can register and "
                            "install new models with a `pull` request."
                        ),
                    )

                # Get additional properties from the model registry
                config_to_use = LoadConfig(**supported_models[config.model_name])

                # For locally uploaded models, convert the relative checkpoint path to absolute path
                model_source = supported_models.get(config.model_name, {}).get(
                    "source", None
                )
                if (
                    model_source == "local_upload"
                    and config_to_use.checkpoint
                    and not config_to_use.recipe.startswith("hf-")
                ):
                    # Check if checkpoint is a relative path (stored during upload)
                    if not os.path.isabs(config_to_use.checkpoint):
                        # Convert relative path to absolute by joining with HF_HUB_CACHE
                        absolute_checkpoint = os.path.join(
                            HF_HUB_CACHE, config_to_use.checkpoint
                        )
                        if os.path.exists(absolute_checkpoint):
                            config_to_use.checkpoint = absolute_checkpoint
                        else:
                            logging.warning(
                                f"Checkpoint path does not exist: {absolute_checkpoint}"
                            )

                    # Also resolve mmproj path if present
                    if config_to_use.mmproj and not os.path.isabs(config_to_use.mmproj):
                        absolute_mmproj = os.path.join(HF_HUB_CACHE, config_to_use.mmproj)
                        if os.path.exists(absolute_mmproj):
                            config_to_use.mmproj = absolute_mmproj
                        else:
                            logging.warning(
                                f"MMProj path does not exist: {absolute_mmproj}"
                            )

            # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
            if (
                self.llm_loaded
                and config_to_use.checkpoint == self.llm_loaded.checkpoint
            ):
                if (
                    self.llm_loaded.recipe == "llamacpp"
                    or self.llm_loaded.recipe == "flm"
                ) and self.wrapped_server.process.poll():
                    # wrapped server process has gone away for some reason, so we should
                    # proceed with loading to get it back
                    pass
                else:
                    return {
                        "status": "success",
                        "message": f"Model already loaded: {config.model_name}",
                    }

            # Unload the current model if needed
            if self.llm_loaded:
                await self.unload_llm(require_lock=False)

            logging.info(f"Loading llm: {config.model_name}")
            try:
                if config_to_use.recipe == "llamacpp":
                    self.wrapped_server = LlamaServer(self.llamacpp_backend)
                    self.wrapped_server.load(
                        model_config=config_to_use,
                        ctx_size=self.ctx_size,
                        do_not_upgrade=True,
                    )

                elif config_to_use.recipe == "flm":
                    self.wrapped_server = FlmServer()
                    self.wrapped_server.load(
                        model_config=config_to_use,
                        ctx_size=self.ctx_size,
                        do_not_upgrade=True,
                    )

                else:
                    self.model, self.tokenizer = lemonade_api.from_pretrained(
                        checkpoint=config_to_use.checkpoint, recipe=config_to_use.recipe
                    )
                self.llm_loaded = config_to_use

                return {
                    "status": "success",
                    "message": f"Loaded model: {config.model_name}",
                }
            except HTTPException:
                raise
            except Exception:  # pylint: disable=broad-exception-caught
                self.model_load_failure(config.model_name)

        finally:
            self._load_lock.release()

            # Release all generate locks
            for _ in range(self.max_concurrent_generations):
                self._generate_semaphore.release()

            # Refresh the list of downloaded models, to ensure it
            # includes the model we just loaded
            if config.model_name not in self.local_models:
                self.local_models = ModelManager().downloaded_models_enabled

    async def unload_llm(self, require_lock: bool = True):
        try:
            if require_lock:
                await self._load_lock.acquire()

                # Acquire all generate locks
                for _ in range(self.max_concurrent_generations):
                    await self._generate_semaphore.acquire()

            if self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm":
                self.wrapped_server.process.terminate()

            self.llm_loaded = None
            self.tokenizer = None
            self.model = None
            return {"status": "success", "message": "Unloaded model"}
        except Exception as e:  # pylint: disable=broad-exception-caught
            return {
                "status": "error",
                "message": f"Failed to unload model: {str(e)}",
            }
        finally:
            if require_lock:
                self._load_lock.release()

                # Release all generate locks
                for _ in range(self.max_concurrent_generations):
                    self._generate_semaphore.release()

    async def models(self):
        """
        Return a list of available models in OpenAI-compatible format.
        """
        models_list = []
        for model in self.local_models:
            m = ServerModel(
                id=model,
                owned_by="lemonade",
                object="model",
                created=int(time.time()),
                checkpoint=self.local_models[model]["checkpoint"],
                recipe=self.local_models[model]["recipe"],
            )
            models_list.append(m)

        return {"object": "list", "data": models_list}

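    # For reference, the listing above serializes to an OpenAI-style payload of
    # roughly this shape (model name, checkpoint, and timestamp are illustrative):
    #
    #   {
    #       "object": "list",
    #       "data": [
    #           {
    #               "id": "<model-name>",
    #               "object": "model",
    #               "created": 1700000000,
    #               "owned_by": "lemonade",
    #               "checkpoint": "<checkpoint>",
    #               "recipe": "<recipe>",
    #           }
    #       ],
    #   }
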
    async def retrieve_model(self, model_id: str):
        """
        Retrieve a specific model by ID in OpenAI-compatible format.
        """
        # Raise an error if the model does not exist
        if model_id not in self.local_models:
            # Mimic the error format of the OpenAI API
            raise HTTPException(
                status_code=404,
                detail={
                    "message": f"model {model_id} not found",
                    "type": "api_error",
                    "param": None,
                    "code": None,
                },
            )

        # Return the specific model
        model_info = self.local_models[model_id]
        model = ServerModel(
            id=model_id,
            owned_by="lemonade",
            object="model",
            created=int(time.time()),
            checkpoint=model_info["checkpoint"],
            recipe=model_info["recipe"],
        )

        return model

    def setup_middleware_timer(self):
        logging.info("Middleware set up")

        @self.app.middleware("http")
        async def log_request_time(request: Request, call_next):
            """
            Log the request processing time for any request.
            For streaming responses, wraps the body iterator to measure total time.
            Only applies the wrapper in debug mode.
            """
            start_time = time.perf_counter()
            response = await call_next(request)

            if (
                self.debug_logging_enabled
                and hasattr(response, "body_iterator")
                and response.body_iterator is not None
            ):
                original_iterator = response.body_iterator

                async def wrapped_iterator():
                    async for chunk in original_iterator:
                        yield chunk
                    request_time = time.perf_counter() - start_time
                    logging.debug(
                        f"Total request time (streamed): {request_time:.4f} seconds"
                    )

                response.body_iterator = wrapped_iterator()
            else:
                request_time = time.perf_counter() - start_time
                if self.debug_logging_enabled:
                    logging.debug(f"Total request time: {request_time:.4f} seconds")
            return response

    async def logs_ws(self, websocket: WebSocket):
        if not self.log_file or not os.path.exists(self.log_file):
            await websocket.close(code=4000)
            return
        await log_streamer(websocket, self.log_file)

    async def get_incompatible_models(self):
        """
        Get information about incompatible RyzenAI models in the cache.
        """
        try:
            return ModelManager().get_incompatible_ryzenai_models()
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to scan for incompatible models: {str(e)}",
            )

    async def cleanup_incompatible_models(self, request: Request):
        """
        Delete selected incompatible RyzenAI models from the cache.
        """
        try:
            body = await request.json()
            model_paths = body.get("model_paths", [])

            if not model_paths:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="No model_paths provided",
                )

            result = ModelManager().cleanup_incompatible_models(model_paths)
            return result
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to cleanup models: {str(e)}",
            )


# This file was originally licensed under Apache 2.0. It has been modified.
# Modifications Copyright (c) 2025 AMD