lemonade-sdk 8.1.11__py3-none-any.whl → 8.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lemonade-sdk has been flagged as potentially problematic.

Files changed (32)
  1. lemonade/cache.py +6 -1
  2. lemonade/common/status.py +4 -4
  3. lemonade/common/system_info.py +0 -26
  4. lemonade/tools/bench.py +22 -1
  5. lemonade/tools/flm/utils.py +70 -22
  6. lemonade/tools/llamacpp/bench.py +111 -23
  7. lemonade/tools/llamacpp/load.py +30 -2
  8. lemonade/tools/llamacpp/utils.py +234 -15
  9. lemonade/tools/oga/bench.py +0 -26
  10. lemonade/tools/oga/load.py +38 -142
  11. lemonade/tools/oga/migration.py +403 -0
  12. lemonade/tools/report/table.py +6 -0
  13. lemonade/tools/server/flm.py +2 -6
  14. lemonade/tools/server/llamacpp.py +20 -1
  15. lemonade/tools/server/serve.py +335 -17
  16. lemonade/tools/server/static/js/models.js +416 -18
  17. lemonade/tools/server/static/js/shared.js +44 -6
  18. lemonade/tools/server/static/logs.html +29 -19
  19. lemonade/tools/server/static/styles.css +204 -0
  20. lemonade/tools/server/static/webapp.html +32 -0
  21. lemonade/version.py +1 -1
  22. lemonade_install/install.py +33 -579
  23. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/METADATA +5 -3
  24. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/RECORD +32 -31
  25. lemonade_server/cli.py +10 -0
  26. lemonade_server/model_manager.py +172 -11
  27. lemonade_server/server_models.json +102 -66
  28. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/WHEEL +0 -0
  29. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/entry_points.txt +0 -0
  30. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/licenses/LICENSE +0 -0
  31. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/licenses/NOTICE.md +0 -0
  32. {lemonade_sdk-8.1.11.dist-info → lemonade_sdk-8.2.0.dist-info}/top_level.txt +0 -0

lemonade/tools/llamacpp/utils.py

@@ -7,11 +7,10 @@ import zipfile
 from typing import Optional
 import subprocess
 import requests
+import lemonade.common.build as build
 import lemonade.common.printing as printing
 from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter
-
 from lemonade.common.system_info import get_system_info
-
 from dotenv import set_key, load_dotenv
 
 LLAMA_VERSION_VULKAN = "b6510"
@@ -175,6 +174,13 @@ def get_llama_cli_exe_path(backend: str):
     return get_llama_exe_path("llama-cli", backend)
 
 
+def get_llama_bench_exe_path(backend: str):
+    """
+    Get path to platform-specific llama-bench executable
+    """
+    return get_llama_exe_path("llama-bench", backend)
+
+
 def get_version_txt_path(backend: str):
     """
     Get path to text file that contains version information
@@ -370,7 +376,7 @@ def install_llamacpp(backend):
         import stat
 
         # Find and make executable files executable
-        for root, dirs, files in os.walk(llama_server_exe_dir):
+        for root, _, files in os.walk(llama_server_exe_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Make files in bin/ directories executable
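For reference, the loop above only walks the extracted tree; the permission change itself happens per file. A minimal sketch of that step, assuming POSIX-style permission bits (the rest of install_llamacpp is not shown in this hunk):

import os
import stat

def make_executable(file_path: str) -> None:
    # Add the owner/group/other execute bits on top of the existing mode,
    # roughly equivalent to `chmod +x file_path`.
    mode = os.stat(file_path).st_mode
    os.chmod(file_path, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)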
@@ -406,6 +412,7 @@ def install_llamacpp(backend):
     exe_paths = [
         (get_llama_server_exe_path(backend), "llama-server"),
         (get_llama_cli_exe_path(backend), "llama-cli"),
+        (get_llama_bench_exe_path(backend), "llama-bench"),
     ]
 
     for exe_path, exe_name in exe_paths:
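The new path helper follows the same pattern as the existing llama-server and llama-cli helpers, so callers can treat all three binaries uniformly. A short sketch of that idea (verify_backend_binaries is a hypothetical helper, and the import path is assumed from the file list above):

import os

from lemonade.tools.llamacpp.utils import (
    get_llama_server_exe_path,
    get_llama_cli_exe_path,
    get_llama_bench_exe_path,
)

def verify_backend_binaries(backend: str) -> dict:
    """Return a {name: path} map for the llama.cpp binaries that exist on disk."""
    candidates = {
        "llama-server": get_llama_server_exe_path(backend),
        "llama-cli": get_llama_cli_exe_path(backend),
        "llama-bench": get_llama_bench_exe_path(backend),
    }
    return {name: path for name, path in candidates.items() if os.path.exists(path)}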
@@ -647,15 +654,91 @@ def identify_gguf_models(
     return core_files, sharded_files
 
 
-def download_gguf(config_checkpoint, config_mmproj=None, do_not_upgrade=False) -> dict:
+def resolve_local_gguf_model(
+    checkpoint: str, variant: str, config_mmproj: str = None
+) -> dict | None:
     """
-    Downloads the GGUF file for the given model configuration.
+    Attempts to resolve a GGUF model from the local HuggingFace cache.
+    """
+    from huggingface_hub.constants import HF_HUB_CACHE
+
+    # Convert checkpoint to cache directory format
+    if checkpoint.startswith("models--"):
+        model_cache_dir = os.path.join(HF_HUB_CACHE, checkpoint)
+    else:
+        # This is a HuggingFace repo - convert to cache directory format
+        repo_cache_name = checkpoint.replace("/", "--")
+        model_cache_dir = os.path.join(HF_HUB_CACHE, f"models--{repo_cache_name}")
+
+    # Check if the cache directory exists
+    if not os.path.exists(model_cache_dir):
+        return None
+
+    gguf_file_found = None
+
+    # If variant is specified, look for that specific file
+    if variant:
+        search_term = variant if variant.endswith(".gguf") else f"{variant}.gguf"
+
+        for root, _, files in os.walk(model_cache_dir):
+            if search_term in files:
+                gguf_file_found = os.path.join(root, search_term)
+                break
+
+    # If no variant or variant not found, find any .gguf file (excluding mmproj)
+    if not gguf_file_found:
+        for root, _, files in os.walk(model_cache_dir):
+            gguf_files = [
+                f for f in files if f.endswith(".gguf") and "mmproj" not in f.lower()
+            ]
+            if gguf_files:
+                gguf_file_found = os.path.join(root, gguf_files[0])
+                break
+
+    # If no GGUF file found, model is not in cache
+    if not gguf_file_found:
+        return None
+
+    # Build result dictionary
+    result = {"variant": gguf_file_found}
+
+    # Search for mmproj file if provided
+    if config_mmproj:
+        for root, _, files in os.walk(model_cache_dir):
+            if config_mmproj in files:
+                result["mmproj"] = os.path.join(root, config_mmproj)
+                break
+
+    logging.info(f"Resolved local GGUF model: {result}")
+    return result
 
-    For sharded models, if the variant points to a folder (e.g. Q4_0), all files in that folder
-    will be downloaded but only the first file will be returned for loading.
+
+def download_gguf(
+    config_checkpoint: str, config_mmproj=None, do_not_upgrade: bool = False
+) -> dict:
     """
+    Downloads the GGUF file for the given model configuration from HuggingFace.
+
+    This function downloads models from the internet. It does NOT check the local cache first.
+    Callers should use resolve_local_gguf_model() if they want to check for existing models first.
+
+    Args:
+        config_checkpoint: Checkpoint identifier (file path or HF repo with variant)
+        config_mmproj: Optional mmproj file to also download
+        do_not_upgrade: If True, use local cache only without attempting to download updates
 
-    # This code handles all cases by constructing the appropriate filename or pattern
+    Returns:
+        Dictionary with "variant" (and optionally "mmproj") file paths
+    """
+    # Handle direct file path case - if the checkpoint is an actual file on disk
+    if os.path.exists(config_checkpoint):
+        result = {"variant": config_checkpoint}
+        if config_mmproj:
+            result["mmproj"] = config_mmproj
+        return result
+
+    # Parse checkpoint to extract base and variant
+    # Checkpoint format: repo_name:variant (e.g., "unsloth/Qwen3-0.6B-GGUF:Q4_0")
     checkpoint, variant = parse_checkpoint(config_checkpoint)
 
     # Identify the GGUF model files in the repository that match the variant
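The new docstring describes a two-step caller pattern: try the local HuggingFace cache first, then fall back to a download. A minimal sketch of that pattern, assuming parse_checkpoint, resolve_local_gguf_model, and download_gguf are all importable from lemonade.tools.llamacpp.utils (the wrapper function itself is hypothetical):

from lemonade.tools.llamacpp.utils import (
    parse_checkpoint,
    resolve_local_gguf_model,
    download_gguf,
)

def get_gguf_files(config_checkpoint: str, config_mmproj: str = None) -> dict:
    """Prefer a locally cached GGUF model; download only when it is missing."""
    checkpoint, variant = parse_checkpoint(config_checkpoint)
    local = resolve_local_gguf_model(checkpoint, variant, config_mmproj)
    if local is not None:
        return local  # {"variant": ..., and optionally "mmproj": ...}
    return download_gguf(config_checkpoint, config_mmproj)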
@@ -699,8 +782,10 @@ class LlamaCppAdapter(ModelAdapter):
         context_size,
         threads,
         executable,
+        bench_executable,
         reasoning=False,
         lib_dir=None,
+        state=None,
     ):
         super().__init__()
 
@@ -712,8 +797,10 @@ class LlamaCppAdapter(ModelAdapter):
         self.context_size = context_size
         self.threads = threads
         self.executable = os.path.normpath(executable)
+        self.bench_executable = os.path.normpath(bench_executable)
         self.reasoning = reasoning
         self.lib_dir = lib_dir
+        self.state = state
 
     def generate(
         self,
@@ -754,32 +841,54 @@ class LlamaCppAdapter(ModelAdapter):
             self.executable,
             "-m",
             self.model,
-            "--ctx-size",
+            "--ctx-size",  # size of the prompt context, 0 = loaded from model
             str(self.context_size),
-            "-n",
+            "-n",  # number of tokens to predict, -1 = infinity, -2 = until context filled
             str(n_predict),
-            "-t",
+            "-t",  # number of threads to use during generation
             str(self.threads),
             "-p",
             prompt,
+            "-b",  # logical maximum batch size
+            "1",
+            "-ub",  # physical maximum batch size
+            "1",
             "--temp",
             str(temperature),
             "--top-p",
             str(top_p),
             "--top-k",
             str(top_k),
-            "-e",
-            "-no-cnv",
-            "--reasoning-format",
+            "-e",  # process escape sequences
+            "--no-conversation",  # disable conversation mode
+            "--reasoning-format",  # leaves thoughts unparsed in message content
             "none",
         ]
 
+        # If the prompt exceeds 500 characters, then use a file
+        if len(prompt) < 500:
+            cmd += ["-p", prompt]
+        else:
+            # Create prompt file in cache directory
+            prompt_file = os.path.join(
+                build.output_dir(self.state.cache_dir, self.state.build_name),
+                "prompt.txt",
+            )
+            with open(prompt_file, "w", encoding="utf-8") as file:
+                file.write(prompt)
+            cmd += ["-f", prompt_file]
+
         # Configure GPU layers: 99 for GPU, 0 for CPU-only
         ngl_value = "99" if self.device == "igpu" else "0"
         cmd = cmd + ["-ngl", ngl_value]
 
         cmd = [str(m) for m in cmd]
 
+        # save llama-cli command
+        self.state.llama_cli_cmd = getattr(self.state, "llama_cli_cmd", []) + [
+            " ".join(cmd)
+        ]
+
         try:
             # Set up environment with library path for Linux
             env = os.environ.copy()
@@ -809,6 +918,15 @@ class LlamaCppAdapter(ModelAdapter):
             )
 
             raw_output, stderr = process.communicate(timeout=600)
+
+            # save llama-cli command output with performance info to state
+            # (can be viewed in state.yaml file in cache)
+            self.state.llama_cli_stderr = getattr(
+                self.state, "llama_cli_stderr", []
+            ) + [
+                [line for line in stderr.splitlines() if line.startswith("llama_perf_")]
+            ]
+
             if process.returncode != 0:
                 error_msg = f"llama.cpp failed with return code {process.returncode}.\n"
                 error_msg += f"Command: {' '.join(cmd)}\n"
@@ -873,7 +991,108 @@ class LlamaCppAdapter(ModelAdapter):
             return [output_text]
 
         except Exception as e:
-            error_msg = f"Failed to run llama.cpp command: {str(e)}\n"
+            error_msg = f"Failed to run llama-cli.exe command: {str(e)}\n"
+            error_msg += f"Command: {' '.join(cmd)}"
+            raise Exception(error_msg)
+
+    def benchmark(self, prompts, iterations, output_tokens):
+        """
+        Runs the llama-bench.exe tool to measure TTFT and TPS
+        """
+        cmd = [
+            self.bench_executable,
+            "-m",
+            self.model,
+            "-r",
+            iterations,
+            "-p",
+            ",".join([str(p) for p in prompts]),
+            "-n",
+            output_tokens,
+            "-t",
+            self.threads if self.threads > 0 else 16,
+            "-b",
+            1,
+            "-ub",
+            1,
+        ]
+        cmd = [str(m) for m in cmd]
+
+        # save llama-bench command
+        self.state.llama_bench_cmd = " ".join(cmd)
+
+        try:
+            # Set up environment with library path for Linux
+            env = os.environ.copy()
+
+            # Load environment variables from .env file in the executable directory
+            exe_dir = os.path.dirname(self.executable)
+            env_file_path = os.path.join(exe_dir, ".env")
+            if os.path.exists(env_file_path):
+                load_dotenv(env_file_path, override=True)
+                env.update(os.environ)
+
+            if self.lib_dir and os.name != "nt":  # Not Windows
+                current_ld_path = env.get("LD_LIBRARY_PATH", "")
+                if current_ld_path:
+                    env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
+                else:
+                    env["LD_LIBRARY_PATH"] = self.lib_dir
+
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                universal_newlines=True,
+                encoding="utf-8",
+                errors="replace",
+                env=env,
+            )
+
+            raw_output, stderr = process.communicate(timeout=600)
+
+            # save llama-bench command output with performance info to state
+            # (can be viewed in state.yaml file in cache)
+            self.state.llama_bench_standard_output = raw_output.splitlines()
+
+            if process.returncode != 0:
+                error_msg = (
+                    f"llama-bench.exe failed with return code {process.returncode}.\n"
+                )
+                error_msg += f"Command: {' '.join(cmd)}\n"
+                error_msg += f"Error output:\n{stderr}\n"
+                error_msg += f"Standard output:\n{raw_output}"
+                raise Exception(error_msg)
+
+            if raw_output is None:
+                raise Exception("No output received from llama-bench.exe process")
+
+            # Parse information from llama-bench.exe output
+            prompt_lengths = []
+            pp_tps = []
+            pp_tps_sd = []
+            tg_tps = None
+            tg_tps_sd = None
+
+            for line in self.state.llama_bench_standard_output:
+                # Parse TPS information
+                for p in prompts:
+                    if f"pp{p:d}" in line:
+                        parts = line.split("|")
+                        timings = parts[-2].strip().split(" ")
+                        prompt_lengths.append(p)
+                        pp_tps.append(float(timings[0]))
+                        pp_tps_sd.append(float(timings[-1]))
+                if f"tg{output_tokens:d}" in line:
+                    parts = line.split("|")
+                    timings = parts[-2].strip().split(" ")
+                    tg_tps = float(timings[0])
+                    tg_tps_sd = float(timings[-1])
+
+            return prompt_lengths, pp_tps, pp_tps_sd, tg_tps, tg_tps_sd
+
+        except Exception as e:
+            error_msg = f"Failed to run llama-bench.exe command: {str(e)}\n"
             error_msg += f"Command: {' '.join(cmd)}"
             raise Exception(error_msg)
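The new benchmark() method parses llama-bench's table output, where each result row ends with a "t/s" column formatted as "mean ± standard deviation". A small illustration of the indexing used above, with a made-up row (the exact column layout is an assumption about llama-bench output; the indexing mirrors the code):

# Hypothetical llama-bench row for a pp256 (prompt processing) test
line = "| llama 8B Q4_0 | 4.33 GiB | 8.03 B | Vulkan | 99 | pp256 | 612.43 ± 4.12 |"

parts = line.split("|")                  # trailing "|" leaves an empty last element
timings = parts[-2].strip().split(" ")   # -> ["612.43", "±", "4.12"]
mean_tps = float(timings[0])             # 612.43 tokens/second (mean)
std_dev_tps = float(timings[-1])         # 4.12 tokens/second (standard deviation)
print(mean_tps, std_dev_tps)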
 

lemonade/tools/oga/bench.py

@@ -2,7 +2,6 @@ import argparse
 import statistics
 from statistics import StatisticsError
 from lemonade.state import State
-from lemonade.cache import Keys
 from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
 from lemonade.tools.bench import Bench
 
@@ -20,16 +19,6 @@ class OgaBench(Bench):
 
     unique_name = "oga-bench"
 
-    def __init__(self):
-        super().__init__()
-
-        # Additional statistics generated by this bench tool
-        self.status_stats.insert(
-            self.status_stats.index(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND) + 1,
-            Keys.STD_DEV_TOKENS_PER_SECOND,
-        )
-        self.std_dev_token_generation_tokens_per_second_list = []
-
     @staticmethod
     def parser(add_help: bool = True) -> argparse.ArgumentParser:
         parser = __class__.helpful_parser(
@@ -121,21 +110,6 @@ class OgaBench(Bench):
                 # Less than 2 measurements
                 self.std_dev_token_generation_tokens_per_second_list.append(None)
 
-    def save_stats(self, state):
-        super().save_stats(state)
-
-        # Save additional statistics
-        if not all(
-            element is None
-            for element in self.std_dev_token_generation_tokens_per_second_list
-        ):
-            state.save_stat(
-                Keys.STD_DEV_TOKENS_PER_SECOND,
-                self.get_item_or_list(
-                    self.std_dev_token_generation_tokens_per_second_list
-                ),
-            )
-
 
 # This file was originally licensed under Apache 2.0. It has been modified.
 # Modifications Copyright (c) 2025 AMD

lemonade/tools/oga/load.py

@@ -4,7 +4,6 @@
 
 import argparse
 import subprocess
-import sys
 import os
 import json
 import webbrowser
@@ -38,6 +37,17 @@ execution_providers = {
 }
 
 
+def find_onnx_files_recursively(directory):
+    """
+    Recursively search for ONNX files in a directory and its subdirectories.
+    """
+    for _, _, files in os.walk(directory):
+        for file in files:
+            if file.endswith(".onnx"):
+                return True
+    return False
+
+
 def _get_npu_driver_version():
     """
     Get the NPU driver version using PowerShell directly.
@@ -321,6 +331,7 @@ class OgaLoad(FirstTool):
 
     @staticmethod
    def _setup_model_dependencies(full_model_path, device, ryzenai_version, oga_path):
+        # pylint: disable=unused-argument
         """
         Sets up model dependencies for hybrid and NPU inference by:
         1. Configuring the custom_ops_library path in genai_config.json.
@@ -328,116 +339,35 @@ class OgaLoad(FirstTool):
         3. Check NPU driver version if required for device and ryzenai_version.
         """
 
-        env_path = sys.prefix
+        # For RyzenAI 1.6.0, check NPU driver version for NPU and hybrid devices
+        if device in ["npu", "hybrid"]:
+            required_driver_version = REQUIRED_NPU_DRIVER_VERSION
 
-        if "1.4.0" in ryzenai_version:
-            if device == "npu":
-                custom_ops_path = os.path.join(
-                    oga_path, "libs", "onnxruntime_vitis_ai_custom_ops.dll"
-                )
-            else:
-                custom_ops_path = os.path.join(oga_path, "libs", "onnx_custom_ops.dll")
-        else:
-            # For 1.5.0+, check NPU driver version for NPU and hybrid devices
-            if device in ["npu", "hybrid"]:
-                required_driver_version = REQUIRED_NPU_DRIVER_VERSION
-
-                current_driver_version = _get_npu_driver_version()
-
-                if not current_driver_version:
-                    printing.log_warning(
-                        f"NPU driver not found. {device.upper()} inference requires NPU driver "
-                        f"version {required_driver_version}.\n"
-                        "Please download and install the NPU Driver from:\n"
-                        f"{NPU_DRIVER_DOWNLOAD_URL}\n"
-                        "NPU functionality may not work properly."
-                    )
-                    _open_driver_install_page()
-
-                elif not _compare_driver_versions(
-                    current_driver_version, required_driver_version
-                ):
-                    printing.log_warning(
-                        f"Incorrect NPU driver version detected: {current_driver_version}\n"
-                        f"{device.upper()} inference with RyzenAI 1.5.0 requires driver "
-                        f"version {required_driver_version} or higher.\n"
-                        "Please download and install the correct NPU Driver from:\n"
-                        f"{NPU_DRIVER_DOWNLOAD_URL}\n"
-                        "NPU functionality may not work properly."
-                    )
-                    _open_driver_install_page()
-
-            if device == "npu":
-                # For 1.5.0, custom ops are in the conda environment's onnxruntime package
-                custom_ops_path = os.path.join(
-                    env_path,
-                    "Lib",
-                    "site-packages",
-                    "onnxruntime",
-                    "capi",
-                    "onnxruntime_vitis_ai_custom_ops.dll",
-                )
-                dll_source_path = os.path.join(
-                    env_path, "Lib", "site-packages", "onnxruntime", "capi"
-                )
-                required_dlls = ["dyn_dispatch_core.dll", "xaiengine.dll"]
-            else:
-                custom_ops_path = os.path.join(
-                    env_path,
-                    "Lib",
-                    "site-packages",
-                    "onnxruntime_genai",
-                    "onnx_custom_ops.dll",
-                )
-                dll_source_path = os.path.join(
-                    env_path, "Lib", "site-packages", "onnxruntime_genai"
-                )
-                required_dlls = ["libutf8_validity.dll", "abseil_dll.dll"]
-
-            # Validate that all required DLLs exist in the source directory
-            missing_dlls = []
-            if not os.path.exists(custom_ops_path):
-                missing_dlls.append(custom_ops_path)
-
-            for dll_name in required_dlls:
-                dll_source = os.path.join(dll_source_path, dll_name)
-                if not os.path.exists(dll_source):
-                    missing_dlls.append(dll_source)
-
-            if missing_dlls:
-                dll_list = "\n - ".join(missing_dlls)
-                raise RuntimeError(
-                    f"Required DLLs not found for {device} inference:\n - {dll_list}\n"
-                    f"Please ensure your RyzenAI installation is complete and supports {device}."
+            current_driver_version = _get_npu_driver_version()
+            rai_version, _ = _get_ryzenai_version_info(device)
+
+            if not current_driver_version:
+                printing.log_warning(
+                    f"NPU driver not found. {device.upper()} inference requires NPU driver "
+                    f"version {required_driver_version}.\n"
+                    "Please download and install the NPU Driver from:\n"
+                    f"{NPU_DRIVER_DOWNLOAD_URL}\n"
+                    "NPU functionality may not work properly."
                 )
+                _open_driver_install_page()
 
-            # Add the DLL source directory to PATH
-            current_path = os.environ.get("PATH", "")
-            if dll_source_path not in current_path:
-                os.environ["PATH"] = dll_source_path + os.pathsep + current_path
-
-        # Update the model config with custom_ops_library path
-        config_path = os.path.join(full_model_path, "genai_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r", encoding="utf-8") as f:
-                config = json.load(f)
-
-            if (
-                "model" in config
-                and "decoder" in config["model"]
-                and "session_options" in config["model"]["decoder"]
+            elif not _compare_driver_versions(
+                current_driver_version, required_driver_version
            ):
-                config["model"]["decoder"]["session_options"][
-                    "custom_ops_library"
-                ] = custom_ops_path
-
-            with open(config_path, "w", encoding="utf-8") as f:
-                json.dump(config, f, indent=4)
-
-        else:
-            printing.log_info(
-                f"Model's `genai_config.json` not found in {full_model_path}"
-            )
+                printing.log_warning(
+                    f"Incorrect NPU driver version detected: {current_driver_version}\n"
+                    f"{device.upper()} inference with RyzenAI {rai_version} requires driver "
+                    f"version {required_driver_version} or higher.\n"
+                    "Please download and install the correct NPU Driver from:\n"
+                    f"{NPU_DRIVER_DOWNLOAD_URL}\n"
+                    "NPU functionality may not work properly."
+                )
+                _open_driver_install_page()
 
     @staticmethod
     def _is_preoptimized_model(input_model_path):
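The refactored check leans on _compare_driver_versions() to decide whether the installed NPU driver is new enough. That helper is not shown in this diff; a plausible sketch of such a dotted-version comparison (an assumption about its behavior, with made-up version strings) is:

def compare_driver_versions(current: str, required: str) -> bool:
    """Return True when `current` is greater than or equal to `required`."""

    def to_tuple(version: str) -> tuple:
        # "32.0.100.2" -> (32, 0, 100, 2); non-numeric parts are skipped
        return tuple(int(part) for part in version.split(".") if part.isdigit())

    return to_tuple(current) >= to_tuple(required)

print(compare_driver_versions("2.0.10.5", "2.0.9.0"))  # True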
@@ -502,34 +432,6 @@ class OgaLoad(FirstTool):
 
         return full_model_path
 
-    @staticmethod
-    def _setup_npu_environment(ryzenai_version, oga_path):
-        """
-        Sets up environment for NPU flow of ONNX model and returns saved state to be restored
-        later in cleanup.
-        """
-        if "1.5.0" in ryzenai_version:
-            # For PyPI installation (1.5.0+), no environment setup needed
-            return None
-        elif "1.4.0" in ryzenai_version:
-            # Legacy lemonade-install approach for 1.4.0
-            if not os.path.exists(os.path.join(oga_path, "libs", "onnxruntime.dll")):
-                raise RuntimeError(
-                    f"Cannot find libs/onnxruntime.dll in lib folder: {oga_path}"
-                )
-
-            # Save current state so they can be restored after inference.
-            saved_state = {"cwd": os.getcwd(), "path": os.environ["PATH"]}
-
-            # Setup NPU environment (cwd and path will be restored later)
-            os.chdir(oga_path)
-            os.environ["PATH"] = (
-                os.path.join(oga_path, "libs") + os.pathsep + os.environ["PATH"]
-            )
-            return saved_state
-        else:
-            raise ValueError(f"Unsupported RyzenAI version: {ryzenai_version}")
-
     @staticmethod
     def _load_model_and_setup_state(
         state, full_model_path, checkpoint, trust_remote_code
@@ -702,8 +604,7 @@ class OgaLoad(FirstTool):
             state.save_stat(Keys.CHECKPOINT, checkpoint)
             state.save_stat(Keys.LOCAL_MODEL_FOLDER, full_model_path)
             # See if there is a file ending in ".onnx" in this folder
-            dir = os.listdir(input)
-            has_onnx_file = any([filename.endswith(".onnx") for filename in dir])
+            has_onnx_file = find_onnx_files_recursively(input)
             if not has_onnx_file:
                 raise ValueError(
                     f"The folder {input} does not contain an ONNX model file."
@@ -852,15 +753,10 @@ class OgaLoad(FirstTool):
 
         try:
             if device == "npu":
-                saved_env_state = self._setup_npu_environment(
-                    ryzenai_version, oga_path
-                )
                 # Set USE_AIE_RoPE based on model type
                 os.environ["USE_AIE_RoPE"] = (
                     "0" if "phi-" in checkpoint.lower() else "1"
                 )
-            elif device == "hybrid":
-                saved_env_state = None
 
             self._load_model_and_setup_state(
                 state, full_model_path, checkpoint, trust_remote_code