lemonade-sdk 8.0.4__py3-none-any.whl → 8.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,612 @@
+ import logging
+ import os
+ import platform
+ import shutil
+ import sys
+ import zipfile
+ from typing import Optional
+ import subprocess
+ import requests
+ import lemonade.common.printing as printing
+ from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter
+
+ LLAMA_VERSION = "b5902"
+
+
+ def get_llama_folder_path():
+     """
+     Get path for llama.cpp platform-specific executables folder
+     """
+     return os.path.join(os.path.dirname(sys.executable), "llamacpp")
+
+
+ def get_llama_exe_path(exe_name):
+     """
+     Get path to the platform-specific llama.cpp executable with the given name
+     """
+     base_dir = get_llama_folder_path()
+     if platform.system().lower() == "windows":
+         return os.path.join(base_dir, f"{exe_name}.exe")
+     else:  # Linux/Ubuntu
+         # Check if the executable exists in the build/bin subdirectory (current Ubuntu structure)
+         build_bin_path = os.path.join(base_dir, "build", "bin", exe_name)
+         if os.path.exists(build_bin_path):
+             return build_bin_path
+         else:
+             # Fall back to the root directory
+             return os.path.join(base_dir, exe_name)
+
+
+ def get_llama_server_exe_path():
+     """
+     Get path to platform-specific llama-server executable
+     """
+     return get_llama_exe_path("llama-server")
+
+
+ def get_llama_cli_exe_path():
+     """
+     Get path to platform-specific llama-cli executable
+     """
+     return get_llama_exe_path("llama-cli")
+
+
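For orientation, a minimal sketch of where these helpers resolve, assuming the functions above are in scope and sys.executable points at a typical environment (the paths shown are illustrative only):

    # Hypothetical: sys.executable == "/opt/venv/bin/python3"
    folder = get_llama_folder_path()      # -> "/opt/venv/bin/llamacpp"
    server = get_llama_server_exe_path()  # Windows: ...\llamacpp\llama-server.exe
    cli = get_llama_cli_exe_path()        # Ubuntu: .../llamacpp/build/bin/llama-cli, if extracted there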
+ def get_version_txt_path():
+     """
+     Get path to text file that contains version information
+     """
+     return os.path.join(get_llama_folder_path(), "version.txt")
+
+
+ def get_llama_installed_version():
+     """
+     Gets version of installed llama.cpp
+     Returns None if llama.cpp is not installed
+     """
+     version_txt_path = get_version_txt_path()
+     if os.path.exists(version_txt_path):
+         with open(version_txt_path, "r", encoding="utf-8") as f:
+             llama_installed_version = f.read()
+             return llama_installed_version
+     return None
+
+
+ def get_binary_url_and_filename(version):
+     """
+     Get the appropriate llama.cpp binary URL and filename based on platform
+     """
+     system = platform.system().lower()
+
+     if system == "windows":
+         filename = f"llama-{version}-bin-win-vulkan-x64.zip"
+     elif system == "linux":
+         filename = f"llama-{version}-bin-ubuntu-vulkan-x64.zip"
+     else:
+         raise NotImplementedError(
+             f"Platform {system} not supported for llamacpp. Supported: Windows, Ubuntu Linux"
+         )
+
+     url = (
+         f"https://github.com/ggml-org/llama.cpp/releases/download/{version}/{filename}"
+     )
+     return url, filename
+
+
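As a concrete illustration of the URL construction above, with LLAMA_VERSION = "b5902" the helper yields (values follow directly from the filename templates):

    url, filename = get_binary_url_and_filename(LLAMA_VERSION)
    # Windows: filename == "llama-b5902-bin-win-vulkan-x64.zip"
    # Linux:   filename == "llama-b5902-bin-ubuntu-vulkan-x64.zip"
    # url == "https://github.com/ggml-org/llama.cpp/releases/download/b5902/" + filename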
+ def validate_platform_support():
+     """
+     Validate platform support before attempting download
+     """
+     system = platform.system().lower()
+
+     if system not in ["windows", "linux"]:
+         raise NotImplementedError(
+             f"Platform {system} not supported for llamacpp. "
+             "Supported: Windows, Ubuntu Linux"
+         )
+
+     if system == "linux":
+         # Check if we're actually on Ubuntu/compatible distro and log a warning if not
+         try:
+             with open("/etc/os-release", "r", encoding="utf-8") as f:
+                 os_info = f.read().lower()
+                 if "ubuntu" not in os_info and "debian" not in os_info:
+                     logging.warning(
+                         "llamacpp binaries are built for Ubuntu. "
+                         "Compatibility with other Linux distributions is not guaranteed."
+                     )
+         except (FileNotFoundError, PermissionError, OSError) as e:
+             logging.warning(
+                 "Could not determine Linux distribution (%s). "
+                 "llamacpp binaries are built for Ubuntu.",
+                 str(e),
+             )
+
+
+ def install_llamacpp():
+     """
+     Installs or upgrades llama.cpp binaries if needed
+     """
+
+     # Exception will be thrown if platform is not supported
+     validate_platform_support()
+
+     # Installation location for llama.cpp
+     llama_folder_path = get_llama_folder_path()
+
+     # Check whether the llamacpp install needs an upgrade
+     if os.path.exists(llama_folder_path):
+         if get_llama_installed_version() != LLAMA_VERSION:
+             # Remove the existing install, which will trigger a new install
+             # in the next code block
+             shutil.rmtree(llama_folder_path)
+
+     # Download llama.cpp server if it isn't already available
+     if not os.path.exists(llama_folder_path):
+         # Download llama.cpp server zip
+         llama_zip_url, filename = get_binary_url_and_filename(LLAMA_VERSION)
+         llama_zip_path = os.path.join(os.path.dirname(sys.executable), filename)
+         logging.info(f"Downloading llama.cpp server from {llama_zip_url}")
+
+         with requests.get(llama_zip_url, stream=True) as r:
+             r.raise_for_status()
+             with open(llama_zip_path, "wb") as f:
+                 for chunk in r.iter_content(chunk_size=8192):
+                     f.write(chunk)
+
+         # Extract zip
+         logging.info(f"Extracting {llama_zip_path} to {llama_folder_path}")
+         with zipfile.ZipFile(llama_zip_path, "r") as zip_ref:
+             zip_ref.extractall(llama_folder_path)
+
+         # Make executable on Linux - need to update paths after extraction
+         if platform.system().lower() == "linux":
+             # Re-get the paths since extraction might have changed the directory structure
+             for updated_exe_path in [
+                 get_llama_server_exe_path(),
+                 get_llama_cli_exe_path(),
+             ]:
+                 if os.path.exists(updated_exe_path):
+                     os.chmod(updated_exe_path, 0o755)
+                     logging.info(f"Set executable permissions for {updated_exe_path}")
+                 else:
+                     logging.warning(
+                         f"Could not find llama.cpp executable at {updated_exe_path}"
+                     )
+
+         # Save version.txt
+         with open(get_version_txt_path(), "w", encoding="utf-8") as vf:
+             vf.write(LLAMA_VERSION)
+
+         # Delete zip file
+         os.remove(llama_zip_path)
+         logging.info("Cleaned up zip file")
+
+
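A hedged usage sketch of the install flow above, assuming these helpers are in scope; the llama-server flags shown (-m, --port) are standard llama.cpp options rather than anything this module sets:

    import subprocess

    install_llamacpp()  # downloads and extracts the pinned b5902 binaries if missing or outdated
    server = get_llama_server_exe_path()
    # "/path/to/model.gguf" is a placeholder for any local GGUF model file
    proc = subprocess.Popen([server, "-m", "/path/to/model.gguf", "--port", "8080"])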
+ def parse_checkpoint(checkpoint: str) -> tuple[str, str | None]:
+     """
+     Parse a checkpoint string that may contain a variant separated by a colon.
+
+     For GGUF models, the format is "repository:variant" (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0").
+     For other models, there is no variant.
+
+     Args:
+         checkpoint: The checkpoint string, potentially with variant
+
+     Returns:
+         tuple: (base_checkpoint, variant) where variant is None if no colon is present
+     """
+     if ":" in checkpoint:
+         base_checkpoint, variant = checkpoint.split(":", 1)
+         return base_checkpoint, variant
+     return checkpoint, None
+
+
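The split behaves exactly as the docstring describes, for example:

    parse_checkpoint("unsloth/Qwen3-0.6B-GGUF:Q4_0")  # -> ("unsloth/Qwen3-0.6B-GGUF", "Q4_0")
    parse_checkpoint("unsloth/Qwen3-0.6B-GGUF")       # -> ("unsloth/Qwen3-0.6B-GGUF", None)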
+ def get_local_checkpoint_path(base_checkpoint, variant):
+     """
+     Returns the absolute path to a .gguf checkpoint file in the local HuggingFace hub.
+     Also returns just the .gguf filename.
+
+     The variant is one of the following types:
+     1. Full filename: exact file to use
+     2. Quantization variant: find a single file ending with the variant name (case insensitive)
+     3. Folder name: a subfolder that matches the variant name (case insensitive)
+
+     """
+     full_model_path = None
+     model_to_use = None
+     try:
+         from huggingface_hub import snapshot_download
+
+         snapshot_path = snapshot_download(
+             repo_id=base_checkpoint,
+             local_files_only=True,
+         )
+
+         full_model_path = None
+         model_to_use = None
+
+         if os.path.isdir(snapshot_path) and os.listdir(snapshot_path):
+
+             snapshot_files = [filename for filename in os.listdir(snapshot_path)]
+
+             if variant.endswith(".gguf"):
+                 # Variant is an exact file
+                 model_to_use = variant
+                 if variant in snapshot_files:
+                     full_model_path = os.path.join(snapshot_path, variant)
+                 else:
+                     raise ValueError(
+                         f"The variant {variant} is not available locally in {snapshot_path}."
+                     )
+
+             else:
+                 # Variant is a quantization
+                 end_with_variant = [
+                     file
+                     for file in snapshot_files
+                     if file.lower().endswith(f"{variant}.gguf".lower())
+                 ]
+                 if len(end_with_variant) == 1:
+                     model_to_use = end_with_variant[0]
+                     full_model_path = os.path.join(snapshot_path, model_to_use)
+                 elif len(end_with_variant) > 1:
+                     raise ValueError(
+                         f"Multiple .gguf files found for variant {variant}, "
+                         f"but only one is allowed."
+                     )
+                 else:
+                     # Check whether the variant corresponds to a folder with
+                     # sharded files (case insensitive)
+                     quantization_folder = [
+                         folder
+                         for folder in snapshot_files
+                         if folder.lower() == variant.lower()
+                         and os.path.exists(os.path.join(snapshot_path, folder))
+                         and os.path.isdir(os.path.join(snapshot_path, folder))
+                     ]
+                     if len(quantization_folder) == 1:
+                         quantization_folder = os.path.join(
+                             snapshot_path, quantization_folder[0]
+                         )
+                         sharded_files = [
+                             f
+                             for f in os.listdir(quantization_folder)
+                             if f.endswith(".gguf")
+                         ]
+                         if not sharded_files:
+                             raise ValueError(
+                                 f"No .gguf files found for variant {variant}."
+                             )
+                         else:
+                             model_to_use = sharded_files[0]
+                             full_model_path = os.path.join(
+                                 quantization_folder, model_to_use
+                             )
+                     elif len(quantization_folder) > 1:
+                         raise ValueError(
+                             f"Multiple checkpoint folder names match the variant {variant}."
+                         )
+                     else:
+                         raise ValueError(f"No .gguf files found for variant {variant}.")
+         else:
+             raise ValueError(
+                 f"The checkpoint {base_checkpoint} is not a local checkpoint."
+             )
+
+     except Exception as e:  # pylint: disable=broad-exception-caught
+         # Log any errors but continue with the original path
+         printing.log_info(f"Error checking Hugging Face cache: {e}")
+
+     return full_model_path, model_to_use
+
+
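A minimal sketch of resolving a locally cached model with the two functions above, assuming the repository has already been fetched into the Hugging Face cache (otherwise the lookup logs an error and returns None, None):

    base, variant = parse_checkpoint("unsloth/Qwen3-0.6B-GGUF:Q4_0")
    full_path, filename = get_local_checkpoint_path(base, variant)
    if full_path is None:
        # Not cached locally; fall back to downloading (download_gguf is defined below)
        full_path = download_gguf("unsloth/Qwen3-0.6B-GGUF:Q4_0")["variant"]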
+ def identify_gguf_models(
+     checkpoint: str, variant: str, mmproj: str
+ ) -> tuple[dict, list[str]]:
+     """
+     Identifies the GGUF model files in the repository that match the variant.
+     """
+
+     hint = """
+     The CHECKPOINT:VARIANT scheme is used to specify model files in Hugging Face repositories.
+
+     The VARIANT format can be one of several types:
+     1. Full filename: exact file to download
+     2. None/empty: gets the first .gguf file in the repository (excludes mmproj files)
+     3. Quantization variant: find a single file ending with the variant name (case insensitive)
+     4. Folder name: downloads all .gguf files in the folder that matches the variant name (case insensitive)
+
+     Examples:
+     - "unsloth/Qwen3-8B-GGUF:qwen3.gguf" -> downloads "qwen3.gguf"
+     - "unsloth/Qwen3-30B-A3B-GGUF" -> downloads "Qwen3-30B-A3B-GGUF.gguf"
+     - "unsloth/Qwen3-8B-GGUF:Q4_1" -> downloads "Qwen3-8B-GGUF-Q4_1.gguf"
+     - "unsloth/Qwen3-30B-A3B-GGUF:Q4_0" -> downloads all files in "Q4_0/" folder
+     """
+
+     from huggingface_hub import list_repo_files
+
+     repo_files = list_repo_files(checkpoint)
+     sharded_files = []
+
+     # (case 1) If variant ends in .gguf, use it directly
+     if variant and variant.endswith(".gguf"):
+         variant_name = variant
+         if variant_name not in repo_files:
+             raise ValueError(
+                 f"File {variant} not found in Hugging Face repository {checkpoint}. {hint}"
+             )
+     # (case 2) If no variant is provided, get the first .gguf file in the repository
+     elif variant is None:
+         all_variants = [
+             f for f in repo_files if f.endswith(".gguf") and "mmproj" not in f
+         ]
+         if len(all_variants) == 0:
+             raise ValueError(
+                 f"No .gguf files found in Hugging Face repository {checkpoint}. {hint}"
+             )
+         variant_name = all_variants[0]
+     else:
+         # (case 3) Find a single file ending with the variant name (case insensitive)
+         end_with_variant = [
+             f
+             for f in repo_files
+             if f.lower().endswith(f"{variant}.gguf".lower())
+             and "mmproj" not in f.lower()
+         ]
+         if len(end_with_variant) == 1:
+             variant_name = end_with_variant[0]
+         elif len(end_with_variant) > 1:
+             raise ValueError(
+                 f"Multiple .gguf files found for variant {variant}, but only one is allowed. {hint}"
+             )
+         # (case 4) Check whether the variant corresponds to a folder with
+         # sharded files (case insensitive)
+         else:
+             sharded_files = [
+                 f
+                 for f in repo_files
+                 if f.endswith(".gguf") and f.lower().startswith(f"{variant}/".lower())
+             ]
+
+             if not sharded_files:
+                 raise ValueError(f"No .gguf files found for variant {variant}. {hint}")
+
+             # Sort to ensure consistent ordering
+             sharded_files.sort()
+
+             # Use first file as primary (this is how llamacpp handles it)
+             variant_name = sharded_files[0]
+
+     core_files = {"variant": variant_name}
+
+     # If there is a mmproj file, add it to the patterns
+     if mmproj:
+         if mmproj not in repo_files:
+             raise ValueError(
+                 f"The provided mmproj file {mmproj} was not found in {checkpoint}."
+             )
+         core_files["mmproj"] = mmproj
+
+     return core_files, sharded_files
+
+
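Tying the four VARIANT cases to concrete calls; the repository and file names are taken from the hint above and are assumptions about current repository contents, and each call queries the Hugging Face API:

    # Case 1: exact filename
    identify_gguf_models("unsloth/Qwen3-8B-GGUF", "qwen3.gguf", None)
    # Case 2: no variant; the first non-mmproj .gguf in the repository is chosen
    identify_gguf_models("unsloth/Qwen3-30B-A3B-GGUF", None, None)
    # Case 3: quantization suffix, resolving to a single *Q4_1.gguf file
    identify_gguf_models("unsloth/Qwen3-8B-GGUF", "Q4_1", None)
    # Case 4: folder of sharded files; returns the first shard plus the full shard list
    core_files, sharded_files = identify_gguf_models("unsloth/Qwen3-30B-A3B-GGUF", "Q4_0", None)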
+ def download_gguf(config_checkpoint, config_mmproj=None) -> dict:
+     """
+     Downloads the GGUF file for the given model configuration.
+
+     For sharded models, if the variant points to a folder (e.g. Q4_0), all files in that folder
+     will be downloaded but only the first file will be returned for loading.
+     """
+
+     # This code handles all cases by constructing the appropriate filename or pattern
+     checkpoint, variant = parse_checkpoint(config_checkpoint)
+
+     # Identify the GGUF model files in the repository that match the variant
+     core_files, sharded_files = identify_gguf_models(checkpoint, variant, config_mmproj)
+
+     # Download the files
+     from huggingface_hub import snapshot_download
+
+     snapshot_folder = snapshot_download(
+         repo_id=checkpoint,
+         allow_patterns=list(core_files.values()) + sharded_files,
+     )
+
+     # Ensure we downloaded all expected files
+     for file in list(core_files.values()) + sharded_files:
+         expected_path = os.path.join(snapshot_folder, file)
+         if not os.path.exists(expected_path):
+             raise ValueError(
+                 f"Hugging Face snapshot download for {config_checkpoint} "
+                 f"expected file {file} not found at {expected_path}"
+             )
+
+     # Return a dict of the full path of the core GGUF files
+     return {
+         file_name: os.path.join(snapshot_folder, file_path)
+         for file_name, file_path in core_files.items()
+     }
+
+
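A short usage sketch of the download helper, assuming network access to Hugging Face:

    paths = download_gguf("unsloth/Qwen3-0.6B-GGUF:Q4_0")
    model_path = paths["variant"]      # full path to the primary .gguf file
    mmproj_path = paths.get("mmproj")  # present only when config_mmproj was given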
+ class LlamaCppTokenizerAdapter(PassthroughTokenizer):
+     pass
+
+
+ class LlamaCppAdapter(ModelAdapter):
+     def __init__(
+         self,
+         model,
+         device,
+         output_tokens,
+         context_size,
+         threads,
+         executable,
+         reasoning=False,
+         lib_dir=None,
+     ):
+         super().__init__()
+
+         self.model = os.path.normpath(model)
+         self.device = device
+         self.output_tokens = (
+             output_tokens  # default value of max tokens to generate from a prompt
+         )
+         self.context_size = context_size
+         self.threads = threads
+         self.executable = os.path.normpath(executable)
+         self.reasoning = reasoning
+         self.lib_dir = lib_dir
+
+     def generate(
+         self,
+         input_ids: str,
+         max_new_tokens: Optional[int] = None,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+         top_k: int = 40,
+         return_raw: bool = False,
+         **kwargs,  # pylint: disable=unused-argument
+     ):
+         """
+         Pass a text prompt into the llamacpp inference CLI.
+
+         The input_ids arg here should receive the original text that
+         would normally be encoded by a tokenizer.
+
+         Args:
+             input_ids: The input text prompt
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Temperature for sampling (0.0 = greedy)
+             top_p: Top-p sampling threshold
+             top_k: Top-k sampling threshold
+             return_raw: If True, returns the complete raw output including timing info
+             **kwargs: Additional arguments (ignored)
+
+         Returns:
+             List containing a single string with the generated text, or raw output if
+             return_raw=True
+         """
+
+         prompt = input_ids
+         if self.reasoning:
+             prompt += "<think>"
+         n_predict = max_new_tokens if max_new_tokens is not None else self.output_tokens
+
+         cmd = [
+             self.executable,
+             "-m",
+             self.model,
+             "--ctx-size",
+             str(self.context_size),
+             "-n",
+             str(n_predict),
+             "-t",
+             str(self.threads),
+             "-p",
+             prompt,
+             "--temp",
+             str(temperature),
+             "--top-p",
+             str(top_p),
+             "--top-k",
+             str(top_k),
+             "-e",
+             "-no-cnv",
+             "--reasoning-format",
+             "none",
+         ]
+
+         # Configure GPU layers: 99 for GPU, 0 for CPU-only
+         ngl_value = "99" if self.device == "igpu" else "0"
+         cmd = cmd + ["-ngl", ngl_value]
+
+         cmd = [str(m) for m in cmd]
+
+         try:
+             # Set up environment with library path for Linux
+             env = os.environ.copy()
+             if self.lib_dir and os.name != "nt":  # Not Windows
+                 current_ld_path = env.get("LD_LIBRARY_PATH", "")
+                 if current_ld_path:
+                     env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
+                 else:
+                     env["LD_LIBRARY_PATH"] = self.lib_dir
+
+             process = subprocess.Popen(
+                 cmd,
+                 stdout=subprocess.PIPE,
+                 stderr=subprocess.PIPE,
+                 universal_newlines=True,
+                 encoding="utf-8",
+                 errors="replace",
+                 env=env,
+             )
+
+             raw_output, stderr = process.communicate(timeout=600)
+             if process.returncode != 0:
+                 error_msg = f"llama.cpp failed with return code {process.returncode}.\n"
+                 error_msg += f"Command: {' '.join(cmd)}\n"
+                 error_msg += f"Error output:\n{stderr}\n"
+                 error_msg += f"Standard output:\n{raw_output}"
+                 raise Exception(error_msg)
+
+             if raw_output is None:
+                 raise Exception("No output received from llama.cpp process")
+
+             # Parse information from llama.cpp output
+             for line in stderr.splitlines():
+                 # Parse timing and token information
+                 #
+                 # Prompt processing time and length (tokens)
+                 # Sample: llama_perf_context_print: prompt eval time = 35.26 ms /
+                 #     3 tokens ( 11.75 ms per token, 85.09 tokens per second)
+                 #
+                 if "llama_perf_context_print: prompt eval time =" in line:
+                     parts = line.split("=")[1].split()
+                     time_to_first_token_ms = float(parts[0])
+                     self.time_to_first_token = time_to_first_token_ms / 1000
+                     self.prompt_tokens = int(parts[3])
+                 #
+                 # Response processing time and length (tokens)
+                 # Sample: llama_perf_context_print: eval time = 1991.14 ms /
+                 #     63 runs ( 31.61 ms per token, 31.64 tokens per second)
+                 #
+                 if "llama_perf_context_print: eval time =" in line:
+                     parts = line.split("=")[1].split()
+                     self.response_tokens = int(parts[3])
+                     response_time_ms = float(parts[0])
+                     self.tokens_per_second = (
+                         1000 * self.response_tokens / response_time_ms
+                         if response_time_ms > 0
+                         else 0
+                     )
+
+             if return_raw:
+                 return [raw_output, stderr]
+
+             # Find where the prompt ends and the generated text begins
+             prompt_found = False
+             output_text = ""
+             prompt_first_line = prompt.split("\n")[0]
+             for line in raw_output.splitlines():
+                 if prompt_first_line in line:
+                     prompt_found = True
+                 if prompt_found:
+                     line = line.replace("</s> [end of text]", "")
+                     output_text = output_text + line
+
+             if not prompt_found:
+                 raise Exception(
+                     f"Could not find prompt '{prompt_first_line}' in llama.cpp output. "
+                     "This usually means the model failed to process the prompt correctly.\n"
+                     f"Raw output:\n{raw_output}\n"
+                     f"Stderr:\n{stderr}"
+                 )
+
+             # Return list containing the generated text
+             return [output_text]
+
+         except Exception as e:
+             error_msg = f"Failed to run llama.cpp command: {str(e)}\n"
+             error_msg += f"Command: {' '.join(cmd)}"
+             raise Exception(error_msg) from e
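Finally, a hedged end-to-end sketch of driving the CLI adapter defined above; the checkpoint and parameter values are illustrative, and the adapter shells out to llama-cli on each generate() call:

    model_path = download_gguf("unsloth/Qwen3-0.6B-GGUF:Q4_0")["variant"]
    adapter = LlamaCppAdapter(
        model=model_path,
        device="igpu",      # "igpu" enables -ngl 99; any other value runs CPU-only (-ngl 0)
        output_tokens=256,
        context_size=4096,
        threads=8,
        executable=get_llama_cli_exe_path(),
    )
    text = adapter.generate("What is the capital of France?")[0]
    print(text, adapter.tokens_per_second)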