lemonade-sdk 8.0.2__py3-none-any.whl → 8.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of lemonade-sdk has been flagged as a potentially problematic release.
- lemonade/cli.py +2 -2
- lemonade/profilers/profiler.py +4 -1
- lemonade/tools/humaneval.py +1 -1
- lemonade/tools/mmlu.py +1 -1
- lemonade/tools/oga/load.py +3 -9
- lemonade/tools/perplexity.py +2 -2
- lemonade/tools/prompt.py +21 -6
- lemonade/tools/quark/quark_load.py +1 -1
- lemonade/tools/quark/quark_quantize.py +2 -2
- lemonade/tools/report/table.py +80 -0
- lemonade/tools/server/llamacpp.py +148 -16
- lemonade/tools/server/serve.py +73 -0
- lemonade/tools/server/static/styles.css +424 -4
- lemonade/tools/server/static/webapp.html +337 -38
- lemonade/tools/server/tray.py +25 -9
- lemonade/version.py +1 -1
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/METADATA +33 -36
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/RECORD +26 -26
- lemonade_server/model_manager.py +123 -36
- lemonade_server/pydantic_models.py +25 -1
- lemonade_server/server_models.json +53 -43
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/WHEEL +0 -0
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-8.0.2.dist-info → lemonade_sdk-8.0.4.dist-info}/top_level.txt +0 -0
lemonade/cli.py
CHANGED
@@ -90,9 +90,9 @@ https://github.com/lemonade-sdk/lemonade/blob/main/docs/README.md""",
     )

     profiler_instances = [
-        profiler(global_args[profiler.unique_name])
+        profiler(global_args[profiler.unique_name.replace("-", "_")])
         for profiler in profilers
-        if global_args.get(profiler.unique_name, None) is not None
+        if global_args.get(profiler.unique_name.replace("-", "_"), None) is not None
     ]

     if len(evaluation_tools) > 0:
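The `.replace("-", "_")` change matters because argparse stores option values under destination names with dashes converted to underscores, so a profiler whose `unique_name` contains a dash would never find its entry in the parsed-arguments dict. A minimal sketch of that argparse behavior (the `--memory-profile` flag below is hypothetical, used only for illustration):

import argparse

# argparse converts dashes in long option names to underscores in the
# parsed-arguments namespace, so dict lookups must use the underscore form.
parser = argparse.ArgumentParser()
parser.add_argument("--memory-profile", default=None)  # hypothetical profiler flag

args = vars(parser.parse_args(["--memory-profile", "1"]))
unique_name = "memory-profile"

print(unique_name in args)                    # False
print(unique_name.replace("-", "_") in args)  # True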
lemonade/profilers/profiler.py
CHANGED
@@ -48,7 +48,10 @@ class Profiler(abc.ABC):
         This method is called so that the profiler can create its output files.
         The state is passed so that build info can be gathered and stats can be written.
         The timestamp can be used for filename in current working directory.
-        The start times
+        The start times parameter is a dict with the keys being the tools names and
+        the values being the time the tool started. There is an initial "warmup" key
+        that has a start time before the first tool and a "cool down" key that contains the
+        time when the last tool ended.
         """


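As the expanded docstring describes, the start-times argument maps tool names to their start timestamps, bracketed by "warmup" and "cool down" entries. A purely illustrative sketch of that shape (the tool names and timing values here are made up):

import time

# Illustrative only: keys are tool names, values are start times;
# "warmup" precedes the first tool and "cool down" marks when the last tool ended.
start_times = {
    "warmup": time.time(),
    "oga-load": time.time() + 0.5,     # hypothetical tool name
    "llm-prompt": time.time() + 12.0,  # hypothetical tool name
    "cool down": time.time() + 20.0,
}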
lemonade/tools/humaneval.py
CHANGED
@@ -24,7 +24,7 @@ class AccuracyHumaneval(Tool):
         - pass@10: Percentage of problems solved within 10 generation attempts
         - pass@100: Percentage of problems solved within 100 generation attempts

-    See docs/
+    See docs/dev_cli/humaneval_accuracy.md for more details
    """

    unique_name = "accuracy-humaneval"
lemonade/tools/mmlu.py
CHANGED
lemonade/tools/oga/load.py
CHANGED
@@ -1,12 +1,6 @@
 # onnxruntime_genai is not lint-friendly yet and PyLint can't
 # find any of the class methods
 # pylint: disable=no-member
-#
-# Model builder constraints:
-# 11/10/24 Need transformers <4.45.0 OR onnxruntime-genai 0.5.0 (which must be built from source)
-# (transformers v4.45 changes the format of the tokenizer.json file which will be supported in
-# onnxruntime-genai 0.5)
-#

 import argparse
 import os
@@ -51,8 +45,8 @@ def import_error_heler(e: Exception):
     """
     raise ImportError(
         f"{e}\n Please install lemonade-sdk with "
-        "one of the
-        "pip install lemonade-sdk[
+        "one of the oga extras, for example:\n"
+        "pip install lemonade-sdk[dev,oga-cpu]\n"
         "See https://lemonade_server.ai/install_options.html for details"
     )

@@ -64,7 +58,7 @@ class OgaLoad(FirstTool):
     Input: path to a checkpoint.
     Supported choices for cpu and igpu from HF model repository:
         LLM models on Huggingface supported by model_builder. See documentation
-        (https://github.com/lemonade-sdk/lemonade/blob/main/docs/ort_genai_igpu.md)
+        (https://github.com/lemonade-sdk/lemonade/blob/main/docs/dev_cli/ort_genai_igpu.md)
         for supported models.
     Supported choices for npu from HF model repository:
         Models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern
lemonade/tools/perplexity.py
CHANGED
@@ -17,7 +17,7 @@ class AccuracyPerplexity(Tool):

     Output state produced: None

-    See docs/
+    See docs/dev_cli/perplexity.md for more details.
     """

     unique_name = "accuracy-perplexity"
@@ -63,7 +63,7 @@ class AccuracyPerplexity(Tool):
         # try-except will allow a few more LLMs to work
         max_length = 2048
         # Set stride to half of the maximum input length for overlapping window processing
-        # Refer to docs/perplexity.md for more information on sliding window
+        # Refer to docs/dev_cli/perplexity.md for more information on sliding window
         stride = max_length // 2
         # Determine the total sequence length of the tokenized input
         seq_len = encodings.input_ids.size(1)
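The context lines above describe the sliding-window setup: with `max_length = 2048` and `stride = max_length // 2 = 1024`, consecutive evaluation windows overlap by half their length. A small sketch of how the window boundaries advance (the sequence length below is made up for illustration):

# Illustrative window layout for the sliding-window perplexity pass.
max_length = 2048
stride = max_length // 2  # 1024
seq_len = 5000            # hypothetical tokenized input length

starts = list(range(0, seq_len, stride))
windows = [(begin, min(begin + max_length, seq_len)) for begin in starts]
print(windows[:4])  # [(0, 2048), (1024, 3072), (2048, 4096), (3072, 5000)]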
lemonade/tools/prompt.py
CHANGED
@@ -176,12 +176,21 @@ class LLMPrompt(Tool):

         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
         if isinstance(input_ids, (list, str)):
-            # OGA models return a list of tokens
+            # OGA models return a list of tokens (older versions)
             # Our llama.cpp adapter returns a string
             len_tokens_in = len(input_ids)
-
+        elif hasattr(input_ids, "shape"):
             # HF models return a 2-D tensor
-
+            # OGA models with newer versions may return numpy arrays
+            if len(input_ids.shape) == 1:
+                # 1-D array from newer OGA versions
+                len_tokens_in = len(input_ids)
+            else:
+                # 2-D tensor from HF models
+                len_tokens_in = input_ids.shape[1]
+        else:
+            # Fallback: try to get length directly
+            len_tokens_in = len(input_ids)

         len_tokens_out = []
         response_texts = []
@@ -202,9 +211,15 @@ class LLMPrompt(Tool):
             random_seed += 1

         # Flatten the input and response
-
-
-        )
+        if isinstance(input_ids, (list, str)):
+            input_ids_array = input_ids
+        elif hasattr(input_ids, "shape") and len(input_ids.shape) == 1:
+            # 1-D array from newer OGA versions - already flat
+            input_ids_array = input_ids
+        else:
+            # 2-D tensor from HF models - take first row
+            input_ids_array = input_ids[0]
+
         response_array = response if isinstance(response, str) else response[0]

         # Separate the prompt from the response
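The new branching in LLMPrompt covers three token-container shapes: plain lists or strings, 1-D arrays, and 2-D tensors. A minimal, self-contained illustration of how the length logic resolves for each case, using numpy stand-ins for the tensor types (the helper name is invented for this sketch):

import numpy as np

def count_input_tokens(input_ids):
    # Mirrors the branching added in prompt.py: lists/strings use len(),
    # 1-D arrays use len(), 2-D tensors use the second dimension.
    if isinstance(input_ids, (list, str)):
        return len(input_ids)
    elif hasattr(input_ids, "shape"):
        if len(input_ids.shape) == 1:
            return len(input_ids)
        return input_ids.shape[1]
    return len(input_ids)

print(count_input_tokens([1, 2, 3]))                    # 3  (older OGA list)
print(count_input_tokens(np.array([1, 2, 3, 4])))       # 4  (newer OGA 1-D array)
print(count_input_tokens(np.array([[1, 2, 3, 4, 5]])))  # 5  (HF-style 2-D tensor)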
lemonade/tools/quark/quark_quantize.py
CHANGED

@@ -25,7 +25,7 @@ class QuarkQuantize(Tool):
     Output:
         - Modifies `state` with quantized and optionally exported model.

-    See docs/quark.md for more details.
+    See docs/dev_cli/quark.md for more details.
     """

     unique_name = "quark-quantize"
@@ -94,7 +94,7 @@ class QuarkQuantize(Tool):
             help="Number of samples for calibration.",
         )

-        # See docs/quark.md for more details.
+        # See docs/dev_cli/quark.md for more details.
         parser.add_argument(
             "--quant-scheme",
             type=str,
lemonade/tools/report/table.py
CHANGED
@@ -74,6 +74,7 @@ class SimpleStat(TableColumn):
         align="center",
         omit_if_lean=False,
         wrap=None,
+        stat_fn=None,
     ):
         self.column_header = column_header
         self.stat = stat
@@ -81,6 +82,7 @@ class SimpleStat(TableColumn):
         self.align = align
         self.omit_if_lean = omit_if_lean
         self.wrap = wrap or self.default_wrap
+        self.stat_fn = stat_fn

     def get_str(self, build_stats, lean=False):
         if lean and self.omit_if_lean:
@@ -88,6 +90,8 @@ class SimpleStat(TableColumn):
         data = build_stats.get(self.stat, None)
         if data is None:
             return ""
+        if self.stat_fn:
+            data = self.stat_fn(data)
         cell_str = "\n".join(
             [_wrap(f"{x:{self.format_str}}", self.wrap) for x in _to_list(data)]
         )
@@ -233,6 +237,47 @@ class AdditionalStat(TableColumn):
         return "\n".join(cell_entry)


+class DictListStat(TableColumn):
+    """
+    A statistic that is a list of dicts and values from a given list of keys will be
+    pulled out of each dict and placed in the cell
+    """
+
+    def __init__(
+        self,
+        column_header,
+        statistic_name,
+        key_format_list,
+        align="center",
+        omit_if_lean=False,
+        wrap=None,
+    ):
+        self.column_header = column_header
+        self.statistic_name = statistic_name
+        self.key_format_list = key_format_list
+        self.align = align
+        self.omit_if_lean = omit_if_lean
+        self.wrap = wrap or self.default_wrap
+
+    def get_str(self, build_stats, lean=False):
+        if lean and self.omit_if_lean:
+            return None
+        stat = build_stats.get(self.statistic_name, None)
+        if not stat:
+            return ""
+        cell_entry = []
+        for stat_dict in stat:
+            line = [
+                format_str.format(stat_dict[key])
+                for key, format_str in self.key_format_list
+            ]
+            cell_entry.append(" ".join(line))
+        return "\n".join(cell_entry)
+
+    def get_keys(self):
+        return [self.statistic_name]
+
+
 ################################################################################
 # ABSTRACT BASE CLASS FOR DEFINING A TABLE
 ################################################################################
@@ -350,6 +395,28 @@ class Table(ABC):
             headers.append(column.column_header)
             col_align += (column.align,)

+        # Stat column headers
+        stat_columns = self.table_descriptor.get("stat_columns", [])
+        stat_columns_include = []
+        for column in stat_columns:
+            # Check to see that at least one build has data for the column
+            keep_column = False
+            if not (self.lean and column.omit_if_lean):
+                keys = column.get_keys()
+                for build_stats in self.all_stats:
+                    found = [(key in build_stats) for key in keys]
+                    if any(found):
+                        keep_column = True
+                        headers.append(column.column_header)
+                        col_align += (column.align,)
+                        break
+            stat_columns_include.append(keep_column)
+        stat_columns = [
+            column
+            for column, include in zip(stat_columns, stat_columns_include)
+            if include
+        ]
+
         # Final headers
         last_columns = self.table_descriptor.get("last_columns", [])
         for column in last_columns:
@@ -386,6 +453,12 @@ class Table(ABC):
                 if entry_str is not None:
                     row.append(entry_str)

+            # Per stat columns
+            for entry in stat_columns:
+                entry_str = entry.get_str(build_stats, self.lean)
+                if entry_str is not None:
+                    row.append(entry_str)
+
             # Final columns
             for entry in last_columns:
                 entry_str = entry.get_str(build_stats, self.lean)
@@ -514,6 +587,12 @@ class LemonadePerfTable(Table):
                 Keys.STD_DEV_TOKENS_PER_SECOND,
                 ".2f",
             ),
+            SimpleStat(
+                _wrap("Total Generated Tokens", 9),
+                Keys.RESPONSE_TOKENS,
+                "d",
+                stat_fn=sum,
+            ),
             SimpleStat(
                 _wrap("Memory Used (GB)", 8), Keys.MAX_MEMORY_USED_GBYTE, ".3f"
             ),
@@ -537,6 +616,7 @@ class LemonadePerfTable(Table):
                 )
             ],
         },
+        "stat_columns": [],
         "last_columns": [
             SimpleStat(
                 "System Info",
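The new DictListStat column renders a list-of-dicts statistic by pulling named keys out of each dict and applying a format string per key. A standalone sketch of the formatting that its get_str() performs (the stat name, keys, and values here are invented for illustration):

# Standalone sketch of the formatting DictListStat.get_str() performs.
# The keys and values are hypothetical.
key_format_list = [("tokens", "{:d}"), ("tps", "{:.1f}")]
stat = [
    {"tokens": 128, "tps": 41.3},
    {"tokens": 256, "tps": 39.8},
]

cell_entry = []
for stat_dict in stat:
    line = [format_str.format(stat_dict[key]) for key, format_str in key_format_list]
    cell_entry.append(" ".join(line))

print("\n".join(cell_entry))
# 128 41.3
# 256 39.8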
lemonade/tools/server/llamacpp.py
CHANGED

@@ -16,11 +16,29 @@ from fastapi.responses import StreamingResponse

 from openai import OpenAI

-from lemonade_server.pydantic_models import
+from lemonade_server.pydantic_models import (
+    ChatCompletionRequest,
+    PullConfig,
+    EmbeddingsRequest,
+    RerankingRequest,
+)
 from lemonade_server.model_manager import ModelManager
 from lemonade.tools.server.utils.port import find_free_port

-LLAMA_VERSION = "
+LLAMA_VERSION = "b5787"
+
+
+def llamacpp_address(port: int) -> str:
+    """
+    Generate the base URL for the llamacpp server.
+
+    Args:
+        port: The port number the llamacpp server is running on
+
+    Returns:
+        The base URL for the llamacpp server
+    """
+    return f"http://127.0.0.1:{port}/v1"


 def get_llama_server_paths():
@@ -210,15 +228,20 @@ def _log_subprocess_output(
     """

     if process.stdout:
-
-
-
-
+        try:
+            for line in iter(process.stdout.readline, ""):
+                if line:
+                    line_stripped = line.strip()
+                    logging.debug("%s: %s", prefix, line_stripped)

-
+                    telemetry.parse_telemetry_line(line_stripped)

-
-
+                if process.poll() is not None:
+                    break
+        except UnicodeDecodeError as e:
+            logging.debug("Unicode decode error reading subprocess output: %s", str(e))
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            logging.error("Unexpected error reading subprocess output: %s", str(e))


 def _wait_for_load(llama_server_process: subprocess.Popen, port: int):
@@ -239,10 +262,24 @@ def _wait_for_load(llama_server_process: subprocess.Popen, port: int):


 def _launch_llama_subprocess(
-    snapshot_files: dict,
+    snapshot_files: dict,
+    use_gpu: bool,
+    telemetry: LlamaTelemetry,
+    supports_embeddings: bool = False,
+    supports_reranking: bool = False,
 ) -> subprocess.Popen:
     """
-    Launch llama server subprocess with
+    Launch llama server subprocess with appropriate configuration.
+
+    Args:
+        snapshot_files: Dictionary of model files to load
+        use_gpu: Whether to use GPU acceleration
+        telemetry: Telemetry object for tracking performance metrics
+        supports_embeddings: Whether the model supports embeddings
+        supports_reranking: Whether the model supports reranking
+
+    Returns:
+        Subprocess handle for the llama server
     """

     # Get the current executable path (handles both Windows and Ubuntu structures)
@@ -266,6 +303,14 @@ def _launch_llama_subprocess(
     # reasoning_content field
     base_command.extend(["--reasoning-format", "none"])

+    # Add embeddings support if the model supports it
+    if supports_embeddings:
+        base_command.append("--embeddings")
+
+    # Add reranking support if the model supports it
+    if supports_reranking:
+        base_command.append("--reranking")
+
     # Configure GPU layers: 99 for GPU, 0 for CPU-only
     ngl_value = "99" if use_gpu else "0"
     command = base_command + ["-ngl", ngl_value]
@@ -287,6 +332,8 @@ def _launch_llama_subprocess(
         stdout=subprocess.PIPE,
         stderr=subprocess.STDOUT,
         text=True,
+        encoding="utf-8",
+        errors="replace",
         bufsize=1,
         env=env,
     )
@@ -303,7 +350,6 @@ def _launch_llama_subprocess(


 def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
-
     # Validate platform support before proceeding
     validate_platform_support()

@@ -360,15 +406,26 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
         logging.info("Cleaned up zip file")

     # Download the gguf to the hugging face cache
-
+    model_manager = ModelManager()
+    snapshot_files = model_manager.download_gguf(model_config)
     logging.debug(f"GGUF file paths: {snapshot_files}")

+    # Check if model supports embeddings
+    supported_models = model_manager.supported_models
+    model_info = supported_models.get(model_config.model_name, {})
+    supports_embeddings = "embeddings" in model_info.get("labels", [])
+    supports_reranking = "reranking" in model_info.get("labels", [])
+
     # Start the llama-serve.exe process
     logging.debug(f"Using llama_server for GGUF model: {llama_server_exe_path}")

     # Attempt loading on GPU first
     llama_server_process = _launch_llama_subprocess(
-        snapshot_files,
+        snapshot_files,
+        use_gpu=True,
+        telemetry=telemetry,
+        supports_embeddings=supports_embeddings,
+        supports_reranking=supports_reranking,
     )

     # Check the /health endpoint until GPU server is ready
@@ -383,8 +440,16 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
             f"Loading {model_config.model_name} on GPU didn't work, re-attempting on CPU"
         )

+        if os.environ.get("LEMONADE_LLAMACPP_NO_FALLBACK"):
+            # Used for testing, when the test should fail if GPU didn't work
+            raise Exception("llamacpp GPU loading failed")
+
         llama_server_process = _launch_llama_subprocess(
-            snapshot_files,
+            snapshot_files,
+            use_gpu=False,
+            telemetry=telemetry,
+            supports_embeddings=supports_embeddings,
+            supports_reranking=supports_reranking,
         )

         # Check the /health endpoint until CPU server is ready
@@ -405,7 +470,7 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
 def chat_completion(
     chat_completion_request: ChatCompletionRequest, telemetry: LlamaTelemetry
 ):
-    base_url =
+    base_url = llamacpp_address(telemetry.port)
     client = OpenAI(
         base_url=base_url,
         api_key="lemonade",
@@ -456,3 +521,70 @@ def chat_completion(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Chat completion error: {str(e)}",
         )
+
+
+def embeddings(embeddings_request: EmbeddingsRequest, telemetry: LlamaTelemetry):
+    """
+    Generate embeddings using the llamacpp server.
+
+    Args:
+        embeddings_request: The embeddings request containing input text/tokens
+        telemetry: Telemetry object containing the server port
+
+    Returns:
+        Embeddings response from the llamacpp server
+    """
+    base_url = llamacpp_address(telemetry.port)
+    client = OpenAI(
+        base_url=base_url,
+        api_key="lemonade",
+    )
+
+    # Convert Pydantic model to dict and remove unset/null values
+    request_dict = embeddings_request.model_dump(exclude_unset=True, exclude_none=True)
+
+    try:
+        # Call the embeddings endpoint
+        response = client.embeddings.create(**request_dict)
+        return response
+
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Embeddings error: {str(e)}",
+        )
+
+
+def reranking(reranking_request: RerankingRequest, telemetry: LlamaTelemetry):
+    """
+    Rerank documents based on their relevance to a query using the llamacpp server.
+
+    Args:
+        reranking_request: The reranking request containing query and documents
+        telemetry: Telemetry object containing the server port
+
+    Returns:
+        Reranking response from the llamacpp server containing ranked documents and scores
+    """
+    base_url = llamacpp_address(telemetry.port)
+
+    try:
+        # Convert Pydantic model to dict and exclude unset/null values
+        request_dict = reranking_request.model_dump(
+            exclude_unset=True, exclude_none=True
+        )
+
+        # Call the reranking endpoint directly since it's not supported by the OpenAI API
+        response = requests.post(
+            f"{base_url}/rerank",
+            json=request_dict,
+        )
+        response.raise_for_status()
+        return response.json()
+
+    except Exception as e:
+        logging.error("Error during reranking: %s", str(e))
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Reranking error: {str(e)}",
+        ) from e
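Because reranking is not part of the OpenAI client, the new reranking() helper posts straight to the llama.cpp server's /rerank route. A hedged sketch of what such a request might look like from client code (the port, payload contents, and field names are illustrative, following llama.cpp's reranking endpoint):

import requests

# Assumes a llama-server instance started with --reranking is listening locally;
# the port and payload below are illustrative only.
base_url = "http://127.0.0.1:8080/v1"

payload = {
    "query": "What is the capital of France?",
    "documents": [
        "Paris is the capital and largest city of France.",
        "Llamas are members of the camelid family.",
    ],
}

response = requests.post(f"{base_url}/rerank", json=payload, timeout=30)
response.raise_for_status()
print(response.json())  # document indices with relevance scores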
lemonade/tools/server/serve.py
CHANGED
@@ -54,6 +54,8 @@ from lemonade_server.pydantic_models import (
     LoadConfig,
     CompletionRequest,
     ChatCompletionRequest,
+    EmbeddingsRequest,
+    RerankingRequest,
     ResponsesRequest,
     PullConfig,
     DeleteConfig,
@@ -231,8 +233,13 @@ class Server(ManagementTool):

         # OpenAI-compatible routes
         self.app.post(f"{prefix}/chat/completions")(self.chat_completions)
+        self.app.post(f"{prefix}/embeddings")(self.embeddings)
         self.app.get(f"{prefix}/models")(self.models)

+        # JinaAI routes (jina.ai/reranker/)
+        self.app.post(f"{prefix}/reranking")(self.reranking)
+        self.app.post(f"{prefix}/rerank")(self.reranking)
+
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(
@@ -796,6 +803,72 @@ class Server(ManagementTool):
             created=int(time.time()),
         )

+    async def embeddings(self, embeddings_request: EmbeddingsRequest):
+        """
+        Generate embeddings for the provided input.
+        """
+        # Initialize load config from embeddings request
+        lc = LoadConfig(model_name=embeddings_request.model)
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp":
+            try:
+                return llamacpp.embeddings(embeddings_request, self.llama_telemetry)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                # Check if model has embeddings label
+                model_info = ModelManager().supported_models.get(
+                    self.llm_loaded.model_name, {}
+                )
+                if "embeddings" not in model_info.get("labels", []):
+                    raise HTTPException(
+                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                        detail="You tried to generate embeddings for a model that is "
+                        "not labeled as an embeddings model. Please use another model "
+                        "or re-register the current model with the 'embeddings' label.",
+                    ) from e
+                else:
+                    raise e
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Embeddings not supported for recipe: {self.llm_loaded.recipe}",
+            )
+
+    async def reranking(self, reranking_request: RerankingRequest):
+        """
+        Rerank documents based on their relevance to a query using the llamacpp server.
+        """
+        # Initialize load config from reranking request
+        lc = LoadConfig(model_name=reranking_request.model)
+
+        # Load the model if it's different from the currently loaded one
+        await self.load_llm(lc)
+
+        if self.llm_loaded.recipe == "llamacpp":
+            try:
+                return llamacpp.reranking(reranking_request, self.llama_telemetry)
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                # Check if model has reranking label
+                model_info = ModelManager().supported_models.get(
+                    self.llm_loaded.model_name, {}
+                )
+                if "reranking" not in model_info.get("labels", []):
+                    raise HTTPException(
+                        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                        detail="You tried to use reranking for a model that is "
+                        "not labeled as a reranking model. Please use another model "
+                        "or re-register the current model with the 'reranking' label.",
+                    ) from e
+                else:
+                    raise e
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Reranking not supported for recipe: {self.llm_loaded.recipe}",
+            )
+
     def apply_chat_template(
         self, messages: list[dict], tools: list[dict] | None = None
     ):