lemonade_sdk-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
lemonade/tools/server/wrapped_server.py

@@ -0,0 +1,559 @@
import logging
import time
import subprocess
from abc import ABC, abstractmethod

import requests
from tabulate import tabulate
from fastapi import HTTPException, status
from fastapi.responses import StreamingResponse

from openai import OpenAI

from lemonade_server.pydantic_models import (
    ChatCompletionRequest,
    CompletionRequest,
    PullConfig,
    EmbeddingsRequest,
    RerankingRequest,
)
from lemonade_server.model_manager import ModelManager
from lemonade.tools.server.utils.port import find_free_port


class WrappedServerTelemetry(ABC):
    """
    Manages telemetry data collection and display for wrapped server.
    """

    def __init__(self):
        self.input_tokens = None
        self.output_tokens = None
        self.time_to_first_token = None
        self.tokens_per_second = None
        self.prompt_eval_time = None
        self.eval_time = None

    @abstractmethod
    def parse_telemetry_line(self, line: str):
        """
        Parse telemetry data from wrapped server output lines.
        """

    def get_telemetry_data(self):
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "time_to_first_token": self.time_to_first_token,
            "tokens_per_second": self.tokens_per_second,
            "decode_token_times": None,
        }

    def show_telemetry(self):
        # Check if debug logging is enabled
        if not logging.getLogger().isEnabledFor(logging.DEBUG):
            return

        # Prepare telemetry data (transposed format)
        telemetry = [
            ["Input tokens", self.input_tokens],
            ["Output tokens", self.output_tokens],
            [
                "TTFT (s)",
                (
                    f"{self.time_to_first_token:.2f}"
                    if self.time_to_first_token is not None
                    else "N/A"
                ),
            ],
            [
                "TPS",
                (
                    f"{self.tokens_per_second:.2f}"
                    if self.tokens_per_second is not None
                    else "N/A"
                ),
            ],
        ]

        table = tabulate(
            telemetry, headers=["Metric", "Value"], tablefmt="fancy_grid"
        ).split("\n")

        # Show telemetry in debug while complying with uvicorn's log indentation
        logging.debug("\n ".join(table))


class WrappedServer(ABC):
    """
    Abstract base class that defines the interface for Lemonade to "wrap" a server
    like llama-server.
    """

    def __init__(self, server_name: str, telemetry: WrappedServerTelemetry):
        self.port: int = None
        self.process: subprocess.Popen = None
        self.server_name: str = server_name
        self.telemetry: WrappedServerTelemetry = telemetry
        self.log_thread_exception = None

    def _choose_port(self):
        """
        Users probably don't care what port we start the wrapped server on, so let's
        search for an empty port
        """

        self.port = find_free_port()

        if self.port is None:
            msg = f"Failed to find an empty port to start {self.server_name} on"
            logging.error(msg)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=msg,
            )

    def address(self) -> str:
        """
        Generate the base URL for the server.

        Returns:
            The base URL for the wrapped server
        """
        return f"http://127.0.0.1:{self.port}/v1"

    def _separate_openai_params(
        self,
        request_dict: dict,
        endpoint_type: str = "chat",
    ) -> dict:
        """
        Separate standard OpenAI parameters from custom wrapped server parameters.

        Args:
            request_dict: Dictionary of all request parameters
            endpoint_type: Type of endpoint ("chat" or "completion")

        Returns:
            Dictionary with parameters properly separated for OpenAI client
        """
        openai_client_params = {}
        extra_params = {}

        # Common OpenAI parameters for both endpoint types
        common_params = {
            "model",
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "temperature",
            "top_p",
            "user",
        }

        # Standard OpenAI parameters by endpoint type
        if endpoint_type == "chat":
            chat_specific_params = {
                "messages",
                "top_logprobs",
                "response_format",
                "service_tier",
                "stream_options",
                "tools",
                "tool_choice",
                "parallel_tool_calls",
            }
            openai_params = common_params | chat_specific_params
        else:  # completion
            completion_specific_params = {
                "prompt",
                "best_of",
                "echo",
                "suffix",
            }
            openai_params = common_params | completion_specific_params

        for key, value in request_dict.items():
            if key in openai_params:
                openai_client_params[key] = value
            else:
                extra_params[key] = value

        # If there are custom parameters, use extra_body to pass them through
        if extra_params:
            openai_client_params["extra_body"] = extra_params

        return openai_client_params

    def _log_subprocess_output(self, prefix: str):
        """
        Read subprocess output line by line, log to debug, and parse telemetry
        """

        if self.process.stdout:
            try:
                for line in iter(self.process.stdout.readline, ""):
                    if line:
                        line_stripped = line.strip()
                        logging.debug("%s: %s", prefix, line_stripped)

                        self.telemetry.parse_telemetry_line(line_stripped)

                    if self.process.poll() is not None:
                        break
            except HTTPException as e:
                self.log_thread_exception = e
            except UnicodeDecodeError as e:
                logging.debug(
                    "Unicode decode error reading subprocess output: %s", str(e)
                )
            except Exception as e:  # pylint: disable=broad-exception-caught
                logging.error("Unexpected error reading subprocess output: %s", str(e))

    def _wait_for_load(self):
        status_code = None
        while not self.process.poll() and status_code != 200:
            health_url = f"http://localhost:{self.port}/health"
            try:
                health_response = requests.get(health_url)
            except requests.exceptions.ConnectionError:
                logging.debug(
                    f"Not able to connect to {self.server_name} yet, will retry"
                )
            else:
                status_code = health_response.status_code
                logging.debug(
                    f"Testing {self.server_name} readiness (will retry until ready), "
                    f"result: {health_response.json()}"
                )
            time.sleep(1)

        if self.log_thread_exception:
            e = self.log_thread_exception
            self.log_thread_exception = None
            raise e

    @abstractmethod
    def _launch_server_subprocess(
        self,
        model_config: PullConfig,
        snapshot_files: dict,
        ctx_size: int,
        supports_embeddings: bool = False,
        supports_reranking: bool = False,
    ):
        """
        Launch wrapped server subprocess with appropriate configuration.

        Args:
            snapshot_files: Dictionary of model files to load
            supports_embeddings: Whether the model supports embeddings
            supports_reranking: Whether the model supports reranking
        """

    @abstractmethod
    def install_server(self, backend=None):
        """
        Install the wrapped server
        """

    @abstractmethod
    def download_model(
        self, config_checkpoint, config_mmproj=None, do_not_upgrade=False
    ) -> dict:
        """
        Download a model for the wrapper server
        """

    def load(
        self,
        model_config: PullConfig,
        ctx_size: int,
        do_not_upgrade: bool = False,
    ):
        # Install and/or update the wrapped server if needed
        try:
            self.install_server()
        except NotImplementedError as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)
            )

        # Download the model to the hugging face cache
        snapshot_files = self.download_model(
            model_config.checkpoint, model_config.mmproj, do_not_upgrade=do_not_upgrade
        )
        logging.debug(f"Model file paths: {snapshot_files}")

        # Check if model supports embeddings
        supported_models = ModelManager().supported_models
        model_info = supported_models.get(model_config.model_name, {})
        supports_embeddings = "embeddings" in model_info.get("labels", [])
        supports_reranking = "reranking" in model_info.get("labels", [])

        self._launch_server_subprocess(
            model_config=model_config,
            snapshot_files=snapshot_files,
            ctx_size=ctx_size,
            supports_embeddings=supports_embeddings,
            supports_reranking=supports_reranking,
        )

        # Check the /health endpoint until server is ready
        self._wait_for_load()

        if self.process.poll():
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Failed to load {model_config.model_name} with {self.server_name}",
            )

    def chat_completion(self, chat_completion_request: ChatCompletionRequest):
        client = OpenAI(
            base_url=self.address(),
            api_key="lemonade",
        )

        # Convert Pydantic model to dict and remove unset/null values
        request_dict = chat_completion_request.model_dump(
            exclude_unset=True, exclude_none=True
        )

        # Separate standard OpenAI parameters from custom llama.cpp parameters
        openai_client_params = self._separate_openai_params(request_dict, "chat")

        # Check if streaming is requested
        if chat_completion_request.stream:

            def event_stream():
                # Ensure streaming is enabled in params
                stream_params = dict(openai_client_params)
                stream_params["stream"] = True

                # Use streaming context so we can explicitly close on cancellation
                with client.chat.completions.with_streaming_response.create(
                    # pylint: disable=missing-kwoa
                    **stream_params,
                ) as response:
                    try:
                        for line in response.iter_lines():
                            # Preserve SSE event boundaries: blank line separates events
                            if line == b"" or line == "":
                                yield "\n"
                                continue
                            if isinstance(line, bytes):
                                try:
                                    line = line.decode("utf-8", errors="ignore")
                                except (UnicodeDecodeError, LookupError):
                                    # Skip lines that fail decoding due to encoding issues
                                    continue
                            # Forward SSE lines as-is
                            if not line.endswith("\n"):
                                line += "\n"
                            yield line

                        # Show telemetry after completion
                        self.telemetry.show_telemetry()

                    except GeneratorExit:
                        # Client disconnected/cancelled; close upstream stream and stop
                        try:
                            response.close()
                        except Exception:  # pylint: disable=broad-exception-caught
                            pass
                        raise
                    except Exception as e:  # pylint: disable=broad-exception-caught
                        yield f'data: {{"error": "{str(e)}"}}\n\n'

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )
        else:
            # Non-streaming response
            try:
                # Disable streaming for non-streaming requests
                # pylint: disable=missing-kwoa
                response = client.chat.completions.create(**openai_client_params)

                # Show telemetry after completion
                self.telemetry.show_telemetry()

                return response

            except Exception as e:  # pylint: disable=broad-exception-caught
                logging.error("Error during chat completion: %s", str(e))
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Chat completion error: {str(e)}",
                )

    def completion(self, completion_request: CompletionRequest):
        """
        Handle text completions using the wrapped server.

        Args:
            completion_request: The completion request containing prompt and parameters
            telemetry: Telemetry object containing the server port

        Returns:
            Completion response from the wrapped server
        """

        client = OpenAI(
            base_url=self.address(),
            api_key="lemonade",
        )

        # Convert Pydantic model to dict and remove unset/null values
        request_dict = completion_request.model_dump(
            exclude_unset=True, exclude_none=True
        )

        # Separate standard OpenAI parameters from custom llama.cpp parameters
        openai_client_params = self._separate_openai_params(request_dict, "completion")

        # Check if streaming is requested
        if completion_request.stream:

            def event_stream():
                # Ensure streaming is enabled in params
                stream_params = dict(openai_client_params)
                stream_params["stream"] = True

                # Use streaming context so we can explicitly close on cancellation
                with client.completions.with_streaming_response.create(
                    # pylint: disable=missing-kwoa
                    **stream_params,
                ) as response:
                    try:
                        for line in response.iter_lines():
                            # Preserve SSE event boundaries: blank line separates events
                            if line == b"" or line == "":
                                yield "\n"
                                continue
                            if isinstance(line, bytes):
                                try:
                                    line = line.decode("utf-8", errors="ignore")
                                except (UnicodeDecodeError, LookupError):
                                    # Skip lines that fail decoding due to encoding issues
                                    continue
                            # Forward SSE lines as-is
                            if not line.endswith("\n"):
                                line += "\n"
                            yield line

                        # Show telemetry after completion
                        self.telemetry.show_telemetry()

                    except GeneratorExit:
                        # Client disconnected/cancelled; close upstream stream and stop
                        try:
                            response.close()
                        except Exception:  # pylint: disable=broad-exception-caught
                            pass
                        raise
                    except Exception as e:  # pylint: disable=broad-exception-caught
                        yield f'data: {{"error": "{str(e)}"}}\n\n'

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )
        else:
            # Non-streaming response
            try:
                # Disable streaming for non-streaming requests
                # pylint: disable=missing-kwoa
                response = client.completions.create(**openai_client_params)

                # Show telemetry after completion
                self.telemetry.show_telemetry()

                return response

            except Exception as e:  # pylint: disable=broad-exception-caught
                logging.error("Error during completion: %s", str(e))
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Completion error: {str(e)}",
                )

    def embeddings(self, embeddings_request: EmbeddingsRequest):
        """
        Generate embeddings using the wrapped server.

        Args:
            embeddings_request: The embeddings request containing input text/tokens
            telemetry: Telemetry object containing the server port

        Returns:
            Embeddings response from the wrapped server
        """
        client = OpenAI(
            base_url=self.address(),
            api_key="lemonade",
        )

        # Convert Pydantic model to dict and remove unset/null values
        request_dict = embeddings_request.model_dump(
            exclude_unset=True, exclude_none=True
        )

        try:
            # Call the embeddings endpoint
            response = client.embeddings.create(**request_dict)
            return response

        except Exception as e:  # pylint: disable=broad-exception-caught
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Embeddings error: {str(e)}",
            )

    def reranking(self, reranking_request: RerankingRequest):
        """
        Rerank documents based on their relevance to a query using the wrapped server.

        Args:
            reranking_request: The reranking request containing query and documents
            telemetry: Telemetry object containing the server port

        Returns:
            Reranking response from the wrapped server containing ranked documents and scores
        """

        try:
            # Convert Pydantic model to dict and exclude unset/null values
            request_dict = reranking_request.model_dump(
                exclude_unset=True, exclude_none=True
            )

            # Call the reranking endpoint directly since it's not supported by the OpenAI API
            response = requests.post(
                f"{self.address()}/rerank",
                json=request_dict,
            )
            response.raise_for_status()
            return response.json()

        except Exception as e:
            logging.error("Error during reranking: %s", str(e))
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Reranking error: {str(e)}",
            ) from e
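
For reference, a concrete wrapper is expected to subclass both interfaces above. The sketch below is illustrative only and is not part of the package: the `ExampleTelemetry` log pattern, the `example-server` binary, and the model paths are assumptions, not the package's real llama.cpp or FLM integrations (which presumably live in lemonade/tools/server/llamacpp.py and lemonade/tools/server/flm.py).

```python
# Minimal sketch of a hypothetical WrappedServer subclass; names and log
# formats are assumptions for illustration only.
import re
import subprocess
import threading

from lemonade_server.pydantic_models import PullConfig
from lemonade.tools.server.wrapped_server import WrappedServer, WrappedServerTelemetry


class ExampleTelemetry(WrappedServerTelemetry):
    # Hypothetical log format: "stats: prompt=12 gen=87 tps=42.5"
    _pattern = re.compile(r"prompt=(\d+) gen=(\d+) tps=([\d.]+)")

    def parse_telemetry_line(self, line: str):
        match = self._pattern.search(line)
        if match:
            self.input_tokens = int(match.group(1))
            self.output_tokens = int(match.group(2))
            self.tokens_per_second = float(match.group(3))


class ExampleServer(WrappedServer):
    def __init__(self):
        super().__init__(server_name="example-server", telemetry=ExampleTelemetry())

    def install_server(self, backend=None):
        # Assume the binary is already on PATH; a real implementation would
        # download or update it here.
        pass

    def download_model(
        self, config_checkpoint, config_mmproj=None, do_not_upgrade=False
    ) -> dict:
        # Return whatever file paths _launch_server_subprocess needs (made up here).
        return {"model": f"/models/{config_checkpoint}.gguf"}

    def _launch_server_subprocess(
        self,
        model_config: PullConfig,
        snapshot_files: dict,
        ctx_size: int,
        supports_embeddings: bool = False,
        supports_reranking: bool = False,
    ):
        # Pick a free port, start the subprocess, and stream its logs so the
        # base class can parse telemetry and poll /health until ready.
        self._choose_port()
        self.process = subprocess.Popen(
            [
                "example-server",
                "--model", snapshot_files["model"],
                "--port", str(self.port),
                "--ctx-size", str(ctx_size),
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        threading.Thread(
            target=self._log_subprocess_output,
            args=("example-server",),
            daemon=True,
        ).start()
```

A caller would then invoke something like `server.load(model_config, ctx_size=4096)` followed by `server.chat_completion(request)`, letting the base class handle health polling, OpenAI parameter separation, and SSE pass-through.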