lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lemonade-sdk might be problematic.
- lemonade/api.py +3 -3
- lemonade/cli.py +11 -17
- lemonade/common/build.py +0 -47
- lemonade/common/network.py +50 -0
- lemonade/common/status.py +2 -21
- lemonade/common/system_info.py +19 -4
- lemonade/profilers/memory_tracker.py +3 -1
- lemonade/tools/accuracy.py +3 -4
- lemonade/tools/adapter.py +1 -2
- lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
- lemonade/tools/humaneval.py +9 -3
- lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
- lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
- lemonade/tools/mmlu.py +7 -15
- lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
- lemonade/tools/oga/utils.py +423 -0
- lemonade/tools/perplexity.py +4 -3
- lemonade/tools/prompt.py +2 -1
- lemonade/tools/quark/quark_load.py +2 -1
- lemonade/tools/quark/quark_quantize.py +5 -5
- lemonade/tools/report/table.py +3 -3
- lemonade/tools/server/llamacpp.py +154 -29
- lemonade/tools/server/serve.py +169 -146
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/styles.css +568 -0
- lemonade/tools/server/static/webapp.html +439 -0
- lemonade/tools/server/tray.py +458 -0
- lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
- lemonade/tools/server/utils/system_tray.py +395 -0
- lemonade/tools/server/{instructions.py → webapp.py} +4 -10
- lemonade/version.py +1 -1
- lemonade_install/install.py +46 -28
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/METADATA +84 -22
- lemonade_sdk-8.0.0.dist-info/RECORD +70 -0
- lemonade_server/cli.py +182 -27
- lemonade_server/model_manager.py +192 -20
- lemonade_server/pydantic_models.py +9 -4
- lemonade_server/server_models.json +5 -3
- lemonade/common/analyze_model.py +0 -26
- lemonade/common/labels.py +0 -61
- lemonade/common/onnx_helpers.py +0 -176
- lemonade/common/plugins.py +0 -10
- lemonade/common/tensor_helpers.py +0 -83
- lemonade/tools/server/static/instructions.html +0 -262
- lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
- /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
- /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
- /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.0.dist-info}/top_level.txt +0 -0
lemonade/tools/server/serve.py
CHANGED
@@ -1,9 +1,12 @@
+import sys
 import argparse
 import asyncio
 import statistics
 import time
 from threading import Thread, Event
 import logging
+import platform
+import tempfile
 import traceback
 from typing import Optional, Union
 import json
@@ -17,7 +20,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn
 from uvicorn.config import Config
 from uvicorn.server import Server as UvicornServer
-from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 from tabulate import tabulate
 
 from openai.types.completion import Completion, CompletionChoice
@@ -53,12 +55,17 @@ from lemonade_server.pydantic_models import (
     ChatCompletionRequest,
     ResponsesRequest,
     PullConfig,
+    DeleteConfig,
 )
 from lemonade.tools.management_tools import ManagementTool
 import lemonade.tools.server.llamacpp as llamacpp
 from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
-from lemonade.tools.server.
-from lemonade.tools.server.
+from lemonade.tools.server.webapp import get_webapp_html
+from lemonade.tools.server.utils.port import lifespan
+
+# Only import tray on Windows
+if platform.system() == "Windows":
+    from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator
 
 DEFAULT_PORT = 8000
 DEFAULT_LOG_LEVEL = "info"
@@ -100,7 +107,7 @@ class GeneratorThread(Thread):
             self.streamer.done()
 
 
-class StopOnEvent
+class StopOnEvent:
     """
     Custom stopping criteria that halts text generation when a specified event is set.
 
@@ -122,6 +129,7 @@ class Server(ManagementTool):
 
     The server exposes these endpoints:
     - /api/v1/pull: install an LLM by its Lemonade Server Model Name.
+    - /api/v1/delete: delete an LLM by its Lemonade Server Model Name.
     - /api/v1/load: load a model checkpoint.
     - /api/v1/unload: unload a model checkpoint.
    - /api/v1/health: check whether a model is loaded and ready to serve.
@@ -141,6 +149,10 @@ class Server(ManagementTool):
         # Initialize FastAPI app
         self.app = FastAPI(lifespan=lifespan)
 
+        # Lifespan will load some tasks in the background, and then set the
+        # app.initialized flag to True when this is done
+        self.app.initialized = False
+
         # Add CORS middleware
         self.app.add_middleware(
             CORSMiddleware,
@@ -153,11 +165,11 @@ class Server(ManagementTool):
         # Set up custom routes
         self.setup_routes(["/api/v0", "/api/v1"])
 
-        # Set up
-        self.app.get("/")(self.
+        # Set up Web App
+        self.app.get("/")(self.webapp)
 
         # Mount a static assets dir for HTML responses, such
-        # as the
+        # as the Web App
         static_dir = Path(__file__).parent / "static"
         self.app.mount(
             "/static", StaticFiles(directory=static_dir), name="static_assets"
@@ -207,6 +219,7 @@ class Server(ManagementTool):
         for prefix in api_prefixes:
             # Custom routes
             self.app.post(f"{prefix}/pull")(self.pull)
+            self.app.post(f"{prefix}/delete")(self.delete)
             self.app.post(f"{prefix}/load")(self.load_llm)
             self.app.post(f"{prefix}/unload")(self.unload_llm)
             self.app.get(f"{prefix}/health")(self.health)
@@ -226,6 +239,14 @@ class Server(ManagementTool):
             add_help=add_help,
         )
 
+        # Only add the tray option on Windows
+        if platform.system() == "Windows":
+            parser.add_argument(
+                "--tray",
+                action="store_true",
+                help="Run the server in system tray mode",
+            )
+
         parser.add_argument(
             "--port",
             required=False,
@@ -242,6 +263,13 @@ class Server(ManagementTool):
             help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
         )
 
+        parser.add_argument(
+            "--log-file",
+            required=False,
+            type=str,
+            help="Path to the log file",
+        )
+
         return parser
 
     def _setup_server_common(
@@ -249,6 +277,8 @@ class Server(ManagementTool):
         port: int,
         truncate_inputs: bool = False,
         log_level: str = DEFAULT_LOG_LEVEL,
+        tray: bool = False,
+        log_file: str = None,
         threaded_mode: bool = False,
     ):
         """
@@ -280,23 +310,43 @@ class Server(ManagementTool):
         else:
             # Configure logging to match uvicorn's format
             logging_level = getattr(logging, log_level.upper())
-            logging.basicConfig(
-                level=logging_level,
-                format="%(levelprefix)s %(message)s",
-                datefmt="%Y-%m-%d %H:%M:%S",
-            )
 
-            #
-
+            # Set up file handler for logging to lemonade.log
+            uvicorn_formatter = uvicorn.logging.DefaultFormatter(
                 fmt="%(levelprefix)s %(message)s",
                 use_colors=True,
             )
-
-
-
+            if not log_file:
+                log_file = tempfile.NamedTemporaryFile(
+                    prefix="lemonade_", suffix=".log", delete=False
+                ).name
+            file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
+            file_handler.setLevel(logging_level)
+            file_handler.setFormatter(uvicorn_formatter)
+
+            # Set up console handler
+            console_handler = logging.StreamHandler()
+            console_handler.setLevel(logging_level)
+            console_handler.setFormatter(uvicorn_formatter)
+
+            # Configure root logger with both handlers
+            logging.basicConfig(
+                level=logging_level,
+                handlers=[file_handler, console_handler],
+                force=True,
+            )
 
         # Update debug logging state after setting log level
         self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
+        if tray:
+            # Save original stdout/stderr
+            sys.stdout = OutputDuplicator(log_file, sys.stdout)
+            sys.stderr = OutputDuplicator(log_file, sys.stderr)
+
+            # Open lemonade server in tray mode
+            # lambda function used for deferred instantiation and thread safety
+            LemonadeTray(log_file, port, lambda: Server()).run()
+            sys.exit(0)
 
         if self.debug_logging_enabled:
             # Print the elapsed time for each request
@@ -314,6 +364,8 @@ class Server(ManagementTool):
         port: int = DEFAULT_PORT,
         log_level: str = DEFAULT_LOG_LEVEL,
         truncate_inputs: bool = False,
+        tray: bool = False,
+        log_file: str = None,
     ):
         # Common setup
         self._setup_server_common(
@@ -321,6 +373,8 @@ class Server(ManagementTool):
             truncate_inputs=truncate_inputs,
             log_level=log_level,
             threaded_mode=False,
+            tray=tray,
+            log_file=log_file,
         )
 
         uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)
@@ -342,6 +396,7 @@ class Server(ManagementTool):
             truncate_inputs=truncate_inputs,
             log_level=log_level,
             threaded_mode=True,
+            tray=False,
         )
 
         class CustomServer(UvicornServer):
@@ -385,12 +440,12 @@ class Server(ManagementTool):
             # Show telemetry in debug while complying with uvicorn's log indentation
             logging.debug("\n ".join(table))
 
-    def
+    def webapp(self):
         """
-
+        Serve the Web App to the user's browser.
         """
 
-        return
+        return get_webapp_html(port=self.app.port)
 
     def initialize_load_config(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
@@ -405,7 +460,8 @@ class Server(ManagementTool):
         # Get model config
         if "/" in request.model:
             # We know the model is a Hugging Face checkpoint if it contains a /
-
+            # This scenario is only supported for run-as-thread
+            lc = LoadConfig(model_name="custom", checkpoint=request.model)
         else:
             # The model should be a reference to a built-in model
             lc = LoadConfig(model_name=request.model)
@@ -420,7 +476,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(completion_request)
 
         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc
+        await self.load_llm(lc)
 
         # Check if the model supports reasoning
         reasoning_first_token = self.llm_loaded.reasoning
@@ -539,7 +595,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(chat_completion_request)
 
         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc
+        await self.load_llm(lc)
 
         if self.llm_loaded.recipe == "llamacpp":
             return llamacpp.chat_completion(
@@ -758,7 +814,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(responses_request)
 
         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc
+        await self.load_llm(lc)
 
         # Convert chat messages to text using the model's chat template
         if isinstance(responses_request.input, str):
@@ -912,6 +968,16 @@ class Server(ManagementTool):
         Core streaming completion logic, separated from response handling.
         Returns an async generator that yields tokens.
         """
+
+        while not self.app.initialized:
+            # Wait for the app's background tasks to finish before
+            # allowing generation to proceed
+            logging.debug("Waiting for server to fully initialize")
+            asyncio.sleep(0.5)
+        # These should already be imported as part of the app initialization process,
+        # they are just here to make 100% certain and to make the linter happy
+        from transformers import TextIteratorStreamer, StoppingCriteriaList
+
         model = self.model
         tokenizer = self.tokenizer
 
@@ -930,7 +996,7 @@ class Server(ManagementTool):
 
         # Set up the generation parameters
         if "oga-" in self.llm_loaded.recipe:
-            from lemonade.tools.
+            from lemonade.tools.oga.utils import OrtGenaiStreamer
 
             streamer = OrtGenaiStreamer(tokenizer)
             self.input_tokens = len(input_ids)
@@ -969,7 +1035,13 @@ class Server(ManagementTool):
         logging.debug(f"Input Tokens: {self.input_tokens}")
         logging.trace(f"Input Message: {message}")
 
-
+        if self.llm_loaded.recipe.startswith("hf"):
+            stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+        else:
+            # HF expects StoppingCriteriaList, which requires torch
+            # If we aren't using HF, we can just use a list of StopOnEvent to
+            # avoid the torch dep
+            stopping_criteria = [StopOnEvent(self.stop_event)]
 
         generation_kwargs = {
             "input_ids": input_ids,
@@ -1138,56 +1210,58 @@ class Server(ManagementTool):
                 detail=detail,
             )
 
-    def recipe_missing_error(self, model_reference: str):
-        self.model_load_failure(
-            model_reference,
-            message=(
-                f"Attempted to load model by checkpoint name {model_reference}, "
-                "however the required 'recipe' parameter was not provided"
-            ),
-        )
-
     async def pull(self, config: PullConfig):
         """
         Install a supported LLM by its Lemonade Model Name.
         """
 
         # Install the model
-        ModelManager().download_models(
+        ModelManager().download_models(
+            [config.model_name],
+            checkpoint=config.checkpoint,
+            recipe=config.recipe,
+            reasoning=config.reasoning,
+            mmproj=config.mmproj,
+        )
 
         # Refresh the list of downloaded models, to ensure it
         # includes the model we just installed
         self.local_models = ModelManager().downloaded_models_enabled
 
-    async def
+    async def delete(self, config: DeleteConfig):
+        """
+        Delete a supported LLM by its Lemonade Model Name.
+        """
+        try:
+            # If the model to be deleted is currently loaded, unload it first
+            if self.llm_loaded and self.llm_loaded.model_name == config.model_name:
+                await self.unload_llm(require_lock=True)
+
+            # Delete the model
+            ModelManager().delete_model(config.model_name)
+
+            # Refresh the list of downloaded models
+            self.local_models = ModelManager().downloaded_models_enabled
+
+            return {
+                "status": "success",
+                "message": f"Deleted model: {config.model_name}",
+            }
+        except ValueError as e:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=str(e),
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to delete model {config.model_name}: {str(e)}",
+            )
+
+    async def load_llm(self, config: LoadConfig):
         """
-        Load
+        Load a registered LLM into system memory. Install the model first, if needed.
         config: the information required to load the model
-        internal_call: indicates whether the call to this function came from
-            an endpoint (False) or a method of this class (True)
-
-        There are 3 ways this method can be called:
-        1. An external application asks to load a model by name, using the load endpoint
-            a. This only differs from #2 in that an external application may
-                provide more parameters than in #2, so we need to validate
-                that those parameters are ok.
-            b. Load the model
-
-        2. An external application asks to load a model by name,
-            using the completions or chat_completions endpoints
-            a. Look up the name in the built-in model dictionary to create
-                a fully-populated LoadConfig.
-            b. Load the model
-
-        3. An external application asks to load a model by checkpoint and recipe,
-            using the load endpoint
-            a. Populate the checkpoint and recipe into a LoadConfig
-            b. Load the model
-
-        4. Completions or ChatCompletions asks to "load" a model by checkpoint
-            a. This is only available when #3 has already been executed
-            b. Verify that the checkpoint is already loaded,
-                and raise an exception if it hasn't (don't load anything new)
         """
         try:
             await self._load_lock.acquire()
@@ -1196,104 +1270,53 @@ class Server(ManagementTool):
            for _ in range(self.max_concurrent_generations):
                 await self._generate_semaphore.acquire()
 
-            #
-
-
-            # First, ensure that the arguments are valid
-            if config.model_name:
-                # Get the dictionary of supported model from disk
-                supported_models = ModelManager().supported_models
-
-                # Refer to the model by name, since we know the name
-                model_reference = config.model_name
-
-                if config.checkpoint or config.recipe:
-                    # Option #1, verify that there are no parameter mismatches
-                    built_in_config = supported_models[config.model_name]
-                    if config.checkpoint != built_in_config["checkpoint"]:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                f"Load request for model_name={config.model_name} "
-                                "included a mismatched "
-                                f"checkpoint={config.checkpoint} parameter. Remove the checkpoint "
-                                f"parameter, or change it to {built_in_config['checkpoint']}."
-                            ),
-                        )
-                    if config.recipe != built_in_config["recipe"]:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                f"Load request for model_name={config.model_name} "
-                                "included a mismatched "
-                                f"recipe={config.recipe} parameter. Remove the checkpoint "
-                                f"parameter, or change it to {built_in_config['recipe']}."
-                            ),
-                        )
+            # Make sure the model is already registered
+            supported_models = ModelManager().supported_models
 
-
-
-
-                # Option #2, look up the config from the supported models dictionary
-                config_to_use = LoadConfig(**supported_models[config.model_name])
-
-            elif config.checkpoint:
-                # Refer to the model by checkpoint
-                model_reference = config.checkpoint
-
-                if config.recipe and not internal_call:
-                    # Option 3, use the config as-is, but add a custom model name
-                    config_to_use = config
-                    config_to_use.model_name = "Custom"
-                elif internal_call:
-                    # Option 4, make sure the right checkpoint is loaded and then return
-                    if (
-                        self.llm_loaded
-                        and config.checkpoint == self.llm_loaded.checkpoint
-                    ):
-                        return {
-                            "status": "success",
-                            "message": f"Model already loaded: {model_reference}",
-                        }
-                    else:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                "Attempted run completions by using model=<checkpoint name>, "
-                                "however, "
-                                "this feature only works if the model has already been loaded "
-                                "using the load endpoint."
-                            ),
-                        )
-                else:
-                    self.recipe_missing_error(model_reference)
+            # The `custom` name allows run-as-thread servers to bypass loading
+            if config.model_name == "custom":
+                config_to_use = config
             else:
-
-
-
-
-
+                if config.model_name not in supported_models.keys():
+                    self.model_load_failure(
+                        config.model_name,
+                        message=(
+                            f"Load request for model_name={config.model_name} "
+                            "not registered with Lemonade Server. You can register and "
+                            "install new models with a `pull` request."
+                        ),
+                    )
+
+                # Get additional properties from the model registry
+                config_to_use = LoadConfig(**supported_models[config.model_name])
 
             # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
             if (
                 self.llm_loaded
                 and config_to_use.checkpoint == self.llm_loaded.checkpoint
             ):
-
-
-
-
+                if (
+                    self.llm_loaded.recipe == "llamacpp"
+                    and self.llama_server_process.poll()
+                ):
+                    # llama-server process has gone away for some reason, so we should
+                    # proceed with loading to get it back
+                    pass
+                else:
+                    return {
+                        "status": "success",
+                        "message": f"Model already loaded: {config.model_name}",
+                    }
 
             # Unload the current model if needed
             if self.llm_loaded:
                 await self.unload_llm(require_lock=False)
 
-            logging.info(f"Loading llm: {
+            logging.info(f"Loading llm: {config.model_name}")
             try:
                 if config_to_use.recipe == "llamacpp":
                     self.llama_server_process = llamacpp.server_load(
                         model_config=config_to_use,
-                        model_reference=model_reference,
                         telemetry=self.llama_telemetry,
                     )
 
@@ -1305,12 +1328,12 @@ class Server(ManagementTool):
 
             return {
                 "status": "success",
-                "message": f"Loaded model: {
+                "message": f"Loaded model: {config.model_name}",
             }
         except HTTPException:
            raise
        except Exception: # pylint: disable=broad-exception-caught
-            self.model_load_failure(
+            self.model_load_failure(config.model_name)
 
         finally:
             self._load_lock.release()
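For orientation, here is a minimal sketch (not part of the package diff) of exercising the new /api/v1/delete route added in serve.py, next to the existing /api/v1/pull route. It assumes a Lemonade Server instance running locally on the default port (8000, per DEFAULT_PORT above); the model name is a placeholder, and the JSON field names are inferred from the PullConfig and DeleteConfig models used by the handlers.

import requests

BASE_URL = "http://localhost:8000/api/v1"  # assumes the default port from serve.py
MODEL_NAME = "Qwen2.5-0.5B-Instruct-CPU"   # placeholder Lemonade Server Model Name

# Install (or confirm) the model by name via the existing pull endpoint
requests.post(f"{BASE_URL}/pull", json={"model_name": MODEL_NAME}).raise_for_status()

# Delete the model; per the handler above, the server unloads it first if it is currently loaded
response = requests.post(f"{BASE_URL}/delete", json={"model_name": MODEL_NAME})
print(response.json())  # e.g. {"status": "success", "message": "Deleted model: ..."}

Per the delete handler in the diff, an invalid model name surfaces as HTTP 422 and any other failure as HTTP 500.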