lemonade_sdk-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +180 -0
  3. lemonade/cache.py +92 -0
  4. lemonade/cli.py +173 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/build.py +176 -0
  7. lemonade/common/cli_helpers.py +139 -0
  8. lemonade/common/exceptions.py +98 -0
  9. lemonade/common/filesystem.py +368 -0
  10. lemonade/common/inference_engines.py +408 -0
  11. lemonade/common/network.py +93 -0
  12. lemonade/common/printing.py +110 -0
  13. lemonade/common/status.py +471 -0
  14. lemonade/common/system_info.py +1411 -0
  15. lemonade/common/test_helpers.py +28 -0
  16. lemonade/profilers/__init__.py +1 -0
  17. lemonade/profilers/agt_power.py +437 -0
  18. lemonade/profilers/hwinfo_power.py +429 -0
  19. lemonade/profilers/memory_tracker.py +259 -0
  20. lemonade/profilers/profiler.py +58 -0
  21. lemonade/sequence.py +363 -0
  22. lemonade/state.py +159 -0
  23. lemonade/tools/__init__.py +1 -0
  24. lemonade/tools/accuracy.py +432 -0
  25. lemonade/tools/adapter.py +114 -0
  26. lemonade/tools/bench.py +302 -0
  27. lemonade/tools/flm/__init__.py +1 -0
  28. lemonade/tools/flm/utils.py +305 -0
  29. lemonade/tools/huggingface/bench.py +187 -0
  30. lemonade/tools/huggingface/load.py +235 -0
  31. lemonade/tools/huggingface/utils.py +359 -0
  32. lemonade/tools/humaneval.py +264 -0
  33. lemonade/tools/llamacpp/bench.py +255 -0
  34. lemonade/tools/llamacpp/load.py +222 -0
  35. lemonade/tools/llamacpp/utils.py +1260 -0
  36. lemonade/tools/management_tools.py +319 -0
  37. lemonade/tools/mmlu.py +319 -0
  38. lemonade/tools/oga/__init__.py +0 -0
  39. lemonade/tools/oga/bench.py +120 -0
  40. lemonade/tools/oga/load.py +804 -0
  41. lemonade/tools/oga/migration.py +403 -0
  42. lemonade/tools/oga/utils.py +462 -0
  43. lemonade/tools/perplexity.py +147 -0
  44. lemonade/tools/prompt.py +263 -0
  45. lemonade/tools/report/__init__.py +0 -0
  46. lemonade/tools/report/llm_report.py +203 -0
  47. lemonade/tools/report/table.py +899 -0
  48. lemonade/tools/server/__init__.py +0 -0
  49. lemonade/tools/server/flm.py +133 -0
  50. lemonade/tools/server/llamacpp.py +320 -0
  51. lemonade/tools/server/serve.py +2123 -0
  52. lemonade/tools/server/static/favicon.ico +0 -0
  53. lemonade/tools/server/static/index.html +279 -0
  54. lemonade/tools/server/static/js/chat.js +1059 -0
  55. lemonade/tools/server/static/js/model-settings.js +183 -0
  56. lemonade/tools/server/static/js/models.js +1395 -0
  57. lemonade/tools/server/static/js/shared.js +556 -0
  58. lemonade/tools/server/static/logs.html +191 -0
  59. lemonade/tools/server/static/styles.css +2654 -0
  60. lemonade/tools/server/static/webapp.html +321 -0
  61. lemonade/tools/server/tool_calls.py +153 -0
  62. lemonade/tools/server/tray.py +664 -0
  63. lemonade/tools/server/utils/macos_tray.py +226 -0
  64. lemonade/tools/server/utils/port.py +77 -0
  65. lemonade/tools/server/utils/thread.py +85 -0
  66. lemonade/tools/server/utils/windows_tray.py +408 -0
  67. lemonade/tools/server/webapp.py +34 -0
  68. lemonade/tools/server/wrapped_server.py +559 -0
  69. lemonade/tools/tool.py +374 -0
  70. lemonade/version.py +1 -0
  71. lemonade_install/__init__.py +1 -0
  72. lemonade_install/install.py +239 -0
  73. lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
  74. lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
  75. lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
  76. lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
  77. lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
  78. lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
  79. lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
  80. lemonade_server/cli.py +805 -0
  81. lemonade_server/model_manager.py +758 -0
  82. lemonade_server/pydantic_models.py +159 -0
  83. lemonade_server/server_models.json +643 -0
  84. lemonade_server/settings.py +39 -0
lemonade/tools/llamacpp/utils.py
@@ -0,0 +1,1260 @@
import logging
import os
import platform
import shutil
import sys
import threading
import time
import zipfile
from typing import Optional
import psutil
import subprocess
import requests
import lemonade.common.build as build
import lemonade.common.printing as printing
from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter
from lemonade.common.system_info import get_system_info
from dotenv import set_key, load_dotenv

LLAMA_VERSION_VULKAN = "b6510"
LLAMA_VERSION_ROCM = "b1066"
LLAMA_VERSION_METAL = "b6510"
LLAMA_VERSION_CPU = "b6510"


def identify_rocm_arch_from_name(device_name: str) -> str | None:
    """
    Identify the appropriate ROCm target architecture based on the device name
    """
    device_name_lower = device_name.lower()
    if "radeon" not in device_name_lower:
        return None

    # Check iGPUs
    # STX Halo iGPUs (gfx1151 architecture)
    # Radeon 8050S Graphics / Radeon 8060S Graphics
    if any(halo_igpu in device_name_lower for halo_igpu in ["8050s", "8060s"]):
        return "gfx1151"

    # Check dGPUs
    # RDNA4 GPUs (gfx120X architecture)
    # AMD Radeon AI PRO R9700, AMD Radeon RX 9070 XT, AMD Radeon RX 9070 GRE,
    # AMD Radeon RX 9070, AMD Radeon RX 9060 XT
    if any(
        rdna4_gpu in device_name_lower
        for rdna4_gpu in ["r9700", "9060", "9070"]
    ):
        return "gfx120X"

    # RDNA3 GPUs (gfx110X architecture)
    # AMD Radeon PRO V710, AMD Radeon PRO W7900 Dual Slot, AMD Radeon PRO W7900,
    # AMD Radeon PRO W7800 48GB, AMD Radeon PRO W7800, AMD Radeon PRO W7700,
    # AMD Radeon RX 7900 XTX, AMD Radeon RX 7900 XT, AMD Radeon RX 7900 GRE,
    # AMD Radeon RX 7800 XT, AMD Radeon RX 7700 XT
    elif any(
        rdna3_gpu in device_name_lower
        for rdna3_gpu in ["7700", "7800", "7900", "v710"]
    ):
        return "gfx110X"

    return None


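# A minimal illustration of the name-to-architecture mapping above; the device
# names below are hypothetical inputs, not values read from any real system.
def _example_identify_rocm_arch_from_name():
    assert identify_rocm_arch_from_name("AMD Radeon RX 9070 XT") == "gfx120X"
    assert identify_rocm_arch_from_name("AMD Radeon RX 7900 XTX") == "gfx110X"
    assert identify_rocm_arch_from_name("AMD Radeon 8060S Graphics") == "gfx1151"
    assert identify_rocm_arch_from_name("Intel Arc A770") is None

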
def identify_rocm_arch() -> str | None:
    """
    Identify the appropriate ROCm target architecture based on the device info
    Returns the architecture string, or None if no compatible device is found
    """

    # Check for integrated and discrete AMD GPUs
    system_info = get_system_info()
    amd_igpu = system_info.get_amd_igpu_device()
    amd_dgpu = system_info.get_amd_dgpu_devices()
    target_arch = None
    for gpu in [amd_igpu] + amd_dgpu:
        if gpu.get("available") and gpu.get("name"):
            target_arch = identify_rocm_arch_from_name(gpu["name"].lower())
            if target_arch:
                break

    return target_arch


def identify_hip_id() -> int:
    """
    Identify the HIP device ID to use
    """
    # Get HIP devices
    hip_devices = get_hip_devices()
    logging.debug(f"HIP devices found: {hip_devices}")
    if len(hip_devices) == 0:
        raise ValueError("No HIP devices found when identifying HIP ID")

    # Identify HIP devices that are compatible with our ROCm builds
    rocm_devices = []
    for device in hip_devices:
        device_id, device_name = device
        if identify_rocm_arch_from_name(device_name):
            rocm_devices.append([device_id, device_name])
    logging.debug(f"ROCm devices found: {rocm_devices}")

    # If no ROCm devices are found, use the last HIP device
    # This might be needed in some scenarios where HIP reports generic device names
    # Example: "AMD Radeon Graphics" for STX Halo iGPU on Ubuntu 24.04
    if len(rocm_devices) == 0:
        rocm_devices = [hip_devices[-1]]
        logging.warning(
            "No ROCm devices found when identifying HIP ID. "
            f"Falling back to the following device: {rocm_devices[0]}"
        )
    elif len(rocm_devices) > 1:
        logging.warning(
            f"Multiple ROCm devices found when identifying HIP ID: {rocm_devices}. "
            "The last device will be used."
        )

    # Select the last device
    device_selected = rocm_devices[-1]
    logging.debug(f"Selected ROCm device: {device_selected}")

    # Return the device ID
    return device_selected[0]


def get_llama_version(backend: str) -> str:
    """
    Select the appropriate llama.cpp version based on the backend
    """
    if backend == "rocm":
        return LLAMA_VERSION_ROCM
    elif backend == "vulkan":
        return LLAMA_VERSION_VULKAN
    elif backend == "metal":
        return LLAMA_VERSION_METAL
    elif backend == "cpu":
        return LLAMA_VERSION_CPU
    else:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported: vulkan, rocm, metal, cpu"
        )


def get_llama_folder_path(backend: str):
    """
    Get path for llama.cpp platform-specific executables folder.
    Uses sys.prefix to get the environment root (works for both venv and conda):
        - Conda: sys.executable is at env/python.exe, sys.prefix is env/
        - Venv: sys.executable is at .venv/Scripts/python.exe, sys.prefix is .venv/
    """
    return os.path.join(sys.prefix, backend, "llama_server")


def get_llama_exe_path(exe_name: str, backend: str):
    """
    Get path to a platform-specific llama.cpp executable
    """
    base_dir = get_llama_folder_path(backend)
    system = platform.system().lower()

    if system == "windows":
        return os.path.join(base_dir, f"{exe_name}.exe")
    else:  # Darwin/Linux/Ubuntu
        # Check if executable exists in build/bin subdirectory
        build_bin_path = os.path.join(base_dir, "build", "bin", exe_name)
        if os.path.exists(build_bin_path):
            return build_bin_path
        else:
            # Fallback to root directory
            return os.path.join(base_dir, exe_name)


def get_llama_server_exe_path(backend: str):
    """
    Get path to platform-specific llama-server executable
    """
    return get_llama_exe_path("llama-server", backend)


def get_llama_cli_exe_path(backend: str):
    """
    Get path to platform-specific llama-cli executable
    """
    return get_llama_exe_path("llama-cli", backend)


def get_llama_bench_exe_path(backend: str):
    """
    Get path to platform-specific llama-bench executable
    """
    return get_llama_exe_path("llama-bench", backend)


def get_version_txt_path(backend: str):
    """
    Get path to text file that contains version information
    """
    return os.path.join(get_llama_folder_path(backend), "version.txt")


def get_llama_installed_version(backend: str):
    """
    Gets version of installed llama.cpp
    Returns None if llama.cpp is not installed
    """
    version_txt_path = get_version_txt_path(backend)
    if os.path.exists(version_txt_path):
        with open(version_txt_path, "r", encoding="utf-8") as f:
            llama_installed_version = f.read()
        return llama_installed_version
    return None


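# A small sketch of how the path helpers above compose; "vulkan" is an arbitrary
# example backend, and the resulting paths depend on sys.prefix and the host OS.
def _example_llamacpp_paths():
    backend = "vulkan"
    print(get_llama_folder_path(backend))        # <sys.prefix>/vulkan/llama_server
    print(get_llama_server_exe_path(backend))    # .../llama-server or llama-server.exe
    print(get_llama_installed_version(backend))  # None until install_llamacpp() has run

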
def get_binary_url_and_filename(backend: str, target_arch: str = None):
    """
    Get the appropriate binary URL and filename based on platform and backend

    Args:
        backend: Backend to use
        target_arch: ROCm target architecture (only used for the rocm backend)
    """
    system = platform.system().lower()

    if backend == "rocm":

        # ROCm support from lemonade-sdk/llamacpp-rocm
        repo = "lemonade-sdk/llamacpp-rocm"
        version = LLAMA_VERSION_ROCM
        if system == "windows":
            filename = f"llama-{version}-windows-rocm-{target_arch}-x64.zip"
        elif system == "linux":
            filename = f"llama-{version}-ubuntu-rocm-{target_arch}-x64.zip"
        else:
            raise NotImplementedError(
                f"Platform {system} not supported for ROCm llamacpp. Supported: Windows, Ubuntu Linux"
            )

    elif backend == "vulkan":
        # Original Vulkan support from ggml-org/llama.cpp
        repo = "ggml-org/llama.cpp"
        version = LLAMA_VERSION_VULKAN
        if system == "windows":
            filename = f"llama-{version}-bin-win-vulkan-x64.zip"
        elif system == "linux":
            filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
        else:
            raise NotImplementedError(
                f"Platform {system} not supported for Vulkan llamacpp. Supported: Windows, Ubuntu Linux"
            )

    elif backend == "cpu":
        # Original CPU support from ggml-org/llama.cpp
        repo = "ggml-org/llama.cpp"
        version = LLAMA_VERSION_CPU
        if system == "windows":
            filename = f"llama-{version}-bin-win-cpu-x64.zip"
        elif system == "linux":
            filename = f"llama-{version}-bin-ubuntu-x64.zip"
        else:
            raise NotImplementedError(
                f"Platform {system} not supported for CPU llamacpp. Supported: Windows, Ubuntu Linux"
            )

    elif backend == "metal":
        # Metal support for macOS Apple Silicon from ggml-org/llama.cpp
        repo = "ggml-org/llama.cpp"
        version = LLAMA_VERSION_METAL
        if system == "darwin":
            if platform.machine().lower() in ["arm64", "aarch64"]:
                filename = f"llama-{version}-bin-macos-arm64.zip"
            else:
                raise NotImplementedError(
                    "Metal backend only supports Apple Silicon (ARM64) processors"
                )
        else:
            raise NotImplementedError(
                f"Platform {system} not supported for Metal llamacpp. Metal is only supported on macOS"
            )
    else:
        supported_backends = ["vulkan", "rocm", "metal", "cpu"]
        raise NotImplementedError(
            f"Unsupported backend: {backend}. Supported backends: {supported_backends}"
        )

    url = f"https://github.com/{repo}/releases/download/{version}/{filename}"
    return url, filename


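# For illustration, the URL pattern the function above produces; the values shown
# in the comments are examples and depend on the pinned version constants and the
# host OS at runtime.
def _example_binary_url():
    url, filename = get_binary_url_and_filename("vulkan")
    # e.g. on Windows: filename == "llama-b6510-bin-win-vulkan-x64.zip" and
    # url == "https://github.com/ggml-org/llama.cpp/releases/download/b6510/" + filename
    return url, filename

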
def validate_platform_support():
    """
    Validate platform support before attempting download
    """
    system = platform.system().lower()

    if system not in ["windows", "linux", "darwin"]:
        raise NotImplementedError(
            f"Platform {system} not supported for llamacpp. "
            "Supported: Windows, Ubuntu Linux, macOS"
        )

    if system == "linux":
        # Check if we're actually on Ubuntu/compatible distro and log a warning if not
        try:
            with open("/etc/os-release", "r", encoding="utf-8") as f:
                os_info = f.read().lower()
                if "ubuntu" not in os_info and "debian" not in os_info:
                    logging.warning(
                        "llamacpp binaries are built for Ubuntu. "
                        "Compatibility with other Linux distributions is not guaranteed."
                    )
        except (FileNotFoundError, PermissionError, OSError) as e:
            logging.warning(
                "Could not determine Linux distribution (%s). "
                "llamacpp binaries are built for Ubuntu.",
                str(e),
            )


def install_llamacpp(backend):
    """
    Installs or upgrades llama.cpp binaries if needed
    """

    # Exception will be thrown if platform is not supported
    validate_platform_support()

    version = get_llama_version(backend)

    # Get platform-specific paths at runtime
    llama_server_exe_dir = get_llama_folder_path(backend)
    llama_server_exe_path = get_llama_server_exe_path(backend)

    # Check whether the llamacpp install needs an upgrade
    version_txt_path = os.path.join(llama_server_exe_dir, "version.txt")
    backend_txt_path = os.path.join(llama_server_exe_dir, "backend.txt")

    logging.info(f"Using backend: {backend}")

    if os.path.exists(version_txt_path) and os.path.exists(backend_txt_path):
        with open(version_txt_path, "r", encoding="utf-8") as f:
            llamacpp_installed_version = f.read().strip()
        with open(backend_txt_path, "r", encoding="utf-8") as f:
            llamacpp_installed_backend = f.read().strip()

        if (
            llamacpp_installed_version != version
            or llamacpp_installed_backend != backend
        ):
            # Remove the existing install, which will trigger a new install
            # in the next code block
            shutil.rmtree(llama_server_exe_dir)
    elif os.path.exists(version_txt_path):
        # Old installation without backend tracking - remove to upgrade
        shutil.rmtree(llama_server_exe_dir)

    # Download llama.cpp server if it isn't already available
    if not os.path.exists(llama_server_exe_path):

        # Create the directory
        os.makedirs(llama_server_exe_dir, exist_ok=True)

        # Identify the target architecture (only needed for ROCm)
        target_arch = None
        if backend == "rocm":
            # Identify the target architecture
            target_arch = identify_rocm_arch()
            if not target_arch:
                system = platform.system().lower()
                if system == "linux":
                    hint = (
                        "Hint: If you think your device is supported, "
                        "running `sudo update-pciids` may help identify your hardware."
                    )
                else:
                    hint = ""
                raise ValueError(
                    "ROCm backend selected but no compatible ROCm target architecture found. "
                    "See https://github.com/lemonade-sdk/lemonade?tab=readme-ov-file#supported-configurations "
                    f"for supported configurations. {hint}"
                )

        # Download the prebuilt release archive for the selected backend
        llama_archive_url, filename = get_binary_url_and_filename(backend, target_arch)
        llama_archive_path = os.path.join(llama_server_exe_dir, filename)
        logging.info(f"Downloading llama.cpp server from {llama_archive_url}")

        with requests.get(llama_archive_url, stream=True) as r:
            r.raise_for_status()
            with open(llama_archive_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        logging.info(f"Extracting {filename} to {llama_server_exe_dir}")
        if filename.endswith(".zip"):
            with zipfile.ZipFile(llama_archive_path, "r") as zip_ref:
                zip_ref.extractall(llama_server_exe_dir)

            # On Unix-like systems (macOS/Linux), make executables executable
            if platform.system().lower() in ["darwin", "linux"]:
                import stat

                # Find and make executable files executable
                for root, _, files in os.walk(llama_server_exe_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        # Make files in bin/ directories executable
                        if "bin" in root.split(os.sep) or file in [
                            "llama-server",
                            "llama-simple",
                        ]:
                            try:
                                current_permissions = os.stat(file_path).st_mode
                                os.chmod(file_path, current_permissions | stat.S_IEXEC)
                                logging.debug(f"Made {file_path} executable")
                            except Exception as e:
                                raise RuntimeError(
                                    f"Failed to make {file_path} executable. This will prevent "
                                    f"llama-server from starting. Error: {e}"
                                )
        else:
            raise NotImplementedError(f"Unsupported archive format: {filename}")

        # Identify and set HIP ID
        if backend == "rocm":
            try:
                hip_id = identify_hip_id()
            except Exception as e:  # pylint: disable=broad-exception-caught
                hip_id = 0
                logging.warning(f"Error identifying HIP ID: {e}. Falling back to 0.")
            env_file_path = os.path.join(llama_server_exe_dir, ".env")
            set_key(env_file_path, "HIP_VISIBLE_DEVICES", str(hip_id))

        # Make executable on Linux - need to update paths after extraction
        if platform.system().lower() == "linux":
            # Re-get the paths since extraction might have changed the directory structure
            exe_paths = [
                (get_llama_server_exe_path(backend), "llama-server"),
                (get_llama_cli_exe_path(backend), "llama-cli"),
                (get_llama_bench_exe_path(backend), "llama-bench"),
            ]

            for exe_path, exe_name in exe_paths:
                if os.path.exists(exe_path):
                    os.chmod(exe_path, 0o755)
                    logging.info(f"Set executable permissions for {exe_path}")
                else:
                    logging.warning(
                        f"Could not find {exe_name} executable at {exe_path}"
                    )

        # Save version and backend info
        with open(version_txt_path, "w", encoding="utf-8") as vf:
            vf.write(version)
        with open(backend_txt_path, "w", encoding="utf-8") as bf:
            bf.write(backend)

        # Delete the archive file
        os.remove(llama_archive_path)


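# A usage sketch of the install flow above; "vulkan" is an arbitrary backend choice,
# and the first call downloads binaries from GitHub, so the call is shown guarded.
def _example_install_llamacpp():
    backend = "vulkan"
    try:
        install_llamacpp(backend)
    except Exception as e:  # unsupported platform, offline, etc.
        print(f"llama.cpp install skipped: {e}")
        return None
    return get_llama_server_exe_path(backend)

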
def parse_checkpoint(checkpoint: str) -> tuple[str, str | None]:
    """
    Parse a checkpoint string that may contain a variant separated by a colon.

    For GGUF models, the format is "repository:variant" (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0").
    For other models, there is no variant.

    Args:
        checkpoint: The checkpoint string, potentially with variant

    Returns:
        tuple: (base_checkpoint, variant) where variant is None if no colon is present
    """
    if ":" in checkpoint:
        base_checkpoint, variant = checkpoint.split(":", 1)
        return base_checkpoint, variant
    return checkpoint, None


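# A quick illustration of the parsing contract documented above; the checkpoint
# string is the example from the docstring.
def _example_parse_checkpoint():
    assert parse_checkpoint("unsloth/Qwen3-0.6B-GGUF:Q4_0") == (
        "unsloth/Qwen3-0.6B-GGUF",
        "Q4_0",
    )
    assert parse_checkpoint("unsloth/Qwen3-0.6B-GGUF") == (
        "unsloth/Qwen3-0.6B-GGUF",
        None,
    )

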
def get_local_checkpoint_path(base_checkpoint, variant):
    """
    Returns the absolute path to a .gguf checkpoint file in the local HuggingFace hub.
    Also returns just the .gguf filename.

    The variant is one of the following types:
      1. Full filename: exact file to look up
      2. Quantization variant: find a single file ending with the variant name (case insensitive)
      3. Folder name with subfolder that matches the variant name (case insensitive)
    """
    full_model_path = None
    model_to_use = None
    try:
        from lemonade.common.network import custom_snapshot_download

        snapshot_path = custom_snapshot_download(
            base_checkpoint,
            local_files_only=True,
        )

        if os.path.isdir(snapshot_path) and os.listdir(snapshot_path):

            snapshot_files = os.listdir(snapshot_path)

            if variant.endswith(".gguf"):
                # Variant is an exact file
                model_to_use = variant
                if variant in snapshot_files:
                    full_model_path = os.path.join(snapshot_path, variant)
                else:
                    raise ValueError(
                        f"The variant {variant} is not available locally in {snapshot_path}."
                    )

            else:
                # Variant is a quantization
                end_with_variant = [
                    file
                    for file in snapshot_files
                    if file.lower().endswith(f"{variant}.gguf".lower())
                ]
                if len(end_with_variant) == 1:
                    model_to_use = end_with_variant[0]
                    full_model_path = os.path.join(snapshot_path, model_to_use)
                elif len(end_with_variant) > 1:
                    raise ValueError(
                        f"Multiple .gguf files found for variant {variant}, "
                        f"but only one is allowed."
                    )
                else:
                    # Check whether the variant corresponds to a folder with
                    # sharded files (case insensitive)
                    quantization_folder = [
                        folder
                        for folder in snapshot_files
                        if folder.lower() == variant.lower()
                        and os.path.exists(os.path.join(snapshot_path, folder))
                        and os.path.isdir(os.path.join(snapshot_path, folder))
                    ]
                    if len(quantization_folder) == 1:
                        quantization_folder = os.path.join(
                            snapshot_path, quantization_folder[0]
                        )
                        sharded_files = [
                            f
                            for f in os.listdir(quantization_folder)
                            if f.endswith(".gguf")
                        ]
                        if not sharded_files:
                            raise ValueError(
                                f"No .gguf files found for variant {variant}."
                            )
                        else:
                            model_to_use = sharded_files[0]
                            full_model_path = os.path.join(
                                quantization_folder, model_to_use
                            )
                    elif len(quantization_folder) > 1:
                        raise ValueError(
                            f"Multiple checkpoint folder names match the variant {variant}."
                        )
                    else:
                        raise ValueError(f"No .gguf files found for variant {variant}.")
        else:
            raise ValueError(
                f"The checkpoint {base_checkpoint} is not a local checkpoint."
            )

    except Exception as e:  # pylint: disable=broad-exception-caught
        # Log any errors but continue with the original path
        printing.log_info(f"Error checking Hugging Face cache: {e}")

    return full_model_path, model_to_use


def identify_gguf_models(
    checkpoint: str, variant: Optional[str], mmproj: str
) -> tuple[dict, list[str]]:
    """
    Identifies the GGUF model files in the repository that match the variant.
    """

    hint = """
The CHECKPOINT:VARIANT scheme is used to specify model files in Hugging Face repositories.

The VARIANT format can be one of several types:
    0. wildcard (*): download all .gguf files in the repo
    1. Full filename: exact file to download
    2. None/empty: gets the first .gguf file in the repository (excludes mmproj files)
    3. Quantization variant: find a single file ending with the variant name (case insensitive)
    4. Folder name: downloads all .gguf files in the folder that matches the variant name (case insensitive)

Examples:
    - "ggml-org/gpt-oss-120b-GGUF:*" -> downloads all .gguf files in repo
    - "unsloth/Qwen3-8B-GGUF:qwen3.gguf" -> downloads "qwen3.gguf"
    - "unsloth/Qwen3-30B-A3B-GGUF" -> downloads "Qwen3-30B-A3B-GGUF.gguf"
    - "unsloth/Qwen3-8B-GGUF:Q4_1" -> downloads "Qwen3-8B-GGUF-Q4_1.gguf"
    - "unsloth/Qwen3-30B-A3B-GGUF:Q4_0" -> downloads all files in "Q4_0/" folder
"""

    from huggingface_hub import list_repo_files

    repo_files = list_repo_files(checkpoint)
    sharded_files = []

    # (case 0) Wildcard, download everything
    if variant and variant == "*":
        sharded_files = [f for f in repo_files if f.endswith(".gguf")]
        if not sharded_files:
            raise ValueError(
                f"No .gguf files found in Hugging Face repository {checkpoint}. {hint}"
            )

        # Sort to ensure consistent ordering
        sharded_files.sort()

        # Use first file as primary (this is how llamacpp handles it)
        variant_name = sharded_files[0]

    # (case 1) If variant ends in .gguf, use it directly
    elif variant and variant.endswith(".gguf"):
        variant_name = variant
        if variant_name not in repo_files:
            raise ValueError(
                f"File {variant} not found in Hugging Face repository {checkpoint}. {hint}"
            )
    # (case 2) If no variant is provided, get the first .gguf file in the repository
    elif variant is None:
        all_variants = [
            f for f in repo_files if f.endswith(".gguf") and "mmproj" not in f
        ]
        if len(all_variants) == 0:
            raise ValueError(
                f"No .gguf files found in Hugging Face repository {checkpoint}. {hint}"
            )
        variant_name = all_variants[0]
    else:
        # (case 3) Find a single file ending with the variant name (case insensitive)
        end_with_variant = [
            f
            for f in repo_files
            if f.lower().endswith(f"{variant}.gguf".lower())
            and "mmproj" not in f.lower()
        ]
        if len(end_with_variant) == 1:
            variant_name = end_with_variant[0]
        elif len(end_with_variant) > 1:
            raise ValueError(
                f"Multiple .gguf files found for variant {variant}, but only one is allowed. {hint}"
            )
        # (case 4) Check whether the variant corresponds to a folder with
        # sharded files (case insensitive)
        else:
            sharded_files = [
                f
                for f in repo_files
                if f.endswith(".gguf") and f.lower().startswith(f"{variant}/".lower())
            ]

            if not sharded_files:
                raise ValueError(f"No .gguf files found for variant {variant}. {hint}")

            # Sort to ensure consistent ordering
            sharded_files.sort()

            # Use first file as primary (this is how llamacpp handles it)
            variant_name = sharded_files[0]

    core_files = {"variant": variant_name}

    # If there is a mmproj file, add it to the patterns
    if mmproj:
        if mmproj not in repo_files:
            raise ValueError(
                f"The provided mmproj file {mmproj} was not found in {checkpoint}."
            )
        core_files["mmproj"] = mmproj

    return core_files, sharded_files


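# A sketch of how the variant cases above resolve for one of the repositories used
# as an example in the hint text; querying the Hugging Face Hub requires network
# access, so the call is guarded and no specific result is assumed.
def _example_identify_gguf_models():
    try:
        core_files, sharded_files = identify_gguf_models(
            "unsloth/Qwen3-0.6B-GGUF", "Q4_0", mmproj=None
        )
    except Exception as e:  # offline, or the repo layout has changed
        print(f"Could not query the Hub: {e}")
        return None
    # core_files["variant"] is the primary .gguf file; sharded_files is non-empty
    # only for the wildcard and folder cases.
    return core_files, sharded_files

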
def resolve_local_gguf_model(
    checkpoint: str, variant: str, config_mmproj: str = None
) -> dict | None:
    """
    Attempts to resolve a GGUF model from the local HuggingFace cache.
    """
    from huggingface_hub.constants import HF_HUB_CACHE

    # Convert checkpoint to cache directory format
    if checkpoint.startswith("models--"):
        model_cache_dir = os.path.join(HF_HUB_CACHE, checkpoint)
    else:
        # This is a HuggingFace repo - convert to cache directory format
        repo_cache_name = checkpoint.replace("/", "--")
        model_cache_dir = os.path.join(HF_HUB_CACHE, f"models--{repo_cache_name}")

    # Check if the cache directory exists
    if not os.path.exists(model_cache_dir):
        return None

    gguf_file_found = None

    # If variant is specified, look for that specific file
    if variant:
        search_term = variant if variant.endswith(".gguf") else f"{variant}.gguf"

        for root, _, files in os.walk(model_cache_dir):
            if search_term in files:
                gguf_file_found = os.path.join(root, search_term)
                break

    # If no variant or variant not found, find any .gguf file (excluding mmproj)
    if not gguf_file_found:
        for root, _, files in os.walk(model_cache_dir):
            gguf_files = [
                f for f in files if f.endswith(".gguf") and "mmproj" not in f.lower()
            ]
            if gguf_files:
                gguf_file_found = os.path.join(root, gguf_files[0])
                break

    # If no GGUF file found, model is not in cache
    if not gguf_file_found:
        return None

    # Build result dictionary
    result = {"variant": gguf_file_found}

    # Search for mmproj file if provided
    if config_mmproj:
        for root, _, files in os.walk(model_cache_dir):
            if config_mmproj in files:
                result["mmproj"] = os.path.join(root, config_mmproj)
                break

    logging.info(f"Resolved local GGUF model: {result}")
    return result


def download_gguf(
    config_checkpoint: str, config_mmproj=None, do_not_upgrade: bool = False
) -> dict:
    """
    Downloads the GGUF file for the given model configuration from HuggingFace.

    This function downloads models from the internet. It does NOT check the local cache first.
    Callers should use resolve_local_gguf_model() if they want to check for existing models first.

    Args:
        config_checkpoint: Checkpoint identifier (file path or HF repo with variant)
        config_mmproj: Optional mmproj file to also download
        do_not_upgrade: If True, use local cache only without attempting to download updates

    Returns:
        Dictionary with "variant" (and optionally "mmproj") file paths
    """
    # Handle direct file path case - if the checkpoint is an actual file on disk
    if os.path.exists(config_checkpoint):
        result = {"variant": config_checkpoint}
        if config_mmproj:
            result["mmproj"] = config_mmproj
        return result

    # Parse checkpoint to extract base and variant
    # Checkpoint format: repo_name:variant (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0")
    checkpoint, variant = parse_checkpoint(config_checkpoint)

    # Identify the GGUF model files in the repository that match the variant
    core_files, sharded_files = identify_gguf_models(checkpoint, variant, config_mmproj)

    # Download the files
    from lemonade.common.network import custom_snapshot_download

    snapshot_folder = custom_snapshot_download(
        checkpoint,
        allow_patterns=list(core_files.values()) + sharded_files,
        do_not_upgrade=do_not_upgrade,
    )

    # Ensure we downloaded all expected files
    for file in list(core_files.values()) + sharded_files:
        expected_path = os.path.join(snapshot_folder, file)
        if not os.path.exists(expected_path):
            raise ValueError(
                f"Hugging Face snapshot download for {config_checkpoint} "
                f"expected file {file} not found at {expected_path}"
            )

    # Return a dict of the full path of the core GGUF files
    return {
        file_name: os.path.join(snapshot_folder, file_path)
        for file_name, file_path in core_files.items()
    }


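# A sketch of the documented call pattern: check the local cache first via
# resolve_local_gguf_model(), then fall back to download_gguf(). The checkpoint
# string is a hypothetical repo:variant example; the fallback requires network.
def _example_get_gguf(config_checkpoint="unsloth/Qwen3-0.6B-GGUF:Q4_0"):
    checkpoint, variant = parse_checkpoint(config_checkpoint)
    local = resolve_local_gguf_model(checkpoint, variant)
    if local is not None:
        return local
    return download_gguf(config_checkpoint)

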
# Function to read a stream (stdout or stderr) into a list
def stream_reader(stream, output_list):
    for line in iter(stream.readline, b""):
        decoded_line = line.decode().rstrip()
        output_list.append(decoded_line)
    stream.close()


def monitor_process_memory(pid, memory_data, interval=0.5):
    """Monitor memory usage of a process in a separate thread."""

    try:
        is_windows = platform.system() == "Windows"
        if is_windows:
            # We can only collect peak_wset in Windows
            process = psutil.Process(pid)
            while process.is_running():
                try:
                    mem_info = process.memory_info()
                    peak_wset = mem_info.peak_wset
                    if peak_wset is not None:
                        memory_data["peak_wset"] = peak_wset
                except psutil.NoSuchProcess:
                    break
                time.sleep(interval)
    except Exception as e:
        print(f"Error monitoring process: {e}")

    return memory_data


class LlamaCppTokenizerAdapter(PassthroughTokenizer):
    pass


class LlamaCppAdapter(ModelAdapter):
    def __init__(
        self,
        model,
        device,
        output_tokens,
        context_size,
        threads,
        executable,
        bench_executable,
        reasoning=False,
        lib_dir=None,
        state=None,
    ):
        super().__init__()

        self.model = os.path.normpath(model)
        self.device = device
        self.output_tokens = (
            output_tokens  # default value of max tokens to generate from a prompt
        )
        self.context_size = context_size
        self.threads = threads
        self.executable = os.path.normpath(executable)
        self.bench_executable = os.path.normpath(bench_executable)
        self.reasoning = reasoning
        self.lib_dir = lib_dir
        self.state = state

    def generate(
        self,
        input_ids: str,
        max_new_tokens: Optional[int] = None,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 40,
        return_raw: bool = False,
        save_max_memory_used: bool = False,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Pass a text prompt into the llamacpp inference CLI.

        The input_ids arg here should receive the original text that
        would normally be encoded by a tokenizer.

        Args:
            input_ids: The input text prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Temperature for sampling (0.0 = greedy)
            top_p: Top-p sampling threshold
            top_k: Top-k sampling threshold
            return_raw: If True, returns the complete raw output including timing info
            **kwargs: Additional arguments (ignored)

        Returns:
            List containing a single string with the generated text, or raw output if
            return_raw=True
        """

        prompt = input_ids
        if self.reasoning:
            prompt += "<think>"
        n_predict = max_new_tokens if max_new_tokens is not None else self.output_tokens

        cmd = [
            self.executable,
            "-m",
            self.model,
            "--ctx-size",  # size of the prompt context, 0 = loaded from model
            str(self.context_size),
            "-n",  # number of tokens to predict, -1 = infinity, -2 = until context filled
            str(n_predict),
            "-t",  # number of threads to use during generation
            str(self.threads),
            "-b",  # logical maximum batch size
            "1",
            "-ub",  # physical maximum batch size
            "1",
            "--temp",
            str(temperature),
            "--top-p",
            str(top_p),
            "--top-k",
            str(top_k),
            "-e",  # process escape sequences
            "--no-conversation",  # disable conversation mode
            "--reasoning-format",  # leaves thoughts unparsed in message content
            "none",
        ]

        # Pass short prompts inline; if the prompt exceeds 500 characters, use a file
        if len(prompt) < 500:
            cmd += ["-p", prompt]
        else:
            # Create prompt file in cache directory
            prompt_file = os.path.join(
                build.output_dir(self.state.cache_dir, self.state.build_name),
                "prompt.txt",
            )
            with open(prompt_file, "w", encoding="utf-8") as file:
                file.write(prompt)
            cmd += ["-f", prompt_file]

        # Configure GPU layers: 99 for GPU, 0 for CPU-only
        ngl_value = "99" if self.device == "igpu" else "0"
        cmd = cmd + ["-ngl", ngl_value]

        cmd = [str(m) for m in cmd]

        # save llama-cli command
        self.state.llama_cli_cmd = getattr(self.state, "llama_cli_cmd", []) + [
            " ".join(cmd)
        ]

        try:
            # Set up environment with library path for Linux
            env = os.environ.copy()

            # Load environment variables from .env file in the executable directory
            exe_dir = os.path.dirname(self.executable)
            env_file_path = os.path.join(exe_dir, ".env")
            if os.path.exists(env_file_path):
                load_dotenv(env_file_path, override=True)
                env.update(os.environ)

            if self.lib_dir and os.name != "nt":  # Not Windows
                current_ld_path = env.get("LD_LIBRARY_PATH", "")
                if current_ld_path:
                    env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
                else:
                    env["LD_LIBRARY_PATH"] = self.lib_dir

            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                encoding="utf-8",
                errors="replace",
                env=env,
            )

            # Start memory monitoring in a separate thread
            if save_max_memory_used:
                memory_data = {}
                monitor_thread = threading.Thread(
                    target=monitor_process_memory,
                    args=(process.pid, memory_data),
                    daemon=True,
                )
                monitor_thread.start()

            # Communicate with the subprocess
            stdout, stderr = process.communicate(timeout=600)

            # save llama-cli command output with performance info to state
            # (can be viewed in state.yaml file in cache)
            self.state.llama_cli_stderr = getattr(
                self.state, "llama_cli_stderr", []
            ) + [
                [line for line in stderr.splitlines() if line.startswith("llama_perf_")]
            ]

            if process.returncode != 0:
                error_msg = f"llama.cpp failed with return code {process.returncode}.\n"
                error_msg += f"Command: {' '.join(cmd)}\n"
                error_msg += f"Error output:\n{stderr}\n"
                error_msg += f"Standard output:\n{stdout}"
                raise Exception(error_msg)

            if stdout is None:
                raise Exception("No output received from llama.cpp process")

            # Parse information from llama.cpp output
            for line in stderr.splitlines():
                # Parse timing and token information
                #
                # Prompt processing time and length (tokens)
                # Sample: llama_perf_context_print: prompt eval time = 35.26 ms /
                #   3 tokens ( 11.75 ms per token, 85.09 tokens per second)
                #
                if "llama_perf_context_print: prompt eval time =" in line:
                    parts = line.split("=")[1].split()
                    time_to_first_token_ms = float(parts[0])
                    self.time_to_first_token = time_to_first_token_ms / 1000
                    self.prompt_tokens = int(parts[3])
                #
                # Response processing time and length (tokens)
                # Sample: llama_perf_context_print: eval time = 1991.14 ms /
                #   63 runs ( 31.61 ms per token, 31.64 tokens per second)
                #
                if "llama_perf_context_print: eval time =" in line:
                    parts = line.split("=")[1].split()
                    self.response_tokens = int(parts[3]) + 1  # include first token
                    response_time_ms = float(parts[0])
                    self.tokens_per_second = (
                        1000 * self.response_tokens / response_time_ms
                        if response_time_ms > 0
                        else 0
                    )

            # Wait for monitor thread to finish and write peak_wset
            if save_max_memory_used:
                monitor_thread.join(timeout=2)
                self.peak_wset = memory_data.get("peak_wset", None)

            if return_raw:
                return [stdout, stderr]

            # Find where the prompt ends and the generated text begins
            prompt_found = False
            output_text = ""
            prompt_first_line = prompt.split("\n")[0]
            for line in stdout.splitlines():
                if prompt_first_line in line:
                    prompt_found = True
                if prompt_found:
                    line = line.replace("</s> [end of text]", "")
                    output_text = output_text + line

            if not prompt_found:
                raise Exception(
                    f"Could not find prompt '{prompt_first_line}' in llama.cpp output. "
                    "This usually means the model failed to process the prompt correctly.\n"
                    f"Raw output:\n{stdout}\n"
                    f"Stderr:\n{stderr}"
                )

            # Return list containing the generated text
            return [output_text]

        except Exception as e:
            error_msg = f"Failed to run llama-cli command: {str(e)}\n"
            error_msg += f"Command: {' '.join(cmd)}"
            raise Exception(error_msg)

    def benchmark(self, prompt, iterations, output_tokens):
        """
        Runs the llama-bench tool to measure TTFT and TPS
        """
        cmd = [
            self.bench_executable,
            "-m",
            self.model,
            "-r",
            iterations,
            "-p",
            str(prompt),
            "-n",
            output_tokens,
            "-t",
            self.threads if self.threads > 0 else 16,
            "-b",
            1,
            "-ub",
            1,
        ]
        ngl_value = "99" if self.device == "igpu" else "0"
        cmd = cmd + ["-ngl", ngl_value]
        cmd = [str(m) for m in cmd]

        # save llama-bench command
        self.state.llama_bench_cmd = " ".join(cmd)

        try:
            # Set up environment with library path for Linux
            env = os.environ.copy()

            # Load environment variables from .env file in the executable directory
            exe_dir = os.path.dirname(self.executable)
            env_file_path = os.path.join(exe_dir, ".env")
            if os.path.exists(env_file_path):
                load_dotenv(env_file_path, override=True)
                env.update(os.environ)

            if self.lib_dir and os.name != "nt":  # Not Windows
                current_ld_path = env.get("LD_LIBRARY_PATH", "")
                if current_ld_path:
                    env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
                else:
                    env["LD_LIBRARY_PATH"] = self.lib_dir

            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                encoding="utf-8",
                errors="replace",
                env=env,
            )

            # Start memory monitoring in a separate thread
            save_max_memory_used = platform.system() == "Windows"
            if save_max_memory_used:
                memory_data = {}
                monitor_thread = threading.Thread(
                    target=monitor_process_memory,
                    args=(process.pid, memory_data),
                    daemon=True,
                )
                monitor_thread.start()

            # Communicate with the subprocess
            stdout, stderr = process.communicate(timeout=600)

            # save llama-bench command output with performance info to state
            # (can be viewed in state.yaml file in cache)
            self.state.llama_bench_standard_output = stdout.splitlines()

            if process.returncode != 0:
                error_msg = (
                    f"llama-bench failed with return code {process.returncode}.\n"
                )
                error_msg += f"Command: {' '.join(cmd)}\n"
                error_msg += f"Error output:\n{stderr}\n"
                error_msg += f"Standard output:\n{stdout}"
                raise Exception(error_msg)

            if stdout is None:
                error_msg = "No output received from llama-bench process\n"
                error_msg += f"Error output:\n{stderr}\n"
                error_msg += f"Standard output:\n{stdout}"
                raise Exception(error_msg)

            # Parse information from llama-bench output
            prompt_length = None
            pp_tps = None
            pp_tps_sd = None
            tg_tps = None
            tg_tps_sd = None

            for line in stdout.splitlines():
                # Parse TPS information
                if f"pp{prompt:d}" in line:
                    parts = line.split("|")
                    timings = parts[-2].strip().split(" ")
                    prompt_length = prompt
                    pp_tps = float(timings[0])
                    pp_tps_sd = float(timings[-1])
                if f"tg{output_tokens:d}" in line:
                    parts = line.split("|")
                    timings = parts[-2].strip().split(" ")
                    tg_tps = float(timings[0])
                    tg_tps_sd = float(timings[-1])

        except Exception as e:
            error_msg = f"Failed to run llama-bench command: {str(e)}\n"
            error_msg += f"Command: {' '.join(cmd)}"
            raise Exception(error_msg)

        # Determine max memory used
        if save_max_memory_used:
            # Wait for monitor thread to finish
            monitor_thread.join(timeout=2)

            # Track memory usage concurrently
            peak_wset = memory_data.get("peak_wset", None)
        else:
            peak_wset = None

        return prompt_length, pp_tps, pp_tps_sd, tg_tps, tg_tps_sd, peak_wset


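# A hypothetical construction of the adapter above; the model path is a placeholder,
# "vulkan" is an arbitrary backend, and the state argument would normally be a
# lemonade State object supplied by the llamacpp load/bench tools.
def _example_llamacpp_adapter(state):
    adapter = LlamaCppAdapter(
        model="/path/to/model.gguf",
        device="igpu",
        output_tokens=64,
        context_size=4096,
        threads=8,
        executable=get_llama_cli_exe_path("vulkan"),
        bench_executable=get_llama_bench_exe_path("vulkan"),
        state=state,
    )
    # generate() shells out to llama-cli and returns a single-element list of text
    return adapter.generate("Hello, world", max_new_tokens=16)

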
def get_hip_devices():
    """Get list of HIP devices with their IDs and names."""
    import ctypes
    import glob
    from ctypes import c_int, POINTER

    # Get llama.cpp path
    rocm_path = get_llama_folder_path("rocm")

    # Load HIP library
    hip_library_pattern = (
        "amdhip64*.dll" if sys.platform.startswith("win") else "libamdhip64*.so"
    )
    search_pattern = os.path.join(rocm_path, hip_library_pattern)
    matching_files = glob.glob(search_pattern)
    if not matching_files:
        raise RuntimeError(
            f"Could not find HIP runtime library matching pattern: {search_pattern}"
        )
    try:
        libhip = ctypes.CDLL(matching_files[0])
    except OSError:
        raise RuntimeError(
            f"Could not load HIP runtime library from {matching_files[0]}"
        )

    # Setup function signatures
    hipError_t = c_int
    hipDeviceProp_t = ctypes.c_char * 2048
    libhip.hipGetDeviceCount.restype = hipError_t
    libhip.hipGetDeviceCount.argtypes = [POINTER(c_int)]
    libhip.hipGetDeviceProperties.restype = hipError_t
    libhip.hipGetDeviceProperties.argtypes = [POINTER(hipDeviceProp_t), c_int]
    libhip.hipGetErrorString.restype = ctypes.c_char_p
    libhip.hipGetErrorString.argtypes = [hipError_t]

    # Get device count
    device_count = c_int()
    err = libhip.hipGetDeviceCount(ctypes.byref(device_count))
    if err != 0:
        logging.error(
            "hipGetDeviceCount failed: %s", libhip.hipGetErrorString(err).decode()
        )
        return []

    # Get device properties
    devices = []
    for i in range(device_count.value):
        prop = hipDeviceProp_t()
        err = libhip.hipGetDeviceProperties(ctypes.byref(prop), i)
        if err != 0:
            logging.error(
                "hipGetDeviceProperties failed for device %d: %s",
                i,
                libhip.hipGetErrorString(err).decode(),
            )
            continue

        # Extract device name from HIP device properties
        device_name = ctypes.string_at(prop, 256).decode("utf-8").rstrip("\x00")
        devices.append([i, device_name])

    return devices