lemonade-sdk 8.0.5__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -3,26 +3,25 @@
 # pylint: disable=no-member
 
 import argparse
+import subprocess
+import sys
 import os
 import json
-import shutil
+import webbrowser
 from fnmatch import fnmatch
-import subprocess
-
 
 from lemonade.state import State
 from lemonade.tools import FirstTool
+from lemonade.cache import Keys
 import lemonade.common.status as status
 import lemonade.common.printing as printing
-from lemonade.cache import Keys
 from lemonade_install.install import (
-    get_ryzen_ai_version_info,
-    get_oga_npu_dir,
-    get_oga_hybrid_dir,
+    _get_ryzenai_version_info,
     SUPPORTED_RYZEN_AI_SERIES,
+    NPU_DRIVER_DOWNLOAD_URL,
+    REQUIRED_NPU_DRIVER_VERSION,
 )
 
-
 # ONNX Runtime GenAI models will be cached in this subfolder of the lemonade cache folder
 oga_models_path = "oga_models"
 
@@ -39,6 +38,42 @@ execution_providers = {
 }
 
 
+def _get_npu_driver_version():
+    """
+    Get the NPU driver version using PowerShell directly.
+    Returns the driver version string or None if not found.
+    """
+    try:
+
+        # Use PowerShell directly to avoid wmi issues in embedded Python environments
+        powershell_cmd = [
+            "powershell",
+            "-NoProfile",
+            "-ExecutionPolicy",
+            "Bypass",
+            "-Command",
+            (
+                "Get-WmiObject -Class Win32_PnPSignedDriver | "
+                'Where-Object { $_.DeviceName -like "*NPU Compute Accelerator Device*" } | '
+                "Select-Object -ExpandProperty DriverVersion"
+            ),
+        ]
+
+        result = subprocess.run(
+            powershell_cmd, capture_output=True, text=True, check=True, timeout=30
+        )
+
+        driver_version = result.stdout.strip()
+
+        if driver_version and driver_version != "":
+            return driver_version
+        else:
+            return None
+
+    except Exception:  # pylint: disable=broad-except
+        return None
+
+
 def import_error_heler(e: Exception):
     """
     Print a helpful message in the event of an import error
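
Usage sketch for the helper above (illustrative only, not code from the package; the actual comparison against the required version happens later in _setup_model_dependencies):

    # Illustrative sketch using only names introduced in this diff.
    detected = _get_npu_driver_version()
    if detected is None:
        printing.log_warning("No NPU Compute Accelerator driver detected.")
    elif detected != REQUIRED_NPU_DRIVER_VERSION:
        printing.log_warning(
            f"Driver {detected} installed, {REQUIRED_NPU_DRIVER_VERSION} required."
        )
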
@@ -47,10 +82,26 @@ def import_error_heler(e: Exception):
         f"{e}\n Please install lemonade-sdk with "
         "one of the oga extras, for example:\n"
         "pip install lemonade-sdk[dev,oga-cpu]\n"
-        "See https://lemonade_server.ai/install_options.html for details"
+        "See https://lemonade-server.ai/install_options.html for details"
     )
 
 
+def _open_driver_install_page():
+    """
+    Opens the driver installation page in the user's default web browser.
+    """
+    try:
+        driver_page_url = "https://lemonade-server.ai/driver_install.html"
+        printing.log_info(f"Opening driver installation guide: {driver_page_url}")
+        webbrowser.open(driver_page_url)
+    except Exception as e:  # pylint: disable=broad-except
+        printing.log_info(f"Could not open browser automatically: {e}")
+        printing.log_info(
+            "Please visit https://lemonade-server.ai/driver_install.html "
+            "for driver installation instructions."
+        )
+
+
 class OgaLoad(FirstTool):
     """
     Tool that loads an LLM in OnnxRuntime-GenAI for use with CPU or DirectML execution providers.
@@ -208,7 +259,7 @@ class OgaLoad(FirstTool):
         files that have locally been quantized/converted to OGA format and any other
         models that have been manually added by the user.
         """
-        from huggingface_hub import snapshot_download
+        from lemonade.common.network import custom_snapshot_download
 
         if subfolder is None:
             subfolder = f"{execution_providers[device]}-{dtype}"
@@ -232,8 +283,8 @@ class OgaLoad(FirstTool):
         # If not found in lemonade cache, check in Hugging Face cache
         if not model_exists_locally:
             try:
-                snapshot_path = snapshot_download(
-                    repo_id=checkpoint,
+                snapshot_path = custom_snapshot_download(
+                    checkpoint,
                     local_files_only=True,
                 )
 
@@ -258,25 +309,101 @@ class OgaLoad(FirstTool):
         return full_model_path, model_exists_locally
 
     @staticmethod
-    def _update_hybrid_custom_ops_library_path(full_model_path):
+    def _setup_model_dependencies(full_model_path, device, ryzenai_version, oga_path):
         """
-        Modifies the genai_config.json file in the hybrid model folder to set the custom_ops_library
-        path to the location of the onnx_custom_ops.dll in the current environment.
-        This is needed for hybrid inference.
+        Sets up model dependencies for hybrid and NPU inference by:
+        1. Configuring the custom_ops_library path in genai_config.json.
+        2. Adding DLL source directories to PATH for dependent DLL discovery.
+        3. Check NPU driver version if required for device and ryzenai_version.
         """
-        oga_path, version = get_oga_hybrid_dir()
-
-        if "1.3.0" in version:
-            custom_ops_path = os.path.join(
-                oga_path,
-                "onnx_utils",
-                "bin",
-                "onnx_custom_ops.dll",
-            )
+
+        env_path = sys.prefix
+
+        if "1.4.0" in ryzenai_version:
+            if device == "npu":
+                custom_ops_path = os.path.join(
+                    oga_path, "libs", "onnxruntime_vitis_ai_custom_ops.dll"
+                )
+            else:
+                custom_ops_path = os.path.join(oga_path, "libs", "onnx_custom_ops.dll")
         else:
-            custom_ops_path = os.path.join(oga_path, "libs", "onnx_custom_ops.dll")
+            # For 1.5.0+, check NPU driver version for NPU and hybrid devices
+            if device in ["npu", "hybrid"]:
+                required_driver_version = REQUIRED_NPU_DRIVER_VERSION
+
+                current_driver_version = _get_npu_driver_version()
+
+                if not current_driver_version:
+                    printing.log_warning(
+                        f"NPU driver not found. {device.upper()} inference requires NPU driver "
+                        f"version {required_driver_version}.\n"
+                        "Please download and install the NPU Driver from:\n"
+                        f"{NPU_DRIVER_DOWNLOAD_URL}\n"
+                        "NPU functionality may not work properly."
+                    )
+                    _open_driver_install_page()
+
+                elif current_driver_version != required_driver_version:
+                    printing.log_warning(
+                        f"Incorrect NPU driver version detected: {current_driver_version}\n"
+                        f"{device.upper()} inference with RyzenAI 1.5.0 requires driver "
+                        f"version {required_driver_version}.\n"
+                        "Please download and install the correct NPU Driver from:\n"
+                        f"{NPU_DRIVER_DOWNLOAD_URL}\n"
+                        "NPU functionality may not work properly."
+                    )
+                    _open_driver_install_page()
+
+            if device == "npu":
+                # For 1.5.0, custom ops are in the conda environment's onnxruntime package
+                custom_ops_path = os.path.join(
+                    env_path,
+                    "Lib",
+                    "site-packages",
+                    "onnxruntime",
+                    "capi",
+                    "onnxruntime_vitis_ai_custom_ops.dll",
+                )
+                dll_source_path = os.path.join(
+                    env_path, "Lib", "site-packages", "onnxruntime", "capi"
+                )
+                required_dlls = ["dyn_dispatch_core.dll", "xaiengine.dll"]
+            else:
+                custom_ops_path = os.path.join(
+                    env_path,
+                    "Lib",
+                    "site-packages",
+                    "onnxruntime_genai",
+                    "onnx_custom_ops.dll",
+                )
+                dll_source_path = os.path.join(
+                    env_path, "Lib", "site-packages", "onnxruntime_genai"
+                )
+                required_dlls = ["libutf8_validity.dll", "abseil_dll.dll"]
+
+            # Validate that all required DLLs exist in the source directory
+            missing_dlls = []
+            if not os.path.exists(custom_ops_path):
+                missing_dlls.append(custom_ops_path)
+
+            for dll_name in required_dlls:
+                dll_source = os.path.join(dll_source_path, dll_name)
+                if not os.path.exists(dll_source):
+                    missing_dlls.append(dll_source)
+
+            if missing_dlls:
+                dll_list = "\n - ".join(missing_dlls)
+                raise RuntimeError(
+                    f"Required DLLs not found for {device} inference:\n - {dll_list}\n"
+                    f"Please ensure your RyzenAI installation is complete and supports {device}."
+                )
+
+            # Add the DLL source directory to PATH
+            current_path = os.environ.get("PATH", "")
+            if dll_source_path not in current_path:
+                os.environ["PATH"] = dll_source_path + os.pathsep + current_path
 
-        # Insert the custom_ops_path into the model config file
+        # Update the model config with custom_ops_library path
         config_path = os.path.join(full_model_path, "genai_config.json")
         if os.path.exists(config_path):
             with open(config_path, "r", encoding="utf-8") as f:
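
The hunk ends before the config write. A plausible sketch of the update the method performs next, assuming the common ONNX Runtime GenAI layout where the library path sits under the decoder's session_options (the exact key path is an assumption and is not shown in this diff):

    # Sketch under assumptions: point the config at the DLL resolved above.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Assumed genai_config.json layout; the real key path is not shown in this hunk.
    config["model"]["decoder"]["session_options"]["custom_ops_library"] = custom_ops_path
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
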
@@ -363,63 +490,32 @@ class OgaLoad(FirstTool):
         return full_model_path
 
     @staticmethod
-    def _setup_npu_environment():
+    def _setup_npu_environment(ryzenai_version, oga_path):
         """
         Sets up environment for NPU flow of ONNX model and returns saved state to be restored
         later in cleanup.
         """
-        oga_path, version = get_oga_npu_dir()
-
-        if not os.path.exists(os.path.join(oga_path, "libs", "onnxruntime.dll")):
-            raise RuntimeError(
-                f"Cannot find libs/onnxruntime.dll in lib folder: {oga_path}"
-            )
-
-        # Save current state so they can be restored after inference.
-        saved_state = {"cwd": os.getcwd(), "path": os.environ["PATH"]}
-
-        # Setup NPU environment (cwd and path will be restored later)
-        os.chdir(oga_path)
-        os.environ["PATH"] = (
-            os.path.join(oga_path, "libs") + os.pathsep + os.environ["PATH"]
-        )
-        if "1.3.0" in version:
-            os.environ["DD_ROOT"] = ".\\bins"
-            os.environ["DEVICE"] = "stx"
-            os.environ["XLNX_ENABLE_CACHE"] = "0"
+        if "1.5.0" in ryzenai_version:
+            # For PyPI installation (1.5.0+), no environment setup needed
+            return None
+        elif "1.4.0" in ryzenai_version:
+            # Legacy lemonade-install approach for 1.4.0
+            if not os.path.exists(os.path.join(oga_path, "libs", "onnxruntime.dll")):
+                raise RuntimeError(
+                    f"Cannot find libs/onnxruntime.dll in lib folder: {oga_path}"
+                )
 
-        return saved_state
+            # Save current state so they can be restored after inference.
+            saved_state = {"cwd": os.getcwd(), "path": os.environ["PATH"]}
 
-    @staticmethod
-    def _setup_hybrid_environment():
-        """
-        Sets up the environment for the Hybrid flow and returns saved state to be restored later
-        in cleanup.
-        """
-        # Determine the Ryzen AI OGA version and hybrid artifacts path
-        oga_path, version = get_oga_hybrid_dir()
-
-        if "1.3.0" in version:
-            dst_dll = os.path.join(
-                oga_path,
-                "onnx_utils",
-                "bin",
-                "DirectML.dll",
+            # Setup NPU environment (cwd and path will be restored later)
+            os.chdir(oga_path)
+            os.environ["PATH"] = (
+                os.path.join(oga_path, "libs") + os.pathsep + os.environ["PATH"]
             )
-            if not os.path.isfile(dst_dll):
-                # Artifacts 1.3.0 has DirectML.dll in different subfolder, so copy it to the
-                # correct place. This should not be needed in later RAI release artifacts.
-                src_dll = os.path.join(
-                    oga_path,
-                    "onnxruntime_genai",
-                    "lib",
-                    "DirectML.dll",
-                )
-                os.makedirs(os.path.dirname(dst_dll), exist_ok=True)
-                shutil.copy2(src_dll, dst_dll)
-
-        saved_state = None
-        return saved_state
+            return saved_state
+        else:
+            raise ValueError(f"Unsupported RyzenAI version: {ryzenai_version}")
 
     @staticmethod
     def _load_model_and_setup_state(
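
The dictionary returned by the 1.4.0 branch is the same saved state that the cleanup path later in this diff restores; a condensed sketch of that contract, using only names that appear in this diff:

    # Sketch: how a caller uses the saved state (None on the 1.5.0 PyPI path).
    ryzenai_version, oga_path = _get_ryzenai_version_info(device)
    saved_env_state = OgaLoad._setup_npu_environment(ryzenai_version, oga_path)
    try:
        ...  # load the model and run inference
    finally:
        if saved_env_state:
            os.chdir(saved_env_state["cwd"])
            os.environ["PATH"] = saved_env_state["path"]
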
@@ -431,7 +527,6 @@ class OgaLoad(FirstTool):
         """
 
         try:
-            from transformers import AutoTokenizer
            from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiTokenizer
            from lemonade.common.network import is_offline
         except ImportError as e:
@@ -456,6 +551,11 @@ class OgaLoad(FirstTool):
         # Auto-detect offline mode
         offline = is_offline()
 
+        try:
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            import_error_heler(e)
+
         try:
             # Always try to use local files first
             local_files_only = True
@@ -495,42 +595,52 @@ class OgaLoad(FirstTool):
             os.chdir(saved_state["cwd"])
             os.environ["PATH"] = saved_state["path"]
 
-    def _generate_model_for_hybrid_or_npu(
-        self, output_model_path, device, input_model_path
-    ):
+    def _generate_model_for_oga(self, output_model_path, device, input_model_path):
         """
-        Uses a subprocess to run the 'model_generate' command for hybrid or npu devices.
+        Uses the model_generate tool to generate the model for OGA hybrid or npu targets.
         """
+        try:
+            import model_generate
+        except ImportError as e:
+            raise ImportError(
+                f"{e}\nYou are trying to use a developer tool that may not be "
+                "installed. Please install the required package using:\n"
+                "pip install -e .[dev,oga-ryzenai] \
+                --extra-index-url https://pypi.amd.com/simple"
+            )
 
         # Determine the appropriate flag based on the device type
         if device == "hybrid":
-            device_flag = "--hybrid"
+            device_flag = "hybrid"
         elif device == "npu":
-            device_flag = "--npu"
+            device_flag = "npu"
         else:
             raise ValueError(f"Unsupported device type for model generation: {device}")
 
-        command = [
-            "model_generate",
-            device_flag,
-            output_model_path,  # Output model directory
-            input_model_path,  # Input model directory
-        ]
+        printing.log_info(
+            f"Generating model for device: {device_flag}, \
+            input: {input_model_path}, output: {output_model_path}"
+        )
 
-        printing.log_info(f"Running command: {' '.join(command)}")
         try:
-            with open(self.logfile_path, "w", encoding="utf-8") as log_file:
-                subprocess.run(
-                    command, check=True, text=True, stdout=log_file, stderr=log_file
+            if device_flag == "npu":
+                model_generate.generate_npu_model(
+                    input_model=input_model_path,
+                    output_dir=output_model_path,
+                    packed_const=False,
                 )
-        except FileNotFoundError as e:
-            error_message = (
-                "The 'model_generate' package is missing from your system. "
-                "Ensure all required packages are installed. "
-                "To install it, run the following command:\n\n"
-                " lemonade-install --ryzenai <target> --build-model\n"
-            )
-            raise RuntimeError(error_message) from e
+            else:  # hybrid
+                model_generate.generate_hybrid_model(
+                    input_model=input_model_path,
+                    output_dir=output_model_path,
+                    # script_option="jit_npu",
+                    # mode="bf16",
+                    # dml_only=False,
+                )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to generate model for {device_flag} device. Error: {e}"
+            ) from e
 
     def run(
         self,
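
This hunk replaces the external model_generate CLI call (previously run in a subprocess with output redirected to a log file) with direct in-process calls into the model_generate package; a condensed sketch of the call sites that appear later in this diff:

    # Sketch: both call sites shown further down reduce to this form, and
    # failures now surface as RuntimeError rather than a missing-CLI error.
    try:
        self._generate_model_for_oga(full_model_path, device, input_model_path)
    except RuntimeError as e:
        printing.log_warning(f"OGA model generation failed: {e}")
        raise
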
@@ -545,8 +655,11 @@ class OgaLoad(FirstTool):
         trust_remote_code=False,
         subfolder: str = None,
     ) -> State:
-        from huggingface_hub import snapshot_download
-        from lemonade.common.network import get_base_model, is_offline
+        from lemonade.common.network import (
+            custom_snapshot_download,
+            get_base_model,
+            is_offline,
+        )
 
         # Auto-detect offline status
         offline = is_offline()
@@ -562,7 +675,8 @@
         state.save_stat(Keys.DTYPE, dtype)
         state.save_stat(Keys.DEVICE, device)
         if device in ["hybrid", "npu"]:
-            ryzen_ai_version_info = get_ryzen_ai_version_info()
+            ryzenai_version, _ = _get_ryzenai_version_info(device)
+            ryzen_ai_version_info = {"version": ryzenai_version}
             state.save_stat(Keys.RYZEN_AI_VERSION_INFO, ryzen_ai_version_info)
 
         # Check if input is a local folder
@@ -627,8 +741,8 @@
                     "The (device, dtype, checkpoint) combination is not supported: "
                     f"({device}, {dtype}, {checkpoint})"
                 )
-            input_model_path = snapshot_download(
-                repo_id=checkpoint,
+            input_model_path = custom_snapshot_download(
+                checkpoint,
                 ignore_patterns=["*.md", "*.txt"],
                 local_files_only=offline,
             )
@@ -661,7 +775,7 @@
                 else:
                     # If ONNX but not modified yet for Hybrid or NPU,
                     # needs further optimization
-                    self._generate_model_for_hybrid_or_npu(
+                    self._generate_model_for_oga(
                         full_model_path,
                         device,
                         input_model_path,
@@ -673,7 +787,7 @@
                     config = json.load(f)
                 if "quantization_config" in config:
                     # If quantized, use subprocess to generate the model
-                    self._generate_model_for_hybrid_or_npu(
+                    self._generate_model_for_oga(
                         full_model_path, device, input_model_path
                     )
                 else:
@@ -708,18 +822,31 @@
 
         # Load model if download-only argument is not set
         if not download_only:
+            # Get version information for NPU/Hybrid devices
+            if device in ["hybrid", "npu"]:
+                ryzenai_version, oga_path = _get_ryzenai_version_info(device)
+            else:
+                ryzenai_version, oga_path = None, None
 
             saved_env_state = None
+
+            # Setup model dependencies for NPU/Hybrid devices
+            if device in ["hybrid", "npu"]:
+                self._setup_model_dependencies(
+                    full_model_path, device, ryzenai_version, oga_path
+                )
+
             try:
                 if device == "npu":
-                    saved_env_state = self._setup_npu_environment()
+                    saved_env_state = self._setup_npu_environment(
+                        ryzenai_version, oga_path
+                    )
                     # Set USE_AIE_RoPE based on model type
                     os.environ["USE_AIE_RoPE"] = (
                         "0" if "phi-" in checkpoint.lower() else "1"
                     )
                 elif device == "hybrid":
-                    saved_env_state = self._setup_hybrid_environment()
-                    self._update_hybrid_custom_ops_library_path(full_model_path)
+                    saved_env_state = None
 
                 self._load_model_and_setup_state(
                     state, full_model_path, checkpoint, trust_remote_code
@@ -11,6 +11,7 @@ from lemonade.tools.adapter import (
     TokenizerAdapter,
     PassthroughTokenizerResult,
 )
+from lemonade_install.install import _get_ryzenai_version_info
 
 
 class OrtGenaiTokenizer(TokenizerAdapter):
@@ -67,18 +68,29 @@ class OrtGenaiModel(ModelAdapter):
 
     def load_config(self, input_folder):
         rai_config_path = os.path.join(input_folder, "rai_config.json")
-        if os.path.exists(rai_config_path):
-            with open(rai_config_path, "r", encoding="utf-8") as f:
-                max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
-        else:
-            max_prompt_length = None
+        max_prompt_length = None
+
+        try:
+            detected_version, _ = _get_ryzenai_version_info()
+
+            if os.path.exists(rai_config_path):
+                with open(rai_config_path, "r", encoding="utf-8") as f:
+                    rai_config = json.load(f)
+                    if (
+                        "max_prompt_length" in rai_config
+                        and detected_version in rai_config["max_prompt_length"]
+                    ):
+                        max_prompt_length = rai_config["max_prompt_length"][
+                            detected_version
+                        ]
+        except:  # pylint: disable=bare-except
+            pass
 
         config_path = os.path.join(input_folder, "genai_config.json")
         if os.path.exists(config_path):
             with open(config_path, "r", encoding="utf-8") as f:
                 config_dict = json.load(f)
-                if max_prompt_length:
-                    config_dict["max_prompt_length"] = max_prompt_length
+                config_dict["max_prompt_length"] = max_prompt_length
                 return config_dict
         return None
 
@@ -99,13 +111,16 @@ class OrtGenaiModel(ModelAdapter):
     ):
         params = og.GeneratorParams(self.model)
 
+        # OGA models return a list of tokens (older versions) or 1d numpy array (newer versions)
         prompt_length = len(input_ids)
+
         max_prompt_length = self.config.get("max_prompt_length")
         if max_prompt_length and prompt_length > max_prompt_length:
             raise ValueError(
                 f"This prompt (length {prompt_length}) exceeds the model's "
                 f"maximum allowed prompt length ({max_prompt_length})."
             )
+        self.prompt_tokens = prompt_length
 
         # There is a breaking API change in OGA 0.6.0
         # Determine whether we should use the old or new APIs
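
The load_config change above keys max_prompt_length by the detected RyzenAI version, which is then enforced here. A hypothetical rai_config.json shape that satisfies that lookup (the version keys and limits are illustrative only, not values from any released model):

    # Hypothetical rai_config.json contents implied by the lookup in load_config.
    example_rai_config = {"max_prompt_length": {"1.4.0": 2048, "1.5.0": 3072}}
    detected_version = "1.5.0"
    max_prompt_length = example_rai_config["max_prompt_length"].get(detected_version)
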
@@ -206,18 +221,21 @@ class OrtGenaiModel(ModelAdapter):
                 )
                 self.tokens_per_second = 1 / avg_token_gen_latency_s
 
-                return [generator.get_sequence(0)]
+                response = generator.get_sequence(0)
+                self.response_tokens = len(response) - self.prompt_tokens
+                return [response]
             else:
                 if use_oga_post_6_api:
                     generator.append_tokens(input_ids)
                 tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
-
+                self.response_tokens = 0
                 stop_early = False
 
                 while not generator.is_done() and not stop_early:
                     if use_oga_pre_6_api:
                         generator.compute_logits()
                     generator.generate_next_token()
+                    self.response_tokens += 1
 
                     new_token = generator.get_next_tokens()[0]
                     new_text = tokenizer_stream.decode(new_token)
lemonade/tools/prompt.py CHANGED
@@ -161,7 +161,11 @@ class LLMPrompt(Tool):
         # If template flag is set, then wrap prompt in template
         if template:
             # Embed prompt in model's chat template
-            if tokenizer.chat_template:
+            if not hasattr(tokenizer, "prompt_template"):
+                printing.log_warning(
+                    "Templates for this model type are not yet implemented."
+                )
+            elif tokenizer.chat_template:
                 # Use the model's built-in chat template if available
                 messages_dict = [{"role": "user", "content": prompt}]
                 prompt = tokenizer.apply_chat_template(
@@ -175,25 +179,10 @@
             state.save_stat(Keys.PROMPT_TEMPLATE, "Default")
 
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-        if isinstance(input_ids, (list, str)):
-            # OGA models return a list of tokens (older versions)
-            # Our llama.cpp adapter returns a string
-            len_tokens_in = len(input_ids)
-        elif hasattr(input_ids, "shape"):
-            # HF models return a 2-D tensor
-            # OGA models with newer versions may return numpy arrays
-            if len(input_ids.shape) == 1:
-                # 1-D array from newer OGA versions
-                len_tokens_in = len(input_ids)
-            else:
-                # 2-D tensor from HF models
-                len_tokens_in = input_ids.shape[1]
-        else:
-            # Fallback: try to get length directly
-            len_tokens_in = len(input_ids)
 
         len_tokens_out = []
         response_texts = []
+        prompt_tokens = None  # will be determined in generate function
         for trial in range(n_trials):
             if n_trials > 1:
                 self.set_percent_progress(100.0 * trial / n_trials)
@@ -222,19 +211,22 @@
 
             response_array = response if isinstance(response, str) else response[0]
 
-            # Separate the prompt from the response
-            len_tokens_out.append(len(response_array) - len_tokens_in)
+            prompt_tokens = model.prompt_tokens
+            len_tokens_out.append(model.response_tokens)
 
-            input_token = 0
+            # Remove the input from the response
+            # (up to the point they diverge, which they should not)
+            counter = 0
+            len_input_ids = len(input_ids_array)
             while (
-                input_token < len_tokens_in
-                and input_ids_array[input_token] == response_array[input_token]
+                counter < len_input_ids
+                and input_ids_array[counter] == response_array[counter]
             ):
-                input_token += 1
+                counter += 1
 
             # Only decode the actual response (not the prompt)
             response_text = tokenizer.decode(
-                response_array[input_token:], skip_special_tokens=True
+                response_array[counter:], skip_special_tokens=True
             ).strip()
             response_texts.append(response_text)
 
@@ -259,7 +251,7 @@
             plt.savefig(figure_path)
             state.save_stat(Keys.RESPONSE_LENGTHS_HISTOGRAM, figure_path)
 
-        state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
+        state.save_stat(Keys.PROMPT_TOKENS, prompt_tokens)
         state.save_stat(Keys.PROMPT, prompt)
        state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
        state.save_stat(Keys.RESPONSE, sanitize_text(response_texts))
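
Taken together with the OrtGenaiModel changes above, prompt and response token counts are now read from the adapter instead of being re-derived from the concatenated output; a condensed sketch of the new per-trial accounting, using only attributes that appear in this diff:

    # Sketch: OrtGenaiModel.generate() records the counters that LLMPrompt now reads.
    response = model.generate(input_ids)
    prompt_tokens = model.prompt_tokens           # tokens in the prompt
    len_tokens_out.append(model.response_tokens)  # tokens generated in this trial
    state.save_stat(Keys.PROMPT_TOKENS, prompt_tokens)
    state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
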