lemonade-sdk 7.0.0 (lemonade_sdk-7.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lemonade-sdk might be problematic.

Files changed (61)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
lemonade/tools/ort_genai/oga.py
@@ -0,0 +1,1129 @@
1
+ # onnxruntime_genai is not lint-friendly yet and PyLint can't
2
+ # find any of the class methods
3
+ # pylint: disable=no-member
4
+ #
5
+ # Model builder constraints:
6
+ # 11/10/24 Need transformers <4.45.0 OR onnxruntime-genai 0.5.0 (which must be built from source)
7
+ # (transformers v4.45 changes the format of the tokenizer.json file which will be supported in
8
+ # onnxruntime-genai 0.5)
9
+ #
10
+
11
+ import argparse
12
+ import os
13
+ import time
14
+ import json
15
+ import shutil
16
+ import logging
17
+ from fnmatch import fnmatch
18
+ from queue import Queue
19
+ import subprocess
20
+ from packaging.version import Version
21
+ from huggingface_hub import snapshot_download
22
+ import onnxruntime_genai as og
23
+ import onnxruntime_genai.models.builder as model_builder
24
+ from transformers import AutoTokenizer
25
+ from lemonade.state import State
26
+ from lemonade.tools import FirstTool
27
+ import lemonade.common.status as status
28
+ import lemonade.common.printing as printing
29
+ from lemonade.tools.huggingface_load import get_base_model, is_offline
30
+ from lemonade.tools.adapter import (
31
+ ModelAdapter,
32
+ TokenizerAdapter,
33
+ PassthroughTokenizerResult,
34
+ )
35
+ from lemonade.cache import Keys
36
+ from lemonade_install.install import (
37
+ get_ryzen_ai_version_info,
38
+ get_oga_npu_dir,
39
+ get_oga_hybrid_dir,
40
+ SUPPORTED_RYZEN_AI_SERIES,
41
+ )
42
+
43
+
44
+ # ONNX Runtime GenAI models will be cached in this subfolder of the lemonade cache folder
45
+ oga_models_path = "oga_models"
46
+
47
+ # ONNX Runtime GenAI model builder tool uses this subfolder of the lemonade cache as its cache
48
+ oga_model_builder_cache_path = "model_builder"
49
+
50
+ # Mapping from processor to execution provider, used in pathnames and by model_builder
51
+ execution_providers = {
52
+ "cpu": "cpu",
53
+ "npu": "npu",
54
+ "igpu": "dml",
55
+ "hybrid": "hybrid",
56
+ "cuda": "cuda",
57
+ }
58
+
59
+
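As a quick illustration of how these constants combine: models quantized or converted locally land under <cache_dir>/oga_models/<checkpoint lowercased, '/' replaced by '_'>/<EP>-<dtype>, the convention documented in OgaLoad._setup_model_paths further down. A minimal sketch of that layout (the helper name and example checkpoint are illustrative, not part of this module):

```python
import os

def local_oga_path(cache_dir: str, checkpoint: str, device: str, dtype: str) -> str:
    # Mirror the layout used by OgaLoad._setup_model_paths:
    # <cache_dir>/oga_models/<checkpoint with "/" -> "_", lowercased>/<EP>-<dtype>
    subfolder = f"{execution_providers[device]}-{dtype}"
    model_dir = checkpoint.replace("/", "_").lower()
    return os.path.join(cache_dir, oga_models_path, model_dir, subfolder)

# local_oga_path("~/.cache/lemonade", "microsoft/Phi-3-mini-4k-instruct", "igpu", "int4")
# -> "~/.cache/lemonade/oga_models/microsoft_phi-3-mini-4k-instruct/dml-int4"
```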
60
+ class OrtGenaiTokenizer(TokenizerAdapter):
61
+ def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
62
+ super().__init__(hf_tokenizer)
63
+ # Initialize OGA tokenizer
64
+ self.tokenizer = og.Tokenizer(model)
65
+
66
+ # Placeholder value since some code will try to query it
67
+ # If we actually need this to return a proper value, then
68
+ # og.GeneratorParams.eos_token_id has it
69
+ self.eos_token_id = None
70
+
71
+ def __call__(self, prompt: str, return_tensors="np"):
72
+ tokens = self.tokenizer.encode(prompt)
73
+ return PassthroughTokenizerResult(tokens)
74
+
75
+ # pylint: disable=unused-argument
76
+ def decode(self, response, skip_special_tokens=True) -> str:
77
+ return self.tokenizer.decode(response)
78
+
79
+
80
+ class OrtGenaiStreamer:
81
+ def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
82
+ self.tokenizer = tokenizer
83
+ self.text_queue = Queue()
84
+ self.stop_signal = None
85
+ self.timeout = timeout
86
+
87
+ def add_text(self, text: str):
88
+ self.text_queue.put(text, timeout=self.timeout)
89
+
90
+ def done(self):
91
+ self.text_queue.put(self.stop_signal, timeout=self.timeout)
92
+
93
+ def __iter__(self):
94
+ return self
95
+
96
+ def __next__(self):
97
+ value = self.text_queue.get(timeout=self.timeout)
98
+ if value == self.stop_signal:
99
+ raise StopIteration()
100
+ else:
101
+ return value
102
+
103
+
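OrtGenaiStreamer is a small blocking-queue bridge: generate() pushes decoded fragments with add_text() and closes the stream with done(), while a consumer iterates until the stop signal. A sketch of that producer/consumer pattern, assuming model and tokenizer are the OrtGenaiModel and OrtGenaiTokenizer produced by the oga-load tool defined later in this file (the import path is an assumption):

```python
from threading import Thread

from lemonade.tools.ort_genai.oga import OrtGenaiStreamer  # assumed import path

def stream_response(model, tokenizer, prompt: str, max_new_tokens: int = 64):
    # Producer: run generate() on a worker thread so fragments can be consumed
    # as soon as the tokenizer stream decodes them.
    input_ids = tokenizer(prompt).input_ids
    streamer = OrtGenaiStreamer(tokenizer)
    worker = Thread(
        target=model.generate,
        kwargs={
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
        },
    )
    worker.start()

    # Consumer: __next__ blocks on the queue and raises StopIteration
    # once done() pushes the stop signal.
    for fragment in streamer:
        print(fragment, end="", flush=True)
    worker.join()
```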
104
+ class OrtGenaiModel(ModelAdapter):
105
+
106
+ def __init__(self, input_folder):
107
+ super().__init__()
108
+ self.model = og.Model(input_folder)
109
+ self.type = "ort-genai"
110
+ self.config = self.load_config(input_folder)
111
+
112
+ def load_config(self, input_folder):
113
+ rai_config_path = os.path.join(input_folder, "rai_config.json")
114
+ if os.path.exists(rai_config_path):
115
+ with open(rai_config_path, "r", encoding="utf-8") as f:
116
+ max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
117
+ else:
118
+ max_prompt_length = None
119
+
120
+ config_path = os.path.join(input_folder, "genai_config.json")
121
+ if os.path.exists(config_path):
122
+ with open(config_path, "r", encoding="utf-8") as f:
123
+ config_dict = json.load(f)
124
+ if max_prompt_length:
125
+ config_dict["max_prompt_length"] = max_prompt_length
126
+ return config_dict
127
+ return None
128
+
129
+ def generate(
130
+ self,
131
+ input_ids,
132
+ max_new_tokens=512,
133
+ min_new_tokens=0,
134
+ do_sample=True,
135
+ top_k=50,
136
+ top_p=1.0,
137
+ temperature=0.7,
138
+ streamer: OrtGenaiStreamer = None,
139
+ pad_token_id=None,
140
+ stopping_criteria=None,
141
+ max_length=None,
142
+ ):
143
+ params = og.GeneratorParams(self.model)
144
+
145
+ prompt_length = len(input_ids)
146
+ max_prompt_length = self.config.get("max_prompt_length")
147
+ if max_prompt_length and prompt_length > max_prompt_length:
148
+ raise ValueError(
149
+ f"This prompt (length {prompt_length}) exceeds the model's "
150
+ f"maximum allowed prompt length ({max_prompt_length})."
151
+ )
152
+
153
+ # There is a breaking API change in OGA 0.6.0
154
+ # Determine whether we should use the old or new APIs
155
+ # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
156
+ use_oga_post_6_api = (
157
+ Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
158
+ )
159
+ use_oga_pre_6_api = not use_oga_post_6_api
160
+
161
+ if pad_token_id:
162
+ params.pad_token_id = pad_token_id
163
+
164
+ # Handle max_length and max_new_tokens
165
+ if max_length and max_new_tokens:
166
+ logging.warning(
167
+ "Both max_length and max_new_tokens were provided. "
168
+ "max_length will take precedence. "
169
+ "When setting max_length, please explicitly set max_new_tokens to None."
170
+ )
171
+ max_length_to_use = None
172
+ if max_length:
173
+ max_length_to_use = max_length
174
+ elif max_new_tokens:
175
+ max_length_to_use = prompt_length + max_new_tokens
176
+
177
+ min_length = prompt_length + min_new_tokens
178
+
179
+ if use_oga_pre_6_api:
180
+ params.input_ids = input_ids
181
+
182
+ if self.config and "search" in self.config:
183
+ search_config = self.config["search"]
184
+ params.set_search_options(
185
+ do_sample=search_config.get("do_sample", do_sample),
186
+ top_k=search_config.get("top_k", top_k),
187
+ top_p=search_config.get("top_p", top_p),
188
+ temperature=search_config.get("temperature", temperature),
189
+ max_length=max_length_to_use,
190
+ min_length=min_length,
191
+ early_stopping=search_config.get("early_stopping", False),
192
+ length_penalty=search_config.get("length_penalty", 1.0),
193
+ num_beams=search_config.get("num_beams", 1),
194
+ num_return_sequences=search_config.get("num_return_sequences", 1),
195
+ repetition_penalty=search_config.get("repetition_penalty", 1.0),
196
+ past_present_share_buffer=search_config.get(
197
+ "past_present_share_buffer", True
198
+ ),
199
+ # Make sure that results do not vary across laptops.
200
+ # By default, random_seed=-1 causes different laptops to give
201
+ # different results
202
+ random_seed=1,
203
+ # Not currently supported by OGA
204
+ # diversity_penalty=search_config.get('diversity_penalty', 0.0),
205
+ # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
206
+ )
207
+ else:
208
+ params.set_search_options(
209
+ do_sample=do_sample,
210
+ top_k=top_k,
211
+ top_p=top_p,
212
+ temperature=temperature,
213
+ max_length=max_length_to_use,
214
+ min_length=min_length,
215
+ )
216
+ params.try_graph_capture_with_max_batch_size(1)
217
+
218
+ generator = og.Generator(self.model, params)
219
+
220
+ if streamer is None:
221
+ prompt_start_time = time.perf_counter()
222
+ if use_oga_post_6_api:
223
+ generator.append_tokens(input_ids)
224
+ if use_oga_pre_6_api:
225
+ generator.compute_logits()
226
+ generator.generate_next_token()
227
+ prompt_end_time = time.perf_counter()
228
+
229
+ self.time_to_first_token = prompt_end_time - prompt_start_time
230
+
231
+ if max_new_tokens is None or max_new_tokens > 1:  # None when length is governed by max_length
232
+
233
+ token_gen_times = []
234
+ while not generator.is_done():
235
+ token_gen_start_time = time.perf_counter()
236
+ if use_oga_pre_6_api:
237
+ generator.compute_logits()
238
+ generator.generate_next_token()
239
+ token_gen_end_time = time.perf_counter()
240
+
241
+ token_gen_times.append(token_gen_end_time - token_gen_start_time)
242
+
243
+ if token_gen_times:
244
+ # List will be empty if we generated 1 or 0 tokens, and we don't
245
+ # want a divide-by-zero error in those cases
246
+ avg_token_gen_latency_s = sum(token_gen_times) / len(
247
+ token_gen_times
248
+ )
249
+ self.tokens_per_second = 1 / avg_token_gen_latency_s
250
+
251
+ return [generator.get_sequence(0)]
252
+ else:
253
+ if use_oga_post_6_api:
254
+ generator.append_tokens(input_ids)
255
+ tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
256
+
257
+ stop_early = False
258
+
259
+ while not generator.is_done() and not stop_early:
260
+ if use_oga_pre_6_api:
261
+ generator.compute_logits()
262
+ generator.generate_next_token()
263
+
264
+ new_token = generator.get_next_tokens()[0]
265
+ new_text = tokenizer_stream.decode(new_token)
266
+
267
+ streamer.add_text(new_text)
268
+
269
+ if stopping_criteria is not None:
270
+ if stopping_criteria[0].stop_event.is_set():
271
+ stop_early = True
272
+
273
+ streamer.done()
274
+
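The pre/post-0.6 check near the top of generate() leans on a packaging.version quirk: PEP 440 orders pre-releases such as 0.6.0.dev0 below 0.6.0, which is why the substring fallback is needed. A standalone illustration (no OGA required):

```python
from packaging.version import Version

assert Version("0.6.0.dev0") < Version("0.6.0")  # dev builds sort below the release

def is_post_6(version_string: str) -> bool:
    # Same test as in generate(): Version comparison plus a substring fallback
    # so that 0.6.0 pre-releases are also treated as the new API.
    return Version(version_string) >= Version("0.6.0") or "0.6.0" in version_string

assert is_post_6("0.6.0.dev0")   # caught by the substring fallback
assert is_post_6("0.7.1")        # caught by the Version comparison
assert not is_post_6("0.5.2")
```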
275
+ def _model_call(self, input_ids):
276
+ """
277
+ Run the model on input_ids and get logits.
278
+
279
+ This method directly accesses model logits rather than using the full generate pipeline for
280
+ several important reasons:
281
+ 1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
282
+ producing multiple tokens through iterative inference
283
+ 2. Efficiency: Direct access is more efficient for logprob calculations with no
284
+ sampling overhead
285
+ 3. Precision: Logprob calculations require exact control over input-to-output mapping
286
+ 4. Consistency: Similar approach used in both HF and OGA implementations
287
+
288
+ Args:
289
+ input_ids: Input token IDs
290
+
291
+ Returns:
292
+ Logits for each token in the sequence
293
+ """
294
+ import torch
295
+
296
+ # Setup generator params
297
+ params = og.GeneratorParams(self.model)
298
+
299
+ # Configure for a simple forward pass
300
+ params.set_search_options(
301
+ do_sample=False,
302
+ temperature=0.0,
303
+ max_length=len(input_ids),
304
+ )
305
+
306
+ # Initialize generator
307
+ generator = og.Generator(self.model, params)
308
+
309
+ # Feed tokens to the model (assumes the OGA >= 0.6 append_tokens API)
310
+ generator.append_tokens(input_ids)
311
+
312
+ # Extract logits - this returns a list of logits tensors
313
+ logits = generator.get_output("logits")
314
+
315
+ # Convert to torch tensor for easier processing
316
+ return torch.tensor(logits[0])
317
+
318
+ def _select_cont_toks(self, logits, context_len, continuation_tokens):
319
+ """
320
+ Select and process logits for continuation tokens.
321
+
322
+ Args:
323
+ logits: Full sequence logits
324
+ context_len: Length of context tokens
325
+ continuation_tokens: List or tensor of continuation token IDs
326
+
327
+ Returns:
328
+ Log probabilities for continuation tokens
329
+ """
330
+ import torch
331
+
332
+ # Extract relevant logits for continuation prediction (shift by one)
333
+ cont_logits = logits[
334
+ context_len - 1 : context_len - 1 + len(continuation_tokens)
335
+ ]
336
+
337
+ # Convert to torch tensors if needed
338
+ if not isinstance(continuation_tokens, torch.Tensor):
339
+ continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
340
+
341
+ # Apply log softmax to get log probabilities
342
+ log_probs = torch.log_softmax(cont_logits, dim=-1)
343
+
344
+ # Get log probs for the specific continuation tokens
345
+ token_log_probs = torch.gather(
346
+ log_probs, 1, continuation_tokens.unsqueeze(-1)
347
+ ).squeeze(-1)
348
+
349
+ return token_log_probs
350
+
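A toy example of the shift-by-one indexing used above: because the logits at position i score token i+1, the logits that score a continuation of length L start at index context_len - 1. Shapes only, random numbers, no model involved:

```python
import torch

vocab_size, context_len = 5, 3
logits = torch.randn(6, vocab_size)            # logits for a 6-token sequence
continuation = torch.tensor([2, 4, 1])         # 3 continuation token ids

cont_logits = logits[context_len - 1 : context_len - 1 + len(continuation)]
log_probs = torch.log_softmax(cont_logits, dim=-1)
token_log_probs = torch.gather(log_probs, 1, continuation.unsqueeze(-1)).squeeze(-1)
assert token_log_probs.shape == (3,)           # one log-probability per continuation token
```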
351
+ def compute_logprobs(
352
+ self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
353
+ ):
354
+ """
355
+ Compute log probabilities for all tokens in the given text.
356
+
357
+ Args:
358
+ text: The full text to analyze (e.g., prompt + completion)
+ tokenizer: Tokenizer adapter used to encode the text and decode token strings
359
+ prompt_length: Number of tokens in the prompt. If provided and echo=False,
360
+ only completion tokens after this position will be returned.
361
+ logprobs: If not None, return log probabilities. Value indicates how many top
362
+ alternatives to return. If True but not an integer, defaults to 5 alternatives.
363
+ echo: If True, include logprobs for prompt tokens. If False, only return logprobs
364
+ for completion tokens.
365
+
366
+ Returns:
367
+ - text_offset: Character offsets for each token in the text
368
+ - token_logprobs: Log probability for each token
369
+ - tokens: The actual tokens used
370
+ - top_logprobs: Top alternative log probabilities for each position
371
+ """
372
+ import torch
373
+
374
+ if tokenizer is None:
375
+ raise ValueError("Tokenizer is required for logprob calculation")
376
+
377
+ # Encode the full text
378
+ tokens = tokenizer(text).input_ids # pylint: disable=E1102
379
+
380
+ # Track character offsets for each token
381
+ text_offset = []
382
+ start_idx = 0
383
+
384
+ token_strings = []
385
+ for token_id in tokens:
386
+ token_str = tokenizer.decode([token_id])
387
+ token_strings.append(token_str)
388
+
389
+ # Calculate character offsets for tokens - handles cases where tokens
390
+ # may not directly match in the original text due to encoding differences,
391
+ # special characters, or tokenization artifacts
392
+ try:
393
+ pos = text[start_idx:].find(token_str)
394
+ if pos != -1:
395
+ text_offset.append(start_idx + pos)
396
+ start_idx += pos + len(token_str)
397
+ else:
398
+ text_offset.append(start_idx)
399
+ except (TypeError, ValueError, UnicodeError):
400
+ # Fallback to current position when matching fails due to encoding issues
401
+ text_offset.append(start_idx)
402
+
403
+ # Get logits from model
404
+ logits = self._model_call(tokens)
405
+
406
+ # Calculate log probabilities for each token
407
+ all_log_probs = torch.log_softmax(logits, dim=-1)
408
+
409
+ # The first token doesn't have a conditional probability
410
+ # For tokens after the first, get the predicted probability
411
+ token_log_probs = []
412
+ top_logprobs_list = []
413
+
414
+ # For each position, get the actual token probability and top alternatives
415
+ for i in range(len(tokens)):
416
+ # Get previous token position logits
417
+ if i > 0: # First token has no preceding context
418
+ prev_logits = all_log_probs[i - 1]
419
+ curr_token_id = tokens[i]
420
+ # Get probability of the actual token that appeared
421
+ token_logprob = prev_logits[curr_token_id].item()
422
+ token_log_probs.append(token_logprob)
423
+
424
+ # Get top-k alternatives if requested
425
+ if logprobs is not None:
426
+ num_alternatives = logprobs if isinstance(logprobs, int) else 5
427
+ topk_values, topk_indices = torch.topk(
428
+ prev_logits, min(num_alternatives, prev_logits.size(-1))
429
+ )
430
+
431
+ # Create dictionary of token: logprob
432
+ position_logprobs = {}
433
+ for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
434
+ token_str = tokenizer.decode([idx])
435
+ position_logprobs[token_str] = val
436
+
437
+ top_logprobs_list.append(position_logprobs)
438
+ else:
439
+ # For the first token, we don't have a conditional probability
440
+ token_log_probs.append(None)
441
+ top_logprobs_list.append({})
442
+
443
+ # If we don't want to echo prompt tokens, filter them out
444
+ if not echo and prompt_length is not None:
445
+ # Ensure prompt_length is within bounds
446
+ prompt_length = min(prompt_length, len(tokens))
447
+
448
+ # Filter results to only include completion tokens
449
+ if prompt_length < len(tokens):
450
+ filtered_text_offset = text_offset[prompt_length:]
451
+ filtered_token_logprobs = token_log_probs[prompt_length:]
452
+ filtered_tokens = token_strings[prompt_length:]
453
+ filtered_top_logprobs = top_logprobs_list[prompt_length:]
454
+
455
+ return (
456
+ filtered_text_offset,
457
+ filtered_token_logprobs,
458
+ filtered_tokens,
459
+ filtered_top_logprobs,
460
+ )
461
+ else:
462
+ # No completion tokens
463
+ return [], [], [], []
464
+
465
+ return text_offset, token_log_probs, token_strings, top_logprobs_list
466
+
467
+
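The four parallel lists returned above line up position by position; the first entry has no conditional log-probability, which is why it is None with an empty alternatives dict. A format-only illustration with made-up numbers (not real model output):

```python
text_offset    = [0, 5, 11]
token_logprobs = [None, -0.21, -1.34]   # first token has no conditional probability
tokens         = ["Hello", " there", "!"]
top_logprobs   = [{}, {" there": -0.21, " world": -1.87}, {"!": -1.34, ".": -1.52}]

for offset, logprob, token in zip(text_offset, token_logprobs, tokens):
    print(f"offset={offset:<3} token={token!r:10} logprob={logprob}")
```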
468
+ class OgaLoad(FirstTool):
469
+ """
470
+ Tool that loads an LLM with OnnxRuntime-GenAI (OGA) for the cpu, igpu (DirectML), npu, hybrid, or cuda execution providers.
471
+
472
+ Input: path to a checkpoint.
473
+ Supported choices for cpu and igpu from HF model repository:
474
+ LLM models on Huggingface supported by model_builder. See documentation
475
+ (https://github.com/lemonade-sdk/lemonade/blob/main/docs/ort_genai_igpu.md)
476
+ for supported models.
477
+ Supported choices for npu from HF model repository:
478
+ Models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern
479
+ Local models for cpu, igpu, or npu:
480
+ The specified checkpoint is converted to a local path, via mapping to lower case
481
+ and replacing '/' with '_'. If this model already exists in the 'oga_models' folder
482
+ of the lemonade cache and if it has a subfolder <EP for device>-<dtype>, then this model
483
+ will be used. If the --force flag is used and the model is built with model_builder,
484
+ then it will be rebuilt.
485
+
486
+
487
+
488
+ Output:
489
+ state.model: handle to a Huggingface-style LLM loaded on the selected device
490
+ state.tokenizer: Huggingface-style LLM tokenizer instance
491
+ state.dtype: data type used to load the model on the selected device
492
+ state.checkpoint: name of the checkpoint used to load state.model
493
+
494
+ Note: This tool expects the onnxruntime-genai-directml library to be pre-installed.
495
+ If that library is not installed, this tool will not load.
496
+ """
497
+
498
+ unique_name = "oga-load"
499
+
500
+ def __init__(self):
501
+ super().__init__(monitor_message="Loading OnnxRuntime-GenAI model")
502
+
503
+ self.status_stats = [
504
+ Keys.DTYPE,
505
+ Keys.DEVICE,
506
+ Keys.LOCAL_MODEL_FOLDER,
507
+ ]
508
+
509
+ @staticmethod
510
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
511
+ parser = __class__.helpful_parser(
512
+ short_description="Load model in onnxruntime-genai (OGA)",
513
+ add_help=add_help,
514
+ )
515
+
516
+ parser.add_argument(
517
+ "-ip",
518
+ "--input_path",
519
+ default="",
520
+ help="the local huggingface model in your disk",
521
+ )
522
+
523
+ parser.add_argument(
524
+ "-d",
525
+ "--device",
526
+ choices=["igpu", "npu", "cpu", "hybrid", "cuda"],
527
+ default="igpu",
528
+ help="Which device to load the model on to (default: igpu)",
529
+ )
530
+
531
+ parser.add_argument(
532
+ "--dtype",
533
+ choices=["int4", "fp16", "fp32"],
534
+ required=True,
535
+ help="Data type to load the model in",
536
+ )
537
+
538
+ parser.add_argument(
539
+ "--int4-block-size",
540
+ default=None,
541
+ help="Specify the block_size for int4 quantization.",
542
+ choices=[16, 32, 64, 128, 256],
543
+ type=int,
544
+ )
545
+
546
+ parser.add_argument(
547
+ "--force",
548
+ action="store_true",
549
+ help="Forces downloading of Hugging-Face model again (if changed). Additionally for"
550
+ " cpu and igpu devices only, forces model_builder to run again on the HF model"
551
+ " (changed or not).",
552
+ )
553
+
554
+ parser.add_argument(
555
+ "--download-only",
556
+ action="store_true",
557
+ help="Download the model if needed, but don't load it",
558
+ )
559
+
560
+ parser.add_argument(
561
+ "--trust-remote-code",
562
+ action="store_true",
563
+ help="Set this flag to use models whose code is on the Hugging Face hub rather "
564
+ "than natively in the OnnxRuntime Gen AI libraries. Please review the model code "
565
+ "in advance as this is a security risk.",
566
+ )
567
+
568
+ parser.add_argument(
569
+ "--subfolder",
570
+ default=None,
571
+ help="Subfolder where model is located <LEMONADE CACHE>/oga_models/<MODELNAME>"
572
+ "/<SUBFOLDER>, default is <EP for device>-<dtype>. The EPs are: "
573
+ f'{", ".join([value + " for " + key for key, value in execution_providers.items()])}.',
574
+ )
575
+
576
+ return parser
577
+
578
+ @staticmethod
579
+ def _validate_model_configuration(device, dtype, checkpoint):
580
+ """
581
+ Validate if the device, dtype, platform and checkpoint combination are consistent with
582
+ HuggingFace checkpoint naming conventions and specifically for AMD models for NPU
583
+ and hybrid flows.
584
+
585
+ Returns True if device, dtype, and model are consistent.
586
+ """
587
+
588
+ hf_supported_models = {
589
+ "cpu": {"int4": "*/*", "fp32": "*/*"},
590
+ "igpu": {"int4": "*/*", "fp16": "*/*"},
591
+ "npu": {"int4": "*/*"},
592
+ "hybrid": {"int4": "*/*"},
593
+ "cuda": {"int4": "*/*", "fp16": "*/*"},
594
+ }
595
+
596
+ hf_supported = (
597
+ device in hf_supported_models
598
+ and dtype in hf_supported_models[device]
599
+ and fnmatch(checkpoint, hf_supported_models[device][dtype])
600
+ )
601
+ return hf_supported
602
+
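The support table above is wildcard-based, so any HF checkpoint name matches as long as the (device, dtype) pair is listed; unsupported dtypes for a device fall through to False. A couple of illustrative checks (the checkpoint string is a placeholder):

```python
assert OgaLoad._validate_model_configuration("npu", "int4", "some-org/some-model")
assert not OgaLoad._validate_model_configuration("npu", "fp16", "some-org/some-model")
assert OgaLoad._validate_model_configuration("igpu", "fp16", "some-org/some-model")
```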
603
+ @staticmethod
604
+ def _setup_model_paths(
605
+ state, checkpoint, device, dtype, subfolder, int4_block_size
606
+ ):
607
+ """
608
+ Determines and returns the following model path information for models produced by OGA
609
+ model builder:
610
+
611
+ full_model_path - Full path to where the OGA model files are stored.
612
+ oga_models_subfolder - The subfolder of the oga_models folder where the model files
613
+ are stored. (<full_model_path> = <oga_models>/<oga_models_subfolder>)
614
+ This subfolder is usually
615
+ <checkpoint_string>/<device>-<dtype>[-block-<int4_block_size]>
616
+ but if the argument subfolder is not None it will override the latter part
617
+ of this path.
618
+ model_exists_locally - True if full_model_path is a folder that contains files
619
+
620
+ Note: Model files already in ONNX format on Hugging Face will be stored in the
621
+ Hugging Face cache, not this folder. The <oga_models> folder contains model
622
+ files that have locally been quantized/converted to OGA format and any other
623
+ models that have been manually added by the user.
624
+ """
625
+ if subfolder is None:
626
+ subfolder = f"{execution_providers[device]}-{dtype}"
627
+ subfolder += (
628
+ f"-block-{int4_block_size}"
629
+ if dtype == "int4" and int4_block_size is not None
630
+ else ""
631
+ )
632
+
633
+ # First, check in the lemonade oga_models cache
634
+ oga_models_subfolder = os.path.join(
635
+ checkpoint.replace("/", "_").lower(), subfolder
636
+ )
637
+ full_model_path = os.path.join(
638
+ state.cache_dir, oga_models_path, oga_models_subfolder
639
+ )
640
+ model_exists_locally = os.path.isdir(full_model_path) and os.listdir(
641
+ full_model_path
642
+ )
643
+
644
+ # If not found in lemonade cache, check in Hugging Face cache
645
+ if not model_exists_locally:
646
+ try:
647
+ snapshot_path = snapshot_download(
648
+ repo_id=checkpoint,
649
+ local_files_only=True,
650
+ )
651
+
652
+ # Check if the snapshot contains ONNX files
653
+ if os.path.isdir(snapshot_path) and os.listdir(snapshot_path):
654
+ is_onnx_model = any(
655
+ filename.endswith(".onnx")
656
+ for filename in os.listdir(snapshot_path)
657
+ )
658
+
659
+ if is_onnx_model:
660
+ # If the model is in HF cache and has ONNX files, use it
661
+ full_model_path = snapshot_path
662
+ model_exists_locally = True
663
+ printing.log_info(
664
+ f"Found ONNX model in Hugging Face cache: {full_model_path}"
665
+ )
666
+ except Exception as e: # pylint: disable=broad-exception-caught
667
+ # Log any errors but continue with the original path
668
+ printing.log_info(f"Error checking Hugging Face cache: {e}")
669
+
670
+ return full_model_path, model_exists_locally
671
+
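The Hugging Face cache fallback above can be reproduced in isolation: snapshot_download with local_files_only=True resolves a checkpoint to its cached snapshot without touching the network, and the folder is only used if it actually contains .onnx files. A minimal sketch (the helper name is illustrative):

```python
import os
from huggingface_hub import snapshot_download

def find_cached_onnx(checkpoint: str):
    """Return the local HF-cache snapshot for `checkpoint` if it holds ONNX files."""
    try:
        snapshot_path = snapshot_download(repo_id=checkpoint, local_files_only=True)
    except Exception:  # pylint: disable=broad-exception-caught
        return None    # not present in the local cache
    if any(name.endswith(".onnx") for name in os.listdir(snapshot_path)):
        return snapshot_path
    return None
```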
672
+ @staticmethod
673
+ def _update_hybrid_custom_ops_library_path(full_model_path):
674
+ """
675
+ Modifies the genai_config.json file in the hybrid model folder to set the custom_ops_library
676
+ path to the location of the onnx_custom_ops.dll in the current environment.
677
+ This is needed for hybrid inference.
678
+ """
679
+ oga_path, version = get_oga_hybrid_dir()
680
+
681
+ if "1.3.0" in version:
682
+ custom_ops_path = os.path.join(
683
+ oga_path,
684
+ "onnx_utils",
685
+ "bin",
686
+ "onnx_custom_ops.dll",
687
+ )
688
+ else:
689
+ custom_ops_path = os.path.join(oga_path, "libs", "onnx_custom_ops.dll")
690
+
691
+ # Insert the custom_ops_path into the model config file
692
+ config_path = os.path.join(full_model_path, "genai_config.json")
693
+ if os.path.exists(config_path):
694
+ with open(config_path, "r", encoding="utf-8") as f:
695
+ config = json.load(f)
696
+
697
+ if (
698
+ "model" in config
699
+ and "decoder" in config["model"]
700
+ and "session_options" in config["model"]["decoder"]
701
+ ):
702
+ config["model"]["decoder"]["session_options"][
703
+ "custom_ops_library"
704
+ ] = custom_ops_path
705
+
706
+ with open(config_path, "w", encoding="utf-8") as f:
707
+ json.dump(config, f, indent=4)
708
+
709
+ else:
710
+ printing.log_info(
711
+ f"Model's `genai_config.json` not found in {full_model_path}"
712
+ )
713
+
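For reference, the edit above only touches one nested key in genai_config.json; the resulting structure looks roughly like this (the DLL path is a placeholder for wherever onnx_custom_ops.dll lives in the active environment):

```python
import json

config = {"model": {"decoder": {"session_options": {}}}}
config["model"]["decoder"]["session_options"]["custom_ops_library"] = (
    "C:\\path\\to\\onnx_custom_ops.dll"  # placeholder path
)
print(json.dumps(config, indent=4))
```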
714
+ @staticmethod
715
+ def _is_preoptimized_model(input_model_path):
716
+ """
717
+ Checks if the 'custom_ops_library' field exists in the genai_config.json file
718
+ to determine if this is a pre-optimized model for hybrid as well
719
+ as NPU only.
720
+
721
+ Args:
722
+ input_model_path (str): Path to the input model directory.
723
+
724
+ Returns:
725
+ bool: True if 'custom_ops_library' exists, False otherwise.
726
+ """
727
+ config_path = os.path.join(input_model_path, "genai_config.json")
728
+ if not os.path.exists(config_path):
729
+ printing.log_info(f"Model's `genai_config.json` not found in {config_path}")
730
+ return False
731
+
732
+ with open(config_path, "r", encoding="utf-8") as f:
733
+ config = json.load(f)
734
+ if (
735
+ "model" in config
736
+ and "decoder" in config["model"]
737
+ and "session_options" in config["model"]["decoder"]
738
+ ):
739
+ return "custom_ops_library" in config["model"]["decoder"]["session_options"]
740
+ return False
741
+
742
+ @staticmethod
743
+ def _download_and_build_safetensors_model(
744
+ checkpoint, device, dtype, full_model_path, int4_block_size, input_path, state
745
+ ):
746
+ """
747
+ Uses OGA model builder to quantize safetensors format model and convert to ONNX
748
+ format. The model files are saved to the full_model_path folder.
749
+ """
750
+ printing.log_info(f"Building {checkpoint} for {device} using {dtype}")
751
+ extra_options = {}
752
+ if int4_block_size is not None:
753
+ extra_options["int4-block-size"] = int4_block_size
754
+ try:
755
+ model_builder.create_model(
756
+ checkpoint,
757
+ input_path,
758
+ full_model_path,
759
+ dtype,
760
+ execution_providers[device],
761
+ os.path.join(state.cache_dir, oga_model_builder_cache_path),
762
+ **extra_options,
763
+ )
764
+ except NotImplementedError as e:
765
+ raise NotImplementedError("[Model builder] " + str(e)) from e
766
+ except OSError as e:
767
+ raise ValueError("[Model builder] " + str(e)) from e
768
+
769
+ return full_model_path
770
+
771
+ @staticmethod
772
+ def _setup_npu_environment():
773
+ """
774
+ Sets up environment for NPU flow of ONNX model and returns saved state to be restored
775
+ later in cleanup.
776
+ """
777
+ oga_path, version = get_oga_npu_dir()
778
+
779
+ if not os.path.exists(os.path.join(oga_path, "libs", "onnxruntime.dll")):
780
+ raise RuntimeError(
781
+ f"Cannot find libs/onnxruntime.dll in lib folder: {oga_path}"
782
+ )
783
+
784
+ # Save current state so they can be restored after inference.
785
+ saved_state = {"cwd": os.getcwd(), "path": os.environ["PATH"]}
786
+
787
+ # Setup NPU environment (cwd and path will be restored later)
788
+ os.chdir(oga_path)
789
+ os.environ["PATH"] = (
790
+ os.path.join(oga_path, "libs") + os.pathsep + os.environ["PATH"]
791
+ )
792
+ if "1.3.0" in version:
793
+ os.environ["DD_ROOT"] = ".\\bins"
794
+ os.environ["DEVICE"] = "stx"
795
+ os.environ["XLNX_ENABLE_CACHE"] = "0"
796
+
797
+ return saved_state
798
+
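The NPU setup follows a save/mutate/restore pattern: capture cwd and PATH, prepend the runtime's libs folder, and hand the saved dict to _cleanup_environment in a finally block. A condensed sketch with a placeholder path:

```python
import os

saved_state = {"cwd": os.getcwd(), "path": os.environ["PATH"]}
try:
    os.environ["PATH"] = (
        os.path.join("C:\\placeholder\\oga", "libs") + os.pathsep + os.environ["PATH"]
    )
    # ... load the model and run inference ...
finally:
    os.chdir(saved_state["cwd"])
    os.environ["PATH"] = saved_state["path"]
```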
799
+ @staticmethod
800
+ def _setup_hybrid_environment():
801
+ """
802
+ Sets up the environment for the Hybrid flow and returns saved state to be restored later
803
+ in cleanup.
804
+ """
805
+ # Determine the Ryzen AI OGA version and hybrid artifacts path
806
+ oga_path, version = get_oga_hybrid_dir()
807
+
808
+ if "1.3.0" in version:
809
+ dst_dll = os.path.join(
810
+ oga_path,
811
+ "onnx_utils",
812
+ "bin",
813
+ "DirectML.dll",
814
+ )
815
+ if not os.path.isfile(dst_dll):
816
+ # Artifacts 1.3.0 has DirectML.dll in different subfolder, so copy it to the
817
+ # correct place. This should not be needed in later RAI release artifacts.
818
+ src_dll = os.path.join(
819
+ oga_path,
820
+ "onnxruntime_genai",
821
+ "lib",
822
+ "DirectML.dll",
823
+ )
824
+ os.makedirs(os.path.dirname(dst_dll), exist_ok=True)
825
+ shutil.copy2(src_dll, dst_dll)
826
+
827
+ saved_state = None
828
+ return saved_state
829
+
830
+ @staticmethod
831
+ def _load_model_and_setup_state(
832
+ state, full_model_path, checkpoint, trust_remote_code
833
+ ):
834
+ """
835
+ Loads the OGA model from local folder and then loads the tokenizer.
836
+ Will auto-detect if we're offline.
837
+ """
838
+ try:
839
+ state.model = OrtGenaiModel(full_model_path)
840
+ except Exception as e:
841
+ if "invalid unordered_map<K, T>" in str(e):
842
+ raise ValueError(
843
+ "Error initializing model: Invalid configuration detected.\n"
844
+ "Please check the following:\n"
845
+ f"1. Please check your model's config file in {full_model_path} "
846
+ "and ensure custom_ops_library points to the valid "
847
+ "onnx_custom_ops.dll path.\n"
848
+ "2. Make sure the NPU driver is loaded.\n"
849
+ "3. Make sure hybrid has been installed on a Ryzen AI "
850
+ f"{'or '.join(SUPPORTED_RYZEN_AI_SERIES)}-series processor."
851
+ ) from e
852
+ raise
853
+
854
+ # Auto-detect offline mode
855
+ offline = is_offline()
856
+
857
+ try:
858
+ # Always try to use local files first
859
+ local_files_only = True
860
+
861
+ hf_tokenizer = AutoTokenizer.from_pretrained(
862
+ full_model_path,
863
+ local_files_only=local_files_only,
864
+ trust_remote_code=trust_remote_code,
865
+ )
866
+ except ValueError as e:
867
+ if "trust_remote_code" in str(e):
868
+ raise ValueError(
869
+ "This model requires you to execute code from the repo. Please review it "
870
+ "and if you trust it, then use the `--trust-remote-code` flag with oga-load."
871
+ )
872
+
873
+ if offline and "Can't load tokenizer for" in str(e):
874
+ raise ValueError(
875
+ f"Cannot load tokenizer for {checkpoint} in offline mode. "
876
+ f"The tokenizer files may not be available locally in {full_model_path}."
877
+ )
878
+ raise
879
+
880
+ state.tokenizer = OrtGenaiTokenizer(
881
+ state.model.model,
882
+ hf_tokenizer,
883
+ )
884
+
885
+ status.add_to_state(state=state, name=checkpoint, model=checkpoint)
886
+
887
+ @staticmethod
888
+ def _cleanup_environment(saved_state):
889
+ """
890
+ Restores environment to its original state after inference is complete.
891
+ """
892
+ if saved_state:
893
+ os.chdir(saved_state["cwd"])
894
+ os.environ["PATH"] = saved_state["path"]
895
+
896
+ def _generate_model_for_hybrid_or_npu(
897
+ self, output_model_path, device, input_model_path
898
+ ):
899
+ """
900
+ Uses a subprocess to run the 'model_generate' command for hybrid or npu devices.
901
+ """
902
+
903
+ # Determine the appropriate flag based on the device type
904
+ if device == "hybrid":
905
+ device_flag = "--hybrid"
906
+ elif device == "npu":
907
+ device_flag = "--npu"
908
+ else:
909
+ raise ValueError(f"Unsupported device type for model generation: {device}")
910
+
911
+ command = [
912
+ "model_generate",
913
+ device_flag,
914
+ output_model_path, # Output model directory
915
+ input_model_path, # Input model directory
916
+ ]
917
+
918
+ printing.log_info(f"Running command: {' '.join(command)}")
919
+ try:
920
+ with open(self.logfile_path, "w", encoding="utf-8") as log_file:
921
+ subprocess.run(
922
+ command, check=True, text=True, stdout=log_file, stderr=log_file
923
+ )
924
+ except FileNotFoundError as e:
925
+ error_message = (
926
+ "The 'model_generate' package is missing from your system. "
927
+ "Ensure all required packages are installed. "
928
+ "To install it, run the following command:\n\n"
929
+ " lemonade-install --ryzenai <target> --build-model\n"
930
+ )
931
+ raise RuntimeError(error_message) from e
932
+
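The wrapper above is equivalent to invoking the Ryzen AI model_generate CLI directly and capturing its output in the tool's log file; a stripped-down version with placeholder paths:

```python
import subprocess

command = [
    "model_generate",
    "--hybrid",                 # or "--npu"
    "/path/to/output_model",    # placeholder output directory
    "/path/to/input_model",     # placeholder input directory
]
with open("oga-load.log", "w", encoding="utf-8") as log_file:
    subprocess.run(command, check=True, text=True, stdout=log_file, stderr=log_file)
```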
933
+ def run(
934
+ self,
935
+ state: State,
936
+ input: str,
937
+ input_path: str = "",
938
+ device: str = "igpu",
939
+ dtype: str = "int4",
940
+ int4_block_size: int = None,
941
+ force: bool = False,
942
+ download_only: bool = False,
943
+ trust_remote_code=False,
944
+ subfolder: str = None,
945
+ ) -> State:
946
+ # Auto-detect offline status
947
+ offline = is_offline()
948
+ if offline:
949
+ printing.log_warning(
950
+ "Network connectivity to huggingface.co not detected. Running in offline mode."
951
+ )
952
+
953
+ state.device = device
954
+ state.dtype = dtype
955
+
956
+ # Log initial stats
957
+ state.save_stat(Keys.DTYPE, dtype)
958
+ state.save_stat(Keys.DEVICE, device)
959
+ if device in ["hybrid", "npu"]:
960
+ ryzen_ai_version_info = get_ryzen_ai_version_info()
961
+ state.save_stat(Keys.RYZEN_AI_VERSION_INFO, ryzen_ai_version_info)
962
+
963
+ # Check if input is a local folder
964
+ if os.path.isdir(input):
965
+ # input is a local folder
966
+ full_model_path = os.path.abspath(input)
967
+ checkpoint = "local_model"
968
+ state.checkpoint = checkpoint
969
+ state.save_stat(Keys.CHECKPOINT, checkpoint)
970
+ state.save_stat(Keys.LOCAL_MODEL_FOLDER, full_model_path)
971
+ # See if there is a file ending in ".onnx" in this folder
972
+ folder_contents = os.listdir(input)
973
+ has_onnx_file = any(filename.endswith(".onnx") for filename in folder_contents)
974
+ if not has_onnx_file:
975
+ raise ValueError(
976
+ f"The folder {input} does not contain an ONNX model file."
977
+ )
978
+ if force:
979
+ raise ValueError(
980
+ "Your input (-i, --input) points to a local folder, which is not "
981
+ "compatible with the force argument."
982
+ )
983
+
984
+ else:
985
+ # input is a model checkpoint
986
+ checkpoint = input
987
+ state.checkpoint = checkpoint
988
+ state.save_stat(Keys.CHECKPOINT, checkpoint)
989
+
990
+ # Get base model information
991
+ if not offline:
992
+ base_model = get_base_model(checkpoint)
993
+ if base_model is not None:
994
+ state.save_stat("base_model", base_model)
995
+
996
+ # Setup paths
997
+ full_model_path, model_exists_locally = self._setup_model_paths(
998
+ state, checkpoint, device, dtype, subfolder, int4_block_size
999
+ )
1000
+
1001
+ # If in offline mode, we can only use locally available models
1002
+ if offline and not model_exists_locally:
1003
+ raise ValueError(
1004
+ f"Model {checkpoint} is not available locally for {device} with {dtype}. "
1005
+ f"Cannot download in offline mode. Check {full_model_path}"
1006
+ )
1007
+
1008
+ # Handle download/build if needed
1009
+ if (not model_exists_locally) or force:
1010
+ if offline:
1011
+ raise ValueError(
1012
+ f"Cannot download or build model {checkpoint} in offline mode"
1013
+ )
1014
+
1015
+ # Validate configuration
1016
+ hf_supported = self._validate_model_configuration(
1017
+ device, dtype, checkpoint
1018
+ )
1019
+
1020
+ if not hf_supported:
1021
+ raise ValueError(
1022
+ "The (device, dtype, checkpoint) combination is not supported: "
1023
+ f"({device}, {dtype}, {checkpoint})"
1024
+ )
1025
+ input_model_path = snapshot_download(
1026
+ repo_id=checkpoint,
1027
+ ignore_patterns=["*.md", "*.txt"],
1028
+ local_files_only=offline,
1029
+ )
1030
+ # Check if model is ONNX or safetensors
1031
+ is_onnx_model = any(
1032
+ [
1033
+ filename.endswith(".onnx")
1034
+ for filename in os.listdir(input_model_path)
1035
+ ]
1036
+ )
1037
+ is_preoptimized_onnx = is_onnx_model and self._is_preoptimized_model(
1038
+ input_model_path
1039
+ )
1040
+ is_safetensors_model = any(
1041
+ [
1042
+ filename.endswith(".safetensors")
1043
+ for filename in os.listdir(input_model_path)
1044
+ ]
1045
+ )
1046
+ if not (is_onnx_model or is_safetensors_model):
1047
+ raise ValueError(
1048
+ f"The model {checkpoint} is not supported. "
1049
+ "It does not contain ONNX or safetensors files."
1050
+ )
1051
+ if device in ["npu", "hybrid"]:
1052
+ if is_onnx_model:
1053
+ if is_preoptimized_onnx:
1054
+ # Use HuggingFace cache path as it is
1055
+ full_model_path = input_model_path
1056
+ else:
1057
+ # If ONNX but not modified yet for Hybrid or NPU,
1058
+ # needs further optimization
1059
+ self._generate_model_for_hybrid_or_npu(
1060
+ full_model_path,
1061
+ device,
1062
+ input_model_path,
1063
+ )
1064
+ elif is_safetensors_model:
1065
+ config_path = os.path.join(input_model_path, "config.json")
1066
+ if os.path.exists(config_path):
1067
+ with open(config_path, "r", encoding="utf-8") as f:
1068
+ config = json.load(f)
1069
+ if "quantization_config" in config:
1070
+ # If quantized, use subprocess to generate the model
1071
+ self._generate_model_for_hybrid_or_npu(
1072
+ full_model_path, device, input_model_path
1073
+ )
1074
+ else:
1075
+ raise ValueError(
1076
+ f"The safetensors model {checkpoint} is not quantized. "
1077
+ "Only quantized safetensors models are supported"
1078
+ " on npu or hybrid targets."
1079
+ )
1080
+ else:
1081
+ raise ValueError(
1082
+ f"config.json not found for safetensors model: {checkpoint}"
1083
+ )
1084
+ else:
1085
+ raise ValueError(
1086
+ f"Unsupported model type for checkpoint: {checkpoint}"
1087
+ )
1088
+ else:
1089
+ if is_onnx_model:
1090
+ # Use HuggingFace cache path as it is
1091
+ full_model_path = input_model_path
1092
+ else:
1093
+ self._download_and_build_safetensors_model(
1094
+ checkpoint,
1095
+ device,
1096
+ dtype,
1097
+ full_model_path,
1098
+ int4_block_size,
1099
+ input_path,
1100
+ state,
1101
+ )
1102
+ state.save_stat(Keys.LOCAL_MODEL_FOLDER, full_model_path)
1103
+
1104
+ # Load model if download-only argument is not set
1105
+ if not download_only:
1106
+
1107
+ saved_env_state = None
1108
+ try:
1109
+ if device == "npu":
1110
+ saved_env_state = self._setup_npu_environment()
1111
+ # Set USE_AIE_RoPE based on model type
1112
+ os.environ["USE_AIE_RoPE"] = (
1113
+ "0" if "phi-" in checkpoint.lower() else "1"
1114
+ )
1115
+ elif device == "hybrid":
1116
+ saved_env_state = self._setup_hybrid_environment()
1117
+ self._update_hybrid_custom_ops_library_path(full_model_path)
1118
+
1119
+ self._load_model_and_setup_state(
1120
+ state, full_model_path, checkpoint, trust_remote_code
1121
+ )
1122
+ finally:
1123
+ self._cleanup_environment(saved_env_state)
1124
+
1125
+ return state
1126
+
1127
+
1128
+ # This file was originally licensed under Apache 2.0. It has been modified.
1129
+ # Modifications Copyright (c) 2025 AMD
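Finally, a rough end-to-end sketch of driving this tool from Python rather than the CLI. The State constructor arguments and import paths are assumptions (see lemonade/state.py and lemonade/api.py for the real entry points), and the checkpoint is a placeholder:

```python
from lemonade.state import State
from lemonade.tools.ort_genai.oga import OgaLoad

# Assumed State signature; check lemonade/state.py for the actual constructor.
state = State(cache_dir="~/.cache/lemonade", build_name="oga-demo")
state = OgaLoad().run(state, input="some-org/some-model", device="igpu", dtype="int4")

input_ids = state.tokenizer("Hello from lemonade").input_ids
output_ids = state.model.generate(input_ids, max_new_tokens=32)[0]
print(state.tokenizer.decode(output_ids))
```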