lemonade-sdk 8.0.6-py3-none-any.whl → 8.1.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of lemonade-sdk has been flagged as a potentially problematic release.

@@ -1,5 +1,4 @@
 import os
-import sys
 import logging
 import time
 import subprocess
@@ -9,6 +8,7 @@ import platform
 
 import requests
 from tabulate import tabulate
+from dotenv import load_dotenv
 from fastapi import HTTPException, status
 from fastapi.responses import StreamingResponse
 
@@ -29,8 +29,6 @@ from lemonade.tools.llamacpp.utils import (
     download_gguf,
 )
 
-LLAMA_VERSION = "b5787"
-
 
 def llamacpp_address(port: int) -> str:
     """
@@ -45,45 +43,6 @@ def llamacpp_address(port: int) -> str:
     return f"http://127.0.0.1:{port}/v1"
 
 
-def get_llama_server_paths():
-    """
-    Get platform-specific paths for llama server directory and executable
-    """
-    base_dir = os.path.join(os.path.dirname(sys.executable), "llama_server")
-
-    if platform.system().lower() == "windows":
-        return base_dir, os.path.join(base_dir, "llama-server.exe")
-    else:  # Linux/Ubuntu
-        # Check if executable exists in build/bin subdirectory (Current Ubuntu structure)
-        build_bin_path = os.path.join(base_dir, "build", "bin", "llama-server")
-        if os.path.exists(build_bin_path):
-            return base_dir, build_bin_path
-        else:
-            # Fallback to root directory
-            return base_dir, os.path.join(base_dir, "llama-server")
-
-
-def get_binary_url_and_filename(version):
-    """
-    Get the appropriate binary URL and filename based on platform
-    """
-    system = platform.system().lower()
-
-    if system == "windows":
-        filename = f"llama-{version}-bin-win-vulkan-x64.zip"
-    elif system == "linux":
-        filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
-    else:
-        raise NotImplementedError(
-            f"Platform {system} not supported for llamacpp. Supported: Windows, Ubuntu Linux"
-        )
-
-    url = (
-        f"https://github.com/ggml-org/llama.cpp/releases/download/{version}/{filename}"
-    )
-    return url, filename
-
-
 class LlamaTelemetry:
     """
     Manages telemetry data collection and display for llama server.
@@ -125,7 +84,7 @@ class LlamaTelemetry:
             device_count = int(vulkan_match.group(1))
             if device_count > 0:
                 logging.info(
-                    f"GPU acceleration active: {device_count} Vulkan device(s) "
+                    f"GPU acceleration active: {device_count} device(s) "
                     "detected by llama-server"
                 )
                 return
@@ -236,6 +195,8 @@ def _launch_llama_subprocess(
     snapshot_files: dict,
     use_gpu: bool,
     telemetry: LlamaTelemetry,
+    backend: str,
+    ctx_size: int,
     supports_embeddings: bool = False,
     supports_reranking: bool = False,
 ) -> subprocess.Popen:
@@ -246,6 +207,7 @@ def _launch_llama_subprocess(
         snapshot_files: Dictionary of model files to load
         use_gpu: Whether to use GPU acceleration
         telemetry: Telemetry object for tracking performance metrics
+        backend: Backend to use (e.g., 'vulkan', 'rocm')
         supports_embeddings: Whether the model supports embeddings
         supports_reranking: Whether the model supports reranking
 
@@ -254,10 +216,16 @@ def _launch_llama_subprocess(
     """
 
     # Get the current executable path (handles both Windows and Ubuntu structures)
-    exe_path = get_llama_server_exe_path()
+    exe_path = get_llama_server_exe_path(backend)
 
     # Build the base command
-    base_command = [exe_path, "-m", snapshot_files["variant"]]
+    base_command = [
+        exe_path,
+        "-m",
+        snapshot_files["variant"],
+        "--ctx-size",
+        str(ctx_size),
+    ]
     if "mmproj" in snapshot_files:
         base_command.extend(["--mmproj", snapshot_files["mmproj"]])
     if not use_gpu:
@@ -288,6 +256,15 @@ def _launch_llama_subprocess(
 
     # Set up environment with library path for Linux
    env = os.environ.copy()
+
+    # Load environment variables from .env file in the executable directory
+    exe_dir = os.path.dirname(exe_path)
+    env_file_path = os.path.join(exe_dir, ".env")
+    if os.path.exists(env_file_path):
+        load_dotenv(env_file_path, override=True)
+        env.update(os.environ)
+        logging.debug(f"Loaded environment variables from {env_file_path}")
+
     if platform.system().lower() == "linux":
         lib_dir = os.path.dirname(exe_path)  # Same directory as the executable
         current_ld_path = env.get("LD_LIBRARY_PATH", "")
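
For context, the .env handling added above boils down to the following standalone pattern; this is a minimal sketch, the executable path is hypothetical, and python-dotenv is assumed to be installed. Because load_dotenv(..., override=True) writes into os.environ, the previously copied env dict has to be refreshed with env.update(os.environ) before it is handed to the subprocess.

```python
import os
import subprocess
from dotenv import load_dotenv  # provided by the python-dotenv package

exe_path = "/opt/llama_server/llama-server"  # hypothetical install location
env = os.environ.copy()

# Pick up key=value overrides from a .env file sitting next to the executable.
env_file = os.path.join(os.path.dirname(exe_path), ".env")
if os.path.exists(env_file):
    load_dotenv(env_file, override=True)  # writes the values into os.environ
    env.update(os.environ)                # fold them into the env for the child process

subprocess.Popen([exe_path, "--help"], env=env)
```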
@@ -320,18 +297,17 @@ def _launch_llama_subprocess(
     return process
 
 
-def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
+def server_load(
+    model_config: PullConfig, telemetry: LlamaTelemetry, backend: str, ctx_size: int
+):
     # Install and/or update llama.cpp if needed
     try:
-        install_llamacpp()
+        install_llamacpp(backend)
     except NotImplementedError as e:
         raise HTTPException(
             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)
         )
 
-    # Get platform-specific paths at runtime
-    llama_server_exe_path = get_llama_server_exe_path()
-
     # Download the gguf to the hugging face cache
     snapshot_files = download_gguf(model_config.checkpoint, model_config.mmproj)
     logging.debug(f"GGUF file paths: {snapshot_files}")
@@ -342,14 +318,13 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
     supports_embeddings = "embeddings" in model_info.get("labels", [])
     supports_reranking = "reranking" in model_info.get("labels", [])
 
-    # Start the llama-serve.exe process
-    logging.debug(f"Using llama_server for GGUF model: {llama_server_exe_path}")
-
     # Attempt loading on GPU first
     llama_server_process = _launch_llama_subprocess(
         snapshot_files,
         use_gpu=True,
         telemetry=telemetry,
+        backend=backend,
+        ctx_size=ctx_size,
         supports_embeddings=supports_embeddings,
         supports_reranking=supports_reranking,
     )
@@ -374,6 +349,8 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
         snapshot_files,
         use_gpu=False,
         telemetry=telemetry,
+        backend=backend,
+        ctx_size=ctx_size,
         supports_embeddings=supports_embeddings,
         supports_reranking=supports_reranking,
     )
@@ -1,5 +1,4 @@
 import sys
-import argparse
 import asyncio
 import statistics
 import time
@@ -48,6 +47,11 @@ from openai.types.responses import (
 )
 
 import lemonade.api as lemonade_api
+import lemonade.tools.server.llamacpp as llamacpp
+from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
+from lemonade.tools.server.webapp import get_webapp_html
+from lemonade.tools.server.utils.port import lifespan
+
 from lemonade_server.model_manager import ModelManager
 from lemonade_server.pydantic_models import (
     DEFAULT_MAX_NEW_TOKENS,
@@ -60,18 +64,17 @@ from lemonade_server.pydantic_models import (
     PullConfig,
     DeleteConfig,
 )
-from lemonade.tools.management_tools import ManagementTool
-import lemonade.tools.server.llamacpp as llamacpp
-from lemonade.tools.server.tool_calls import extract_tool_calls, get_tool_call_pattern
-from lemonade.tools.server.webapp import get_webapp_html
-from lemonade.tools.server.utils.port import lifespan
 
 # Only import tray on Windows
 if platform.system() == "Windows":
+    # pylint: disable=ungrouped-imports
     from lemonade.tools.server.tray import LemonadeTray, OutputDuplicator
 
+
 DEFAULT_PORT = 8000
 DEFAULT_LOG_LEVEL = "info"
+DEFAULT_LLAMACPP_BACKEND = "vulkan"
+DEFAULT_CTX_SIZE = 4096
 
 
 class ServerModel(Model):
@@ -126,7 +129,7 @@ class StopOnEvent:
         return self.stop_event.is_set()
 
 
-class Server(ManagementTool):
+class Server:
     """
     Open a web server that apps can use to communicate with the LLM.
 
@@ -144,11 +147,25 @@ class Server(ManagementTool):
     - /api/v1/models: list all available models.
     """
 
-    unique_name = "serve"
-
-    def __init__(self):
+    def __init__(
+        self,
+        port: int = DEFAULT_PORT,
+        log_level: str = DEFAULT_LOG_LEVEL,
+        ctx_size: int = DEFAULT_CTX_SIZE,
+        tray: bool = False,
+        log_file: str = None,
+        llamacpp_backend: str = DEFAULT_LLAMACPP_BACKEND,
+    ):
         super().__init__()
 
+        # Save args as members
+        self.port = port
+        self.log_level = log_level
+        self.ctx_size = ctx_size
+        self.tray = tray
+        self.log_file = log_file
+        self.llamacpp_backend = llamacpp_backend
+
         # Initialize FastAPI app
         self.app = FastAPI(lifespan=lifespan)
 
@@ -186,9 +203,6 @@ class Server(ManagementTool):
         self.output_tokens = None
         self.decode_token_times = None
 
-        # Input truncation settings
-        self.truncate_inputs = False
-
         # Store debug logging state
         self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
 
@@ -241,66 +255,18 @@ class Server(ManagementTool):
         self.app.post(f"{prefix}/reranking")(self.reranking)
         self.app.post(f"{prefix}/rerank")(self.reranking)
 
-    @staticmethod
-    def parser(add_help: bool = True) -> argparse.ArgumentParser:
-        parser = __class__.helpful_parser(
-            short_description="Launch an industry-standard LLM server",
-            add_help=add_help,
-        )
-
-        # Only add the tray option on Windows
-        if platform.system() == "Windows":
-            parser.add_argument(
-                "--tray",
-                action="store_true",
-                help="Run the server in system tray mode",
-            )
-
-        parser.add_argument(
-            "--port",
-            required=False,
-            type=int,
-            default=DEFAULT_PORT,
-            help=f"Port number to run the server on (default: {DEFAULT_PORT})",
-        )
-        parser.add_argument(
-            "--log-level",
-            required=False,
-            type=str,
-            default=DEFAULT_LOG_LEVEL,
-            choices=["critical", "error", "warning", "info", "debug", "trace"],
-            help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
-        )
-
-        parser.add_argument(
-            "--log-file",
-            required=False,
-            type=str,
-            help="Path to the log file",
-        )
-
-        return parser
-
     def _setup_server_common(
         self,
-        port: int,
-        truncate_inputs: bool = False,
-        log_level: str = DEFAULT_LOG_LEVEL,
         tray: bool = False,
-        log_file: str = None,
         threaded_mode: bool = False,
     ):
         """
         Common setup logic shared between run() and run_in_thread().
 
         Args:
-            port: Port number for the server
-            truncate_inputs: Whether to truncate inputs if they exceed max length
-            log_level: Logging level to configure
+            tray: Whether to run the server in tray mode
             threaded_mode: Whether this is being set up for threaded execution
         """
-        # Store truncation settings
-        self.truncate_inputs = truncate_inputs
 
         # Define TRACE level
         logging.TRACE = 9  # Lower than DEBUG which is 10
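
With the argparse parser and the ManagementTool plumbing removed, configuration now flows through the Server constructor shown earlier instead of CLI flags. A minimal usage sketch follows; the import path is an assumption, not something stated in this diff.

```python
# Hypothetical launch script; the module path below is assumed.
from lemonade.tools.server.serve import Server

server = Server(
    port=8000,
    log_level="info",
    ctx_size=8192,              # forwarded to llama-server as --ctx-size
    llamacpp_backend="vulkan",  # DEFAULT_LLAMACPP_BACKEND is "vulkan"
)
server.run()  # blocking; wraps uvicorn.run(self.app, host="localhost", ...)
```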
@@ -318,18 +284,20 @@ class Server(ManagementTool):
            logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
         else:
             # Configure logging to match uvicorn's format
-            logging_level = getattr(logging, log_level.upper())
+            logging_level = getattr(logging, self.log_level.upper())
 
             # Set up file handler for logging to lemonade.log
             uvicorn_formatter = uvicorn.logging.DefaultFormatter(
                 fmt="%(levelprefix)s %(message)s",
                 use_colors=True,
             )
-            if not log_file:
-                log_file = tempfile.NamedTemporaryFile(
+            if not self.log_file:
+                self.log_file = tempfile.NamedTemporaryFile(
                     prefix="lemonade_", suffix=".log", delete=False
                 ).name
-            file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
+            file_handler = logging.FileHandler(
+                self.log_file, mode="a", encoding="utf-8"
+            )
             file_handler.setLevel(logging_level)
             file_handler.setFormatter(uvicorn_formatter)
 
@@ -349,12 +317,12 @@ class Server(ManagementTool):
         self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
         if tray:
             # Save original stdout/stderr
-            sys.stdout = OutputDuplicator(log_file, sys.stdout)
-            sys.stderr = OutputDuplicator(log_file, sys.stderr)
+            sys.stdout = OutputDuplicator(self.log_file, sys.stdout)
+            sys.stderr = OutputDuplicator(self.log_file, sys.stderr)
 
             # Open lemonade server in tray mode
             # lambda function used for deferred instantiation and thread safety
-            LemonadeTray(log_file, port, lambda: Server()).run()
+            LemonadeTray(self.log_file, self.port, lambda: self).run()
             sys.exit(0)
 
         if self.debug_logging_enabled:
@@ -363,47 +331,26 @@ class Server(ManagementTool):
 
         # Let the app know what port it's running on, so
         # that the lifespan can access it
-        self.app.port = port
+        self.app.port = self.port
 
-    def run(
-        self,
-        # ManagementTool has a required cache_dir arg, but
-        # we always use the default cache directory
-        _=None,
-        port: int = DEFAULT_PORT,
-        log_level: str = DEFAULT_LOG_LEVEL,
-        truncate_inputs: bool = False,
-        tray: bool = False,
-        log_file: str = None,
-    ):
+    def run(self):
         # Common setup
         self._setup_server_common(
-            port=port,
-            truncate_inputs=truncate_inputs,
-            log_level=log_level,
             threaded_mode=False,
-            tray=tray,
-            log_file=log_file,
+            tray=self.tray,
         )
 
-        uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)
+        uvicorn.run(
+            self.app, host="localhost", port=self.port, log_level=self.log_level
+        )
 
-    def run_in_thread(
-        self,
-        port: int = DEFAULT_PORT,
-        host: str = "localhost",
-        log_level: str = "warning",
-        truncate_inputs: bool = False,
-    ):
+    def run_in_thread(self, host: str = "localhost"):
         """
         Set up the server for running in a thread.
         Returns a uvicorn server instance that can be controlled externally.
         """
         # Common setup
         self._setup_server_common(
-            port=port,
-            truncate_inputs=truncate_inputs,
-            log_level=log_level,
             threaded_mode=True,
             tray=False,
         )
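
A sketch of driving the slimmed-down run_in_thread(), assuming it returns the uvicorn.Server built from the Config shown in the next hunk and that the import path is the same guess as above:

```python
import threading

from lemonade.tools.server.serve import Server  # import path assumed

server = Server(port=8123, log_level="warning", ctx_size=4096)
uvicorn_server = server.run_in_thread(host="localhost")  # a uvicorn.Server instance

# Run the event loop in a background thread and shut it down cooperatively.
thread = threading.Thread(target=uvicorn_server.run, daemon=True)
thread.start()
# ... issue requests against http://localhost:8123/api/v1/... ...
uvicorn_server.should_exit = True  # uvicorn's standard shutdown flag
thread.join()
```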
@@ -418,8 +365,8 @@ class Server(ManagementTool):
         config = Config(
             app=self.app,
             host=host,
-            port=port,
-            log_level=log_level,
+            port=self.port,
+            log_level=self.log_level,
             log_config=None,
         )
 
@@ -1099,29 +1046,21 @@ class Server(ManagementTool):
         )
         self.input_tokens = len(input_ids[0])
 
-        if (
-            self.llm_loaded.max_prompt_length
-            and self.input_tokens > self.llm_loaded.max_prompt_length
-        ):
-            if self.truncate_inputs:
-                # Truncate input ids
-                truncate_amount = self.input_tokens - self.llm_loaded.max_prompt_length
-                input_ids = input_ids[: self.llm_loaded.max_prompt_length]
-
-                # Update token count
-                self.input_tokens = len(input_ids)
-
-                # Show warning message
-                truncation_message = (
-                    f"Input exceeded {self.llm_loaded.max_prompt_length} tokens. "
-                    f"Truncated {truncate_amount} tokens."
-                )
-                logging.warning(truncation_message)
-            else:
-                raise RuntimeError(
-                    f"Prompt tokens ({self.input_tokens}) cannot be greater "
-                    f"than the model's max prompt length ({self.llm_loaded.max_prompt_length})"
-                )
+        # For non-llamacpp recipes, truncate inputs to ctx_size if needed
+        if self.llm_loaded.recipe != "llamacpp" and self.input_tokens > self.ctx_size:
+            # Truncate input ids
+            truncate_amount = self.input_tokens - self.ctx_size
+            input_ids = input_ids[: self.ctx_size]
+
+            # Update token count
+            self.input_tokens = len(input_ids)
+
+            # Show warning message
+            truncation_message = (
+                f"Input exceeded {self.ctx_size} tokens. "
+                f"Truncated {truncate_amount} tokens from the beginning."
+            )
+            logging.warning(truncation_message)
 
         # Log the input tokens early to avoid this not showing due to potential crashes
         logging.debug(f"Input Tokens: {self.input_tokens}")
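
Where 8.0.6 either raised a RuntimeError or truncated only when truncate_inputs was set, 8.1.1 always truncates non-llamacpp inputs to ctx_size. A rough sketch of the arithmetic with illustrative numbers:

```python
# Illustrative numbers only; ctx_size comes from Server(ctx_size=...), default 4096.
ctx_size = 4096
input_tokens = 5000

if input_tokens > ctx_size:
    truncate_amount = input_tokens - ctx_size  # 904 tokens dropped
    input_tokens = ctx_size
    print(f"Input exceeded {ctx_size} tokens. Truncated {truncate_amount} tokens.")
```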
@@ -1317,7 +1256,7 @@ class Server(ManagementTool):
         self.tokenizer = None
         self.model = None
 
-        default_message = f"model {model_reference} not found"
+        default_message = "see stack trace and error message below"
         if message:
             detail = message
         else:
@@ -1438,6 +1377,8 @@ class Server(ManagementTool):
             self.llama_server_process = llamacpp.server_load(
                 model_config=config_to_use,
                 telemetry=self.llama_telemetry,
+                backend=self.llamacpp_backend,
+                ctx_size=self.ctx_size,
             )
 
         else: