lemonade_sdk-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +180 -0
  3. lemonade/cache.py +92 -0
  4. lemonade/cli.py +173 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/build.py +176 -0
  7. lemonade/common/cli_helpers.py +139 -0
  8. lemonade/common/exceptions.py +98 -0
  9. lemonade/common/filesystem.py +368 -0
  10. lemonade/common/inference_engines.py +408 -0
  11. lemonade/common/network.py +93 -0
  12. lemonade/common/printing.py +110 -0
  13. lemonade/common/status.py +471 -0
  14. lemonade/common/system_info.py +1411 -0
  15. lemonade/common/test_helpers.py +28 -0
  16. lemonade/profilers/__init__.py +1 -0
  17. lemonade/profilers/agt_power.py +437 -0
  18. lemonade/profilers/hwinfo_power.py +429 -0
  19. lemonade/profilers/memory_tracker.py +259 -0
  20. lemonade/profilers/profiler.py +58 -0
  21. lemonade/sequence.py +363 -0
  22. lemonade/state.py +159 -0
  23. lemonade/tools/__init__.py +1 -0
  24. lemonade/tools/accuracy.py +432 -0
  25. lemonade/tools/adapter.py +114 -0
  26. lemonade/tools/bench.py +302 -0
  27. lemonade/tools/flm/__init__.py +1 -0
  28. lemonade/tools/flm/utils.py +305 -0
  29. lemonade/tools/huggingface/bench.py +187 -0
  30. lemonade/tools/huggingface/load.py +235 -0
  31. lemonade/tools/huggingface/utils.py +359 -0
  32. lemonade/tools/humaneval.py +264 -0
  33. lemonade/tools/llamacpp/bench.py +255 -0
  34. lemonade/tools/llamacpp/load.py +222 -0
  35. lemonade/tools/llamacpp/utils.py +1260 -0
  36. lemonade/tools/management_tools.py +319 -0
  37. lemonade/tools/mmlu.py +319 -0
  38. lemonade/tools/oga/__init__.py +0 -0
  39. lemonade/tools/oga/bench.py +120 -0
  40. lemonade/tools/oga/load.py +804 -0
  41. lemonade/tools/oga/migration.py +403 -0
  42. lemonade/tools/oga/utils.py +462 -0
  43. lemonade/tools/perplexity.py +147 -0
  44. lemonade/tools/prompt.py +263 -0
  45. lemonade/tools/report/__init__.py +0 -0
  46. lemonade/tools/report/llm_report.py +203 -0
  47. lemonade/tools/report/table.py +899 -0
  48. lemonade/tools/server/__init__.py +0 -0
  49. lemonade/tools/server/flm.py +133 -0
  50. lemonade/tools/server/llamacpp.py +320 -0
  51. lemonade/tools/server/serve.py +2123 -0
  52. lemonade/tools/server/static/favicon.ico +0 -0
  53. lemonade/tools/server/static/index.html +279 -0
  54. lemonade/tools/server/static/js/chat.js +1059 -0
  55. lemonade/tools/server/static/js/model-settings.js +183 -0
  56. lemonade/tools/server/static/js/models.js +1395 -0
  57. lemonade/tools/server/static/js/shared.js +556 -0
  58. lemonade/tools/server/static/logs.html +191 -0
  59. lemonade/tools/server/static/styles.css +2654 -0
  60. lemonade/tools/server/static/webapp.html +321 -0
  61. lemonade/tools/server/tool_calls.py +153 -0
  62. lemonade/tools/server/tray.py +664 -0
  63. lemonade/tools/server/utils/macos_tray.py +226 -0
  64. lemonade/tools/server/utils/port.py +77 -0
  65. lemonade/tools/server/utils/thread.py +85 -0
  66. lemonade/tools/server/utils/windows_tray.py +408 -0
  67. lemonade/tools/server/webapp.py +34 -0
  68. lemonade/tools/server/wrapped_server.py +559 -0
  69. lemonade/tools/tool.py +374 -0
  70. lemonade/version.py +1 -0
  71. lemonade_install/__init__.py +1 -0
  72. lemonade_install/install.py +239 -0
  73. lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
  74. lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
  75. lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
  76. lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
  77. lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
  78. lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
  79. lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
  80. lemonade_server/cli.py +805 -0
  81. lemonade_server/model_manager.py +758 -0
  82. lemonade_server/pydantic_models.py +159 -0
  83. lemonade_server/server_models.json +643 -0
  84. lemonade_server/settings.py +39 -0
@@ -0,0 +1,2123 @@
1
+ import sys
2
+ import asyncio
3
+ import statistics
4
+ import time
5
+ from threading import Thread, Event
6
+ import logging
7
+ import platform
8
+ import tempfile
9
+ import traceback
10
+ from typing import Optional, Union, List
11
+ import json
12
+ from pathlib import Path
13
+ import os
14
+ import shutil
15
+ from fastapi import FastAPI, HTTPException, status, Request, WebSocket, Form, UploadFile
16
+ from fastapi.responses import StreamingResponse
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.staticfiles import StaticFiles
19
+ from starlette.websockets import WebSocketDisconnect, WebSocketState
20
+ import uvicorn
21
+ from uvicorn.config import Config
22
+ from uvicorn.server import Server as UvicornServer
23
+ from tabulate import tabulate
24
+
25
+ from openai.types.completion import Completion, CompletionChoice
26
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk
27
+ from openai.types.chat import ChatCompletionMessage
28
+ from openai.types.chat.chat_completion_message_tool_call import (
29
+ ChatCompletionMessageToolCall,
30
+ Function,
31
+ )
32
+ from openai.types.completion_usage import CompletionUsage
33
+ from openai.types.chat.chat_completion import Choice
34
+ from openai.types.chat.chat_completion_chunk import (
35
+ ChoiceDelta,
36
+ ChoiceDeltaToolCall,
37
+ ChoiceDeltaToolCallFunction,
38
+ )
39
+ from openai.types.completion_choice import Logprobs
40
+ from openai.types.model import Model
41
+ from openai.types.responses import (
42
+ Response,
43
+ ResponseOutputMessage,
44
+ ResponseOutputText,
45
+ ResponseCreatedEvent,
46
+ ResponseTextDeltaEvent,
47
+ ResponseCompletedEvent,
48
+ )
49
+
50
+ import lemonade.api as lemonade_api
51
+ from lemonade.tools.server.wrapped_server import WrappedServer
52
+ from lemonade.tools.server.llamacpp import LlamaServer
53
+ from lemonade.tools.server.flm import FlmServer
54
+ from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
55
+ from lemonade.tools.server.webapp import get_webapp_html
56
+ from lemonade.tools.server.utils.port import lifespan
57
+
58
+ from lemonade_server.model_manager import ModelManager
59
+ from lemonade_server.pydantic_models import (
60
+ DEFAULT_PORT,
61
+ DEFAULT_HOST,
62
+ DEFAULT_LOG_LEVEL,
63
+ DEFAULT_LLAMACPP_BACKEND,
64
+ DEFAULT_CTX_SIZE,
65
+ LoadConfig,
66
+ CompletionRequest,
67
+ ChatCompletionRequest,
68
+ EmbeddingsRequest,
69
+ RerankingRequest,
70
+ ResponsesRequest,
71
+ PullConfig,
72
+ DeleteConfig,
73
+ LogLevelConfig,
74
+ )
75
+ from lemonade_server.settings import save_setting
76
+
77
+ # Set to a high number to allow for interesting experiences in real apps
78
+ # Tests should use the max_new_tokens argument to set a lower value
79
+ DEFAULT_MAX_NEW_TOKENS = 1500
80
+
81
+ if platform.system() in ["Windows", "Darwin"]:
82
+ # pylint: disable=ungrouped-imports
83
+ from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator
84
+
85
+
86
+ class ServerLogFilter(logging.Filter):
87
+ def __init__(self, server):
88
+ super().__init__()
89
+ self.server = server
90
+ self.noisy_paths = {
91
+ "/api/v1/health",
92
+ "/api/v0/health",
93
+ "/api/v1/models",
94
+ "/api/v0/models",
95
+ }
96
+
97
+ def filter(self, record: logging.LogRecord) -> bool:
98
+ msg = record.getMessage()
99
+
100
+ # Filter out websocket logs
101
+ if "> TEXT" in msg:
102
+ return False
103
+
104
+ # Filter out noisy HTTP routes if debug logs are OFF
105
+ if not self.server.debug_logging_enabled:
106
+ if any(path in msg for path in self.noisy_paths):
107
+ return False
108
+
109
+ # Otherwise, allow the log
110
+ return True
111
+
112
+
113
+ async def log_streamer(websocket: WebSocket, path: str, interval: float = 1.0):
114
+ logger = logging.getLogger()
115
+ await websocket.accept()
116
+ try:
117
+ with open(path, "r", encoding="utf-8") as f:
118
+ f.seek(0) # start at the beginning of the file
119
+ while True:
120
+ # Try reading a line
121
+ line = f.readline()
122
+ if not line:
123
+ await asyncio.sleep(interval)
124
+ continue
125
+
126
+ # Send defensively: if disconnected, bail out
127
+ if websocket.application_state != WebSocketState.CONNECTED:
128
+ # Server-side state says we're not connected anymore
129
+ break
130
+
131
+ try:
132
+ await websocket.send_text(line)
133
+ except WebSocketDisconnect:
134
+ # Client closed — normal path out
135
+ break
136
+ except RuntimeError as re:
137
+ # Starlette will raise this if a close has already been sent
138
+ logger.debug("RuntimeError during send: %s", re)
139
+ break
140
+
141
+ except WebSocketDisconnect:
142
+ # Client closed the socket; do not try to send or close again
143
+ pass
144
+ except Exception as e: # pylint: disable=broad-except
145
+ # Log server-side; do not attempt to send error over a possibly closed socket
146
+ logger.exception("Error in log_streamer: %s", e)
147
+ finally:
148
+ # Only close if Starlette still thinks we're connected.
149
+ # This prevents "Cannot call send once a close message has been sent."
150
+ try:
151
+ if websocket.application_state == WebSocketState.CONNECTED:
152
+ await websocket.close()
153
+ except Exception: # pylint: disable=broad-except
154
+ # If close itself races, swallow — we're shutting down anyway.
155
+ pass
156
+
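+ # Illustrative wiring sketch (assumed, not taken from this file): a WebSocket
+ # route, such as the logs_ws handler registered later in this file, can simply
+ # delegate to log_streamer with the server's log file path:
+ #
+ #     @app.websocket("/logs/ws")
+ #     async def logs_ws(websocket: WebSocket):
+ #         # "/path/to/lemonade.log" is a placeholder for the configured log file
+ #         await log_streamer(websocket, "/path/to/lemonade.log", interval=1.0)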
157
+
158
+ class ServerModel(Model):
159
+ """
160
+ An extension of OpenAI's Model class that adds
161
+ checkpoint and recipe attributes.
162
+ """
163
+
164
+ checkpoint: str
165
+ recipe: str
166
+
167
+
168
+ class GeneratorThread(Thread):
169
+ """
170
+ Thread class designed for use with streaming generation within
171
+ an LLM server. It needs access to the streamer in order
172
+ to help the completions APIs escape the "for text in streamer" loop.
173
+ It also provides exception handling that works nicely with HTTP
174
+ servers by providing the stack trace and making the exception
175
+ information available to the main thread.
176
+ """
177
+
178
+ def __init__(self, streamer, *args, **kwargs):
179
+ super().__init__(*args, **kwargs)
180
+ self.exception = None
181
+ self.streamer = streamer
182
+
183
+ def run(self):
184
+ try:
185
+ if self._target:
186
+ self._target(*self._args, **self._kwargs)
187
+ except Exception as e: # pylint: disable=broad-except
188
+ self.exception = e
189
+ logging.error(f"Exception raised in generate thread: {e}")
190
+ traceback.print_exc()
191
+ self.streamer.done()
192
+
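+ # Minimal usage sketch: the start/streaming pattern matches _generate_tokens
+ # further down in this file; the join/re-raise step is an assumed, typical way
+ # to surface the stored exception to the HTTP layer.
+ #
+ #     thread = GeneratorThread(streamer, target=model.generate, kwargs=generation_kwargs)
+ #     thread.start()
+ #     for text in streamer:
+ #         ...  # forward tokens to the client
+ #     thread.join()
+ #     if thread.exception:
+ #         raise thread.exception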
193
+
194
+ class StopOnEvent:
195
+ """
196
+ Custom stopping criteria that halts text generation when a specified event is set.
197
+
198
+ This allows for external control of generation, such as stopping a generation
199
+ before it reaches the maximum token limit.
200
+ """
201
+
202
+ def __init__(self, stop_event: Event):
203
+ super().__init__()
204
+ self.stop_event = stop_event
205
+
206
+ def __call__(self, input_ids, scores, **kwargs):
207
+ return self.stop_event.is_set()
208
+
209
+
210
+ class NoCacheStaticFiles(StaticFiles):
211
+ """Custom StaticFiles class with no-cache headers"""
212
+
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def file_response(self, *args, **kwargs) -> Response:
217
+ response = super().file_response(*args, **kwargs)
218
+ # Add no-cache headers for all static files
219
+ response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
220
+ response.headers["Pragma"] = "no-cache"
221
+ response.headers["Expires"] = "0"
222
+ return response
223
+
224
+
225
+ class Server:
226
+ """
227
+ Open a web server that apps can use to communicate with the LLM.
228
+
229
+ The server exposes these endpoints:
230
+ - /api/v1/pull: install an LLM by its Lemonade Server Model Name.
231
+ - /api/v1/delete: delete an LLM by its Lemonade Server Model Name.
232
+ - /api/v1/load: load a model checkpoint.
233
+ - /api/v1/unload: unload a model checkpoint.
234
+ - /api/v1/health: check whether a model is loaded and ready to serve.
235
+ - /api/v1/stats: performance statistics for the generation.
236
+ - /api/v1/halt: stop an in-progress generation from making more tokens.
237
+ - /api/v1/completions: completion responses using HTTP chunked transfer encoding.
238
+ - /api/v1/chat/completions: chat completion responses using HTTP chunked transfer encoding.
239
+ - /api/v1/responses: responses API using HTTP chunked transfer encoding.
240
+ - /api/v1/models: list all available models.
241
+ - /api/v1/models/{model_id}: retrieve a specific model by ID.
242
+ """
243
+
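+ # Illustrative client sketch (not part of this class): the OpenAI-compatible
+ # routes listed above can be exercised with the official `openai` Python client.
+ # The port 8000 below stands in for the configured DEFAULT_PORT; the model name
+ # is a placeholder for an installed Lemonade Server model.
+ #
+ #     from openai import OpenAI
+ #
+ #     client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="none")
+ #     reply = client.chat.completions.create(
+ #         model="<installed-model-name>",
+ #         messages=[{"role": "user", "content": "Hello"}],
+ #     )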
244
+ def __init__(
245
+ self,
246
+ port: int = DEFAULT_PORT,
247
+ host: str = DEFAULT_HOST,
248
+ log_level: str = DEFAULT_LOG_LEVEL,
249
+ ctx_size: int = DEFAULT_CTX_SIZE,
250
+ tray: bool = False,
251
+ log_file: str = None,
252
+ llamacpp_backend: str = DEFAULT_LLAMACPP_BACKEND,
253
+ ):
254
+ super().__init__()
255
+
256
+ # Save args as members
257
+ self.port = port
258
+ self.host = host
259
+ self.log_level = log_level
260
+ self.ctx_size = ctx_size
261
+ self.tray = tray
262
+ self.log_file = log_file
263
+ self.llamacpp_backend = llamacpp_backend
264
+
265
+ # Initialize FastAPI app
266
+ self.app = FastAPI(lifespan=lifespan)
267
+
268
+ # Lifespan will load some tasks in the background, and then set the
269
+ # app.initialized flag to True when this is done
270
+ self.app.initialized = False
271
+
272
+ # Add CORS middleware
273
+ self.app.add_middleware(
274
+ CORSMiddleware,
275
+ allow_origins=["*"], # Allows all origins
276
+ allow_credentials=True,
277
+ allow_methods=["*"], # Allows all methods
278
+ allow_headers=["*"], # Allows all headers
279
+ )
280
+
281
+ # Set up debug middleware if debug logging is enabled
282
+ # This must be done during app initialization, not at runtime
283
+ self.debug_logging_enabled = log_level == "debug"
284
+ if self.debug_logging_enabled:
285
+ self.setup_middleware_timer()
286
+
287
+ # Set up custom routes
288
+ self.setup_routes(["/api/v0", "/api/v1"])
289
+
290
+ # Set up Web App
291
+ self.app.get("/")(self.webapp)
292
+
293
+ # Mount a static assets dir for HTML responses, such
294
+ # as the Web App
295
+ static_dir = Path(__file__).parent / "static"
296
+ self.app.mount(
297
+ "/static", NoCacheStaticFiles(directory=static_dir), name="static_assets"
298
+ )
299
+
300
+ # Performance stats that are set during /ws and can be
301
+ # fetched in /stats
302
+ self.time_to_first_token = None
303
+ self.tokens_per_second = None
304
+ self.input_tokens = None
305
+ self.output_tokens = None
306
+ self.decode_token_times = None
307
+
308
+ # Store debug logging state
309
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
310
+
311
+ # Flag that tells the LLM to stop generating text and end the response
312
+ self.stop_event = Event()
313
+
314
+ self.llm_loaded: LoadConfig = None
315
+ self.tokenizer = None
316
+
317
+ # Placeholders for model and configs
318
+ self.model = None
319
+
320
+ # Initialize semaphore for tracking active generations
321
+ self.max_concurrent_generations = 1
322
+ self._generate_semaphore = asyncio.Semaphore(self.max_concurrent_generations)
323
+
324
+ # Dictionary of installed LLMs, mapping model name to information about each model
325
+ # Does not include non-installed models
326
+ self.local_models = ModelManager().downloaded_models_enabled
327
+
328
+ # Add lock for load/unload operations
329
+ self._load_lock = asyncio.Lock()
330
+
331
+ # Subprocess handle for wrapped instance of llama_server.exe, etc.
332
+ self.wrapped_server: WrappedServer = None
333
+
334
+ def setup_routes(self, api_prefixes: list[str]):
335
+ for prefix in api_prefixes:
336
+ # Custom routes
337
+ self.app.post(f"{prefix}/pull")(self.pull)
338
+ self.app.post(f"{prefix}/delete")(self.delete)
339
+ self.app.post(f"{prefix}/load")(self.load_llm)
340
+ self.app.post(f"{prefix}/unload")(self.unload_llm)
341
+ self.app.get(f"{prefix}/health")(self.health)
342
+ self.app.get(f"{prefix}/halt")(self.halt_generation)
343
+ self.app.get(f"{prefix}/stats")(self.send_stats)
344
+ self.app.get(f"{prefix}/system-info")(self.get_system_info)
345
+ self.app.post(f"{prefix}/completions")(self.completions)
346
+ self.app.post(f"{prefix}/responses")(self.responses)
347
+ self.app.post(f"{prefix}/log-level")(self.set_log_level)
348
+ self.app.websocket(f"{prefix}/logs/ws")(self.logs_ws)
349
+ self.app.post(f"{prefix}/add-local-model")(self.add_local_model)
350
+
351
+ # OpenAI-compatible routes
352
+ self.app.post(f"{prefix}/chat/completions")(self.chat_completions)
353
+ self.app.post(f"{prefix}/embeddings")(self.embeddings)
354
+ self.app.get(f"{prefix}/models")(self.models)
355
+ self.app.get(f"{prefix}/models/{{model_id}}")(self.retrieve_model)
356
+
357
+ # JinaAI routes (jina.ai/reranker/)
358
+ self.app.post(f"{prefix}/reranking")(self.reranking)
359
+ self.app.post(f"{prefix}/rerank")(self.reranking)
360
+
361
+ # Migration routes
362
+ self.app.get(f"{prefix}/migration/incompatible-models")(
363
+ self.get_incompatible_models
364
+ )
365
+ self.app.post(f"{prefix}/migration/cleanup")(
366
+ self.cleanup_incompatible_models
367
+ )
368
+
369
+ async def add_local_model(
370
+ self,
371
+ model_name: str = Form(...),
372
+ checkpoint: str = Form(""),
373
+ recipe: str = Form(...),
374
+ reasoning: bool = Form(False),
375
+ vision: bool = Form(False),
376
+ mmproj: str = Form(None),
377
+ model_files: List[UploadFile] = None,
378
+ ):
379
+ from huggingface_hub.constants import HF_HUB_CACHE
380
+ from lemonade.tools.llamacpp.utils import parse_checkpoint
381
+
382
+ # Upload and register a local model from files.
383
+ try:
384
+ if not model_files:
385
+ raise HTTPException(
386
+ status_code=status.HTTP_400_BAD_REQUEST,
387
+ detail="No model files provided for upload",
388
+ )
389
+
390
+ if not model_name.startswith("user."):
391
+ raise HTTPException(
392
+ status_code=status.HTTP_400_BAD_REQUEST,
393
+ detail="Model name must start with 'user.'",
394
+ )
395
+
396
+ valid_recipes = ["llamacpp", "oga-npu", "oga-hybrid", "oga-cpu"]
397
+ if recipe not in valid_recipes:
398
+ raise HTTPException(
399
+ status_code=status.HTTP_400_BAD_REQUEST,
400
+ detail=f"Invalid recipe. Must be one of: {', '.join(valid_recipes)}",
401
+ )
402
+
403
+ if recipe == "llamacpp" and not any(
404
+ f.filename.lower().endswith(".gguf") for f in model_files
405
+ ):
406
+ raise HTTPException(
407
+ status_code=status.HTTP_400_BAD_REQUEST,
408
+ detail="At least one .gguf file is required for llamacpp",
409
+ )
410
+
411
+ # Check if model name already exists
412
+ if model_name in ModelManager().supported_models:
413
+ raise HTTPException(
414
+ status_code=status.HTTP_409_CONFLICT,
415
+ detail=(
416
+ f"Model name '{model_name}' already exists. "
417
+ "Please use a different name."
418
+ ),
419
+ )
420
+
421
+ model_name_clean = model_name.replace("user.", "")
422
+
423
+ # Files are saved to models--{model_name_clean}
424
+ # Note: This is based on the user's custom model name, NOT the checkpoint field
425
+ repo_cache_name = model_name_clean.replace("/", "--")
426
+ snapshot_path = os.path.join(HF_HUB_CACHE, f"models--{repo_cache_name}")
427
+ os.makedirs(snapshot_path, exist_ok=True)
428
+
429
+ # Extract variant from checkpoint field if provided
430
+ # checkpoint field format: "folder:variant" or just "folder"
431
+ variant = None
432
+ if checkpoint and ":" in checkpoint:
433
+ _, variant = parse_checkpoint(checkpoint)
434
+ # variant now contains just the variant filename, with or without the
435
+ # .gguf extension (e.g., "LFM2-VL-1.6B-F16" or "LFM2-VL-1.6B-F16.gguf")
436
+
437
+ # Save uploaded files, preserving folder structure
438
+ for file in model_files:
439
+ relative_path = file.filename
440
+ path_parts = relative_path.split("/")
441
+
442
+ if len(path_parts) > 1:
443
+ internal_path = "/".join(path_parts[1:])
444
+ file_path = os.path.join(snapshot_path, internal_path)
445
+ else:
446
+ file_path = os.path.join(snapshot_path, path_parts[0])
447
+
448
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
449
+ with open(file_path, "wb") as f:
450
+ content = await file.read()
451
+ f.write(content)
452
+
453
+ # Resolve actual file paths after upload (for faster loading later)
454
+ resolved_checkpoint = None
455
+ resolved_mmproj = None
456
+
457
+ # For OGA models, find genai_config.json
458
+ if recipe.startswith("oga-"):
459
+ for root, _, files in os.walk(snapshot_path):
460
+ if "genai_config.json" in files:
461
+ resolved_checkpoint = root
462
+ break
463
+ if not resolved_checkpoint:
464
+ resolved_checkpoint = snapshot_path
465
+
466
+ # For llamacpp models, find the GGUF file
467
+ elif recipe == "llamacpp":
468
+ gguf_file_found = None
469
+
470
+ # If variant is specified, look for that specific file
471
+ if variant:
472
+ search_term = (
473
+ variant if variant.endswith(".gguf") else f"{variant}.gguf"
474
+ )
475
+ for root, _, files in os.walk(snapshot_path):
476
+ if search_term in files:
477
+ gguf_file_found = os.path.join(root, search_term)
478
+ break
479
+
480
+ # If no variant or variant not found, search for any .gguf file (excluding mmproj)
481
+ if not gguf_file_found:
482
+ for root, _, files in os.walk(snapshot_path):
483
+ gguf_files = [
484
+ f
485
+ for f in files
486
+ if f.endswith(".gguf") and "mmproj" not in f.lower()
487
+ ]
488
+ if gguf_files:
489
+ gguf_file_found = os.path.join(root, gguf_files[0])
490
+ break
491
+
492
+ resolved_checkpoint = (
493
+ gguf_file_found if gguf_file_found else snapshot_path
494
+ )
495
+
496
+ # Search for mmproj file if provided
497
+ if mmproj:
498
+ for root, _, files in os.walk(snapshot_path):
499
+ if mmproj in files:
500
+ resolved_mmproj = os.path.join(root, mmproj)
501
+ break
502
+
503
+ # Build checkpoint for registration
504
+ # If an actual model path was resolved above, store it relative to HF_HUB_CACHE
505
+ if resolved_checkpoint:
506
+ # Store as relative path from HF_HUB_CACHE for portability
507
+ checkpoint_to_register = os.path.relpath(
508
+ resolved_checkpoint, HF_HUB_CACHE
509
+ )
510
+ elif variant:
511
+ checkpoint_to_register = f"models--{repo_cache_name}:{variant}"
512
+ else:
513
+ checkpoint_to_register = f"models--{repo_cache_name}"
514
+
515
+ # Register the model
516
+ ModelManager().register_local_model(
517
+ model_name=model_name,
518
+ checkpoint=checkpoint_to_register,
519
+ recipe=recipe,
520
+ reasoning=reasoning,
521
+ vision=vision,
522
+ mmproj=resolved_mmproj if resolved_mmproj else mmproj,
523
+ snapshot_path=snapshot_path,
524
+ )
525
+
526
+ # Refresh local models
527
+ self.local_models = ModelManager().downloaded_models_enabled
528
+
529
+ return {
530
+ "status": "success",
531
+ "message": f"Model {model_name} uploaded and registered successfully",
532
+ }
533
+ except Exception as e:
534
+ if "snapshot_path" in locals() and os.path.exists(snapshot_path):
535
+ shutil.rmtree(snapshot_path)
536
+ raise HTTPException(
537
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
538
+ detail=f"Failed to upload model: {str(e)}",
539
+ )
540
+
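+ # Illustrative request sketch (field names come from the Form parameters above;
+ # the host, port, model name, and file name are placeholders):
+ #
+ #     import requests
+ #
+ #     requests.post(
+ #         "http://localhost:8000/api/v1/add-local-model",
+ #         data={"model_name": "user.my-model", "recipe": "llamacpp"},
+ #         files=[("model_files", open("my-model-F16.gguf", "rb"))],
+ #     )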
541
+ async def set_log_level(self, config: LogLevelConfig):
542
+ """
543
+ Set the logging level of the server.
544
+ """
545
+ try:
546
+ log_level_upper = config.level.upper()
547
+ numeric_level = getattr(logging, log_level_upper, None)
548
+ if not isinstance(numeric_level, int):
549
+ raise ValueError(f"Invalid log level: {config.level}")
550
+
551
+ # Get the root logger
552
+ logger = logging.getLogger()
553
+ logger.setLevel(numeric_level)
554
+
555
+ # Update all handlers
556
+ for handler in logger.handlers:
557
+ handler.setLevel(numeric_level)
558
+
559
+ logging.getLogger("uvicorn.error").setLevel(numeric_level)
560
+ self.debug_logging_enabled = numeric_level <= logging.DEBUG
561
+
562
+ # Save the setting
563
+ save_setting("log_level", config.level)
564
+
565
+ logging.info(f"Log level changed to: {config.level}")
566
+ return {"status": "success", "message": f"Log level set to {config.level}"}
567
+ except Exception as e:
568
+ raise HTTPException(
569
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
570
+ detail=f"Failed to set log level: {str(e)}",
571
+ )
572
+
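+ # Illustrative request sketch (the `level` field mirrors how LogLevelConfig is
+ # used above; host and port are placeholders):
+ #
+ #     import requests
+ #
+ #     requests.post("http://localhost:8000/api/v1/log-level", json={"level": "debug"})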
573
+ def _log_request_parameters(self, request, endpoint_name: str):
574
+ """
575
+ Log request parameters excluding content fields like messages, prompt, or input.
576
+
577
+ Args:
578
+ request: Any request object (CompletionRequest, ChatCompletionRequest, etc.)
579
+ endpoint_name: Name of the endpoint for logging context
580
+ """
581
+ if not logging.getLogger().isEnabledFor(logging.DEBUG):
582
+ return
583
+
584
+ # Fields to exclude from logging (content fields)
585
+ excluded_fields = {"messages", "prompt", "input"}
586
+
587
+ # Get all attributes from the request object
588
+ request_params = {}
589
+ if hasattr(request, "__dict__"):
590
+ # For pydantic models, get the dict representation
591
+ if hasattr(request, "model_dump"):
592
+ all_params = request.model_dump()
593
+ elif hasattr(request, "dict"):
594
+ all_params = request.dict()
595
+ else:
596
+ all_params = request.__dict__
597
+
598
+ # Filter out excluded fields and add special handling for certain fields
599
+ for key, value in all_params.items():
600
+ if key not in excluded_fields:
601
+ # Special handling for tools field - show count instead of full content
602
+ if key == "tools" and value is not None:
603
+ request_params[key] = (
604
+ f"{len(value)} tools" if isinstance(value, list) else value
605
+ )
606
+ # Special handling for input type in responses
607
+ elif key == "input" and hasattr(request, "input"):
608
+ request_params["input_type"] = type(value).__name__
609
+ else:
610
+ request_params[key] = value
611
+
612
+ logging.debug(f"{endpoint_name} request parameters: {request_params}")
613
+
614
+ def _setup_server_common(
615
+ self,
616
+ tray: bool = False,
617
+ threaded_mode: bool = False,
618
+ ):
619
+ """
620
+ Common setup logic shared between run() and run_in_thread().
621
+
622
+ Args:
623
+ tray: Whether to run the server in tray mode
624
+ threaded_mode: Whether this is being set up for threaded execution
625
+ """
626
+
627
+ # Define TRACE level
628
+ logging.TRACE = 9 # Lower than DEBUG which is 10
629
+ logging.addLevelName(logging.TRACE, "TRACE")
630
+
631
+ # Add a convenience function at the module level
632
+ def trace(message, *args, **kwargs):
633
+ logging.log(logging.TRACE, message, *args, **kwargs)
634
+
635
+ logging.trace = trace
636
+
637
+ # Configure logging based on mode
638
+ if threaded_mode:
639
+ # Configure logging for warning level (to reduce noise in threaded execution)
640
+ logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
641
+ else:
642
+ # Configure logging to match uvicorn's format
643
+ logging_level = getattr(logging, self.log_level.upper())
644
+
645
+ # Set up file handler for logging to lemonade.log
646
+ uvicorn_formatter = uvicorn.logging.DefaultFormatter(
647
+ fmt="%(levelprefix)s %(message)s",
648
+ use_colors=True,
649
+ )
650
+ if not self.log_file:
651
+ self.log_file = tempfile.NamedTemporaryFile(
652
+ prefix="lemonade_", suffix=".log", delete=False
653
+ ).name
654
+ file_handler = logging.FileHandler(
655
+ self.log_file, mode="a", encoding="utf-8"
656
+ )
657
+ file_handler.setLevel(logging_level)
658
+ file_handler.setFormatter(uvicorn_formatter)
659
+ file_handler.addFilter(ServerLogFilter(self))
660
+
661
+ # Set up console handler
662
+ console_handler = logging.StreamHandler()
663
+ console_handler.setLevel(logging_level)
664
+ console_handler.setFormatter(uvicorn_formatter)
665
+ console_handler.addFilter(ServerLogFilter(self))
666
+
667
+ # Configure root logger with both handlers
668
+ logging.basicConfig(
669
+ level=logging_level,
670
+ handlers=[file_handler, console_handler],
671
+ force=True,
672
+ )
673
+
674
+ # Update debug logging state after setting log level
675
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
676
+ if tray:
677
+ # Save original stdout/stderr
678
+ sys.stdout = OutputDuplicator(self.log_file, sys.stdout)
679
+ sys.stderr = OutputDuplicator(self.log_file, sys.stderr)
680
+
681
+ # Open lemonade server in tray mode
682
+ # lambda function used for deferred instantiation and thread safety
683
+ LemonadeTray(
684
+ self.log_file, self.port, lambda: self, log_level=self.log_level
685
+ ).run()
686
+ sys.exit(0)
687
+
688
+ # Let the app know what port it's running on, so
689
+ # that the lifespan can access it
690
+ self.app.port = self.port
691
+ # FastAPI already has a `host` function and we cannot use `_host` as
692
+ # PyLint will believe its private
693
+ self.app.host_ = self.host
694
+
695
+ def run(self):
696
+ # Common setup
697
+ self._setup_server_common(
698
+ threaded_mode=False,
699
+ tray=self.tray,
700
+ )
701
+
702
+ uvicorn.run(self.app, host=self.host, port=self.port, log_level=self.log_level)
703
+
704
+ def run_in_thread(self, host: str = "localhost"):
705
+ """
706
+ Set up the server for running in a thread.
707
+ Returns a uvicorn server instance that can be controlled externally.
708
+ """
709
+ # Common setup
710
+ self._setup_server_common(
711
+ threaded_mode=True,
712
+ tray=False,
713
+ )
714
+
715
+ class CustomServer(UvicornServer):
716
+ """Custom Uvicorn server that can be properly shutdown from another thread"""
717
+
718
+ def install_signal_handlers(self):
719
+ pass
720
+
721
+ # Configure the server
722
+ config = Config(
723
+ app=self.app,
724
+ host=host,
725
+ port=self.port,
726
+ log_level=self.log_level,
727
+ log_config=None,
728
+ )
729
+
730
+ # Create and return the uvicorn server
731
+ return CustomServer(config=config)
732
+
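+ # Illustrative usage sketch (assumed, typical uvicorn pattern): the returned
+ # server can be run from a background thread and stopped cooperatively via
+ # `should_exit`. The port is a placeholder.
+ #
+ #     server = Server(port=8000).run_in_thread()
+ #     t = Thread(target=server.run, daemon=True)
+ #     t.start()
+ #     ...
+ #     server.should_exit = True  # ask uvicorn to exit its serve loop
+ #     t.join()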
733
+ async def _show_telemetry(self):
734
+ """
735
+ Show telemetry data in debug mode.
736
+ """
737
+ # Exit early if debug logging is disabled or no telemetry data is available
738
+ if not self.debug_logging_enabled or self.tokens_per_second is None:
739
+ return
740
+
741
+ # Prepare telemetry data (transposed format)
742
+ telemetry = [
743
+ ["Input tokens", self.input_tokens],
744
+ ["Output tokens", self.output_tokens],
745
+ ["TTFT (s)", f"{self.time_to_first_token:.2f}"],
746
+ ["TPS", f"{self.tokens_per_second:.2f}"],
747
+ ]
748
+
749
+ table = tabulate(
750
+ telemetry, headers=["Metric", "Value"], tablefmt="fancy_grid"
751
+ ).split("\n")
752
+
753
+ # Show telemetry in debug while complying with uvicorn's log indentation
754
+ logging.debug("\n ".join(table))
755
+
756
+ def webapp(self):
757
+ """
758
+ Serve the Web App to the user's browser.
759
+ """
760
+
761
+ return get_webapp_html(port=self.app.port)
762
+
763
+ def initialize_load_config(
764
+ self, request: Union[ChatCompletionRequest, CompletionRequest]
765
+ ) -> LoadConfig:
766
+ """
767
+ Turn the Request object into a partially-complete LoadConfig.
768
+
769
+ The load_llm() method is responsible for filling in the rest of
770
+ LoadConfig's parameters.
771
+ """
772
+
773
+ # Get model config
774
+ if "/" in request.model:
775
+ # We know the model is a Hugging Face checkpoint if it contains a /
776
+ # This scenario is only supported for run-as-thread
777
+ lc = LoadConfig(model_name="custom", checkpoint=request.model)
778
+ else:
779
+ # The model should be a reference to a built-in model
780
+ lc = LoadConfig(model_name=request.model)
781
+
782
+ return lc
783
+
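+ # Worked example of the two branches above (model names are hypothetical):
+ #
+ #     "some-org/some-checkpoint" -> LoadConfig(model_name="custom",
+ #                                              checkpoint="some-org/some-checkpoint")
+ #     "Some-Built-In-Model"      -> LoadConfig(model_name="Some-Built-In-Model")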
784
+ async def completions(
785
+ self, completion_request: CompletionRequest, request: Request
786
+ ):
787
+ """
788
+ Stream completion responses using HTTP chunked transfer encoding.
789
+ """
790
+
791
+ lc = self.initialize_load_config(completion_request)
792
+
793
+ # Log request parameters (excluding message content for brevity)
794
+ self._log_request_parameters(completion_request, "Completions")
795
+
796
+ # Load the model if it's different from the currently loaded one
797
+ await self.load_llm(lc)
798
+
799
+ if self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm":
800
+ return self.wrapped_server.completion(completion_request)
801
+
802
+ # Check if the model supports reasoning
803
+ reasoning_first_token = self.llm_loaded.reasoning
804
+
805
+ # If the model supports reasoning, we:
806
+ # 1. add a <think> tag to the model's context
807
+ # 2. ensure that the first token is a <think> token
808
+ text = completion_request.prompt
809
+ if reasoning_first_token:
810
+ text += "<think>"
811
+
812
+ # Prepare generation arguments
813
+ generation_args = {
814
+ "message": text,
815
+ "stop": completion_request.stop,
816
+ "temperature": completion_request.temperature,
817
+ "repeat_penalty": completion_request.repeat_penalty,
818
+ "top_k": completion_request.top_k,
819
+ "top_p": completion_request.top_p,
820
+ "max_new_tokens": completion_request.max_tokens,
821
+ }
822
+
823
+ if completion_request.stream:
824
+
825
+ if completion_request.logprobs:
826
+ logging.warning("logprobs is not supported for streaming completion")
827
+ if completion_request.echo:
828
+ logging.warning(
829
+ "`Echo` parameter is not supported for streaming completions"
830
+ )
831
+
832
+ # Stream the response
833
+ async def generate():
834
+ # Declare it's the same variable from outside scope
835
+ # This is necessary because the variable is modified
836
+ # in the inner function
837
+ nonlocal reasoning_first_token
838
+ try:
839
+ async for token in self._generate_tokens(**generation_args):
840
+ # Handle client disconnect: stop generation and exit
841
+ if await request.is_disconnected():
842
+ self.stop_event.set()
843
+ break
844
+
845
+ choice = CompletionChoice(
846
+ text=(
847
+ "<think>" + token if reasoning_first_token else token
848
+ ),
849
+ index=0,
850
+ finish_reason="stop",
851
+ logprobs=None,
852
+ )
853
+
854
+ completion = Completion(
855
+ id="0",
856
+ choices=[choice],
857
+ model=self.llm_loaded.checkpoint,
858
+ object="text_completion",
859
+ created=int(time.time()),
860
+ )
861
+
862
+ # Format as SSE
863
+ reasoning_first_token = False
864
+ yield f"data: {completion.model_dump_json()}\n\n".encode(
865
+ "utf-8"
866
+ )
867
+
868
+ # Send the [DONE] marker only if still connected
869
+ if not await request.is_disconnected():
870
+ yield b"data: [DONE]\n\n"
871
+ except asyncio.CancelledError:
872
+ # Propagate cancellation to the generator loop
873
+ self.stop_event.set()
874
+ return
875
+
876
+ return StreamingResponse(
877
+ generate(),
878
+ media_type="text/event-stream",
879
+ headers={
880
+ "Cache-Control": "no-cache",
881
+ "Connection": "keep-alive",
882
+ },
883
+ )
884
+
885
+ # If streaming is not requested, collect all generated tokens into a single response
886
+ else:
887
+ full_response = text if completion_request.echo else ""
888
+ async for token in self._generate_tokens(**generation_args):
889
+ full_response += token
890
+
891
+ # If logprobs are requested, create a logprobs object
892
+ logprobs = None
893
+ if completion_request.logprobs:
894
+
895
+ # Compute the logprobs
896
+ text_offset, token_logprobs, tokens, top_logprobs = (
897
+ self.model.compute_logprobs(
898
+ text=full_response,
899
+ tokenizer=self.tokenizer,
900
+ logprobs=completion_request.logprobs,
901
+ )
902
+ )
903
+ logprobs = Logprobs.model_construct(
904
+ text_offset=text_offset,
905
+ token_logprobs=token_logprobs,
906
+ tokens=tokens,
907
+ top_logprobs=top_logprobs,
908
+ )
909
+
910
+ choice = CompletionChoice(
911
+ text=full_response,
912
+ index=0,
913
+ finish_reason="stop",
914
+ logprobs=logprobs,
915
+ )
916
+
917
+ usage = CompletionUsage(
918
+ prompt_tokens=self.input_tokens,
919
+ completion_tokens=self.output_tokens,
920
+ total_tokens=self.input_tokens + self.output_tokens,
921
+ )
922
+
923
+ return Completion(
924
+ id="0",
925
+ choices=[choice],
926
+ usage=usage,
927
+ model=self.llm_loaded.checkpoint,
928
+ object="text_completion",
929
+ created=int(time.time()),
930
+ )
931
+
932
+ async def chat_completions(
933
+ self, chat_completion_request: ChatCompletionRequest, request: Request
934
+ ):
935
+ """
936
+ Stream chat completion responses using HTTP chunked transfer encoding.
937
+ """
938
+
939
+ if chat_completion_request.logprobs:
940
+ logging.warning("logprobs is not supported on chat completion")
941
+
942
+ lc = self.initialize_load_config(chat_completion_request)
943
+
944
+ # Log request parameters (excluding message history for brevity)
945
+ self._log_request_parameters(chat_completion_request, "Chat completions")
946
+
947
+ # Load the model if it's different from the currently loaded one
948
+ await self.load_llm(lc)
949
+
950
+ if self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm":
951
+ if (
952
+ hasattr(chat_completion_request, "enable_thinking")
953
+ and chat_completion_request.enable_thinking is False
954
+ and "qwen3" in self.llm_loaded.model_name.lower()
955
+ ):
956
+
957
+ # Modify the last user message to include /no_think
958
+ if chat_completion_request.messages:
959
+ for i in range(len(chat_completion_request.messages) - 1, -1, -1):
960
+ if chat_completion_request.messages[i].get("role") == "user":
961
+ original_content = chat_completion_request.messages[i][
962
+ "content"
963
+ ]
964
+ chat_completion_request.messages[i][
965
+ "content"
966
+ ] = f"/no_think\n{original_content}"
967
+ break
968
+ return self.wrapped_server.chat_completion(chat_completion_request)
969
+
970
+ # Convert chat messages to text using the model's chat template
971
+ text = self.apply_chat_template(
972
+ chat_completion_request.messages,
973
+ tools=chat_completion_request.tools,
974
+ )
975
+
976
+ # If the model supports reasoning, we:
977
+ # 1. add a <think> tag to the model's context
978
+ # 2. ensure that the first token is a <think> token
979
+ reasoning_first_token = self.llm_loaded.reasoning
980
+
981
+ if reasoning_first_token:
982
+ text += "<think>"
983
+ # Set the max_new_tokens parameter
984
+ if (
985
+ chat_completion_request.max_completion_tokens
986
+ and chat_completion_request.max_tokens
987
+ ):
988
+ raise HTTPException(
989
+ status_code=status.HTTP_400_BAD_REQUEST,
990
+ detail=(
991
+ "Both max_tokens and max_completion_tokens were provided. "
992
+ "Please use only one of these parameters.",
993
+ ),
994
+ )
995
+ max_new_tokens = (
996
+ chat_completion_request.max_completion_tokens
997
+ if chat_completion_request.max_completion_tokens
998
+ else chat_completion_request.max_tokens
999
+ )
1000
+
1001
+ # Prepare generation arguments
1002
+ generation_args = {
1003
+ "message": text,
1004
+ "stop": chat_completion_request.stop,
1005
+ "temperature": chat_completion_request.temperature,
1006
+ "repeat_penalty": chat_completion_request.repeat_penalty,
1007
+ "top_k": chat_completion_request.top_k,
1008
+ "top_p": chat_completion_request.top_p,
1009
+ "max_new_tokens": max_new_tokens,
1010
+ }
1011
+
1012
+ if chat_completion_request.tools:
1013
+ # Get the tool call pattern
1014
+ tool_call_pattern = get_tool_call_pattern(
1015
+ self.tokenizer.auto_tokenizer.added_tokens_decoder
1016
+ )
1017
+
1018
+ if chat_completion_request.stream:
1019
+
1020
+ # Stream the response
1021
+ async def generate():
1022
+ # Declare it's the same variable from outside scope
1023
+ # This is necessary because the variable is modified
1024
+ # in the inner function
1025
+ nonlocal reasoning_first_token
1026
+
1027
+ # Keep track of the full response for tool call extraction
1028
+ full_response = ""
1029
+
1030
+ # Track whether we're still in the thinking phase (before </think> tag)
1031
+ in_thinking_phase = self.llm_loaded.reasoning
1032
+ reasoning_buffer = "" # Accumulate reasoning tokens to detect </think>
1033
+
1034
+ try:
1035
+ async for token in self._generate_tokens(**generation_args):
1036
+ # Handle client disconnect: stop generation and exit
1037
+ if await request.is_disconnected():
1038
+ self.stop_event.set()
1039
+ break
1040
+
1041
+ # Continuously look for tool calls embedded into the generated text
1042
+ openai_tool_calls = None
1043
+ if chat_completion_request.tools:
1044
+
1045
+ # Append the token to the full response
1046
+ full_response += token
1047
+
1048
+ tool_calls, _ = extract_tool_calls(
1049
+ full_response,
1050
+ tool_call_pattern,
1051
+ )
1052
+
1053
+ # If there are tool calls, reset the full response for the next call
1054
+ if tool_calls:
1055
+ openai_tool_calls = []
1056
+ full_response = ""
1057
+ for tool_call in tool_calls:
1058
+ openai_tool_calls.append(
1059
+ ChoiceDeltaToolCall(
1060
+ index=0,
1061
+ id="-",
1062
+ function=ChoiceDeltaToolCallFunction(
1063
+ arguments=json.dumps(
1064
+ tool_call["arguments"]
1065
+ ),
1066
+ name=tool_call["name"],
1067
+ ),
1068
+ type="function",
1069
+ )
1070
+ )
1071
+
1072
+ # Create a ChatCompletionChunk with reasoning_content support
1073
+ # If we're in reasoning mode and haven't seen </think> yet,
1074
+ # send tokens as reasoning_content instead of content
1075
+ delta_content = None
1076
+ delta_reasoning = None
1077
+
1078
+ if reasoning_first_token:
1079
+ # First token - include opening tag in reasoning
1080
+ delta_reasoning = "<think>" + token
1081
+ reasoning_first_token = False
1082
+ reasoning_buffer = token
1083
+ elif in_thinking_phase:
1084
+ # Still in thinking phase - accumulate and check for </think>
1085
+ reasoning_buffer += token
1086
+
1087
+ # Check if we've seen the closing tag
1088
+ if "</think>" in reasoning_buffer:
1089
+ # Split at the closing tag
1090
+ before_close, after_close = reasoning_buffer.split(
1091
+ "</think>", 1
1092
+ )
1093
+
1094
+ # Send everything before + closing tag as reasoning
1095
+ if before_close or not reasoning_buffer.startswith(
1096
+ "</think>"
1097
+ ):
1098
+ delta_reasoning = before_close + "</think>"
1099
+ else:
1100
+ delta_reasoning = "</think>"
1101
+
1102
+ # Everything after goes to content (will be sent in next iteration)
1103
+ # For now, mark that we've exited thinking phase
1104
+ in_thinking_phase = False
1105
+
1106
+ # If there's content after </think>, we need to send it too
1107
+ # But we send it in the current chunk as regular content
1108
+ if after_close:
1109
+ # We have both reasoning and content in this token
1110
+ # Send reasoning first, content will accumulate
1111
+ delta_content = after_close
1112
+ else:
1113
+ # Still accumulating thinking, send as reasoning_content
1114
+ delta_reasoning = token
1115
+ else:
1116
+ # Normal content (after thinking phase ended)
1117
+ delta_content = token
1118
+
1119
+ chunk = ChatCompletionChunk.model_construct(
1120
+ id="0",
1121
+ object="chat.completion.chunk",
1122
+ created=int(time.time()),
1123
+ model=self.llm_loaded.checkpoint,
1124
+ choices=[
1125
+ Choice.model_construct(
1126
+ index=0,
1127
+ delta=ChoiceDelta(
1128
+ content=delta_content,
1129
+ reasoning_content=delta_reasoning,
1130
+ function_call=None,
1131
+ role="assistant",
1132
+ tool_calls=openai_tool_calls,
1133
+ refusal=None,
1134
+ ),
1135
+ finish_reason=None,
1136
+ logprobs=None,
1137
+ )
1138
+ ],
1139
+ )
1140
+
1141
+ # Format as SSE
1142
+ yield f"data: {chunk.model_dump_json()}\n\n".encode("utf-8")
1143
+
1144
+ # Send the [DONE] marker only if still connected
1145
+ if not await request.is_disconnected():
1146
+ yield b"data: [DONE]\n\n"
1147
+ except asyncio.CancelledError:
1148
+ self.stop_event.set()
1149
+ return
1150
+
1151
+ return StreamingResponse(
1152
+ generate(),
1153
+ media_type="text/event-stream",
1154
+ headers={
1155
+ "Cache-Control": "no-cache",
1156
+ "Connection": "keep-alive",
1157
+ },
1158
+ )
1159
+
1160
+ # If streaming is not requested, collect all generated tokens into a single response
1161
+ else:
1162
+ full_response = "<think>" if reasoning_first_token else ""
1163
+ async for token in self._generate_tokens(**generation_args):
1164
+ full_response += token
1165
+
1166
+ # Extract tool calls from the response
1167
+ openai_tool_calls = None
1168
+ if chat_completion_request.tools:
1169
+ tool_calls, full_response = extract_tool_calls(
1170
+ full_response, tool_call_pattern
1171
+ )
1172
+ if tool_calls:
1173
+ openai_tool_calls = []
1174
+ for tool_call in tool_calls:
1175
+ openai_tool_calls.append(
1176
+ ChatCompletionMessageToolCall(
1177
+ id="-",
1178
+ function=Function(
1179
+ arguments=json.dumps(tool_call["arguments"]),
1180
+ name=tool_call["name"],
1181
+ ),
1182
+ type="function",
1183
+ )
1184
+ )
1185
+
1186
+ ccm = ChatCompletionMessage(
1187
+ content=full_response,
1188
+ role="assistant",
1189
+ refusal=None,
1190
+ audio=None,
1191
+ function_call=None,
1192
+ tool_calls=openai_tool_calls,
1193
+ )
1194
+
1195
+ choice = Choice(
1196
+ finish_reason="stop",
1197
+ index=0,
1198
+ message=ccm,
1199
+ logprobs=None,
1200
+ )
1201
+
1202
+ usage = CompletionUsage(
1203
+ prompt_tokens=self.input_tokens,
1204
+ completion_tokens=self.output_tokens,
1205
+ total_tokens=self.input_tokens + self.output_tokens,
1206
+ )
1207
+
1208
+ return ChatCompletion(
1209
+ id="0",
1210
+ choices=[choice],
1211
+ usage=usage,
1212
+ model=self.llm_loaded.checkpoint,
1213
+ object="chat.completion",
1214
+ created=int(time.time()),
1215
+ )
1216
+
1217
+ async def embeddings(self, embeddings_request: EmbeddingsRequest):
1218
+ """
1219
+ Generate embeddings for the provided input.
1220
+ """
1221
+ # Initialize load config from embeddings request
1222
+ lc = LoadConfig(model_name=embeddings_request.model)
1223
+
1224
+ # Load the model if it's different from the currently loaded one
1225
+ await self.load_llm(lc)
1226
+
1227
+ if self.llm_loaded.recipe == "llamacpp":
1228
+ try:
1229
+ return self.wrapped_server.embeddings(embeddings_request)
1230
+ except Exception as e: # pylint: disable=broad-exception-caught
1231
+ # Check if model has embeddings label
1232
+ model_info = ModelManager().supported_models.get(
1233
+ self.llm_loaded.model_name, {}
1234
+ )
1235
+ if "embeddings" not in model_info.get("labels", []):
1236
+ raise HTTPException(
1237
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1238
+ detail="You tried to generate embeddings for a model that is "
1239
+ "not labeled as an embeddings model. Please use another model "
1240
+ "or re-register the current model with the 'embeddings' label.",
1241
+ ) from e
1242
+ else:
1243
+ raise e
1244
+ else:
1245
+ raise HTTPException(
1246
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1247
+ detail=f"Embeddings not supported for recipe: {self.llm_loaded.recipe}",
1248
+ )
1249
+
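+ # Illustrative request sketch (assumes EmbeddingsRequest follows the OpenAI-style
+ # `model`/`input` fields and that the model is registered with the 'embeddings'
+ # label; host, port, and model name are placeholders):
+ #
+ #     import requests
+ #
+ #     requests.post(
+ #         "http://localhost:8000/api/v1/embeddings",
+ #         json={"model": "<embeddings-model-name>", "input": "hello world"},
+ #     )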
1250
+ async def reranking(self, reranking_request: RerankingRequest):
1251
+ """
1252
+ Rerank documents based on their relevance to a query.
1253
+ """
1254
+ # Initialize load config from reranking request
1255
+ lc = LoadConfig(model_name=reranking_request.model)
1256
+
1257
+ # Load the model if it's different from the currently loaded one
1258
+ await self.load_llm(lc)
1259
+
1260
+ if self.llm_loaded.recipe == "llamacpp":
1261
+ try:
1262
+ return self.wrapped_server.reranking(reranking_request)
1263
+ except Exception as e: # pylint: disable=broad-exception-caught
1264
+ # Check if model has reranking label
1265
+ model_info = ModelManager().supported_models.get(
1266
+ self.llm_loaded.model_name, {}
1267
+ )
1268
+ if "reranking" not in model_info.get("labels", []):
1269
+ raise HTTPException(
1270
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1271
+ detail="You tried to use reranking for a model that is "
1272
+ "not labeled as a reranking model. Please use another model "
1273
+ "or re-register the current model with the 'reranking' label.",
1274
+ ) from e
1275
+ else:
1276
+ raise e
1277
+ else:
1278
+ raise HTTPException(
1279
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1280
+ detail=f"Reranking not supported for recipe: {self.llm_loaded.recipe}",
1281
+ )
1282
+
1283
+ def apply_chat_template(
1284
+ self, messages: list[dict], tools: list[dict] | None = None
1285
+ ):
1286
+ """
1287
+ Apply the model's chat template to the messages.
1288
+ """
1289
+ if self.tokenizer.chat_template:
1290
+
1291
+ return self.tokenizer.apply_chat_template(
1292
+ messages,
1293
+ tokenize=False,
1294
+ add_generation_prompt=True,
1295
+ tools=tools,
1296
+ )
1297
+
1298
+ # Fallback to a standardized template if the model doesn't provide one
1299
+ logging.warning("No chat template found. Using default template.")
1300
+ formatted_messages = []
1301
+ for msg in messages:
1302
+ role = msg.get("role", "user")
1303
+ content = msg.get("content", "")
1304
+ role_marker = "<|assistant|>" if role == "assistant" else "<|user|>"
1305
+ formatted_messages.append(f"{role_marker}\n{content} <|end|>")
1306
+ return "\n".join(formatted_messages) + "\n<|assistant|>"
1307
+
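+ # Worked example of the fallback template above:
+ #
+ #     messages = [{"role": "user", "content": "Hi"}]
+ #     apply_chat_template(messages)  # -> "<|user|>\nHi <|end|>\n<|assistant|>"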
1308
+ async def responses(self, responses_request: ResponsesRequest, request: Request):
1309
+ """
1310
+ Stream responses using HTTP chunked transfer encoding.
1311
+ """
1312
+
1313
+ lc = self.initialize_load_config(responses_request)
1314
+
1315
+ # Log request parameters (excluding message history for brevity)
1316
+ self._log_request_parameters(responses_request, "Responses")
1317
+
1318
+ # Load the model if it's different from the currently loaded one
1319
+ await self.load_llm(lc)
1320
+
1321
+ if self.llm_loaded.recipe == "llamacpp":
1322
+ raise HTTPException(
1323
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1324
+ detail=f"Responses API not supported for recipe: {self.llm_loaded.recipe}",
1325
+ )
1326
+
1327
+ # Convert chat messages to text using the model's chat template
1328
+ if isinstance(responses_request.input, str):
1329
+ text = responses_request.input
1330
+ else:
1331
+ text = self.apply_chat_template(responses_request.input)
1332
+
1333
+ # If the model supports reasoning, we:
1334
+ # 1. add a <think> tag to the model's context
1335
+ # 2. ensure that the first token is a <think> token
1336
+ reasoning_first_token = self.llm_loaded.reasoning
1337
+
1338
+ if reasoning_first_token:
1339
+ text += "<think>"
1340
+
1341
+ # Prepare generation arguments
1342
+ generation_args = {
1343
+ "message": text,
1344
+ "temperature": responses_request.temperature,
1345
+ "repeat_penalty": responses_request.repeat_penalty,
1346
+ "top_k": responses_request.top_k,
1347
+ "top_p": responses_request.top_p,
1348
+ "max_new_tokens": responses_request.max_output_tokens,
1349
+ }
1350
+
1351
+ if responses_request.stream:
1352
+
1353
+ # Stream the response
1354
+ async def generate():
1355
+ # Declare it's the same variable from outside scope
1356
+ # This is necessary because the variable is modified
1357
+ # in the inner function
1358
+ nonlocal reasoning_first_token
1359
+
1360
+ # Send initial creation event
1361
+ response = Response(
1362
+ id="0",
1363
+ model=self.llm_loaded.checkpoint,
1364
+ created_at=int(time.time()),
1365
+ object="response",
1366
+ output=[],
1367
+ parallel_tool_calls=True,
1368
+ tool_choice="auto",
1369
+ tools=[],
1370
+ )
1371
+ created_event = ResponseCreatedEvent(
1372
+ response=response,
1373
+ type="response.created",
1374
+ sequence_number=0,
1375
+ )
1376
+ yield f"data: {created_event.model_dump_json()}\n\n".encode("utf-8")
1377
+
1378
+ full_response = "<think>" if reasoning_first_token else ""
1379
+
1380
+ try:
1381
+ async for token in self._generate_tokens(**generation_args):
1382
+ # Handle client disconnect: stop generation and exit
1383
+ if await request.is_disconnected():
1384
+ self.stop_event.set()
1385
+ break
1386
+
1387
+ # Create an event
1388
+ delta_event = ResponseTextDeltaEvent(
1389
+ content_index=0,
1390
+ delta=(
1391
+ "<think>" + token if reasoning_first_token else token
1392
+ ),
1393
+ item_id="0 ",
1394
+ logprobs=[],
1395
+ output_index=0,
1396
+ sequence_number=0,
1397
+ type="response.output_text.delta",
1398
+ )
1399
+ full_response += token
1400
+
1401
+ # Format as SSE
1402
+ reasoning_first_token = False
1403
+ yield f"data: {delta_event.model_dump_json()}\n\n".encode(
1404
+ "utf-8"
1405
+ )
1406
+
1407
+ # Send the completed event (only if still connected)
1408
+ if not await request.is_disconnected():
1409
+ response_output_message = ResponseOutputMessage(
1410
+ id="0",
1411
+ content=[
1412
+ ResponseOutputText(
1413
+ annotations=[],
1414
+ text=full_response,
1415
+ type="output_text",
1416
+ )
1417
+ ],
1418
+ role="assistant",
1419
+ status="completed",
1420
+ type="message",
1421
+ )
1422
+ response = Response(
1423
+ id="0",
1424
+ model=self.llm_loaded.checkpoint,
1425
+ created_at=int(time.time()),
1426
+ object="response",
1427
+ output=[response_output_message],
1428
+ parallel_tool_calls=True,
1429
+ tool_choice="auto",
1430
+ tools=[],
1431
+ )
1432
+ completed_event = ResponseCompletedEvent(
1433
+ response=response,
1434
+ type="response.completed",
1435
+ sequence_number=0,
1436
+ )
1437
+ yield f"data: {completed_event.model_dump_json()}\n\n".encode(
1438
+ "utf-8"
1439
+ )
1440
+
1441
+ # Send the [DONE] marker
1442
+ yield b"data: [DONE]\n\n"
1443
+ except asyncio.CancelledError:
1444
+ self.stop_event.set()
1445
+ return
1446
+
1447
+ return StreamingResponse(
1448
+ generate(),
1449
+ media_type="text/event-stream",
1450
+ headers={
1451
+ "Cache-Control": "no-cache",
1452
+ "Connection": "keep-alive",
1453
+ },
1454
+ )
1455
+
1456
+ # If streaming is not requested, collect all generated tokens into a single response
1457
+ else:
1458
+ full_response = "<think>" if reasoning_first_token else ""
1459
+ async for token in self._generate_tokens(**generation_args):
1460
+ full_response += token
1461
+
1462
+ # Send the completed event
1463
+ response_output_message = ResponseOutputMessage(
1464
+ id="0",
1465
+ content=[
1466
+ ResponseOutputText(
1467
+ annotations=[],
1468
+ text=full_response,
1469
+ type="output_text",
1470
+ )
1471
+ ],
1472
+ role="assistant",
1473
+ status="completed",
1474
+ type="message",
1475
+ )
1476
+ return Response(
1477
+ id="0",
1478
+ model=self.llm_loaded.checkpoint,
1479
+ created_at=int(time.time()),
1480
+ object="response",
1481
+ output=[response_output_message],
1482
+ parallel_tool_calls=True,
1483
+ tool_choice="auto",
1484
+ tools=[],
1485
+ )
1486
+
1487
+ async def _generate_tokens(
1488
+ self,
1489
+ message: str,
1490
+ stop: list[str] | str | None = None,
1491
+ max_new_tokens: int | None = None,
1492
+ temperature: float | None = None,
1493
+ repeat_penalty: float | None = None,
1494
+ top_k: int | None = None,
1495
+ top_p: float | None = None,
1496
+ ):
1497
+ """
1498
+ Core streaming completion logic, separated from response handling.
1499
+ Returns an async generator that yields tokens.
1500
+ """
1501
+
1502
+ while not self.app.initialized:
1503
+ # Wait for the app's background tasks to finish before
1504
+ # allowing generation to proceed
1505
+ logging.debug("Waiting for server to fully initialize")
1506
+ await asyncio.sleep(0.5)
1507
+ # These should already be imported as part of the app initialization process,
1508
+ # they are just here to make 100% certain and to make the linter happy
1509
+ from transformers import TextIteratorStreamer, StoppingCriteriaList
1510
+
1511
+ model = self.model
1512
+ tokenizer = self.tokenizer
1513
+
1514
+ # Reset the early-exit flag before we start each generation
1515
+ self.stop_event.clear()
1516
+
1517
+ input_ids = tokenizer(message, return_tensors="pt").input_ids
1518
+
1519
+ # Process stop sequences
1520
+ stop_sequences = []
1521
+ if stop is not None:
1522
+ if isinstance(stop, str):
1523
+ stop_sequences = [stop]
1524
+ else:
1525
+ stop_sequences = stop[:4] # Limit to 4 sequences as per spec
1526
+
1527
+ # Set up the generation parameters
1528
+ if "oga-" in self.llm_loaded.recipe:
1529
+ from lemonade.tools.oga.utils import OrtGenaiStreamer
1530
+
1531
+ streamer = OrtGenaiStreamer(tokenizer)
1532
+ self.input_tokens = len(input_ids)
1533
+ else:
1534
+ streamer = TextIteratorStreamer(
1535
+ tokenizer,
1536
+ skip_prompt=True,
1537
+ )
1538
+ self.input_tokens = len(input_ids[0])
1539
+
1540
+ max_prompt_length = self.ctx_size # Default fallback
1541
+ # For OGA models, try to read the actual max prompt length from config
1542
+ if "oga-" in self.llm_loaded.recipe:
1543
+ try:
1544
+ if model.config and model.config.get("max_prompt_length"):
1545
+ max_prompt_length = model.config["max_prompt_length"]
1546
+ logging.debug(
1547
+ f"Using OGA model max_prompt_length: {max_prompt_length}"
1548
+ )
1549
+ # pylint: disable=broad-exception-caught
1550
+ except Exception as e:
1551
+ logging.debug(f"Could not read OGA model config, using ctx_size: {e}")
1552
+
1553
+ # Apply truncation if input exceeds the limit
1554
+ if self.input_tokens > max_prompt_length:
1555
+ # Truncate the oldest tokens and update the token count
1556
+ truncate_amount = self.input_tokens - max_prompt_length
1557
+ if "oga-" in self.llm_loaded.recipe:
1558
+ input_ids = input_ids[-max_prompt_length:]
1559
+ self.input_tokens = len(input_ids)
1560
+ else:
1561
+ input_ids = input_ids[:, -max_prompt_length:]
1562
+ self.input_tokens = len(input_ids[0])
1563
+
1564
+ # Log warning message instead of raising exception
1565
+ truncation_message = (
1566
+ f"Input exceeded {max_prompt_length} tokens. "
1567
+ f"Truncated {truncate_amount} tokens from the beginning."
1568
+ )
1569
+ logging.warning(truncation_message)
1570
+
1571
+ # Log the input tokens early to avoid this not showing due to potential crashes
1572
+ logging.debug(f"Input Tokens: {self.input_tokens}")
1573
+ logging.trace(f"Input Message: {message}")
1574
+
1575
+ if self.llm_loaded.recipe.startswith("hf"):
1576
+ stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
1577
+ else:
1578
+ # HF expects StoppingCriteriaList, which requires torch
1579
+ # If we aren't using HF, we can just use a list of StopOnEvent to
1580
+ # avoid the torch dep
1581
+ stopping_criteria = [StopOnEvent(self.stop_event)]
1582
+
1583
+ generation_kwargs = {
1584
+ "input_ids": input_ids,
1585
+ "streamer": streamer,
1586
+ "max_new_tokens": (
1587
+ max_new_tokens if max_new_tokens else DEFAULT_MAX_NEW_TOKENS
1588
+ ),
1589
+ "min_new_tokens": 1,
1590
+ "pad_token_id": tokenizer.eos_token_id,
1591
+ "stopping_criteria": stopping_criteria,
1592
+ "temperature": temperature,
1593
+ "repeat_penalty": repeat_penalty,
1594
+ "top_k": top_k,
1595
+ "top_p": top_p,
1596
+ }
1597
+
1598
+ # Initialize performance variables
1599
+ generation_start_time = time.perf_counter()
1600
+ first_token = True
1601
+ self.decode_token_times = []
1602
+ self.output_tokens = 0
1603
+
1604
+ # Begin generation
1605
+ thread = GeneratorThread(
1606
+ streamer, target=model.generate, kwargs=generation_kwargs
1607
+ )
1608
+ thread.start()
1609
+
1610
+ # Acquire the generation semaphore
1611
+ await self._generate_semaphore.acquire()
1612
+ active_generations = (
1613
+ self.max_concurrent_generations
1614
+ - self._generate_semaphore._value # pylint: disable=protected-access
1615
+ )
1616
+
1617
+ logging.debug(f"Active generations: {active_generations}")
1618
+
1619
+ try:
1620
+ # Generate the response using streaming
1621
+ new_text = ""
1622
+ for new_text in streamer:
1623
+ # Yield control back to the event loop
1624
+ # This gives the FastAPI server a chance to send the chunks to the client
1625
+ await asyncio.sleep(0)
1626
+
1627
+ # Capture performance stats about this token
1628
+ self.output_tokens = self.output_tokens + 1
1629
+ if first_token:
1630
+ self.time_to_first_token = (
1631
+ time.perf_counter() - generation_start_time
1632
+ )
1633
+ first_token = False
1634
+ else:
1635
+ self.decode_token_times.append(
1636
+ time.perf_counter() - next_token_start_time
1637
+ )
1638
+ next_token_start_time = time.perf_counter()
1639
+
1640
+ # Remove the EOS token from the response if needed
1641
+ if getattr(self.tokenizer, "eos_token", None):
1642
+ new_text = new_text.replace(self.tokenizer.eos_token, "")
1643
+
1644
+ # Check for stop sequences
1645
+ if stop_sequences:
1646
+ for stop_seq in stop_sequences:
1647
+ if stop_seq in new_text:
1648
+ # Make sure we yield the text up to before the stop sequence
1649
+ new_text = new_text[: new_text.find(stop_seq)]
1650
+ self.stop_event.set()
1651
+
1652
+ yield new_text
1653
+
1654
+ # Allow the user to finish the response early
1655
+ if self.stop_event.is_set():
1656
+ logging.info("Stopping generation early.")
1657
+ break
1658
+
1659
+ if len(self.decode_token_times) > 0:
1660
+ self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
1661
+ else:
1662
+ self.tokens_per_second = 0
1663
+
1664
+ finally:
1665
+ thread.join()
1666
+
1667
+ # Release the semaphore when generation is complete (or if an error occurs)
1668
+ self._generate_semaphore.release()
1669
+ active_generations = (
1670
+ self.max_concurrent_generations
1671
+ - self._generate_semaphore._value # pylint: disable=protected-access
1672
+ )
+ logging.debug(f"Active generations: {active_generations}")
1673
+
1674
+ # Check if an exception occurred in the generation thread
1675
+ # If it did, raise it as an HTTPException so that the client
1676
+ # knows they won't be getting a completion
1677
+ if thread.exception:
1678
+ raise HTTPException(
1679
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1680
+ detail=f"Completion failure: {thread.exception}",
1681
+ )
1682
+
1683
+ # Display telemetry if in debug mode
1684
+ await self._show_telemetry()
1685
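As a standalone illustration of the pattern `_generate_tokens` implements, an async generator that yields text chunks and keeps handing control back to the event loop, here is a self-contained sketch; `fake_tokens` is purely hypothetical and stands in for the real generator.

```python
# Self-contained illustration of the async-generator streaming pattern.
# fake_tokens() is hypothetical and stands in for _generate_tokens().
import asyncio


async def fake_tokens():
    for chunk in ["Hello", ", ", "world", "!"]:
        # Hand control back to the event loop between chunks, as the server
        # does with `await asyncio.sleep(0)` while draining the streamer.
        await asyncio.sleep(0)
        yield chunk


async def collect(gen) -> str:
    pieces = []
    async for token in gen:
        pieces.append(token)
    return "".join(pieces)


print(asyncio.run(collect(fake_tokens())))  # -> Hello, world!
```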
+
1686
+ async def send_stats(self):
1687
+ """
1688
+ Send performance statistics to the client.
1689
+ """
1690
+ # If using wrapped server, get telemetry from the telemetry instance
1691
+ if self.llm_loaded and (
1692
+ self.llm_loaded.recipe == "llamacpp" or self.llm_loaded.recipe == "flm"
1693
+ ):
1694
+ return self.wrapped_server.telemetry.get_telemetry_data()
1695
+
1696
+ # For built-in server, use the existing telemetry
1697
+ return {
1698
+ "time_to_first_token": self.time_to_first_token,
1699
+ "tokens_per_second": self.tokens_per_second,
1700
+ "input_tokens": self.input_tokens,
1701
+ "output_tokens": self.output_tokens,
1702
+ "decode_token_times": self.decode_token_times,
1703
+ }
1704
+
1705
+ async def halt_generation(self):
1706
+ """
1707
+ Allow the client to halt an in-progress generation.
1708
+ """
1709
+
1710
+ self.stop_event.set()
1711
+
1712
+ return {
1713
+ "terminated": True,
1714
+ }
1715
+
1716
+ async def health(self):
1717
+ """
1718
+ Report server health information to the client.
1719
+ """
1720
+
1721
+ return {
1722
+ "status": "ok",
1723
+ "checkpoint_loaded": (
1724
+ self.llm_loaded.checkpoint if self.llm_loaded else None
1725
+ ),
1726
+ "model_loaded": (
1727
+ self.llm_loaded.model_name
1728
+ if (self.llm_loaded and self.llm_loaded.model_name)
1729
+ else None
1730
+ ),
1731
+ }
1732
+
1733
+ async def get_system_info(self, request: Request):
1734
+ """
1735
+ Return system and device enumeration information.
1736
+ Supports optional 'verbose' query parameter.
1737
+ """
1738
+ from lemonade.common.system_info import (
1739
+ get_system_info_dict,
1740
+ get_device_info_dict,
1741
+ get_system_info as get_system_info_obj,
1742
+ )
1743
+
1744
+ # Get verbose parameter from query string (default to False)
1745
+ verbose = request.query_params.get("verbose", "false").lower() in ["true", "1"]
1746
+
1747
+ info = get_system_info_dict()
1748
+ info["devices"] = get_device_info_dict()
1749
+
1750
+ # Filter out verbose-only information if not in verbose mode
1751
+ if not verbose:
1752
+ essential_keys = ["OS Version", "Processor", "Physical Memory", "devices"]
1753
+ info = {k: v for k, v in info.items() if k in essential_keys}
1754
+ else:
1755
+ # In verbose mode, add Python packages at the end
1756
+ system_info_obj = get_system_info_obj()
1757
+ info["Python Packages"] = system_info_obj.get_python_packages()
1758
+
1759
+ return info
1760
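Since `get_system_info` changes shape based on the `verbose` query parameter, a side-by-side request makes the difference concrete. The `/api/v1/system-info` path and the port are again assumptions.

```python
# Hedged sketch: the /api/v1/system-info path and port are assumptions.
import requests

BASE = "http://localhost:8000/api/v1"

concise = requests.get(f"{BASE}/system-info", timeout=10).json()
verbose = requests.get(
    f"{BASE}/system-info", params={"verbose": "true"}, timeout=30
).json()

print(sorted(concise))               # essential keys plus "devices"
print("Python Packages" in verbose)  # True: verbose mode appends installed packages
```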
+
1761
+ def model_load_failure(self, model_reference: str, message: Optional[str] = None):
1762
+ """
1763
+ Clean up after a model load failure, then log it and raise
1764
+ an HTTPException with details.
1765
+ """
1766
+ self.llm_loaded = None
1767
+ self.tokenizer = None
1768
+ self.model = None
1769
+
1770
+ default_message = "see stack trace and error message below"
1771
+ if message:
1772
+ detail = message
1773
+ else:
1774
+ detail = default_message
1775
+
1776
+ logging.exception(f"Tried to load LLM {model_reference} and failed: {detail}")
1777
+
1778
+ raise HTTPException(
1779
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1780
+ detail=detail,
1781
+ )
1782
+
1783
+ async def pull(self, config: PullConfig):
1784
+ """
1785
+ Install a supported LLM by its Lemonade Model Name.
1786
+ """
1787
+
1788
+ # Install the model
1789
+ ModelManager().download_models(
1790
+ [config.model_name],
1791
+ checkpoint=config.checkpoint,
1792
+ recipe=config.recipe,
1793
+ reasoning=config.reasoning,
1794
+ vision=config.vision,
1795
+ mmproj=config.mmproj,
1796
+ # The pull endpoint will download an upgraded model if available, even
1797
+ # if we already have a local copy of the model
1798
+ do_not_upgrade=False,
1799
+ )
1800
+
1801
+ # Refresh the list of downloaded models, to ensure it
1802
+ # includes the model we just installed
1803
+ self.local_models = ModelManager().downloaded_models_enabled
1804
+
1805
+ async def delete(self, config: DeleteConfig):
1806
+ """
1807
+ Delete a supported LLM by its Lemonade Model Name.
1808
+ """
1809
+ try:
1810
+ # If the model to be deleted is currently loaded, unload it first
1811
+ if self.llm_loaded and self.llm_loaded.model_name == config.model_name:
1812
+ await self.unload_llm(require_lock=True)
1813
+
1814
+ # Delete the model
1815
+ ModelManager().delete_model(config.model_name)
1816
+
1817
+ # Refresh the list of downloaded models
1818
+ self.local_models = ModelManager().downloaded_models_enabled
1819
+
1820
+ return {
1821
+ "status": "success",
1822
+ "message": f"Deleted model: {config.model_name}",
1823
+ }
1824
+ except ValueError as e:
1825
+ raise HTTPException(
1826
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1827
+ detail=str(e),
1828
+ )
1829
+ except Exception as e:
1830
+ raise HTTPException(
1831
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1832
+ detail=f"Failed to delete model {config.model_name}: {str(e)}",
1833
+ )
1834
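`pull` and `delete` both take a small JSON body whose fields mirror `PullConfig` and `DeleteConfig`. The sketch below assumes `/api/v1/pull` and `/api/v1/delete` routes and uses a hypothetical model name; only `model_name` is shown, with the other `PullConfig` fields treated as optional overrides.

```python
# Hedged sketch: routes, port, and the model name are assumptions.
import requests

BASE = "http://localhost:8000/api/v1"
MODEL = "Qwen2.5-0.5B-Instruct-CPU"  # hypothetical model name

# Install (or upgrade) a model registered with Lemonade Server.
requests.post(f"{BASE}/pull", json={"model_name": MODEL}, timeout=None)

# Remove it again; the handler unloads it first if it is currently loaded.
resp = requests.post(f"{BASE}/delete", json={"model_name": MODEL}, timeout=60)
print(resp.json())  # {"status": "success", "message": "Deleted model: ..."}
```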
+
1835
+ async def load_llm(self, config: LoadConfig):
1836
+ """
1837
+ Load a registered LLM into system memory. Install the model first, if needed.
1838
+ config: the information required to load the model
1839
+ """
1840
+ from huggingface_hub.constants import HF_HUB_CACHE
1841
+
1842
+ try:
1843
+ await self._load_lock.acquire()
1844
+ # Acquire all generate locks
1845
+ for _ in range(self.max_concurrent_generations):
1846
+ await self._generate_semaphore.acquire()
1847
+
1848
+ # Make sure the model is already registered
1849
+ supported_models = ModelManager().supported_models
1850
+
1851
+ # The `custom` name allows run-as-thread servers to bypass loading
1852
+ if config.model_name == "custom":
1853
+ config_to_use = config
1854
+ else:
1855
+ if config.model_name not in supported_models:
1856
+ self.model_load_failure(
1857
+ config.model_name,
1858
+ message=(
1859
+ f"Load request for model_name={config.model_name} "
1860
+ "not registered with Lemonade Server. You can register and "
1861
+ "install new models with a `pull` request."
1862
+ ),
1863
+ )
1864
+
1865
+ # Get additional properties from the model registry
1866
+ config_to_use = LoadConfig(**supported_models[config.model_name])
1867
+
1868
+ # For locally uploaded models, convert the relative checkpoint path to absolute path
1869
+ model_source = supported_models.get(config.model_name, {}).get(
1870
+ "source", None
1871
+ )
1872
+ if (
1873
+ model_source == "local_upload"
1874
+ and config_to_use.checkpoint
1875
+ and not config_to_use.recipe.startswith("hf-")
1876
+ ):
1877
+ # Check if checkpoint is a relative path (stored during upload)
1878
+ if not os.path.isabs(config_to_use.checkpoint):
1879
+ # Convert relative path to absolute by joining with HF_HUB_CACHE
1880
+ absolute_checkpoint = os.path.join(
1881
+ HF_HUB_CACHE, config_to_use.checkpoint
1882
+ )
1883
+ if os.path.exists(absolute_checkpoint):
1884
+ config_to_use.checkpoint = absolute_checkpoint
1885
+ else:
1886
+ logging.warning(
1887
+ f"Checkpoint path does not exist: {absolute_checkpoint}"
1888
+ )
1889
+
1890
+ # Also resolve mmproj path if present
1891
+ if config_to_use.mmproj and not os.path.isabs(config_to_use.mmproj):
1892
+ absolute_mmproj = os.path.join(HF_HUB_CACHE, config_to_use.mmproj)
1893
+ if os.path.exists(absolute_mmproj):
1894
+ config_to_use.mmproj = absolute_mmproj
1895
+ else:
1896
+ logging.warning(
1897
+ f"MMProj path does not exist: {absolute_mmproj}"
1898
+ )
1899
+
1900
+ # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
1901
+ if (
1902
+ self.llm_loaded
1903
+ and config_to_use.checkpoint == self.llm_loaded.checkpoint
1904
+ ):
1905
+ if (
1906
+ self.llm_loaded.recipe == "llamacpp"
1907
+ or self.llm_loaded.recipe == "flm"
1908
+ ) and self.wrapped_server.process.poll() is not None:
1909
+ # wrapped server process has gone away for some reason, so we should
1910
+ # proceed with loading to get it back
1911
+ pass
1912
+ else:
1913
+ return {
1914
+ "status": "success",
1915
+ "message": f"Model already loaded: {config.model_name}",
1916
+ }
1917
+
1918
+ # Unload the current model if needed
1919
+ if self.llm_loaded:
1920
+ await self.unload_llm(require_lock=False)
1921
+
1922
+ logging.info(f"Loading llm: {config.model_name}")
1923
+ try:
1924
+ if config_to_use.recipe == "llamacpp":
1925
+ self.wrapped_server = LlamaServer(self.llamacpp_backend)
1926
+ self.wrapped_server.load(
1927
+ model_config=config_to_use,
1928
+ ctx_size=self.ctx_size,
1929
+ do_not_upgrade=True,
1930
+ )
1931
+
1932
+ elif config_to_use.recipe == "flm":
1933
+ self.wrapped_server = FlmServer()
1934
+ self.wrapped_server.load(
1935
+ model_config=config_to_use,
1936
+ ctx_size=self.ctx_size,
1937
+ do_not_upgrade=True,
1938
+ )
1939
+
1940
+ else:
1941
+ self.model, self.tokenizer = lemonade_api.from_pretrained(
1942
+ checkpoint=config_to_use.checkpoint, recipe=config_to_use.recipe
1943
+ )
1944
+ self.llm_loaded = config_to_use
1945
+
1946
+ return {
1947
+ "status": "success",
1948
+ "message": f"Loaded model: {config.model_name}",
1949
+ }
1950
+ except HTTPException:
1951
+ raise
1952
+ except Exception: # pylint: disable=broad-exception-caught
1953
+ self.model_load_failure(config.model_name)
1954
+
1955
+ finally:
1956
+ self._load_lock.release()
1957
+
1958
+ # Release all generate locks
1959
+ for _ in range(self.max_concurrent_generations):
1960
+ self._generate_semaphore.release()
1961
+
1962
+ # Refresh the list of downloaded models, to ensure it
1963
+ # includes the model we just loaded
1964
+ if config.model_name not in self.local_models:
1965
+ self.local_models = ModelManager().downloaded_models_enabled
1966
+
1967
+ async def unload_llm(self, require_lock: bool = True):
1968
+ try:
1969
+ if require_lock:
1970
+ await self._load_lock.acquire()
1971
+
1972
+ # Acquire all generate locks
1973
+ for _ in range(self.max_concurrent_generations):
1974
+ await self._generate_semaphore.acquire()
1975
+
1976
+ if self.llm_loaded and self.llm_loaded.recipe in ("llamacpp", "flm"):
1977
+ self.wrapped_server.process.terminate()
1978
+
1979
+ self.llm_loaded = None
1980
+ self.tokenizer = None
1981
+ self.model = None
1982
+ return {"status": "success", "message": "Unloaded model"}
1983
+ except Exception as e: # pylint: disable=broad-exception-caught
1984
+ return {
1985
+ "status": "error",
1986
+ "message": f"Failed to unload model: {str(e)}",
1987
+ }
1988
+ finally:
1989
+ if require_lock:
1990
+ self._load_lock.release()
1991
+
1992
+ # Release all generate locks
1993
+ for _ in range(self.max_concurrent_generations):
1994
+ self._generate_semaphore.release()
1995
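Explicit load and unload requests pair naturally with the two handlers above. The `/api/v1/load` and `/api/v1/unload` routes and the model name are assumptions; the load body mirrors `LoadConfig`, for which `model_name` alone is enough when the model is registered.

```python
# Hedged sketch: routes, port, and the model name are assumptions.
import requests

BASE = "http://localhost:8000/api/v1"

# Load a registered model into memory (installing it first if needed).
print(
    requests.post(
        f"{BASE}/load",
        json={"model_name": "Qwen2.5-0.5B-Instruct-CPU"},  # hypothetical name
        timeout=None,
    ).json()
)

# Later, free the memory without shutting the server down.
print(requests.post(f"{BASE}/unload", timeout=60).json())
# {"status": "success", "message": "Unloaded model"}
```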
+
1996
+ async def models(self):
1997
+ """
1998
+ Return a list of available models in OpenAI-compatible format.
1999
+ """
2000
+ models_list = []
2001
+ for model in self.local_models:
2002
+ m = ServerModel(
2003
+ id=model,
2004
+ owned_by="lemonade",
2005
+ object="model",
2006
+ created=int(time.time()),
2007
+ checkpoint=self.local_models[model]["checkpoint"],
2008
+ recipe=self.local_models[model]["recipe"],
2009
+ )
2010
+ models_list.append(m)
2011
+
2012
+ return {"object": "list", "data": models_list}
2013
+
2014
+ async def retrieve_model(self, model_id: str):
2015
+ """
2016
+ Retrieve a specific model by ID in OpenAI-compatible format.
2017
+ """
2018
+ # Raise an error if the model does not exist
2019
+ if model_id not in self.local_models:
2020
+ # Mimic the error format of the OpenAI API
2021
+ raise HTTPException(
2022
+ status_code=404,
2023
+ detail={
2024
+ "message": f"model {model_id} not found",
2025
+ "type": "api_error",
2026
+ "param": None,
2027
+ "code": None,
2028
+ },
2029
+ )
2030
+
2031
+ # Return the specific model
2032
+ model_info = self.local_models[model_id]
2033
+ model = ServerModel(
2034
+ id=model_id,
2035
+ owned_by="lemonade",
2036
+ object="model",
2037
+ created=int(time.time()),
2038
+ checkpoint=model_info["checkpoint"],
2039
+ recipe=model_info["recipe"],
2040
+ )
2041
+
2042
+ return model
2043
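Because `models` and `retrieve_model` follow the OpenAI wire format, the standard OpenAI client can browse them directly; base URL and port remain assumptions.

```python
# Hedged sketch: base URL and port are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="lemonade")

models = client.models.list().data
for m in models:
    # checkpoint/recipe are Lemonade extensions, so read them defensively.
    print(m.id, getattr(m, "checkpoint", None), getattr(m, "recipe", None))

# Unknown IDs produce a 404 with an OpenAI-style error body, as coded above.
if models:
    print(client.models.retrieve(models[0].id))
```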
+
2044
+ def setup_middleware_timer(self):
2045
+ logging.info("Middleware set up")
2046
+
2047
+ @self.app.middleware("http")
2048
+ async def log_request_time(request: Request, call_next):
2049
+ """
2050
+ Log the request processing time for any request.
2051
+ For streaming responses, wraps the body iterator to measure total time.
2052
+ Only applies the wrapper in debug mode.
2053
+ """
2054
+ start_time = time.perf_counter()
2055
+ response = await call_next(request)
2056
+
2057
+ if (
2058
+ self.debug_logging_enabled
2059
+ and hasattr(response, "body_iterator")
2060
+ and response.body_iterator is not None
2061
+ ):
2062
+ original_iterator = response.body_iterator
2063
+
2064
+ async def wrapped_iterator():
2065
+ async for chunk in original_iterator:
2066
+ yield chunk
2067
+ request_time = time.perf_counter() - start_time
2068
+ logging.debug(
2069
+ f"Total request time (streamed): {request_time:.4f} seconds"
2070
+ )
2071
+
2072
+ response.body_iterator = wrapped_iterator()
2073
+ else:
2074
+ request_time = time.perf_counter() - start_time
2075
+ if self.debug_logging_enabled:
2076
+ logging.debug(f"Total request time: {request_time:.4f} seconds")
2077
+ return response
2078
+
2079
+ async def logs_ws(self, websocket: WebSocket):
2080
+ if not self.log_file or not os.path.exists(self.log_file):
2081
+ await websocket.close(code=4000)
2082
+ return
2083
+ await log_streamer(websocket, self.log_file)
2084
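`logs_ws` streams the server log file over a WebSocket. A minimal client sketch with the `websockets` package follows; the `ws://.../api/v1/logs/ws` URL is only a guess at the registered route and should be checked against the actual router setup.

```python
# Hedged sketch: the WebSocket URL is a guess at the registered route.
import asyncio

import websockets


async def tail_logs():
    async with websockets.connect("ws://localhost:8000/api/v1/logs/ws") as ws:
        async for message in ws:
            print(message, end="")


asyncio.run(tail_logs())
```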
+
2085
+ async def get_incompatible_models(self):
2086
+ """
2087
+ Get information about incompatible RyzenAI models in the cache.
2088
+ """
2089
+ try:
2090
+ return ModelManager().get_incompatible_ryzenai_models()
2091
+ except Exception as e:
2092
+ raise HTTPException(
2093
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
2094
+ detail=f"Failed to scan for incompatible models: {str(e)}",
2095
+ )
2096
+
2097
+ async def cleanup_incompatible_models(self, request: Request):
2098
+ """
2099
+ Delete selected incompatible RyzenAI models from the cache.
2100
+ """
2101
+ try:
2102
+ body = await request.json()
2103
+ model_paths = body.get("model_paths", [])
2104
+
2105
+ if not model_paths:
2106
+ raise HTTPException(
2107
+ status_code=status.HTTP_400_BAD_REQUEST,
2108
+ detail="No model_paths provided",
2109
+ )
2110
+
2111
+ result = ModelManager().cleanup_incompatible_models(model_paths)
2112
+ return result
2113
+ except HTTPException:
2114
+ raise
2115
+ except Exception as e:
2116
+ raise HTTPException(
2117
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
2118
+ detail=f"Failed to cleanup models: {str(e)}",
2119
+ )
2120
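The two maintenance handlers above pair a scan with a selective cleanup. The sketch assumes hypothetical routes and shows the required `model_paths` body; an empty list is rejected with a 400, as coded above.

```python
# Hedged sketch: both routes and the port are assumptions.
import requests

BASE = "http://localhost:8000/api/v1"

# Scan the cache for incompatible RyzenAI models and inspect the result.
found = requests.get(f"{BASE}/incompatible-models", timeout=30).json()
print(found)

# Delete a selection of them; the path below is a placeholder taken from the scan.
resp = requests.post(
    f"{BASE}/cleanup-incompatible-models",
    json={"model_paths": ["<path-from-scan>"]},
    timeout=120,
)
print(resp.json())
```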
+
2121
+
2122
+ # This file was originally licensed under Apache 2.0. It has been modified.
2123
+ # Modifications Copyright (c) 2025 AMD