lemonade-sdk 7.0.3__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lemonade-sdk has been flagged as potentially problematic.

Files changed (55):
  1. lemonade/api.py +3 -3
  2. lemonade/cli.py +11 -17
  3. lemonade/common/build.py +0 -47
  4. lemonade/common/network.py +50 -0
  5. lemonade/common/status.py +2 -21
  6. lemonade/common/system_info.py +19 -4
  7. lemonade/profilers/memory_tracker.py +3 -1
  8. lemonade/tools/accuracy.py +3 -4
  9. lemonade/tools/adapter.py +1 -2
  10. lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
  11. lemonade/tools/huggingface/load.py +235 -0
  12. lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
  13. lemonade/tools/humaneval.py +9 -3
  14. lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
  15. lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
  16. lemonade/tools/mmlu.py +7 -15
  17. lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
  18. lemonade/tools/oga/utils.py +423 -0
  19. lemonade/tools/perplexity.py +4 -3
  20. lemonade/tools/prompt.py +2 -1
  21. lemonade/tools/quark/quark_load.py +2 -1
  22. lemonade/tools/quark/quark_quantize.py +5 -5
  23. lemonade/tools/report/table.py +3 -3
  24. lemonade/tools/server/llamacpp.py +159 -34
  25. lemonade/tools/server/serve.py +169 -147
  26. lemonade/tools/server/static/favicon.ico +0 -0
  27. lemonade/tools/server/static/styles.css +568 -0
  28. lemonade/tools/server/static/webapp.html +439 -0
  29. lemonade/tools/server/tray.py +458 -0
  30. lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
  31. lemonade/tools/server/utils/system_tray.py +395 -0
  32. lemonade/tools/server/{instructions.py → webapp.py} +4 -10
  33. lemonade/version.py +1 -1
  34. lemonade_install/install.py +46 -28
  35. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/METADATA +84 -22
  36. lemonade_sdk-8.0.0.dist-info/RECORD +70 -0
  37. lemonade_server/cli.py +182 -27
  38. lemonade_server/model_manager.py +192 -20
  39. lemonade_server/pydantic_models.py +9 -4
  40. lemonade_server/server_models.json +5 -3
  41. lemonade/common/analyze_model.py +0 -26
  42. lemonade/common/labels.py +0 -61
  43. lemonade/common/onnx_helpers.py +0 -176
  44. lemonade/common/plugins.py +0 -10
  45. lemonade/common/tensor_helpers.py +0 -83
  46. lemonade/tools/server/static/instructions.html +0 -262
  47. lemonade_sdk-7.0.3.dist-info/RECORD +0 -69
  48. /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
  49. /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
  50. /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
  51. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/WHEEL +0 -0
  52. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/entry_points.txt +0 -0
  53. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/LICENSE +0 -0
  54. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/NOTICE.md +0 -0
  55. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/top_level.txt +0 -0
lemonade/tools/server/serve.py (+169 -147):

@@ -1,9 +1,12 @@
+ import sys
  import argparse
  import asyncio
  import statistics
  import time
  from threading import Thread, Event
  import logging
+ import platform
+ import tempfile
  import traceback
  from typing import Optional, Union
  import json
@@ -17,7 +20,6 @@ from fastapi.staticfiles import StaticFiles
  import uvicorn
  from uvicorn.config import Config
  from uvicorn.server import Server as UvicornServer
- from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
  from tabulate import tabulate

  from openai.types.completion import Completion, CompletionChoice
@@ -53,12 +55,17 @@ from lemonade_server.pydantic_models import (
  ChatCompletionRequest,
  ResponsesRequest,
  PullConfig,
+ DeleteConfig,
  )
  from lemonade.tools.management_tools import ManagementTool
  import lemonade.tools.server.llamacpp as llamacpp
  from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
- from lemonade.tools.server.instructions import get_instructions_html
- from lemonade.tools.server.port_utils import lifespan
+ from lemonade.tools.server.webapp import get_webapp_html
+ from lemonade.tools.server.utils.port import lifespan
+
+ # Only import tray on Windows
+ if platform.system() == "Windows":
+ from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator

  DEFAULT_PORT = 8000
  DEFAULT_LOG_LEVEL = "info"
@@ -100,7 +107,7 @@ class GeneratorThread(Thread):
  self.streamer.done()


- class StopOnEvent(StoppingCriteria):
+ class StopOnEvent:
  """
  Custom stopping criteria that halts text generation when a specified event is set.

@@ -122,6 +129,7 @@ class Server(ManagementTool):

  The server exposes these endpoints:
  - /api/v1/pull: install an LLM by its Lemonade Server Model Name.
+ - /api/v1/delete: delete an LLM by its Lemonade Server Model Name.
  - /api/v1/load: load a model checkpoint.
  - /api/v1/unload: unload a model checkpoint.
  - /api/v1/health: check whether a model is loaded and ready to serve.
@@ -141,6 +149,10 @@
  # Initialize FastAPI app
  self.app = FastAPI(lifespan=lifespan)

+ # Lifespan will load some tasks in the background, and then set the
+ # app.initialized flag to True when this is done
+ self.app.initialized = False
+
  # Add CORS middleware
  self.app.add_middleware(
  CORSMiddleware,
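
The app.initialized flag added above pairs with the lifespan handler imported from lemonade.tools.server.utils.port, whose body is not part of this diff. For orientation only, a hypothetical sketch of that pattern (assumed names, not the package's actual code) looks like:

# Hypothetical sketch: a FastAPI lifespan that does slow startup work in the
# background and flips app.initialized when finished, matching how serve.py
# polls the flag before generating.
from contextlib import asynccontextmanager
import asyncio
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    async def background_init():
        # Pay the heavyweight import cost off the request path
        await asyncio.to_thread(__import__, "transformers")
        app.initialized = True

    task = asyncio.create_task(background_init())
    yield
    task.cancel()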
@@ -153,11 +165,11 @@
  # Set up custom routes
  self.setup_routes(["/api/v0", "/api/v1"])

- # Set up instructions
- self.app.get("/")(self.instructions)
+ # Set up Web App
+ self.app.get("/")(self.webapp)

  # Mount a static assets dir for HTML responses, such
- # as the instructions
+ # as the Web App
  static_dir = Path(__file__).parent / "static"
  self.app.mount(
  "/static", StaticFiles(directory=static_dir), name="static_assets"
@@ -207,6 +219,7 @@
  for prefix in api_prefixes:
  # Custom routes
  self.app.post(f"{prefix}/pull")(self.pull)
+ self.app.post(f"{prefix}/delete")(self.delete)
  self.app.post(f"{prefix}/load")(self.load_llm)
  self.app.post(f"{prefix}/unload")(self.unload_llm)
  self.app.get(f"{prefix}/health")(self.health)
@@ -226,6 +239,14 @@
  add_help=add_help,
  )

+ # Only add the tray option on Windows
+ if platform.system() == "Windows":
+ parser.add_argument(
+ "--tray",
+ action="store_true",
+ help="Run the server in system tray mode",
+ )
+
  parser.add_argument(
  "--port",
  required=False,
@@ -242,6 +263,13 @@
  help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
  )

+ parser.add_argument(
+ "--log-file",
+ required=False,
+ type=str,
+ help="Path to the log file",
+ )
+
  return parser

  def _setup_server_common(
@@ -249,6 +277,8 @@
  port: int,
  truncate_inputs: bool = False,
  log_level: str = DEFAULT_LOG_LEVEL,
+ tray: bool = False,
+ log_file: str = None,
  threaded_mode: bool = False,
  ):
  """
@@ -280,23 +310,43 @@
  else:
  # Configure logging to match uvicorn's format
  logging_level = getattr(logging, log_level.upper())
- logging.basicConfig(
- level=logging_level,
- format="%(levelprefix)s %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )

- # Add uvicorn's log formatter
- logging.root.handlers[0].formatter = uvicorn.logging.DefaultFormatter(
+ # Set up file handler for logging to lemonade.log
+ uvicorn_formatter = uvicorn.logging.DefaultFormatter(
  fmt="%(levelprefix)s %(message)s",
  use_colors=True,
  )
-
- # Ensure the log level is properly set
- logging.getLogger().setLevel(logging_level)
+ if not log_file:
+ log_file = tempfile.NamedTemporaryFile(
+ prefix="lemonade_", suffix=".log", delete=False
+ ).name
+ file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
+ file_handler.setLevel(logging_level)
+ file_handler.setFormatter(uvicorn_formatter)
+
+ # Set up console handler
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging_level)
+ console_handler.setFormatter(uvicorn_formatter)
+
+ # Configure root logger with both handlers
+ logging.basicConfig(
+ level=logging_level,
+ handlers=[file_handler, console_handler],
+ force=True,
+ )

  # Update debug logging state after setting log level
  self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
+ if tray:
+ # Save original stdout/stderr
+ sys.stdout = OutputDuplicator(log_file, sys.stdout)
+ sys.stderr = OutputDuplicator(log_file, sys.stderr)
+
+ # Open lemonade server in tray mode
+ # lambda function used for deferred instantiation and thread safety
+ LemonadeTray(log_file, port, lambda: Server()).run()
+ sys.exit(0)

  if self.debug_logging_enabled:
  # Print the elapsed time for each request
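
The logging rework above swaps the single root handler for a file handler plus a console handler that share uvicorn's formatter. A minimal standalone sketch of the same dual-handler pattern (standard library only; the file name here is illustrative):

# Sketch of the dual-handler setup: records sent to the root logger go to
# both a log file and the console. force=True clears handlers installed by
# any earlier basicConfig() call.
import logging

file_handler = logging.FileHandler("lemonade.log", mode="a", encoding="utf-8")
console_handler = logging.StreamHandler()

logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s %(message)s",
    handlers=[file_handler, console_handler],
    force=True,
)

logging.info("written to lemonade.log and the console")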
@@ -314,6 +364,8 @@
  port: int = DEFAULT_PORT,
  log_level: str = DEFAULT_LOG_LEVEL,
  truncate_inputs: bool = False,
+ tray: bool = False,
+ log_file: str = None,
  ):
  # Common setup
  self._setup_server_common(
@@ -321,6 +373,8 @@
  truncate_inputs=truncate_inputs,
  log_level=log_level,
  threaded_mode=False,
+ tray=tray,
+ log_file=log_file,
  )

  uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)
@@ -342,6 +396,7 @@
  truncate_inputs=truncate_inputs,
  log_level=log_level,
  threaded_mode=True,
+ tray=False,
  )

  class CustomServer(UvicornServer):
@@ -385,12 +440,12 @@
  # Show telemetry in debug while complying with uvicorn's log indentation
  logging.debug("\n ".join(table))

- def instructions(self):
+ def webapp(self):
  """
- Show instructions on how to use the server.
+ Serve the Web App to the user's browser.
  """

- return get_instructions_html(port=self.app.port)
+ return get_webapp_html(port=self.app.port)

  def initialize_load_config(
  self, request: Union[ChatCompletionRequest, CompletionRequest]
@@ -405,7 +460,8 @@
  # Get model config
  if "/" in request.model:
  # We know the model is a Hugging Face checkpoint if it contains a /
- lc = LoadConfig(checkpoint=request.model)
+ # This scenario is only supported for run-as-thread
+ lc = LoadConfig(model_name="custom", checkpoint=request.model)
  else:
  # The model should be a reference to a built-in model
  lc = LoadConfig(model_name=request.model)
@@ -420,7 +476,7 @@
  lc = self.initialize_load_config(completion_request)

  # Load the model if it's different from the currently loaded one
- await self.load_llm(lc, internal_call=True)
+ await self.load_llm(lc)

  # Check if the model supports reasoning
  reasoning_first_token = self.llm_loaded.reasoning
@@ -539,7 +595,7 @@
  lc = self.initialize_load_config(chat_completion_request)

  # Load the model if it's different from the currently loaded one
- await self.load_llm(lc, internal_call=True)
+ await self.load_llm(lc)

  if self.llm_loaded.recipe == "llamacpp":
  return llamacpp.chat_completion(
@@ -758,7 +814,7 @@
  lc = self.initialize_load_config(responses_request)

  # Load the model if it's different from the currently loaded one
- await self.load_llm(lc, internal_call=True)
+ await self.load_llm(lc)

  # Convert chat messages to text using the model's chat template
  if isinstance(responses_request.input, str):
@@ -912,6 +968,16 @@
  Core streaming completion logic, separated from response handling.
  Returns an async generator that yields tokens.
  """
+
+ while not self.app.initialized:
+ # Wait for the app's background tasks to finish before
+ # allowing generation to proceed
+ logging.debug("Waiting for server to fully initialize")
+ asyncio.sleep(0.5)
+ # These should already be imported as part of the app initialization process,
+ # they are just here to make 100% certain and to make the linter happy
+ from transformers import TextIteratorStreamer, StoppingCriteriaList
+
  model = self.model
  tokenizer = self.tokenizer

@@ -930,7 +996,7 @@

  # Set up the generation parameters
  if "oga-" in self.llm_loaded.recipe:
- from lemonade.tools.ort_genai.oga import OrtGenaiStreamer
+ from lemonade.tools.oga.utils import OrtGenaiStreamer

  streamer = OrtGenaiStreamer(tokenizer)
  self.input_tokens = len(input_ids)
@@ -969,7 +1035,13 @@
  logging.debug(f"Input Tokens: {self.input_tokens}")
  logging.trace(f"Input Message: {message}")

- stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+ if self.llm_loaded.recipe.startswith("hf"):
+ stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+ else:
+ # HF expects StoppingCriteriaList, which requires torch
+ # If we aren't using HF, we can just use a list of StopOnEvent to
+ # avoid the torch dep
+ stopping_criteria = [StopOnEvent(self.stop_event)]

  generation_kwargs = {
  "input_ids": input_ids,
@@ -1103,7 +1175,6 @@
  """
  Report server health information to the client.
  """
- self.stop_event.set()

  return {
  "status": "ok",
@@ -1139,56 +1210,58 @@
  detail=detail,
  )

- def recipe_missing_error(self, model_reference: str):
- self.model_load_failure(
- model_reference,
- message=(
- f"Attempted to load model by checkpoint name {model_reference}, "
- "however the required 'recipe' parameter was not provided"
- ),
- )
-
  async def pull(self, config: PullConfig):
  """
  Install a supported LLM by its Lemonade Model Name.
  """

  # Install the model
- ModelManager().download_models([config.model_name])
+ ModelManager().download_models(
+ [config.model_name],
+ checkpoint=config.checkpoint,
+ recipe=config.recipe,
+ reasoning=config.reasoning,
+ mmproj=config.mmproj,
+ )

  # Refresh the list of downloaded models, to ensure it
  # includes the model we just installed
  self.local_models = ModelManager().downloaded_models_enabled

- async def load_llm(self, config: LoadConfig, internal_call=False):
+ async def delete(self, config: DeleteConfig):
+ """
+ Delete a supported LLM by its Lemonade Model Name.
  """
- Load an LLM into system memory.
+ try:
+ # If the model to be deleted is currently loaded, unload it first
+ if self.llm_loaded and self.llm_loaded.model_name == config.model_name:
+ await self.unload_llm(require_lock=True)
+
+ # Delete the model
+ ModelManager().delete_model(config.model_name)
+
+ # Refresh the list of downloaded models
+ self.local_models = ModelManager().downloaded_models_enabled
+
+ return {
+ "status": "success",
+ "message": f"Deleted model: {config.model_name}",
+ }
+ except ValueError as e:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=str(e),
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to delete model {config.model_name}: {str(e)}",
+ )
+
+ async def load_llm(self, config: LoadConfig):
+ """
+ Load a registered LLM into system memory. Install the model first, if needed.
  config: the information required to load the model
- internal_call: indicates whether the call to this function came from
- an endpoint (False) or a method of this class (True)
-
- There are 3 ways this method can be called:
- 1. An external application asks to load a model by name, using the load endpoint
- a. This only differs from #2 in that an external application may
- provide more parameters than in #2, so we need to validate
- that those parameters are ok.
- b. Load the model
-
- 2. An external application asks to load a model by name,
- using the completions or chat_completions endpoints
- a. Look up the name in the built-in model dictionary to create
- a fully-populated LoadConfig.
- b. Load the model
-
- 3. An external application asks to load a model by checkpoint and recipe,
- using the load endpoint
- a. Populate the checkpoint and recipe into a LoadConfig
- b. Load the model
-
- 4. Completions or ChatCompletions asks to "load" a model by checkpoint
- a. This is only available when #3 has already been executed
- b. Verify that the checkpoint is already loaded,
- and raise an exception if it hasn't (don't load anything new)
  """
  try:
  await self._load_lock.acquire()
@@ -1197,104 +1270,53 @@
  for _ in range(self.max_concurrent_generations):
  await self._generate_semaphore.acquire()

- # We will populate a LoadConfig that has all of the required fields
- config_to_use: LoadConfig
-
- # First, ensure that the arguments are valid
- if config.model_name:
- # Get the dictionary of supported model from disk
- supported_models = ModelManager().supported_models
-
- # Refer to the model by name, since we know the name
- model_reference = config.model_name
-
- if config.checkpoint or config.recipe:
- # Option #1, verify that there are no parameter mismatches
- built_in_config = supported_models[config.model_name]
- if config.checkpoint != built_in_config["checkpoint"]:
- self.model_load_failure(
- model_reference,
- message=(
- f"Load request for model_name={config.model_name} "
- "included a mismatched "
- f"checkpoint={config.checkpoint} parameter. Remove the checkpoint "
- f"parameter, or change it to {built_in_config['checkpoint']}."
- ),
- )
- if config.recipe != built_in_config["recipe"]:
- self.model_load_failure(
- model_reference,
- message=(
- f"Load request for model_name={config.model_name} "
- "included a mismatched "
- f"recipe={config.recipe} parameter. Remove the checkpoint "
- f"parameter, or change it to {built_in_config['recipe']}."
- ),
- )
+ # Make sure the model is already registered
+ supported_models = ModelManager().supported_models

- # Use the config as-is
- config_to_use = config
- else:
- # Option #2, look up the config from the supported models dictionary
- config_to_use = LoadConfig(**supported_models[config.model_name])
-
- elif config.checkpoint:
- # Refer to the model by checkpoint
- model_reference = config.checkpoint
-
- if config.recipe and not internal_call:
- # Option 3, use the config as-is, but add a custom model name
- config_to_use = config
- config_to_use.model_name = "Custom"
- elif internal_call:
- # Option 4, make sure the right checkpoint is loaded and then return
- if (
- self.llm_loaded
- and config.checkpoint == self.llm_loaded.checkpoint
- ):
- return {
- "status": "success",
- "message": f"Model already loaded: {model_reference}",
- }
- else:
- self.model_load_failure(
- model_reference,
- message=(
- "Attempted run completions by using model=<checkpoint name>, "
- "however, "
- "this feature only works if the model has already been loaded "
- "using the load endpoint."
- ),
- )
- else:
- self.recipe_missing_error(model_reference)
+ # The `custom` name allows run-as-thread servers to bypass loading
+ if config.model_name == "custom":
+ config_to_use = config
  else:
- self.model_load_failure(
- None,
- message="Load requests must contain either a model_name or a "
- "checkpoint parameter",
- )
+ if config.model_name not in supported_models.keys():
+ self.model_load_failure(
+ config.model_name,
+ message=(
+ f"Load request for model_name={config.model_name} "
+ "not registered with Lemonade Server. You can register and "
+ "install new models with a `pull` request."
+ ),
+ )
+
+ # Get additional properties from the model registry
+ config_to_use = LoadConfig(**supported_models[config.model_name])

  # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
  if (
  self.llm_loaded
  and config_to_use.checkpoint == self.llm_loaded.checkpoint
  ):
- return {
- "status": "success",
- "message": f"Model already loaded: {model_reference}",
- }
+ if (
+ self.llm_loaded.recipe == "llamacpp"
+ and self.llama_server_process.poll()
+ ):
+ # llama-server process has gone away for some reason, so we should
+ # proceed with loading to get it back
+ pass
+ else:
+ return {
+ "status": "success",
+ "message": f"Model already loaded: {config.model_name}",
+ }

  # Unload the current model if needed
  if self.llm_loaded:
  await self.unload_llm(require_lock=False)

- logging.info(f"Loading llm: {model_reference}")
+ logging.info(f"Loading llm: {config.model_name}")
  try:
  if config_to_use.recipe == "llamacpp":
  self.llama_server_process = llamacpp.server_load(
  model_config=config_to_use,
- model_reference=model_reference,
  telemetry=self.llama_telemetry,
  )

@@ -1306,12 +1328,12 @@

  return {
  "status": "success",
- "message": f"Loaded model: {model_reference}",
+ "message": f"Loaded model: {config.model_name}",
  }
  except HTTPException:
  raise
  except Exception: # pylint: disable=broad-exception-caught
- self.model_load_failure(model_reference)
+ self.model_load_failure(config.model_name)

  finally:
  self._load_lock.release()
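
Taken together, the new pull and delete routes give clients an install-and-remove workflow against the server. A minimal sketch of exercising them, assuming the default localhost:8000 address from this file and the field names visible in the PullConfig/DeleteConfig usage above (the model name is illustrative; the real registry lives in lemonade_server/server_models.json):

# Sketch only: install a model by its Lemonade Server Model Name, then delete it.
# Field names mirror the PullConfig/DeleteConfig usage shown in the diff;
# the model name below is a placeholder, not necessarily a registered model.
import requests

BASE = "http://localhost:8000/api/v1"
MODEL = "Qwen2.5-0.5B-Instruct-CPU"  # example name; check server_models.json

requests.post(f"{BASE}/pull", json={"model_name": MODEL}).raise_for_status()
requests.post(f"{BASE}/delete", json={"model_name": MODEL}).raise_for_status()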