lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lemonade-sdk might be problematic.

Files changed (56)
  1. lemonade/api.py +3 -3
  2. lemonade/cli.py +11 -17
  3. lemonade/common/build.py +0 -47
  4. lemonade/common/network.py +50 -0
  5. lemonade/common/status.py +2 -21
  6. lemonade/common/system_info.py +19 -4
  7. lemonade/profilers/memory_tracker.py +3 -1
  8. lemonade/tools/accuracy.py +3 -4
  9. lemonade/tools/adapter.py +1 -2
  10. lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
  11. lemonade/tools/huggingface/load.py +235 -0
  12. lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
  13. lemonade/tools/humaneval.py +9 -3
  14. lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
  15. lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
  16. lemonade/tools/mmlu.py +7 -15
  17. lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
  18. lemonade/tools/oga/utils.py +423 -0
  19. lemonade/tools/perplexity.py +4 -3
  20. lemonade/tools/prompt.py +2 -1
  21. lemonade/tools/quark/quark_load.py +2 -1
  22. lemonade/tools/quark/quark_quantize.py +5 -5
  23. lemonade/tools/report/table.py +3 -3
  24. lemonade/tools/server/llamacpp.py +188 -45
  25. lemonade/tools/server/serve.py +184 -146
  26. lemonade/tools/server/static/favicon.ico +0 -0
  27. lemonade/tools/server/static/styles.css +568 -0
  28. lemonade/tools/server/static/webapp.html +439 -0
  29. lemonade/tools/server/tray.py +458 -0
  30. lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
  31. lemonade/tools/server/utils/system_tray.py +395 -0
  32. lemonade/tools/server/{instructions.py → webapp.py} +4 -10
  33. lemonade/version.py +1 -1
  34. lemonade_install/install.py +46 -28
  35. lemonade_sdk-8.0.1.dist-info/METADATA +179 -0
  36. lemonade_sdk-8.0.1.dist-info/RECORD +70 -0
  37. lemonade_server/cli.py +182 -27
  38. lemonade_server/model_manager.py +192 -20
  39. lemonade_server/pydantic_models.py +9 -4
  40. lemonade_server/server_models.json +5 -3
  41. lemonade/common/analyze_model.py +0 -26
  42. lemonade/common/labels.py +0 -61
  43. lemonade/common/onnx_helpers.py +0 -176
  44. lemonade/common/plugins.py +0 -10
  45. lemonade/common/tensor_helpers.py +0 -83
  46. lemonade/tools/server/static/instructions.html +0 -262
  47. lemonade_sdk-7.0.4.dist-info/METADATA +0 -113
  48. lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
  49. /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
  50. /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
  51. /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
  52. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/WHEEL +0 -0
  53. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/entry_points.txt +0 -0
  54. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/LICENSE +0 -0
  55. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/NOTICE.md +0 -0
  56. {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/top_level.txt +0 -0
lemonade/tools/server/serve.py
@@ -1,9 +1,12 @@
+import sys
 import argparse
 import asyncio
 import statistics
 import time
 from threading import Thread, Event
 import logging
+import platform
+import tempfile
 import traceback
 from typing import Optional, Union
 import json
@@ -17,7 +20,6 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn
 from uvicorn.config import Config
 from uvicorn.server import Server as UvicornServer
-from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
 from tabulate import tabulate

 from openai.types.completion import Completion, CompletionChoice
@@ -27,6 +29,7 @@ from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
+from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta,
@@ -53,12 +56,17 @@ from lemonade_server.pydantic_models import (
     ChatCompletionRequest,
     ResponsesRequest,
     PullConfig,
+    DeleteConfig,
 )
 from lemonade.tools.management_tools import ManagementTool
 import lemonade.tools.server.llamacpp as llamacpp
 from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
-from lemonade.tools.server.instructions import get_instructions_html
-from lemonade.tools.server.port_utils import lifespan
+from lemonade.tools.server.webapp import get_webapp_html
+from lemonade.tools.server.utils.port import lifespan
+
+# Only import tray on Windows
+if platform.system() == "Windows":
+    from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator

 DEFAULT_PORT = 8000
 DEFAULT_LOG_LEVEL = "info"
@@ -100,7 +108,7 @@ class GeneratorThread(Thread):
             self.streamer.done()


-class StopOnEvent(StoppingCriteria):
+class StopOnEvent:
     """
     Custom stopping criteria that halts text generation when a specified event is set.

@@ -122,6 +130,7 @@ class Server(ManagementTool):

     The server exposes these endpoints:
     - /api/v1/pull: install an LLM by its Lemonade Server Model Name.
+    - /api/v1/delete: delete an LLM by its Lemonade Server Model Name.
     - /api/v1/load: load a model checkpoint.
     - /api/v1/unload: unload a model checkpoint.
     - /api/v1/health: check whether a model is loaded and ready to serve.
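
Note: the health endpoint listed above gives clients a way to wait until the server is ready before sending completion requests. A small polling sketch, assuming the default port; this is illustrative code, not part of the package:

import time
import requests

def wait_for_server(base="http://localhost:8000/api/v1", timeout=60.0):
    # Poll /health until the server answers, or give up after `timeout` seconds.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(f"{base}/health", timeout=2).ok:
                return True
        except requests.exceptions.ConnectionError:
            pass
        time.sleep(0.5)
    return False
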
@@ -141,6 +150,10 @@ class Server(ManagementTool):
         # Initialize FastAPI app
         self.app = FastAPI(lifespan=lifespan)

+        # Lifespan will load some tasks in the background, and then set the
+        # app.initialized flag to True when this is done
+        self.app.initialized = False
+
         # Add CORS middleware
         self.app.add_middleware(
             CORSMiddleware,
@@ -153,11 +166,11 @@ class Server(ManagementTool):
         # Set up custom routes
         self.setup_routes(["/api/v0", "/api/v1"])

-        # Set up instructions
-        self.app.get("/")(self.instructions)
+        # Set up Web App
+        self.app.get("/")(self.webapp)

         # Mount a static assets dir for HTML responses, such
-        # as the instructions
+        # as the Web App
         static_dir = Path(__file__).parent / "static"
         self.app.mount(
             "/static", StaticFiles(directory=static_dir), name="static_assets"
@@ -207,6 +220,7 @@ class Server(ManagementTool):
         for prefix in api_prefixes:
             # Custom routes
             self.app.post(f"{prefix}/pull")(self.pull)
+            self.app.post(f"{prefix}/delete")(self.delete)
             self.app.post(f"{prefix}/load")(self.load_llm)
             self.app.post(f"{prefix}/unload")(self.unload_llm)
             self.app.get(f"{prefix}/health")(self.health)
@@ -226,6 +240,14 @@ class Server(ManagementTool):
             add_help=add_help,
         )

+        # Only add the tray option on Windows
+        if platform.system() == "Windows":
+            parser.add_argument(
+                "--tray",
+                action="store_true",
+                help="Run the server in system tray mode",
+            )
+
         parser.add_argument(
             "--port",
             required=False,
@@ -242,6 +264,13 @@ class Server(ManagementTool):
             help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
         )

+        parser.add_argument(
+            "--log-file",
+            required=False,
+            type=str,
+            help="Path to the log file",
+        )
+
         return parser

     def _setup_server_common(
@@ -249,6 +278,8 @@ class Server(ManagementTool):
         port: int,
         truncate_inputs: bool = False,
         log_level: str = DEFAULT_LOG_LEVEL,
+        tray: bool = False,
+        log_file: str = None,
         threaded_mode: bool = False,
     ):
         """
@@ -280,23 +311,43 @@ class Server(ManagementTool):
         else:
             # Configure logging to match uvicorn's format
             logging_level = getattr(logging, log_level.upper())
-            logging.basicConfig(
-                level=logging_level,
-                format="%(levelprefix)s %(message)s",
-                datefmt="%Y-%m-%d %H:%M:%S",
-            )

-            # Add uvicorn's log formatter
-            logging.root.handlers[0].formatter = uvicorn.logging.DefaultFormatter(
+            # Set up file handler for logging to lemonade.log
+            uvicorn_formatter = uvicorn.logging.DefaultFormatter(
                 fmt="%(levelprefix)s %(message)s",
                 use_colors=True,
             )
-
-            # Ensure the log level is properly set
-            logging.getLogger().setLevel(logging_level)
+            if not log_file:
+                log_file = tempfile.NamedTemporaryFile(
+                    prefix="lemonade_", suffix=".log", delete=False
+                ).name
+            file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
+            file_handler.setLevel(logging_level)
+            file_handler.setFormatter(uvicorn_formatter)
+
+            # Set up console handler
+            console_handler = logging.StreamHandler()
+            console_handler.setLevel(logging_level)
+            console_handler.setFormatter(uvicorn_formatter)
+
+            # Configure root logger with both handlers
+            logging.basicConfig(
+                level=logging_level,
+                handlers=[file_handler, console_handler],
+                force=True,
+            )

         # Update debug logging state after setting log level
         self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
+        if tray:
+            # Save original stdout/stderr
+            sys.stdout = OutputDuplicator(log_file, sys.stdout)
+            sys.stderr = OutputDuplicator(log_file, sys.stderr)
+
+            # Open lemonade server in tray mode
+            # lambda function used for deferred instantiation and thread safety
+            LemonadeTray(log_file, port, lambda: Server()).run()
+            sys.exit(0)

         if self.debug_logging_enabled:
             # Print the elapsed time for each request
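
Note: the replacement logging setup relies on logging.basicConfig(..., force=True), which removes any handlers already attached to the root logger before installing the file and console handlers, so repeated setup calls do not stack duplicates. A standalone illustration of that behavior (not code from the package):

import logging
import sys

# First configuration attaches a stream handler to the root logger.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# force=True discards the handler above and installs only the file handler.
logging.basicConfig(
    level=logging.DEBUG,
    handlers=[logging.FileHandler("example.log", encoding="utf-8")],
    force=True,
)
logging.getLogger().debug("written to example.log only")
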
@@ -314,6 +365,8 @@ class Server(ManagementTool):
         port: int = DEFAULT_PORT,
         log_level: str = DEFAULT_LOG_LEVEL,
         truncate_inputs: bool = False,
+        tray: bool = False,
+        log_file: str = None,
     ):
         # Common setup
         self._setup_server_common(
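
Note: the new tray and log_file arguments are threaded from this entry point into _setup_server_common(). Assuming this hunk belongs to the ManagementTool run() entry point (the method name is not visible in the hunk), launching the server programmatically with the new option might look like the following sketch:

from lemonade.tools.server.serve import Server

# Illustrative only: start the server with an explicit log file.
# On Windows, tray=True would instead hand control to LemonadeTray.
Server().run(port=8000, log_level="info", log_file="lemonade.log")
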
@@ -321,6 +374,8 @@ class Server(ManagementTool):
             truncate_inputs=truncate_inputs,
             log_level=log_level,
             threaded_mode=False,
+            tray=tray,
+            log_file=log_file,
         )

         uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)
@@ -342,6 +397,7 @@ class Server(ManagementTool):
             truncate_inputs=truncate_inputs,
             log_level=log_level,
             threaded_mode=True,
+            tray=False,
         )

         class CustomServer(UvicornServer):
@@ -385,12 +441,12 @@ class Server(ManagementTool):
         # Show telemetry in debug while complying with uvicorn's log indentation
         logging.debug("\n ".join(table))

-    def instructions(self):
+    def webapp(self):
         """
-        Show instructions on how to use the server.
+        Serve the Web App to the user's browser.
         """

-        return get_instructions_html(port=self.app.port)
+        return get_webapp_html(port=self.app.port)

     def initialize_load_config(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
@@ -405,7 +461,8 @@ class Server(ManagementTool):
         # Get model config
         if "/" in request.model:
             # We know the model is a Hugging Face checkpoint if it contains a /
-            lc = LoadConfig(checkpoint=request.model)
+            # This scenario is only supported for run-as-thread
+            lc = LoadConfig(model_name="custom", checkpoint=request.model)
         else:
             # The model should be a reference to a built-in model
             lc = LoadConfig(model_name=request.model)
@@ -420,7 +477,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(completion_request)

         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc, internal_call=True)
+        await self.load_llm(lc)

         # Check if the model supports reasoning
         reasoning_first_token = self.llm_loaded.reasoning
@@ -520,9 +577,16 @@ class Server(ManagementTool):
             logprobs=logprobs,
         )

+        usage = CompletionUsage(
+            prompt_tokens=self.input_tokens,
+            completion_tokens=self.output_tokens,
+            total_tokens=self.input_tokens + self.output_tokens,
+        )
+
         return Completion(
             id="0",
             choices=[choice],
+            usage=usage,
             model=self.llm_loaded.checkpoint,
             object="text_completion",
             created=int(time.time()),
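
Note: completion responses now carry an OpenAI-style usage block built from the tracked input and output token counts (the chat-completions path receives the same addition below). A hedged sketch of reading it with the standard OpenAI client pointed at a local server; the model name is a placeholder:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="none")
completion = client.completions.create(model="some-installed-model", prompt="Hello")

# usage mirrors CompletionUsage(prompt_tokens, completion_tokens, total_tokens)
print(completion.usage.prompt_tokens,
      completion.usage.completion_tokens,
      completion.usage.total_tokens)
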
@@ -539,7 +603,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(chat_completion_request)

         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc, internal_call=True)
+        await self.load_llm(lc)

         if self.llm_loaded.recipe == "llamacpp":
             return llamacpp.chat_completion(
@@ -717,9 +781,16 @@ class Server(ManagementTool):
             logprobs=None,
         )

+        usage = CompletionUsage(
+            prompt_tokens=self.input_tokens,
+            completion_tokens=self.output_tokens,
+            total_tokens=self.input_tokens + self.output_tokens,
+        )
+
         return ChatCompletion(
             id="0",
             choices=[choice],
+            usage=usage,
             model=self.llm_loaded.checkpoint,
             object="chat.completion",
             created=int(time.time()),
@@ -758,7 +829,7 @@ class Server(ManagementTool):
         lc = self.initialize_load_config(responses_request)

         # Load the model if it's different from the currently loaded one
-        await self.load_llm(lc, internal_call=True)
+        await self.load_llm(lc)

         # Convert chat messages to text using the model's chat template
         if isinstance(responses_request.input, str):
@@ -912,6 +983,16 @@ class Server(ManagementTool):
         Core streaming completion logic, separated from response handling.
         Returns an async generator that yields tokens.
         """
+
+        while not self.app.initialized:
+            # Wait for the app's background tasks to finish before
+            # allowing generation to proceed
+            logging.debug("Waiting for server to fully initialize")
+            asyncio.sleep(0.5)
+        # These should already be imported as part of the app initialization process,
+        # they are just here to make 100% certain and to make the linter happy
+        from transformers import TextIteratorStreamer, StoppingCriteriaList
+
         model = self.model
         tokenizer = self.tokenizer

@@ -930,7 +1011,7 @@ class Server(ManagementTool):

         # Set up the generation parameters
         if "oga-" in self.llm_loaded.recipe:
-            from lemonade.tools.ort_genai.oga import OrtGenaiStreamer
+            from lemonade.tools.oga.utils import OrtGenaiStreamer

             streamer = OrtGenaiStreamer(tokenizer)
             self.input_tokens = len(input_ids)
@@ -969,7 +1050,13 @@ class Server(ManagementTool):
         logging.debug(f"Input Tokens: {self.input_tokens}")
         logging.trace(f"Input Message: {message}")

-        stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+        if self.llm_loaded.recipe.startswith("hf"):
+            stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
+        else:
+            # HF expects StoppingCriteriaList, which requires torch
+            # If we aren't using HF, we can just use a list of StopOnEvent to
+            # avoid the torch dep
+            stopping_criteria = [StopOnEvent(self.stop_event)]

         generation_kwargs = {
             "input_ids": input_ids,
@@ -1138,56 +1225,58 @@
             detail=detail,
         )

-    def recipe_missing_error(self, model_reference: str):
-        self.model_load_failure(
-            model_reference,
-            message=(
-                f"Attempted to load model by checkpoint name {model_reference}, "
-                "however the required 'recipe' parameter was not provided"
-            ),
-        )
-
     async def pull(self, config: PullConfig):
         """
         Install a supported LLM by its Lemonade Model Name.
         """

         # Install the model
-        ModelManager().download_models([config.model_name])
+        ModelManager().download_models(
+            [config.model_name],
+            checkpoint=config.checkpoint,
+            recipe=config.recipe,
+            reasoning=config.reasoning,
+            mmproj=config.mmproj,
+        )

         # Refresh the list of downloaded models, to ensure it
         # includes the model we just installed
         self.local_models = ModelManager().downloaded_models_enabled

-    async def load_llm(self, config: LoadConfig, internal_call=False):
+    async def delete(self, config: DeleteConfig):
         """
-        Load an LLM into system memory.
+        Delete a supported LLM by its Lemonade Model Name.
+        """
+        try:
+            # If the model to be deleted is currently loaded, unload it first
+            if self.llm_loaded and self.llm_loaded.model_name == config.model_name:
+                await self.unload_llm(require_lock=True)
+
+            # Delete the model
+            ModelManager().delete_model(config.model_name)
+
+            # Refresh the list of downloaded models
+            self.local_models = ModelManager().downloaded_models_enabled
+
+            return {
+                "status": "success",
+                "message": f"Deleted model: {config.model_name}",
+            }
+        except ValueError as e:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=str(e),
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to delete model {config.model_name}: {str(e)}",
+            )
+
+    async def load_llm(self, config: LoadConfig):
+        """
+        Load a registered LLM into system memory. Install the model first, if needed.
         config: the information required to load the model
-        internal_call: indicates whether the call to this function came from
-            an endpoint (False) or a method of this class (True)
-
-        There are 3 ways this method can be called:
-        1. An external application asks to load a model by name, using the load endpoint
-            a. This only differs from #2 in that an external application may
-                provide more parameters than in #2, so we need to validate
-                that those parameters are ok.
-            b. Load the model
-
-        2. An external application asks to load a model by name,
-            using the completions or chat_completions endpoints
-            a. Look up the name in the built-in model dictionary to create
-                a fully-populated LoadConfig.
-            b. Load the model
-
-        3. An external application asks to load a model by checkpoint and recipe,
-            using the load endpoint
-            a. Populate the checkpoint and recipe into a LoadConfig
-            b. Load the model
-
-        4. Completions or ChatCompletions asks to "load" a model by checkpoint
-            a. This is only available when #3 has already been executed
-            b. Verify that the checkpoint is already loaded,
-                and raise an exception if it hasn't (don't load anything new)
         """
         try:
             await self._load_lock.acquire()
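
Note: pull now accepts registration details (checkpoint, recipe, reasoning, mmproj) so it can install models that are not built in, and load_llm only loads models that are already registered, pointing callers at pull otherwise. A hedged sketch of the resulting register-then-load flow; the names and field values below are placeholders, not models shipped with the package:

import requests

base = "http://localhost:8000/api/v1"

# Register and install a user-supplied checkpoint via pull...
requests.post(f"{base}/pull", json={
    "model_name": "user.My-Model-GGUF",
    "checkpoint": "someorg/some-model-GGUF",
    "recipe": "llamacpp",
    "reasoning": False,
})

# ...then load it by the name it was registered under.
requests.post(f"{base}/load", json={"model_name": "user.My-Model-GGUF"})
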
@@ -1196,104 +1285,53 @@
             for _ in range(self.max_concurrent_generations):
                 await self._generate_semaphore.acquire()

-            # We will populate a LoadConfig that has all of the required fields
-            config_to_use: LoadConfig
-
-            # First, ensure that the arguments are valid
-            if config.model_name:
-                # Get the dictionary of supported model from disk
-                supported_models = ModelManager().supported_models
-
-                # Refer to the model by name, since we know the name
-                model_reference = config.model_name
-
-                if config.checkpoint or config.recipe:
-                    # Option #1, verify that there are no parameter mismatches
-                    built_in_config = supported_models[config.model_name]
-                    if config.checkpoint != built_in_config["checkpoint"]:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                f"Load request for model_name={config.model_name} "
-                                "included a mismatched "
-                                f"checkpoint={config.checkpoint} parameter. Remove the checkpoint "
-                                f"parameter, or change it to {built_in_config['checkpoint']}."
-                            ),
-                        )
-                    if config.recipe != built_in_config["recipe"]:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                f"Load request for model_name={config.model_name} "
-                                "included a mismatched "
-                                f"recipe={config.recipe} parameter. Remove the checkpoint "
-                                f"parameter, or change it to {built_in_config['recipe']}."
-                            ),
-                        )
+            # Make sure the model is already registered
+            supported_models = ModelManager().supported_models

-                    # Use the config as-is
-                    config_to_use = config
-                else:
-                    # Option #2, look up the config from the supported models dictionary
-                    config_to_use = LoadConfig(**supported_models[config.model_name])
-
-            elif config.checkpoint:
-                # Refer to the model by checkpoint
-                model_reference = config.checkpoint
-
-                if config.recipe and not internal_call:
-                    # Option 3, use the config as-is, but add a custom model name
-                    config_to_use = config
-                    config_to_use.model_name = "Custom"
-                elif internal_call:
-                    # Option 4, make sure the right checkpoint is loaded and then return
-                    if (
-                        self.llm_loaded
-                        and config.checkpoint == self.llm_loaded.checkpoint
-                    ):
-                        return {
-                            "status": "success",
-                            "message": f"Model already loaded: {model_reference}",
-                        }
-                    else:
-                        self.model_load_failure(
-                            model_reference,
-                            message=(
-                                "Attempted run completions by using model=<checkpoint name>, "
-                                "however, "
-                                "this feature only works if the model has already been loaded "
-                                "using the load endpoint."
-                            ),
-                        )
-                else:
-                    self.recipe_missing_error(model_reference)
+            # The `custom` name allows run-as-thread servers to bypass loading
+            if config.model_name == "custom":
+                config_to_use = config
             else:
-                self.model_load_failure(
-                    None,
-                    message="Load requests must contain either a model_name or a "
-                    "checkpoint parameter",
-                )
+                if config.model_name not in supported_models.keys():
+                    self.model_load_failure(
+                        config.model_name,
+                        message=(
+                            f"Load request for model_name={config.model_name} "
+                            "not registered with Lemonade Server. You can register and "
+                            "install new models with a `pull` request."
+                        ),
+                    )
+
+                # Get additional properties from the model registry
+                config_to_use = LoadConfig(**supported_models[config.model_name])

             # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
             if (
                 self.llm_loaded
                 and config_to_use.checkpoint == self.llm_loaded.checkpoint
             ):
-                return {
-                    "status": "success",
-                    "message": f"Model already loaded: {model_reference}",
-                }
+                if (
+                    self.llm_loaded.recipe == "llamacpp"
+                    and self.llama_server_process.poll()
+                ):
+                    # llama-server process has gone away for some reason, so we should
+                    # proceed with loading to get it back
+                    pass
+                else:
+                    return {
+                        "status": "success",
+                        "message": f"Model already loaded: {config.model_name}",
+                    }

             # Unload the current model if needed
             if self.llm_loaded:
                 await self.unload_llm(require_lock=False)

-            logging.info(f"Loading llm: {model_reference}")
+            logging.info(f"Loading llm: {config.model_name}")
             try:
                 if config_to_use.recipe == "llamacpp":
                     self.llama_server_process = llamacpp.server_load(
                         model_config=config_to_use,
-                        model_reference=model_reference,
                         telemetry=self.llama_telemetry,
                     )

@@ -1305,12 +1343,12 @@

                 return {
                     "status": "success",
-                    "message": f"Loaded model: {model_reference}",
+                    "message": f"Loaded model: {config.model_name}",
                 }
             except HTTPException:
                 raise
             except Exception: # pylint: disable=broad-exception-caught
-                self.model_load_failure(model_reference)
+                self.model_load_failure(config.model_name)

         finally:
             self._load_lock.release()