lemonade-sdk 8.1.0__py3-none-any.whl → 8.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lemonade-sdk has been flagged as potentially problematic.

@@ -10,21 +10,138 @@ import requests
  import lemonade.common.printing as printing
  from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter

- LLAMA_VERSION = "b5902"
+ from lemonade.common.system_info import get_system_info

+ from dotenv import set_key, load_dotenv

- def get_llama_folder_path():
+ LLAMA_VERSION_VULKAN = "b6097"
+ LLAMA_VERSION_ROCM = "b1021"
+
+
+ def identify_rocm_arch_from_name(device_name: str) -> str | None:
+     """
+     Identify the appropriate ROCm target architecture based on the device name
+     """
+     device_name_lower = device_name.lower()
+     if "radeon" not in device_name_lower:
+         return None
+
+     # Check iGPUs
+     # STX Halo iGPUs (gfx1151 architecture)
+     # Radeon 8050S Graphics / Radeon 8060S Graphics
+     target_arch = None
+     if any(halo_igpu in device_name_lower.lower() for halo_igpu in ["8050s", "8060s"]):
+         return "gfx1151"
+
+     # Check dGPUs
+     # RDNA4 GPUs (gfx120X architecture)
+     # AMD Radeon AI PRO R9700, AMD Radeon RX 9070 XT, AMD Radeon RX 9070 GRE,
+     # AMD Radeon RX 9070, AMD Radeon RX 9060 XT
+     if any(
+         rdna4_gpu in device_name_lower.lower()
+         for rdna4_gpu in ["r9700", "9060", "9070"]
+     ):
+         return "gfx120X"
+
+     # RDNA3 GPUs (gfx110X architecture)
+     # AMD Radeon PRO V710, AMD Radeon PRO W7900 Dual Slot, AMD Radeon PRO W7900,
+     # AMD Radeon PRO W7800 48GB, AMD Radeon PRO W7800, AMD Radeon PRO W7700,
+     # AMD Radeon RX 7900 XTX, AMD Radeon RX 7900 XT, AMD Radeon RX 7900 GRE,
+     # AMD Radeon RX 7800 XT, AMD Radeon RX 7700 XT
+     elif any(
+         rdna3_gpu in device_name_lower.lower()
+         for rdna3_gpu in ["7700", "7800", "7900", "v710"]
+     ):
+         return "gfx110X"
+
+     return None
+
+
+ def identify_rocm_arch() -> str:
+     """
+     Identify the appropriate ROCm target architecture based on the device info
+     Returns the architecture string, or None if no compatible device is found
+     """
+
+     # Check for integrated and discrete AMD GPUs
+     system_info = get_system_info()
+     amd_igpu = system_info.get_amd_igpu_device()
+     amd_dgpu = system_info.get_amd_dgpu_devices()
+     target_arch = None
+     for gpu in [amd_igpu] + amd_dgpu:
+         if gpu.get("available") and gpu.get("name"):
+             target_arch = identify_rocm_arch_from_name(gpu["name"].lower())
+             if target_arch:
+                 break
+
+     return target_arch
+
+
+ def identify_hip_id() -> str:
+     """
+     Identify the HIP ID
+     """
+     # Get HIP devices
+     hip_devices = get_hip_devices()
+     logging.debug(f"HIP devices found: {hip_devices}")
+     if len(hip_devices) == 0:
+         raise ValueError("No HIP devices found when identifying HIP ID")
+
+     # Identify HIP devices that are compatible with our ROCm builds
+     rocm_devices = []
+     for device in hip_devices:
+         device_id, device_name = device
+         if identify_rocm_arch_from_name(device_name):
+             rocm_devices.append([device_id, device_name])
+     logging.debug(f"ROCm devices found: {rocm_devices}")
+
+     # If no ROCm devices are found, use the last HIP device
+     # This might be needed in some scenarios where HIP reports generic device names
+     # Example: "AMD Radeon Graphics" for STX Halo iGPU on Ubuntu 24.04
+     if len(rocm_devices) == 0:
+         rocm_devices = [hip_devices[-1]]
+         logging.warning(
+             "No ROCm devices found when identifying HIP ID. "
+             f"Falling back to the following device: {rocm_devices[0]}"
+         )
+     elif len(rocm_devices) > 1:
+         logging.warning(
+             f"Multiple ROCm devices found when identifying HIP ID: {rocm_devices}. "
+             "The last device will be used."
+         )
+
+     # Select the last device
+     device_selected = rocm_devices[-1]
+     logging.debug(f"Selected ROCm device: {device_selected}")
+
+     # Return the device ID
+     return device_selected[0]
+
+
+ def get_llama_version(backend: str) -> str:
+     """
+     Select the appropriate llama.cpp version based on the backend
+     """
+     if backend == "rocm":
+         return LLAMA_VERSION_ROCM
+     elif backend == "vulkan":
+         return LLAMA_VERSION_VULKAN
+     else:
+         raise ValueError(f"Unsupported backend: {backend}")
+
+
+ def get_llama_folder_path(backend: str):
      """
      Get path for llama.cpp platform-specific executables folder
      """
-     return os.path.join(os.path.dirname(sys.executable), "llamacpp")
+     return os.path.join(os.path.dirname(sys.executable), backend, "llama_server")


- def get_llama_exe_path(exe_name):
+ def get_llama_exe_path(exe_name: str, backend: str):
      """
      Get path to platform-specific llama-server executable
      """
-     base_dir = get_llama_folder_path()
+     base_dir = get_llama_folder_path(backend)
      if platform.system().lower() == "windows":
          return os.path.join(base_dir, f"{exe_name}.exe")
      else:  # Linux/Ubuntu
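The helpers above pin a separate llama.cpp release per backend and map retail device names onto ROCm gfx targets by substring matching. A quick illustrative sketch of their behavior; the import path follows the lemonade.tools.llamacpp.utils import that appears in the server changes later in this diff:

    from lemonade.tools.llamacpp.utils import (
        get_llama_version,
        identify_rocm_arch_from_name,
    )

    # Substring matching against the lowercased device name
    identify_rocm_arch_from_name("AMD Radeon RX 9070 XT")  # -> "gfx120X" (RDNA4)
    identify_rocm_arch_from_name("Radeon 8060S Graphics")  # -> "gfx1151" (STX Halo iGPU)
    identify_rocm_arch_from_name("AMD Radeon Graphics")    # -> None (generic name, no match)

    # Per-backend release pins
    get_llama_version("vulkan")  # -> "b6097" (ggml-org/llama.cpp release)
    get_llama_version("rocm")    # -> "b1021" (lemonade-sdk/llamacpp-rocm release)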
@@ -37,33 +154,33 @@ def get_llama_exe_path(exe_name):
              return os.path.join(base_dir, exe_name)


- def get_llama_server_exe_path():
+ def get_llama_server_exe_path(backend: str):
      """
      Get path to platform-specific llama-server executable
      """
-     return get_llama_exe_path("llama-server")
+     return get_llama_exe_path("llama-server", backend)


- def get_llama_cli_exe_path():
+ def get_llama_cli_exe_path(backend: str):
      """
      Get path to platform-specific llama-cli executable
      """
-     return get_llama_exe_path("llama-cli")
+     return get_llama_exe_path("llama-cli", backend)


- def get_version_txt_path():
+ def get_version_txt_path(backend: str):
      """
      Get path to text file that contains version information
      """
-     return os.path.join(get_llama_folder_path(), "version.txt")
+     return os.path.join(get_llama_folder_path(backend), "version.txt")


- def get_llama_installed_version():
+ def get_llama_installed_version(backend: str):
      """
      Gets version of installed llama.cpp
      Returns None if llama.cpp is not installed
      """
-     version_txt_path = get_version_txt_path()
+     version_txt_path = get_version_txt_path(backend)
      if os.path.exists(version_txt_path):
          with open(version_txt_path, "r", encoding="utf-8") as f:
              llama_installed_version = f.read()
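Every path helper now takes the backend name, so Vulkan and ROCm builds can sit side by side under the Python environment. A rough sketch of the layout these helpers resolve to (paths are illustrative):

    import os
    import sys

    backend = "vulkan"  # or "rocm"
    install_dir = os.path.join(os.path.dirname(sys.executable), backend, "llama_server")
    # e.g. <python_dir>/vulkan/llama_server/llama-server (llama-server.exe on Windows)
    #      <python_dir>/vulkan/llama_server/version.txt and backend.txt
    #      (both marker files are written by install_llamacpp further down)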
@@ -71,24 +188,48 @@ def get_llama_installed_version():
      return None


- def get_binary_url_and_filename(version):
+ def get_binary_url_and_filename(backend: str, target_arch: str = None):
      """
-     Get the appropriate llama.cpp binary URL and filename based on platform
+     Get the appropriate binary URL and filename based on platform and backend
+
+     Args:
+         backend: Backend to use
      """
      system = platform.system().lower()

-     if system == "windows":
-         filename = f"llama-{version}-bin-win-vulkan-x64.zip"
-     elif system == "linux":
-         filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
+     if backend == "rocm":
+
+         # ROCm support from lemonade-sdk/llamacpp-rocm
+         repo = "lemonade-sdk/llamacpp-rocm"
+         version = LLAMA_VERSION_ROCM
+         if system == "windows":
+             filename = f"llama-{version}-windows-rocm-{target_arch}-x64.zip"
+         elif system == "linux":
+             filename = f"llama-{version}-ubuntu-rocm-{target_arch}-x64.zip"
+         else:
+             raise NotImplementedError(
+                 f"Platform {system} not supported for ROCm llamacpp. Supported: Windows, Ubuntu Linux"
+             )
+
+     elif backend == "vulkan":
+         # Original Vulkan support from ggml-org/llama.cpp
+         repo = "ggml-org/llama.cpp"
+         version = LLAMA_VERSION_VULKAN
+         if system == "windows":
+             filename = f"llama-{version}-bin-win-vulkan-x64.zip"
+         elif system == "linux":
+             filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
+         else:
+             raise NotImplementedError(
+                 f"Platform {system} not supported for Vulkan llamacpp. Supported: Windows, Ubuntu Linux"
+             )
      else:
+         supported_backends = ["vulkan", "rocm"]
          raise NotImplementedError(
-             f"Platform {system} not supported for llamacpp. Supported: Windows, Ubuntu Linux"
+             f"Unsupported backend: {backend}. Supported backends: {supported_backends}"
          )

-     url = (
-         f"https://github.com/ggml-org/llama.cpp/releases/download/{version}/{filename}"
-     )
+     url = f"https://github.com/{repo}/releases/download/{version}/{filename}"
      return url, filename


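Assuming a Windows host, the selector above resolves to release assets like the following; the asset names are built directly from the f-strings above and are otherwise unverified here:

    from lemonade.tools.llamacpp.utils import get_binary_url_and_filename

    get_binary_url_and_filename("vulkan")
    # -> ("https://github.com/ggml-org/llama.cpp/releases/download/b6097/"
    #     "llama-b6097-bin-win-vulkan-x64.zip",
    #     "llama-b6097-bin-win-vulkan-x64.zip")

    get_binary_url_and_filename("rocm", target_arch="gfx110X")
    # -> ("https://github.com/lemonade-sdk/llamacpp-rocm/releases/download/b1021/"
    #     "llama-b1021-windows-rocm-gfx110X-x64.zip",
    #     "llama-b1021-windows-rocm-gfx110X-x64.zip")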
@@ -122,7 +263,7 @@ def validate_platform_support():
          )


- def install_llamacpp():
+ def install_llamacpp(backend):
      """
      Installs or upgrades llama.cpp binaries if needed
      """
@@ -130,56 +271,110 @@
      # Exception will be thrown if platform is not supported
      validate_platform_support()

-     # Installation location for llama.cpp
-     llama_folder_path = get_llama_folder_path()
+     version = get_llama_version(backend)
+
+     # Get platform-specific paths at runtime
+     llama_server_exe_dir = get_llama_folder_path(backend)
+     llama_server_exe_path = get_llama_server_exe_path(backend)

      # Check whether the llamacpp install needs an upgrade
-     if os.path.exists(llama_folder_path):
-         if get_llama_installed_version() != LLAMA_VERSION:
+     version_txt_path = os.path.join(llama_server_exe_dir, "version.txt")
+     backend_txt_path = os.path.join(llama_server_exe_dir, "backend.txt")
+
+     logging.info(f"Using backend: {backend}")
+
+     if os.path.exists(version_txt_path) and os.path.exists(backend_txt_path):
+         with open(version_txt_path, "r", encoding="utf-8") as f:
+             llamacpp_installed_version = f.read().strip()
+         with open(backend_txt_path, "r", encoding="utf-8") as f:
+             llamacpp_installed_backend = f.read().strip()
+
+         if (
+             llamacpp_installed_version != version
+             or llamacpp_installed_backend != backend
+         ):
              # Remove the existing install, which will trigger a new install
              # in the next code block
-             shutil.rmtree(llama_folder_path)
+             shutil.rmtree(llama_server_exe_dir)
+     elif os.path.exists(version_txt_path):
+         # Old installation without backend tracking - remove to upgrade
+         shutil.rmtree(llama_server_exe_dir)

      # Download llama.cpp server if it isn't already available
-     if not os.path.exists(llama_folder_path):
-         # Download llama.cpp server zip
-         llama_zip_url, filename = get_binary_url_and_filename(LLAMA_VERSION)
-         llama_zip_path = os.path.join(os.path.dirname(sys.executable), filename)
-         logging.info(f"Downloading llama.cpp server from {llama_zip_url}")
+     if not os.path.exists(llama_server_exe_path):
+
+         # Create the directory
+         os.makedirs(llama_server_exe_dir, exist_ok=True)
+
+         # Identify the target architecture (only needed for ROCm)
+         target_arch = None
+         if backend == "rocm":
+             # Identify the target architecture
+             target_arch = identify_rocm_arch()
+             if not target_arch:
+                 system = platform.system().lower()
+                 if system == "linux":
+                     hint = (
+                         "Hint: If you think your device is supported, "
+                         "running `sudo update-pciids` may help identify your hardware."
+                     )
+                 else:
+                     hint = ""
+                 raise ValueError(
+                     "ROCm backend selected but no compatible ROCm target architecture found. "
+                     "See https://github.com/lemonade-sdk/lemonade?tab=readme-ov-file#supported-configurations "
+                     f"for supported configurations. {hint}"
+                 )
+
+         # Direct download for Vulkan/ROCm
+         llama_archive_url, filename = get_binary_url_and_filename(backend, target_arch)
+         llama_archive_path = os.path.join(llama_server_exe_dir, filename)
+         logging.info(f"Downloading llama.cpp server from {llama_archive_url}")

-         with requests.get(llama_zip_url, stream=True) as r:
+         with requests.get(llama_archive_url, stream=True) as r:
              r.raise_for_status()
-             with open(llama_zip_path, "wb") as f:
+             with open(llama_archive_path, "wb") as f:
                  for chunk in r.iter_content(chunk_size=8192):
                      f.write(chunk)

-         # Extract zip
-         logging.info(f"Extracting {llama_zip_path} to {llama_folder_path}")
-         with zipfile.ZipFile(llama_zip_path, "r") as zip_ref:
-             zip_ref.extractall(llama_folder_path)
+         logging.info(f"Extracting {filename} to {llama_server_exe_dir}")
+         if filename.endswith(".zip"):
+             with zipfile.ZipFile(llama_archive_path, "r") as zip_ref:
+                 zip_ref.extractall(llama_server_exe_dir)
+         else:
+             raise NotImplementedError(f"Unsupported archive format: {filename}")
+
+         # Identify and set HIP ID
+         if backend == "rocm":
+             hip_id = identify_hip_id()
+             env_file_path = os.path.join(llama_server_exe_dir, ".env")
+             set_key(env_file_path, "HIP_VISIBLE_DEVICES", str(hip_id))

          # Make executable on Linux - need to update paths after extraction
          if platform.system().lower() == "linux":
              # Re-get the paths since extraction might have changed the directory structure
-             for updated_exe_path in [
-                 get_llama_server_exe_path(),
-                 get_llama_cli_exe_path(),
-             ]:
-                 if os.path.exists(updated_exe_path):
-                     os.chmod(updated_exe_path, 0o755)
-                     logging.info(f"Set executable permissions for {updated_exe_path}")
+             exe_paths = [
+                 (get_llama_server_exe_path(backend), "llama-server"),
+                 (get_llama_cli_exe_path(backend), "llama-cli"),
+             ]
+
+             for exe_path, exe_name in exe_paths:
+                 if os.path.exists(exe_path):
+                     os.chmod(exe_path, 0o755)
+                     logging.info(f"Set executable permissions for {exe_path}")
                  else:
                      logging.warning(
-                         f"Could not find llama.cpp executable at {updated_exe_path}"
+                         f"Could not find {exe_name} executable at {exe_path}"
                      )

-         # Save version.txt
-         with open(get_version_txt_path(), "w", encoding="utf-8") as vf:
-             vf.write(LLAMA_VERSION)
+         # Save version and backend info
+         with open(version_txt_path, "w", encoding="utf-8") as vf:
+             vf.write(version)
+         with open(backend_txt_path, "w", encoding="utf-8") as bf:
+             bf.write(backend)

-         # Delete zip file
-         os.remove(llama_zip_path)
-         logging.info("Cleaned up zip file")
+         # Delete the archive file
+         os.remove(llama_archive_path)


  def parse_checkpoint(checkpoint: str) -> tuple[str, str | None]:
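install_llamacpp() now keys the whole install on the requested backend: it records version.txt and backend.txt beside the binaries, wipes the folder whenever the pinned release or the backend changes, and for ROCm persists the detected HIP device into a .env file. A hedged usage sketch, assuming the function is importable as in the server changes below:

    from lemonade.tools.llamacpp.utils import install_llamacpp

    install_llamacpp("vulkan")  # downloads the b6097 Vulkan build if missing or stale
    install_llamacpp("rocm")    # downloads the b1021 ROCm build into its own folder
                                # and stores HIP_VISIBLE_DEVICES in <folder>/.env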
@@ -525,6 +720,14 @@ class LlamaCppAdapter(ModelAdapter):
          try:
              # Set up environment with library path for Linux
              env = os.environ.copy()
+
+             # Load environment variables from .env file in the executable directory
+             exe_dir = os.path.dirname(self.executable)
+             env_file_path = os.path.join(exe_dir, ".env")
+             if os.path.exists(env_file_path):
+                 load_dotenv(env_file_path, override=True)
+                 env.update(os.environ)
+
              if self.lib_dir and os.name != "nt":  # Not Windows
                  current_ld_path = env.get("LD_LIBRARY_PATH", "")
                  if current_ld_path:
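The adapter re-reads that .env file just before launching, so the HIP_VISIBLE_DEVICES value chosen at install time reaches the llama.cpp process. A minimal self-contained sketch of the python-dotenv round trip used here (the temporary path is a stand-in for the real install directory):

    import os
    import tempfile
    from dotenv import load_dotenv, set_key

    env_file = os.path.join(tempfile.mkdtemp(), ".env")  # stand-in for <install_dir>/.env
    set_key(env_file, "HIP_VISIBLE_DEVICES", "0")        # what install_llamacpp does for ROCm
    load_dotenv(env_file, override=True)                 # what the adapter and server do at launch
    print(os.environ["HIP_VISIBLE_DEVICES"])             # -> 0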
@@ -610,3 +813,68 @@ class LlamaCppAdapter(ModelAdapter):
              error_msg = f"Failed to run llama.cpp command: {str(e)}\n"
              error_msg += f"Command: {' '.join(cmd)}"
              raise Exception(error_msg)
+
+
+ def get_hip_devices():
+     """Get list of HIP devices with their IDs and names."""
+     import ctypes
+     import sys
+     import os
+     import glob
+     from ctypes import c_int, POINTER
+     from ctypes.util import find_library
+
+     # Get llama.cpp path
+     rocm_path = get_llama_folder_path("rocm")
+
+     # Load HIP library
+     hip_library_pattern = (
+         "amdhip64*.dll" if sys.platform.startswith("win") else "libamdhip64*.so"
+     )
+     search_pattern = os.path.join(rocm_path, hip_library_pattern)
+     matching_files = glob.glob(search_pattern)
+     if not matching_files:
+         raise RuntimeError(
+             f"Could not find HIP runtime library matching pattern: {search_pattern}"
+         )
+     try:
+         libhip = ctypes.CDLL(matching_files[0])
+     except OSError:
+         raise RuntimeError(f"Could not load HIP runtime library from {matching_files[0]}")
+
+     # Setup function signatures
+     hipError_t = c_int
+     hipDeviceProp_t = ctypes.c_char * 2048
+     libhip.hipGetDeviceCount.restype = hipError_t
+     libhip.hipGetDeviceCount.argtypes = [POINTER(c_int)]
+     libhip.hipGetDeviceProperties.restype = hipError_t
+     libhip.hipGetDeviceProperties.argtypes = [POINTER(hipDeviceProp_t), c_int]
+     libhip.hipGetErrorString.restype = ctypes.c_char_p
+     libhip.hipGetErrorString.argtypes = [hipError_t]
+
+     # Get device count
+     device_count = c_int()
+     err = libhip.hipGetDeviceCount(ctypes.byref(device_count))
+     if err != 0:
+         logging.error(
+             f"hipGetDeviceCount failed: {libhip.hipGetErrorString(err).decode()}"
+         )
+         return []
+
+     # Get device properties
+     devices = []
+     for i in range(device_count.value):
+         prop = hipDeviceProp_t()
+         err = libhip.hipGetDeviceProperties(ctypes.byref(prop), i)
+         if err != 0:
+             logging.error(
+                 f"hipGetDeviceProperties failed for device {i}: "
+                 f"{libhip.hipGetErrorString(err).decode()}"
+             )
+             continue
+
+         # Extract device name from HIP device properties
+         device_name = ctypes.string_at(prop, 256).decode("utf-8").rstrip("\x00")
+         devices.append([i, device_name])
+
+     return devices
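get_hip_devices() is the enumerator that identify_hip_id() (added earlier in this diff) filters to pick a device. A hedged sketch of the flow on a machine with one supported discrete GPU; the device names are illustrative and depend on the HIP runtime:

    from lemonade.tools.llamacpp.utils import get_hip_devices, identify_hip_id

    get_hip_devices()
    # e.g. [[0, "AMD Radeon RX 7900 XTX"], [1, "AMD Radeon Graphics"]]

    identify_hip_id()
    # -> 0 in the example above: the generic "AMD Radeon Graphics" entry matches no
    # known architecture, so the 7900 XTX is the only ROCm-compatible candidate and
    # install_llamacpp("rocm") writes its ID to .env as HIP_VISIBLE_DEVICES.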
@@ -109,7 +109,7 @@ class Cache(ManagementTool):
          # pylint: disable=pointless-statement,f-string-without-interpolation
          f"""
          A set of functions for managing the lemonade build cache. The default
-         cache location is {lemonade_cache.DEFAULT_CACHE_DIR}, and can also be
+         cache location is {lemonade_cache.DEFAULT_CACHE_DIR}, and can also be
          selected with
          the global --cache-dir option or the LEMONADE_CACHE_DIR environment variable.

@@ -633,9 +633,9 @@ class OgaLoad(FirstTool):
                  model_generate.generate_hybrid_model(
                      input_model=input_model_path,
                      output_dir=output_model_path,
-                     # script_option="jit_npu",
-                     # mode="bf16",
-                     # dml_only=False,
+                     script_option="jit_npu",
+                     mode="bf16",
+                     dml_only=False,
                  )
              except Exception as e:
                  raise RuntimeError(
@@ -1,5 +1,4 @@
  import os
- import sys
  import logging
  import time
  import subprocess
@@ -9,6 +8,7 @@ import platform

  import requests
  from tabulate import tabulate
+ from dotenv import load_dotenv
  from fastapi import HTTPException, status
  from fastapi.responses import StreamingResponse

@@ -29,8 +29,6 @@ from lemonade.tools.llamacpp.utils import (
      download_gguf,
  )

- LLAMA_VERSION = "b5787"
-

  def llamacpp_address(port: int) -> str:
      """
@@ -45,45 +43,6 @@ def llamacpp_address(port: int) -> str:
      return f"http://127.0.0.1:{port}/v1"


- def get_llama_server_paths():
-     """
-     Get platform-specific paths for llama server directory and executable
-     """
-     base_dir = os.path.join(os.path.dirname(sys.executable), "llama_server")
-
-     if platform.system().lower() == "windows":
-         return base_dir, os.path.join(base_dir, "llama-server.exe")
-     else:  # Linux/Ubuntu
-         # Check if executable exists in build/bin subdirectory (Current Ubuntu structure)
-         build_bin_path = os.path.join(base_dir, "build", "bin", "llama-server")
-         if os.path.exists(build_bin_path):
-             return base_dir, build_bin_path
-         else:
-             # Fallback to root directory
-             return base_dir, os.path.join(base_dir, "llama-server")
-
-
- def get_binary_url_and_filename(version):
-     """
-     Get the appropriate binary URL and filename based on platform
-     """
-     system = platform.system().lower()
-
-     if system == "windows":
-         filename = f"llama-{version}-bin-win-vulkan-x64.zip"
-     elif system == "linux":
-         filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
-     else:
-         raise NotImplementedError(
-             f"Platform {system} not supported for llamacpp. Supported: Windows, Ubuntu Linux"
-         )
-
-     url = (
-         f"https://github.com/ggml-org/llama.cpp/releases/download/{version}/{filename}"
-     )
-     return url, filename
-
-
  class LlamaTelemetry:
      """
      Manages telemetry data collection and display for llama server.
@@ -125,7 +84,7 @@ class LlamaTelemetry:
              device_count = int(vulkan_match.group(1))
              if device_count > 0:
                  logging.info(
-                     f"GPU acceleration active: {device_count} Vulkan device(s) "
+                     f"GPU acceleration active: {device_count} device(s) "
                      "detected by llama-server"
                  )
                  return
@@ -236,6 +195,8 @@ def _launch_llama_subprocess(
      snapshot_files: dict,
      use_gpu: bool,
      telemetry: LlamaTelemetry,
+     backend: str,
+     ctx_size: int,
      supports_embeddings: bool = False,
      supports_reranking: bool = False,
  ) -> subprocess.Popen:
@@ -246,6 +207,7 @@ def _launch_llama_subprocess(
          snapshot_files: Dictionary of model files to load
          use_gpu: Whether to use GPU acceleration
          telemetry: Telemetry object for tracking performance metrics
+         backend: Backend to use (e.g., 'vulkan', 'rocm')
          supports_embeddings: Whether the model supports embeddings
          supports_reranking: Whether the model supports reranking

@@ -254,10 +216,16 @@ def _launch_llama_subprocess(
      """

      # Get the current executable path (handles both Windows and Ubuntu structures)
-     exe_path = get_llama_server_exe_path()
+     exe_path = get_llama_server_exe_path(backend)

      # Build the base command
-     base_command = [exe_path, "-m", snapshot_files["variant"]]
+     base_command = [
+         exe_path,
+         "-m",
+         snapshot_files["variant"],
+         "--ctx-size",
+         str(ctx_size),
+     ]
      if "mmproj" in snapshot_files:
          base_command.extend(["--mmproj", snapshot_files["mmproj"]])
      if not use_gpu:
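With the new ctx_size parameter the server is always launched with an explicit context window. Roughly, the assembled command looks like the sketch below; the paths and model file are hypothetical, and the CPU-only flags added when use_gpu is False sit outside this hunk:

    base_command = [
        "<install_dir>/llama-server",        # exe_path for the chosen backend
        "-m", "<hf_cache>/model.Q4_0.gguf",  # snapshot_files["variant"]
        "--ctx-size", "4096",                # str(ctx_size)
    ]
    # "--mmproj <file>" is appended for checkpoints that ship a projector file.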
@@ -288,6 +256,15 @@ def _launch_llama_subprocess(

      # Set up environment with library path for Linux
      env = os.environ.copy()
+
+     # Load environment variables from .env file in the executable directory
+     exe_dir = os.path.dirname(exe_path)
+     env_file_path = os.path.join(exe_dir, ".env")
+     if os.path.exists(env_file_path):
+         load_dotenv(env_file_path, override=True)
+         env.update(os.environ)
+         logging.debug(f"Loaded environment variables from {env_file_path}")
+
      if platform.system().lower() == "linux":
          lib_dir = os.path.dirname(exe_path)  # Same directory as the executable
          current_ld_path = env.get("LD_LIBRARY_PATH", "")
@@ -320,18 +297,17 @@ def _launch_llama_subprocess(
      return process


- def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
+ def server_load(
+     model_config: PullConfig, telemetry: LlamaTelemetry, backend: str, ctx_size: int
+ ):
      # Install and/or update llama.cpp if needed
      try:
-         install_llamacpp()
+         install_llamacpp(backend)
      except NotImplementedError as e:
          raise HTTPException(
              status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)
          )

-     # Get platform-specific paths at runtime
-     llama_server_exe_path = get_llama_server_exe_path()
-
      # Download the gguf to the hugging face cache
      snapshot_files = download_gguf(model_config.checkpoint, model_config.mmproj)
      logging.debug(f"GGUF file paths: {snapshot_files}")
@@ -342,14 +318,13 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
      supports_embeddings = "embeddings" in model_info.get("labels", [])
      supports_reranking = "reranking" in model_info.get("labels", [])

-     # Start the llama-serve.exe process
-     logging.debug(f"Using llama_server for GGUF model: {llama_server_exe_path}")
-
      # Attempt loading on GPU first
      llama_server_process = _launch_llama_subprocess(
          snapshot_files,
          use_gpu=True,
          telemetry=telemetry,
+         backend=backend,
+         ctx_size=ctx_size,
          supports_embeddings=supports_embeddings,
          supports_reranking=supports_reranking,
      )
@@ -374,6 +349,8 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
          snapshot_files,
          use_gpu=False,
          telemetry=telemetry,
+         backend=backend,
+         ctx_size=ctx_size,
          supports_embeddings=supports_embeddings,
          supports_reranking=supports_reranking,
      )
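server_load() keeps its GPU-first strategy (the call above is the CPU fallback) but now requires the caller to choose a backend and a context size. A hedged call-shape sketch; model_config stands for an already-built PullConfig and LlamaTelemetry comes from this module:

    telemetry = LlamaTelemetry()
    server_load(
        model_config=model_config,  # a PullConfig describing the GGUF checkpoint
        telemetry=telemetry,
        backend="vulkan",           # or "rocm" on supported Radeon hardware
        ctx_size=4096,              # illustrative context-window size
    )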