lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
lemonade/tools/server/serve.py
@@ -0,0 +1,1354 @@
import argparse
import asyncio
import statistics
import time
from threading import Thread, Event
import logging
import traceback
from typing import Optional, Union
import json

from fastapi import FastAPI, HTTPException, status, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from tabulate import tabulate

from openai.types.completion import Completion, CompletionChoice
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_choice import Logprobs
from openai.types.model import Model
from openai.types.responses import (
    Response,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseCreatedEvent,
    ResponseTextDeltaEvent,
    ResponseCompletedEvent,
)

import lemonade.api as lemonade_api
from lemonade_server.model_manager import ModelManager
from lemonade.tools.management_tools import ManagementTool
from lemonade.tools.server.tool_calls import extract_tool_calls

# Set to a high number to allow for interesting experiences in real apps
# Tests should use the max_new_tokens argument to set a lower value
DEFAULT_MAX_NEW_TOKENS = 1500

DEFAULT_PORT = 8000
DEFAULT_LOG_LEVEL = "info"

class ServerModel(Model):
    """
    An extension of OpenAI's Model class that adds
    checkpoint and recipe attributes.
    """

    checkpoint: str
    recipe: str


class GeneratorThread(Thread):
    """
    Thread class designed for use with streaming generation within
    an LLM server. It needs access to the streamer in order
    to help the completions APIs escape the "for text in streamer" loop.
    It also provides exception handling that works nicely with HTTP
    servers by providing the stack trace and making the exception
    information available to the main thread.
    """

    def __init__(self, streamer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exception = None
        self.streamer = streamer

    def run(self):
        try:
            if self._target:
                self._target(*self._args, **self._kwargs)
        except Exception as e:  # pylint: disable=broad-except
            self.exception = e
            logging.error(f"Exception raised in generate thread: {e}")
            traceback.print_exc()
            self.streamer.done()


class StopOnEvent(StoppingCriteria):
    """
    Custom stopping criteria that halts text generation when a specified event is set.

    This allows for external control of generation, such as stopping a generation
    before it reaches the maximum token limit.
    """

    def __init__(self, stop_event: Event):
        super().__init__()
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.stop_event.is_set()

class PullConfig(BaseModel):
    """
    Configuration for installing a supported LLM.
    """

    model_name: str


class LoadConfig(BaseModel):
    """
    Configuration for loading a language model.

    Specifies the model checkpoint, generation parameters,
    and hardware/framework configuration (recipe) for model loading.
    """

    model_name: Optional[str] = None
    checkpoint: Optional[str] = None
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    recipe: Optional[str] = None
    # Indicates the maximum prompt length allowed for that specific
    # checkpoint + recipe combination
    max_prompt_length: Optional[int] = None
    # Indicates whether the model is a reasoning model, like DeepSeek
    reasoning: Optional[bool] = False


class CompletionRequest(BaseModel):
    """
    Request model for text completion API endpoint.

    Contains a prompt, a model identifier, and a streaming
    flag to control response delivery.
    """

    prompt: str
    model: str
    echo: bool = False
    stream: bool = False
    logprobs: int | None = False
    stop: list[str] | str | None = None
    temperature: float | None = None
    max_tokens: int | None = None


class ChatCompletionRequest(BaseModel):
    """
    Request model for chat completion API endpoint.

    Contains a list of chat messages, a model identifier,
    and a streaming flag to control response delivery.
    """

    messages: list[dict]
    model: str
    stream: bool = False
    logprobs: int | None = False
    stop: list[str] | str | None = None
    temperature: float | None = None
    tools: list[dict] | None = None
    max_tokens: int | None = None
    max_completion_tokens: int | None = None


class ResponsesRequest(BaseModel):
    """
    Request model for responses API endpoint.
    """

    input: list[dict] | str
    model: str
    max_output_tokens: int | None = None
    temperature: float | None = None
    stream: bool = False

class Server(ManagementTool):
    """
    Open a web server that apps can use to communicate with the LLM.

    The server exposes these endpoints:
    - /api/v0/pull: install an LLM by its Lemonade Server Model Name.
    - /api/v0/load: load a model checkpoint.
    - /api/v0/unload: unload a model checkpoint.
    - /api/v0/health: check whether a model is loaded and ready to serve.
    - /api/v0/stats: performance statistics for the generation.
    - /api/v0/halt: stop an in-progress generation from making more tokens.
    - /api/v0/completions: completion responses using HTTP chunked transfer encoding.
    - /api/v0/chat/completions: chat completion responses using HTTP chunked transfer encoding.
    - /api/v0/responses: responses API using HTTP chunked transfer encoding.
    - /api/v0/models: list all available models.
    """

    unique_name = "serve"

    def __init__(self):
        super().__init__()

        # Initialize FastAPI app
        self.app = FastAPI()

        # Add CORS middleware
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # Allows all origins
            allow_credentials=True,
            allow_methods=["*"],  # Allows all methods
            allow_headers=["*"],  # Allows all headers
        )

        # Set up custom routes
        self.app.post("/api/v0/pull")(self.pull)
        self.app.post("/api/v0/load")(self.load_llm)
        self.app.post("/api/v0/unload")(self.unload_llm)
        self.app.get("/api/v0/health")(self.health)
        self.app.get("/api/v0/halt")(self.halt_generation)
        self.app.get("/api/v0/stats")(self.send_stats)
        self.app.post("/api/v0/completions")(self.completions)
        self.app.post("/api/v0/responses")(self.responses)

        # Set up OpenAI-compatible routes
        self.app.post("/api/v0/chat/completions")(self.chat_completions)
        self.app.post("/api/v0/completions")(self.completions)
        self.app.get("/api/v0/models")(self.models)

        # Set up instructions
        self.app.get("/")(self.instructions)

        # Performance stats that are set during /ws and can be
        # fetched in /stats
        self.time_to_first_token = None
        self.tokens_per_second = None
        self.input_tokens = None
        self.output_tokens = None
        self.decode_token_times = None

        # Input truncation settings
        self.truncate_inputs = False

        # Store debug logging state
        self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

        # Flag that tells the LLM to stop generating text and end the response
        self.stop_event = Event()

        self.llm_loaded: LoadConfig = None
        self.tokenizer = None

        # Placeholders for model and configs
        self.model = None

        # Initialize semaphore for tracking active generations
        self.max_concurrent_generations = 1
        self._generate_semaphore = asyncio.Semaphore(self.max_concurrent_generations)

        # Dictionary of installed LLMs, mapping model name to information
        # about those models. Does not include non-installed models.
        self.local_models = ModelManager().downloaded_models_enabled

        # Add lock for load/unload operations
        self._load_lock = asyncio.Lock()

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        parser = __class__.helpful_parser(
            short_description="Launch an industry-standard LLM server",
            add_help=add_help,
        )

        parser.add_argument(
            "--port",
            required=False,
            type=int,
            default=DEFAULT_PORT,
            help=f"Port number to run the server on (default: {DEFAULT_PORT})",
        )
        parser.add_argument(
            "--log-level",
            required=False,
            type=str,
            default=DEFAULT_LOG_LEVEL,
            choices=["critical", "error", "warning", "info", "debug", "trace"],
            help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
        )

        return parser

    def run(
        self,
        # ManagementTool has a required cache_dir arg, but
        # we always use the default cache directory
        _=None,
        port: int = DEFAULT_PORT,
        log_level: str = DEFAULT_LOG_LEVEL,
        truncate_inputs: bool = False,
    ):
        # Store truncation settings
        self.truncate_inputs = truncate_inputs

        # Define TRACE level
        logging.TRACE = 9  # Lower than DEBUG which is 10
        logging.addLevelName(logging.TRACE, "TRACE")

        # Add a convenience function at the module level
        def trace(message, *args, **kwargs):
            logging.log(logging.TRACE, message, *args, **kwargs)

        logging.trace = trace

        # Configure logging to match uvicorn's format
        logging_level = getattr(logging, log_level.upper())
        logging.basicConfig(
            level=logging_level,
            format="%(levelprefix)s %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        # Add uvicorn's log formatter
        logging.root.handlers[0].formatter = uvicorn.logging.DefaultFormatter(
            fmt="%(levelprefix)s %(message)s",
            use_colors=True,
        )

        # Ensure the log level is properly set
        logging.getLogger().setLevel(logging_level)

        # Update debug logging state after setting log level
        self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

        if self.debug_logging_enabled:
            # Print the elapsed time for each request
            self.setup_middleware_timer()

        uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)

    async def _show_telemetry(self):
        """
        Show telemetry data in debug mode.
        """
        # Exit early if debug logging is disabled or no telemetry data is available
        if not self.debug_logging_enabled or self.tokens_per_second is None:
            return

        # Prepare telemetry data (transposed format)
        telemetry = [
            ["Input tokens", self.input_tokens],
            ["Output tokens", self.output_tokens],
            ["TTFT (s)", f"{self.time_to_first_token:.2f}"],
            ["TPS", f"{self.tokens_per_second:.2f}"],
        ]

        table = tabulate(
            telemetry, headers=["Metric", "Value"], tablefmt="fancy_grid"
        ).split("\n")

        # Show telemetry in debug while complying with uvicorn's log indentation
        logging.debug("\n ".join(table))

    def instructions(self):
        """
        Show instructions on how to use the server.
        """
        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Lemonade Server</title>
            <link rel="icon" href="data:,">
        </head>
        <body>
            <h1>🍋 Welcome to Lemonade Server!</h1>
            <p>
                A standards-compliant server that provides REST APIs for LLM communication.
                To get started, simply point your OpenAI-compatible application at the server's endpoint.
            </p>
            <div class="links">
                <h3>Documentation:</h3>
                <ul>
                    <li><a href="https://github.com/lemonade-sdk/lemonade/tree/main/docs/server/apps/README.md">Examples & Usage</a></li>
                    <li><a href="https://github.com/lemonade-sdk/lemonade/blob/main/docs/server/server_integration.md">Integration Guide</a></li>
                    <li><a href="https://github.com/lemonade-sdk/lemonade/blob/main/docs/server/server_spec.md">Server Specification</a></li>
                </ul>
            </div>
        </body>
        </html>
        """
        return HTMLResponse(content=html_content, status_code=200)

    def initialize_load_config(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> LoadConfig:
        """
        Turn the Request object into a partially-complete LoadConfig.

        The load_llm() method is responsible for filling in the rest of
        LoadConfig's parameters.
        """

        # Get model config
        if "/" in request.model:
            # We know the model is a Hugging Face checkpoint if it contains a /
            lc = LoadConfig(checkpoint=request.model)
        else:
            # The model should be a reference to a built-in model
            lc = LoadConfig(model_name=request.model)

        return lc

    async def completions(self, completion_request: CompletionRequest):
        """
        Stream completion responses using HTTP chunked transfer encoding.
        """

        lc = self.initialize_load_config(completion_request)

        # Load the model if it's different from the currently loaded one
        await self.load_llm(lc, internal_call=True)

        # Check if the model supports reasoning
        reasoning_first_token = self.llm_loaded.reasoning

        # If the model supports reasoning, we:
        # 1. add a <think> tag to the model's context
        # 2. ensure that the first token is a <think> token
        text = completion_request.prompt
        if reasoning_first_token:
            text += "<think>"

        # Prepare generation arguments
        generation_args = {
            "message": text,
            "stop": completion_request.stop,
            "temperature": completion_request.temperature,
            "max_new_tokens": completion_request.max_tokens,
        }

        if completion_request.stream:

            if completion_request.logprobs:
                logging.warning("logprobs is not supported for streaming completion")
            if completion_request.echo:
                logging.warning(
                    "`Echo` parameter is not supported for streaming completions"
                )

            # Stream the response
            async def generate():
                # Declare it's the same variable from outside scope
                # This is necessary because the variable is modified
                # in the inner function
                nonlocal reasoning_first_token

                async for token in self._generate_tokens(**generation_args):
                    choice = CompletionChoice(
                        text=("<think>" + token if reasoning_first_token else token),
                        index=0,
                        finish_reason="stop",
                        logprobs=None,
                    )

                    completion = Completion(
                        id="0",
                        choices=[choice],
                        model=self.llm_loaded.checkpoint,
                        object="text_completion",
                        created=int(time.time()),
                    )

                    # Format as SSE
                    reasoning_first_token = False
                    yield f"data: {completion.model_dump_json()}\n\n".encode("utf-8")

                # Send the [DONE] marker
                yield b"data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # If streaming is not requested, collect all generated tokens into a single response
        else:
            full_response = text if completion_request.echo else ""
            async for token in self._generate_tokens(**generation_args):
                full_response += token

            # If logprobs are requested, create a logprobs object
            logprobs = None
            if completion_request.logprobs:

                # Compute the logprobs
                text_offset, token_logprobs, tokens, top_logprobs = (
                    self.model.compute_logprobs(
                        text=full_response,
                        tokenizer=self.tokenizer,
                        logprobs=completion_request.logprobs,
                    )
                )
                logprobs = Logprobs.model_construct(
                    text_offset=text_offset,
                    token_logprobs=token_logprobs,
                    tokens=tokens,
                    top_logprobs=top_logprobs,
                )

            choice = CompletionChoice(
                text=full_response,
                index=0,
                finish_reason="stop",
                logprobs=logprobs,
            )

            return Completion(
                id="0",
                choices=[choice],
                model=self.llm_loaded.checkpoint,
                object="text_completion",
                created=int(time.time()),
            )

    async def chat_completions(self, chat_completion_request: ChatCompletionRequest):
        """
        Stream chat completion responses using HTTP chunked transfer encoding.
        """

        if chat_completion_request.tools and chat_completion_request.stream:
            logging.warning(
                "tools are only supported on non-streaming chat completions"
            )
        if chat_completion_request.logprobs:
            logging.warning("logprobs is not supported on chat completion")

        lc = self.initialize_load_config(chat_completion_request)

        # Load the model if it's different from the currently loaded one
        await self.load_llm(lc, internal_call=True)

        # Convert chat messages to text using the model's chat template
        text = self.apply_chat_template(
            chat_completion_request.messages,
            tools=(
                chat_completion_request.tools
                if not chat_completion_request.stream
                else None
            ),
        )

        # If the model supports reasoning, we:
        # 1. add a <think> tag to the model's context
        # 2. ensure that the first token is a <think> token
        reasoning_first_token = self.llm_loaded.reasoning

        if reasoning_first_token:
            text += "<think>"

        # Set the max_new_tokens parameter
        if (
            chat_completion_request.max_completion_tokens
            and chat_completion_request.max_tokens
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "Both max_tokens and max_completion_tokens were provided. "
                    "Please use only one of these parameters.",
                ),
            )
        max_new_tokens = (
            chat_completion_request.max_completion_tokens
            if chat_completion_request.max_completion_tokens
            else chat_completion_request.max_tokens
        )

        # Prepare generation arguments
        generation_args = {
            "message": text,
            "stop": chat_completion_request.stop,
            "temperature": chat_completion_request.temperature,
            "max_new_tokens": max_new_tokens,
        }

        if chat_completion_request.stream:

            # Stream the response
            async def generate():
                # Declare it's the same variable from outside scope
                # This is necessary because the variable is modified
                # in the inner function
                nonlocal reasoning_first_token

                async for token in self._generate_tokens(**generation_args):

                    # Create a ChatCompletionChunk
                    chunk = ChatCompletionChunk.model_construct(
                        id="0",
                        object="chat.completion.chunk",
                        created=int(time.time()),
                        model=self.llm_loaded.checkpoint,
                        choices=[
                            Choice.model_construct(
                                index=0,
                                delta=ChoiceDelta(
                                    content=(
                                        "<think>" + token
                                        if reasoning_first_token
                                        else token
                                    ),
                                    function_call=None,
                                    role="assistant",
                                    tool_calls=None,
                                    refusal=None,
                                ),
                                finish_reason=None,
                                logprobs=None,
                            )
                        ],
                    )

                    # Format as SSE
                    reasoning_first_token = False
                    yield f"data: {chunk.model_dump_json()}\n\n".encode("utf-8")

                # Send the [DONE] marker
                yield b"data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # If streaming is not requested, collect all generated tokens into a single response
        else:
            full_response = "<think>" if reasoning_first_token else ""
            async for token in self._generate_tokens(**generation_args):
                full_response += token

            # Extract tool calls from the response
            openai_tool_calls = None
            if chat_completion_request.tools:
                tool_calls, full_response = extract_tool_calls(
                    full_response, self.tokenizer.auto_tokenizer.added_tokens_decoder
                )
                if tool_calls:
                    openai_tool_calls = []
                    for tool_call in tool_calls:
                        openai_tool_calls.append(
                            ChatCompletionMessageToolCall(
                                id="-",
                                function=Function(
                                    arguments=json.dumps(tool_call["arguments"]),
                                    name=tool_call["name"],
                                ),
                                type="function",
                            )
                        )

            ccm = ChatCompletionMessage(
                content=full_response,
                role="assistant",
                refusal=None,
                audio=None,
                function_call=None,
                tool_calls=openai_tool_calls,
            )

            choice = Choice(
                finish_reason="stop",
                index=0,
                message=ccm,
                logprobs=None,
            )

            return ChatCompletion(
                id="0",
                choices=[choice],
                model=self.llm_loaded.checkpoint,
                object="chat.completion",
                created=int(time.time()),
            )

    def apply_chat_template(
        self, messages: list[dict], tools: list[dict] | None = None
    ):
        """
        Apply the model's chat template to the messages.
        """
        if self.tokenizer.chat_template:

            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                tools=tools,
            )

        # Fallback to a standardized template if the model doesn't provide one
        logging.warning("No chat template found. Using default template.")
        formatted_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            role_marker = "<|assistant|>" if role == "assistant" else "<|user|>"
            formatted_messages.append(f"{role_marker}\n{content} <|end|>")
        return "\n".join(formatted_messages) + "\n<|assistant|>"

    async def responses(self, responses_request: ResponsesRequest):
        """
        Stream responses using HTTP chunked transfer encoding.
        """

        lc = self.initialize_load_config(responses_request)

        # Load the model if it's different from the currently loaded one
        await self.load_llm(lc, internal_call=True)

        # Convert chat messages to text using the model's chat template
        if isinstance(responses_request.input, str):
            text = responses_request.input
        else:
            text = self.apply_chat_template(responses_request.input)

        # If the model supports reasoning, we:
        # 1. add a <think> tag to the model's context
        # 2. ensure that the first token is a <think> token
        reasoning_first_token = self.llm_loaded.reasoning

        if reasoning_first_token:
            text += "<think>"

        # Prepare generation arguments
        generation_args = {
            "message": text,
            "temperature": responses_request.temperature,
            "max_new_tokens": responses_request.max_output_tokens,
        }

        if responses_request.stream:

            # Stream the response
            async def generate():
                # Declare it's the same variable from outside scope
                # This is necessary because the variable is modified
                # in the inner function
                nonlocal reasoning_first_token

                # Send initial creation event
                response = Response(
                    id="0",
                    model=self.llm_loaded.checkpoint,
                    created_at=int(time.time()),
                    object="response",
                    output=[],
                    parallel_tool_calls=True,
                    tool_choice="auto",
                    tools=[],
                )
                created_event = ResponseCreatedEvent(
                    response=response,
                    type="response.created",
                )
                yield f"data: {created_event.model_dump_json()}\n\n".encode("utf-8")

                full_response = "<think>" if reasoning_first_token else ""

                async for token in self._generate_tokens(**generation_args):

                    # Create an event
                    delta_event = ResponseTextDeltaEvent(
                        content_index=0,
                        delta=("<think>" + token if reasoning_first_token else token),
                        item_id="0 ",
                        output_index=0,
                        type="response.output_text.delta",
                    )
                    full_response += token

                    # Format as SSE
                    reasoning_first_token = False
                    yield f"data: {delta_event.model_dump_json()}\n\n".encode("utf-8")

                # Send the completed event
                response_output_message = ResponseOutputMessage(
                    id="0",
                    content=[
                        ResponseOutputText(
                            annotations=[],
                            text=full_response,
                            type="output_text",
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                )
                response = Response(
                    id="0",
                    model=self.llm_loaded.checkpoint,
                    created_at=int(time.time()),
                    object="response",
                    output=[response_output_message],
                    parallel_tool_calls=True,
                    tool_choice="auto",
                    tools=[],
                )
                completed_event = ResponseCompletedEvent(
                    response=response,
                    type="response.completed",
                )
                yield f"data: {completed_event.model_dump_json()}\n\n".encode("utf-8")

                # Send the [DONE] marker
                yield b"data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # If streaming is not requested, collect all generated tokens into a single response
        else:
            full_response = "<think>" if reasoning_first_token else ""
            async for token in self._generate_tokens(**generation_args):
                full_response += token

            # Send the completed event
            response_output_message = ResponseOutputMessage(
                id="0",
                content=[
                    ResponseOutputText(
                        annotations=[],
                        text=full_response,
                        type="output_text",
                    )
                ],
                role="assistant",
                status="completed",
                type="message",
            )
            return Response(
                id="0",
                model=self.llm_loaded.checkpoint,
                created_at=int(time.time()),
                object="response",
                output=[response_output_message],
                parallel_tool_calls=True,
                tool_choice="auto",
                tools=[],
            )

    async def _generate_tokens(
        self,
        message: str,
        stop: list[str] | str | None = None,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ):
        """
        Core streaming completion logic, separated from response handling.
        Returns an async generator that yields tokens.
        """
        model = self.model
        tokenizer = self.tokenizer

        # Reset the early-exit flag before we start each generation
        self.stop_event.clear()

        input_ids = tokenizer(message, return_tensors="pt").input_ids

        # Process stop sequences
        stop_sequences = []
        if stop is not None:
            if isinstance(stop, str):
                stop_sequences = [stop]
            else:
                stop_sequences = stop[:4]  # Limit to 4 sequences as per spec

        # Set up the generation parameters
        if "oga-" in self.llm_loaded.recipe:
            from lemonade.tools.ort_genai.oga import OrtGenaiStreamer

            streamer = OrtGenaiStreamer(tokenizer)
            self.input_tokens = len(input_ids)
        else:
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
            )
            self.input_tokens = len(input_ids[0])

        if (
            self.llm_loaded.max_prompt_length
            and self.input_tokens > self.llm_loaded.max_prompt_length
        ):
            if self.truncate_inputs:
                # Truncate input ids
                truncate_amount = self.input_tokens - self.llm_loaded.max_prompt_length
                input_ids = input_ids[: self.llm_loaded.max_prompt_length]

                # Update token count
                self.input_tokens = len(input_ids)

                # Show warning message
                truncation_message = (
                    f"Input exceeded {self.llm_loaded.max_prompt_length} tokens. "
                    f"Truncated {truncate_amount} tokens."
                )
                logging.warning(truncation_message)
            else:
                raise RuntimeError(
                    f"Prompt tokens ({self.input_tokens}) cannot be greater "
                    f"than the model's max prompt length ({self.llm_loaded.max_prompt_length})"
                )

        # Log the input tokens early to avoid this not showing due to potential crashes
        logging.debug(f"Input Tokens: {self.input_tokens}")
        logging.trace(f"Input Message: {message}")

        stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])

        generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": (
                max_new_tokens if max_new_tokens else DEFAULT_MAX_NEW_TOKENS
            ),
            "min_new_tokens": 1,
            "pad_token_id": tokenizer.eos_token_id,
            "stopping_criteria": stopping_criteria,
            "temperature": temperature,
        }

        # Initialize performance variables
        generation_start_time = time.perf_counter()
        first_token = True
        self.decode_token_times = []
        self.output_tokens = 0

        # Begin generation
        thread = GeneratorThread(
            streamer, target=model.generate, kwargs=generation_kwargs
        )
        thread.start()

        # Acquire the generation semaphore
        await self._generate_semaphore.acquire()
        active_generations = (
            self.max_concurrent_generations
            - self._generate_semaphore._value  # pylint: disable=protected-access
        )

        logging.debug(f"Active generations: {active_generations}")

        try:
            # Generate the response using streaming
            new_text = ""
            for new_text in streamer:
                # Yield control back to the event loop
                # This gives the FastAPI server a chance to send the chunks to the client
                await asyncio.sleep(0)

                # Capture performance stats about this token
                self.output_tokens = self.output_tokens + 1
                if first_token:
                    self.time_to_first_token = (
                        time.perf_counter() - generation_start_time
                    )
                    first_token = False
                else:
                    self.decode_token_times.append(
                        time.perf_counter() - next_token_start_time
                    )
                next_token_start_time = time.perf_counter()

                # Remove the EOS token from the response if needed
                if hasattr(self.tokenizer, "eos_token"):
                    new_text = new_text.replace(self.tokenizer.eos_token, "")

                # Check for stop sequences
                if stop_sequences:
                    for stop_seq in stop_sequences:
                        if stop_seq in new_text:
                            # Make sure we yield the text up to before the stop sequence
                            new_text = new_text[: new_text.find(stop_seq)]
                            self.stop_event.set()

                yield new_text

                # Allow the user to finish the response early
                if self.stop_event.is_set():
                    logging.info("Stopping generation early.")
                    break

            if len(self.decode_token_times) > 0:
                self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
            else:
                self.tokens_per_second = 0

        finally:
            thread.join()

            # Release the semaphore when generation is complete (or if an error occurs)
            self._generate_semaphore.release()
            active_generations = (
                self.max_concurrent_generations
                - self._generate_semaphore._value  # pylint: disable=protected-access
            )

            # Check if an exception occurred in the generation thread
            # If it did, raise it as an HTTPException so that the client
            # knows they won't be getting a completion
            if thread.exception:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Completion failure: {thread.exception}",
                )

            # Display telemetry if in debug mode
            await self._show_telemetry()

    async def send_stats(self):
        """
        Send performance statistics to the client.
        """
        return {
            "time_to_first_token": self.time_to_first_token,
            "tokens_per_second": self.tokens_per_second,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "decode_token_times": self.decode_token_times,
        }

    async def halt_generation(self):
        """
        Allow the client to halt an in-progress generation.
        """

        self.stop_event.set()

        return {
            "terminated": True,
        }

    async def health(self):
        """
        Report server health information to the client.
        """
        self.stop_event.set()

        return {
            "status": "ok",
            "checkpoint_loaded": (
                self.llm_loaded.checkpoint if self.llm_loaded else None
            ),
            "model_loaded": (
                self.llm_loaded.model_name
                if (self.llm_loaded and self.llm_loaded.model_name)
                else None
            ),
        }

    def model_load_failure(self, model_reference: str, message: Optional[str] = None):
        """
        Clean up after a model load failure, then log it and raise
        an HTTPException with details.
        """
        self.llm_loaded = None
        self.tokenizer = None
        self.model = None

        default_message = f"model {model_reference} not found"
        if message:
            detail = message
        else:
            detail = default_message

        logging.exception(f"Tried to load LLM {model_reference} and failed: {detail}")

        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=detail,
        )

    def recipe_missing_error(self, model_reference: str):
        self.model_load_failure(
            model_reference,
            message=(
                f"Attempted to load model by checkpoint name {model_reference}, "
                "however the required 'recipe' parameter was not provided"
            ),
        )

    async def pull(self, config: PullConfig):
        """
        Install a supported LLM by its Lemonade Model Name.
        """

        # Install the model
        ModelManager().download_models([config.model_name])

        # Refresh the list of downloaded models, to ensure it
        # includes the model we just installed
        self.local_models = ModelManager().downloaded_models_enabled

    async def load_llm(self, config: LoadConfig, internal_call=False):
        """
        Load an LLM into system memory.
        config: the information required to load the model
        internal_call: indicates whether the call to this function came from
            an endpoint (False) or a method of this class (True)

        There are 4 ways this method can be called:
        1. An external application asks to load a model by name, using the load endpoint
            a. This only differs from #2 in that an external application may
               provide more parameters than in #2, so we need to validate
               that those parameters are ok.
            b. Load the model

        2. An external application asks to load a model by name,
           using the completions or chat_completions endpoints
            a. Look up the name in the built-in model dictionary to create
               a fully-populated LoadConfig.
            b. Load the model

        3. An external application asks to load a model by checkpoint and recipe,
           using the load endpoint
            a. Populate the checkpoint and recipe into a LoadConfig
            b. Load the model

        4. Completions or ChatCompletions asks to "load" a model by checkpoint
            a. This is only available when #3 has already been executed
            b. Verify that the checkpoint is already loaded,
               and raise an exception if it hasn't (don't load anything new)
        """
        try:
            await self._load_lock.acquire()

            # Acquire all generate locks
            for _ in range(self.max_concurrent_generations):
                await self._generate_semaphore.acquire()

            # We will populate a LoadConfig that has all of the required fields
            config_to_use: LoadConfig

            # First, validate that the arguments are valid
            if config.model_name:
                # Get the dictionary of supported models from disk
                supported_models = ModelManager().supported_models

                # Refer to the model by name, since we know the name
                model_reference = config.model_name

                if config.checkpoint or config.recipe:
                    # Option #1, verify that there are no parameter mismatches
                    built_in_config = supported_models[config.model_name]
                    if config.checkpoint != built_in_config["checkpoint"]:
                        self.model_load_failure(
                            model_reference,
                            message=(
                                f"Load request for model_name={config.model_name} "
                                "included a mismatched "
                                f"checkpoint={config.checkpoint} parameter. Remove the checkpoint "
                                f"parameter, or change it to {built_in_config['checkpoint']}."
                            ),
                        )
                    if config.recipe != built_in_config["recipe"]:
                        self.model_load_failure(
                            model_reference,
                            message=(
                                f"Load request for model_name={config.model_name} "
                                "included a mismatched "
                                f"recipe={config.recipe} parameter. Remove the recipe "
                                f"parameter, or change it to {built_in_config['recipe']}."
                            ),
                        )

                    # Use the config as-is
                    config_to_use = config
                else:
                    # Option #2, look up the config from the supported models dictionary
                    config_to_use = LoadConfig(**supported_models[config.model_name])

            elif config.checkpoint:
                # Refer to the model by checkpoint
                model_reference = config.checkpoint

                if config.recipe and not internal_call:
                    # Option 3, use the config as-is, but add a custom model name
                    config_to_use = config
                    config_to_use.model_name = "Custom"
                elif internal_call:
                    # Option 4, make sure the right checkpoint is loaded and then return
                    if (
                        self.llm_loaded
                        and config.checkpoint == self.llm_loaded.checkpoint
                    ):
                        return {
                            "status": "success",
                            "message": f"Model already loaded: {model_reference}",
                        }
                    else:
                        self.model_load_failure(
                            model_reference,
                            message=(
                                "Attempted to run completions by using model=<checkpoint name>, "
                                "however, "
                                "this feature only works if the model has already been loaded "
                                "using the load endpoint."
                            ),
                        )
                else:
                    self.recipe_missing_error(model_reference)
            else:
                self.model_load_failure(
                    None,
                    message="Load requests must contain either a model_name or a "
                    "checkpoint parameter",
                )

            # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
            if (
                self.llm_loaded
                and config_to_use.checkpoint == self.llm_loaded.checkpoint
            ):
                return {
                    "status": "success",
                    "message": f"Model already loaded: {model_reference}",
                }

            # Unload the current model if needed
            if self.llm_loaded:
                await self.unload_llm(require_lock=False)

            logging.info(f"Loading llm: {model_reference}")
            try:
                self.model, self.tokenizer = lemonade_api.from_pretrained(
                    checkpoint=config_to_use.checkpoint, recipe=config_to_use.recipe
                )
                self.llm_loaded = config_to_use

                return {
                    "status": "success",
                    "message": f"Loaded model: {model_reference}",
                }
            except Exception:  # pylint: disable=broad-exception-caught
                self.model_load_failure(model_reference)

        finally:
            self._load_lock.release()

            # Release all generate locks
            for _ in range(self.max_concurrent_generations):
                self._generate_semaphore.release()

            # Refresh the list of downloaded models, to ensure it
            # includes the model we just loaded
            if config.model_name not in self.local_models:
                self.local_models = ModelManager().downloaded_models_enabled

    async def unload_llm(self, require_lock: bool = True):
        try:
            if require_lock:
                await self._load_lock.acquire()

                # Acquire all generate locks
                for _ in range(self.max_concurrent_generations):
                    await self._generate_semaphore.acquire()

            self.llm_loaded = None
            self.tokenizer = None
            self.model = None
            return {"status": "success", "message": "Unloaded model"}
        except Exception as e:  # pylint: disable=broad-exception-caught
            return {
                "status": "error",
                "message": f"Failed to unload model: {str(e)}",
            }
        finally:
            if require_lock:
                self._load_lock.release()

                # Release all generate locks
                for _ in range(self.max_concurrent_generations):
                    self._generate_semaphore.release()

    async def models(self):
        """
        Return a list of available models in OpenAI-compatible format.
        """
        models_list = []
        for model in self.local_models:
            m = ServerModel(
                id=model,
                owned_by="lemonade",
                object="model",
                created=int(time.time()),
                checkpoint=self.local_models[model]["checkpoint"],
                recipe=self.local_models[model]["recipe"],
            )
            models_list.append(m)

        return {"object": "list", "data": models_list}

    def setup_middleware_timer(self):
        logging.info("Middleware set up")

        @self.app.middleware("http")
        async def log_request_time(request: Request, call_next):
            """
            Log the request processing time for any request.
            For streaming responses, wraps the body iterator to measure total time.
            Only applies the wrapper in debug mode.
            """
            start_time = time.perf_counter()
            response = await call_next(request)

            if (
                self.debug_logging_enabled
                and hasattr(response, "body_iterator")
                and response.body_iterator is not None
            ):
                original_iterator = response.body_iterator

                async def wrapped_iterator():
                    async for chunk in original_iterator:
                        yield chunk
                    request_time = time.perf_counter() - start_time
                    logging.debug(
                        f"Total request time (streamed): {request_time:.4f} seconds"
                    )

                response.body_iterator = wrapped_iterator()
            else:
                request_time = time.perf_counter() - start_time
                if self.debug_logging_enabled:
                    logging.debug(f"Total request time: {request_time:.4f} seconds")
            return response


# This file was originally licensed under Apache 2.0. It has been modified.
# Modifications Copyright (c) 2025 AMD
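For orientation, here is a minimal, hypothetical sketch of how a client could exercise the custom /api/v0 management endpoints that serve.py registers (models, load, health, completions, stats). It assumes a Lemonade Server instance is already running locally on the default port 8000 and that the requests package is installed; the model name below is a placeholder, not a model that necessarily ships with this release.

# Hypothetical walkthrough of the /api/v0 management endpoints in serve.py.
# Assumptions: the server is running locally ("lemonade serve", default port
# 8000), "requests" is installed, and MODEL_NAME is replaced with a model
# name that the server actually lists.
import requests

BASE = "http://localhost:8000/api/v0"
MODEL_NAME = "replace-with-an-installed-model-name"  # placeholder only

# List the models the server already has installed
models = requests.get(f"{BASE}/models").json()
print([m["id"] for m in models["data"]])

# Explicitly load a model by its Lemonade Server model name
print(requests.post(f"{BASE}/load", json={"model_name": MODEL_NAME}).json())

# Confirm that a checkpoint is loaded
print(requests.get(f"{BASE}/health").json())

# Run a short, non-streaming completion
completion = requests.post(
    f"{BASE}/completions",
    json={"model": MODEL_NAME, "prompt": "Hello", "max_tokens": 32},
).json()
print(completion["choices"][0]["text"])

# Fetch the telemetry recorded for that generation
print(requests.get(f"{BASE}/stats").json())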
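Because the completions and chat/completions routes emit the OpenAI wire format (SSE chunks terminated by a [DONE] marker), the official openai Python client can also be pointed at the server. The sketch below is illustrative only: the base URL assumes the default host and port, the API key is a dummy value since this local server does not authenticate requests, and the model name is again a placeholder.

# Hypothetical streaming chat example against the OpenAI-compatible routes.
from openai import OpenAI

MODEL_NAME = "replace-with-an-installed-model-name"  # placeholder only

client = OpenAI(base_url="http://localhost:8000/api/v0", api_key="unused")

# Stream a chat completion; the server sends SSE chunks and the client
# surfaces them as ChatCompletionChunk objects.
stream = client.chat.completions.create(
    model=MODEL_NAME,
    messages=[{"role": "user", "content": "Write a haiku about lemons."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()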