lemonade-sdk 7.0.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lemonade-sdk might be problematic.

Files changed (61)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
lemonade/tools/ort_genai/oga_bench.py
@@ -0,0 +1,142 @@
+ import argparse
+ import statistics
+ from statistics import StatisticsError
+ from lemonade.state import State
+ from lemonade.cache import Keys
+ from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
+ from lemonade.tools.bench import Bench
+
+
+ class OgaBench(Bench):
+     """
+     Benchmark any model that adheres to the ModelAdapter interface.
+
+     Required input state:
+         - MODEL: model instance to benchmark.
+         - TOKENIZER: tokenizer instance used to generate inputs for the model.
+
+     Output state produced: None
+     """
+
+     unique_name = "oga-bench"
+
+     def __init__(self):
+         super().__init__()
+
+         # Additional statistics generated by this bench tool
+         self.status_stats.insert(
+             self.status_stats.index(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND) + 1,
+             Keys.STD_DEV_TOKENS_PER_SECOND,
+         )
+         self.std_dev_token_generation_tokens_per_second_list = []
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Benchmark an LLM in onnxruntime-genai (OGA)",
+             add_help=add_help,
+         )
+
+         parser = Bench.parser(parser)
+
+         return parser
+
+     def get_prompt_str(self, state, token_length):
+         """
+         Returns a string with the prescribed token length.
+         """
+         tokenizer: TokenizerAdapter = state.tokenizer
+         test_prompt = "word " * (token_length - 1)
+         input_ids = tokenizer(test_prompt, return_tensors="pt").input_ids
+         test_token_length = len(input_ids)
+         delta = test_token_length - token_length
+         if delta == 0:
+             return test_prompt
+         return "word " * max(token_length - 1 - delta, 0)
+
+     def run_prompt(
+         self,
+         state: State,
+         report_progress_fn,
+         prompt: str,
+         iterations: int,
+         warmup_iterations: int,
+         output_tokens: int,
+     ) -> State:
+
+         model: ModelAdapter = state.model
+         tokenizer: TokenizerAdapter = state.tokenizer
+
+         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+         self.input_ids_len_list.append(len(input_ids))
+         per_iteration_time_to_first_token = []
+         per_iteration_tokens_per_second = []
+
+         # Don't capture time for warmup
+         for count in range(warmup_iterations):
+             outputs = model.generate(input_ids, max_new_tokens=output_tokens)
+             self.tokens_out_len_list.append(len(outputs[0]) - len(input_ids))
+             report_progress_fn((count + 1) / (warmup_iterations + iterations))
+
+         for count in range(iterations):
+             outputs = model.generate(
+                 input_ids,
+                 max_new_tokens=output_tokens,
+                 min_new_tokens=output_tokens,
+             )
+             report_progress_fn(
+                 (warmup_iterations + count + 1) / (warmup_iterations + iterations)
+             )
+
+             token_len = len(outputs[0]) - len(input_ids)
+             self.tokens_out_len_list.append(token_len)
+
+             # Only count an iteration if it produced enough tokens
+             if token_len >= output_tokens:
+                 per_iteration_time_to_first_token.append(model.time_to_first_token)
+                 per_iteration_tokens_per_second.append(model.tokens_per_second)
+
+         if not per_iteration_time_to_first_token or not per_iteration_tokens_per_second:
+             raise Bench.not_enough_tokens(output_tokens)
+
+         mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token)
+         self.mean_time_to_first_token_list.append(mean_time_to_first_token)
+         self.prefill_tokens_per_second_list.append(
+             len(input_ids) / mean_time_to_first_token
+         )
+         self.token_generation_tokens_per_second_list.append(
+             statistics.mean(per_iteration_tokens_per_second)
+         )
+         try:
+             self.std_dev_time_to_first_token_list.append(
+                 statistics.stdev(per_iteration_time_to_first_token)
+             )
+         except StatisticsError:
+             # Less than 2 measurements
+             self.std_dev_time_to_first_token_list.append(None)
+         try:
+             self.std_dev_token_generation_tokens_per_second_list.append(
+                 statistics.stdev(per_iteration_tokens_per_second)
+             )
+         except StatisticsError:
+             # Less than 2 measurements
+             self.std_dev_token_generation_tokens_per_second_list.append(None)
+
+     def save_stats(self, state):
+         super().save_stats(state)
+
+         # Save additional statistics
+         if not all(
+             element is None
+             for element in self.std_dev_token_generation_tokens_per_second_list
+         ):
+             state.save_stat(
+                 Keys.STD_DEV_TOKENS_PER_SECOND,
+                 self.get_item_or_list(
+                     self.std_dev_token_generation_tokens_per_second_list
+                 ),
+             )
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
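
Note: oga-bench reports the mean tokens-per-second across the timed iterations and, when at least two iterations completed, the standard deviation. A minimal standalone sketch of that aggregation pattern (hypothetical helper and values, not part of the package):

    import statistics
    from statistics import StatisticsError

    def summarize_tokens_per_second(per_iteration_tps):
        """Mean tokens/second, plus stdev only when >= 2 measurements exist."""
        mean_tps = statistics.mean(per_iteration_tps)
        try:
            std_dev_tps = statistics.stdev(per_iteration_tps)
        except StatisticsError:
            # Fewer than 2 measurements: no meaningful standard deviation
            std_dev_tps = None
        return mean_tps, std_dev_tps

    print(summarize_tokens_per_second([41.8, 43.2, 42.5]))  # approx. (42.5, 0.7)
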
lemonade/tools/perplexity.py
@@ -0,0 +1,146 @@
+ import os
+ import argparse
+ import pandas as pd
+ import torch
+ from datasets import load_dataset
+ from lemonade.state import State
+ from lemonade.tools import Tool
+ import lemonade.common.printing as printing
+ import lemonade.common.build as build
+
+
+ class AccuracyPerplexity(Tool):
+     """
+     Measure perplexity of an LLM using the Wikitext-2 dataset.
+
+     Required input state:
+         - state.model: instance that provides a __call__() method that returns
+           output.logits and supports model.config.max_position_embeddings
+         - state.tokenizer: instance of Hugging Face PretrainedTokenizer
+
+     Output state produced: None
+
+     See docs/lemonade/perplexity.md for more details.
+     """
+
+     unique_name = "accuracy-perplexity"
+
+     def __init__(self):
+         super().__init__(monitor_message="Measuring perplexity")
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Measure perplexity score",
+             add_help=add_help,
+         )
+         return parser
+
+     def run(
+         self,
+         state: State,
+     ) -> State:
+
+         try:
+             printing.log_info("Downloading dataset ...")
+             dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+         except Exception as e:  # pylint: disable=broad-except
+             printing.log_error(f"Error during dataset load: {e}")
+             raise e
+
+         tokenizer = state.tokenizer
+         model = state.model
+         # Tokenize the entire test dataset text, joining entries with double new lines
+         encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
+
+         # Retrieve the maximum input length that the model can handle
+         try:
+             max_length = model.config.max_position_embeddings
+         except AttributeError:
+             # Some LLMs do not have the config.max_position_embeddings attribute
+             # However, most LLMs support at least 2048 context length, so this
+             # try-except will allow a few more LLMs to work
+             max_length = 2048
+         # Set stride to half of the maximum input length for overlapping window processing
+         # Refer to docs/perplexity.md for more information on sliding window
+         stride = max_length // 2
+         # Determine the total sequence length of the tokenized input
+         seq_len = encodings.input_ids.size(1)
+
+         negative_log_likelihoods = []
+         summary_data = []
+         prev_end_location = 0
+
+         model_results_dir = os.path.join(
+             build.output_dir(state.cache_dir, state.build_name), "perplexity"
+         )
+
+         for begin_location in range(0, seq_len, stride):
+             end_location = min(begin_location + max_length, seq_len)
+             target_len = end_location - prev_end_location
+             input_ids = encodings.input_ids[:, begin_location:end_location]
+             target_ids = input_ids.clone()
+             target_ids[:, :-target_len] = -100
+
+             # Forward pass the model to get logits
+             with torch.no_grad():
+                 try:
+                     outputs = model(input_ids, labels=target_ids)
+                     logits = outputs.logits
+                 except Exception as e:  # pylint: disable=broad-except
+                     printing.log_error(
+                         f"Error during model forward pass execution: {e}"
+                     )
+
+             # Compute loss manually for visualization
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = target_ids[..., 1:].contiguous()
+             effective_token_count = (target_ids != -100).sum().item()
+             negative_log_likelihoods.append(
+                 (outputs.loss.item(), effective_token_count)
+             )
+
+             # Decode predicted and actual next words for the last token position
+             predictions = torch.argmax(shift_logits, dim=-1)
+             predicted_tokens = predictions[:, -1]
+             actual_tokens = shift_labels[:, -1]
+
+             predicted_words = tokenizer.batch_decode(
+                 predicted_tokens, skip_special_tokens=True
+             )
+             actual_words = tokenizer.batch_decode(
+                 actual_tokens, skip_special_tokens=True
+             )
+             context = tokenizer.decode(input_ids[0, :])
+
+             summary_data.append(
+                 {
+                     "Context": context[-stride:],
+                     "Predicted next word": predicted_words,
+                     "Actual next word": actual_words,
+                     "Loss for this window": outputs.loss.item(),
+                 }
+             )
+             prev_end_location = end_location
+
+         # Total loss calculation considering the number of tokens for each segment
+         total_loss = sum(loss * count for loss, count in negative_log_likelihoods)
+         total_tokens = sum(count for _, count in negative_log_likelihoods)
+
+         # Calculate average negative_log_likelihood and perplexity
+         average_negative_log_likelihood = total_loss / total_tokens
+         perplexity = torch.exp(torch.tensor(average_negative_log_likelihood))
+
+         # Save accuracy results to stats file
+         state.save_stat("perplexity_score", float(perplexity.item()))
+
+         # Save accuracy results to CSV file
+         summary_df = pd.DataFrame(summary_data)
+         summary_df.to_csv(
+             os.path.join(model_results_dir, "summary_results.csv"), index=False
+         )
+         return state
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
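
Note: the perplexity score saved above is the exponential of the token-weighted average negative log-likelihood over all sliding windows. A minimal standalone sketch of that final aggregation (made-up window losses for illustration, not part of the package):

    import math

    # (loss, effective_token_count) for each sliding window; illustrative values
    negative_log_likelihoods = [(2.31, 1024), (2.18, 512), (2.40, 512)]

    total_loss = sum(loss * count for loss, count in negative_log_likelihoods)
    total_tokens = sum(count for _, count in negative_log_likelihoods)
    average_nll = total_loss / total_tokens  # token-weighted mean NLL
    perplexity = math.exp(average_nll)       # perplexity = exp(mean NLL)
    print(round(perplexity, 2))              # approx. 9.97
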
lemonade/tools/prompt.py
@@ -0,0 +1,228 @@
+ import argparse
+ import os
+ import matplotlib.pyplot as plt
+ import lemonade.common.build as build
+ import lemonade.common.printing as printing
+ from lemonade.state import State
+ from lemonade.tools import Tool
+ from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
+ from lemonade.cache import Keys
+
+ DEFAULT_GENERATE_PARAMS = {
+     "do_sample": True,
+     "top_k": 50,
+     "top_p": 0.95,
+     "temperature": 0.7,
+ }
+
+ DEFAULT_MAX_NEW_TOKENS = 512
+ DEFAULT_N_TRIALS = 1
+
+
+ def sanitize_string(input_string):
+     return input_string.encode("charmap", "ignore").decode("charmap")
+
+
+ def sanitize_text(text):
+     if isinstance(text, str):
+         return sanitize_string(text)
+     elif isinstance(text, list):
+         return [sanitize_string(item) for item in text]
+     else:
+         raise TypeError("Input must be a string or a list of strings.")
+
+
+ def positive_int(x):
+     """Conversion function for argparse"""
+     i = int(x)
+     if i < 1:
+         raise ValueError("Non-positive values are not allowed")
+     return i
+
+
+ class LLMPrompt(Tool):
+     """
+     Send a prompt to an LLM instance and print the response to the screen.
+
+     Required input state:
+         - state.model: LLM instance that supports the generate() method.
+         - state.tokenizer: LLM tokenizer instance that supports the __call__() (ie, encode)
+           and decode() methods.
+
+     Output state produced:
+         - "response": text response from the LLM.
+     """
+
+     unique_name = "llm-prompt"
+
+     def __init__(self):
+         super().__init__(monitor_message="Prompting LLM")
+
+         self.status_stats = [
+             Keys.PROMPT_TOKENS,
+             Keys.PROMPT,
+             Keys.PROMPT_TEMPLATE,
+             Keys.RESPONSE_TOKENS,
+             Keys.RESPONSE,
+             Keys.RESPONSE_LENGTHS_HISTOGRAM,
+         ]
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Prompt an LLM and print the result",
+             add_help=add_help,
+         )
+
+         parser.add_argument(
+             "--prompt",
+             "-p",
+             help="Input prompt to the LLM. Two formats are supported: "
+             "1) str: use a user-provided prompt string, and "
+             "2) path/to/prompt.txt: load the prompt from a .txt file.",
+             required=True,
+         )
+
+         parser.add_argument(
+             "--template",
+             "-t",
+             action="store_true",
+             help="Insert the prompt into the model's chat template before processing.",
+         )
+
+         parser.add_argument(
+             "--max-new-tokens",
+             "-m",
+             default=DEFAULT_MAX_NEW_TOKENS,
+             type=int,
+             help=f"Maximum number of new tokens in the response "
+             f"(default is {DEFAULT_MAX_NEW_TOKENS})",
+         )
+
+         parser.add_argument(
+             "--n-trials",
+             "-n",
+             default=DEFAULT_N_TRIALS,
+             type=positive_int,
+             help=f"Number of responses the LLM will generate for the prompt "
+             f"(useful for testing, default is {DEFAULT_N_TRIALS})",
+         )
+
+         return parser
+
+     def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+         """
+         Helper function to parse CLI arguments into the args expected
+         by run()
+         """
+
+         parsed_args = super().parse(state, args, known_only)
+
+         # Decode prompt arg into a string prompt
+         if parsed_args.prompt.endswith(".txt") and os.path.exists(parsed_args.prompt):
+             with open(parsed_args.prompt, "r", encoding="utf-8") as f:
+                 parsed_args.prompt = f.read()
+
+         return parsed_args
+
+     def run(
+         self,
+         state: State,
+         prompt: str = "Hello",
+         max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+         n_trials: int = DEFAULT_N_TRIALS,
+         template: bool = False,
+     ) -> State:
+
+         model: ModelAdapter = state.model
+         tokenizer: TokenizerAdapter = state.tokenizer
+
+         # If template flag is set, then wrap prompt in template
+         if template:
+             # Embed prompt in model's chat template
+             if tokenizer.chat_template:
+                 # Use the model's built-in chat template if available
+                 messages_dict = [{"role": "user", "content": prompt}]
+                 prompt = tokenizer.apply_chat_template(
+                     messages_dict, tokenize=False, add_generation_prompt=True
+                 )
+                 state.save_stat(Keys.PROMPT_TEMPLATE, "Model-specific")
+             else:
+                 # Fallback to a standardized template
+                 printing.log_info("No chat template found. Using default template.")
+                 prompt = f"<|user|>\n{prompt} <|end|>\n<|assistant|>"
+                 state.save_stat(Keys.PROMPT_TEMPLATE, "Default")
+
+         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+         if isinstance(input_ids, (list, str)):
+             # OGA models return a list of tokens
+             # Our llama.cpp adapter returns a string
+             len_tokens_in = len(input_ids)
+         else:
+             # HF models return a 2-D tensor
+             len_tokens_in = input_ids.shape[1]
+
+         len_tokens_out = []
+         response_texts = []
+         for trial in range(n_trials):
+             if n_trials > 1:
+                 self.set_percent_progress(100.0 * trial / n_trials)
+
+             # Get the response from the LLM, which may include the prompt in it
+             response = model.generate(
+                 input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+             )
+
+             # Flatten the input and response
+             input_ids_array = (
+                 input_ids if isinstance(input_ids, (list, str)) else input_ids[0]
+             )
+             response_array = response if isinstance(response, str) else response[0]
+
+             # Separate the prompt from the response
+             len_tokens_out.append(len(response_array) - len_tokens_in)
+
+             input_token = 0
+             while (
+                 input_token < len_tokens_in
+                 and input_ids_array[input_token] == response_array[input_token]
+             ):
+                 input_token += 1
+
+             # Only decode the actual response (not the prompt)
+             response_text = tokenizer.decode(
+                 response_array[input_token:], skip_special_tokens=True
+             ).strip()
+             response_texts.append(response_text)
+
+         state.response = response_texts
+
+         if n_trials == 1:
+             len_tokens_out = len_tokens_out[0]
+             response_texts = response_texts[0]
+         else:
+             self.set_percent_progress(None)
+
+             # Plot data
+             plt.figure()
+             plt.hist(len_tokens_out, bins=20)
+             plt.xlabel("Response Length (tokens)")
+             plt.ylabel("Frequency")
+             plt.title(f"Histogram of Response Lengths\n{state.build_name}")
+             figure_path = os.path.join(
+                 build.output_dir(state.cache_dir, state.build_name),
+                 "response_lengths.png",
+             )
+             plt.savefig(figure_path)
+             state.save_stat(Keys.RESPONSE_LENGTHS_HISTOGRAM, figure_path)
+
+         state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
+         state.save_stat(Keys.PROMPT, prompt)
+         state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
+         state.save_stat(Keys.RESPONSE, sanitize_text(response_texts))
+
+         return state
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
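
Note: llm-prompt separates the prompt from the response by walking the generated sequence until it stops matching the input tokens, then decoding only the remainder. A minimal standalone sketch of that prefix-stripping step (hypothetical token ids, not part of the package):

    def strip_prompt_tokens(input_ids, response_ids):
        """Drop the echoed prompt prefix and return only the generated tokens."""
        i = 0
        while i < len(input_ids) and input_ids[i] == response_ids[i]:
            i += 1
        return response_ids[i:]

    prompt_ids = [1, 15043, 3186]            # hypothetical prompt tokens
    output_ids = [1, 15043, 3186, 29991, 2]  # prompt echoed, then new tokens
    print(strip_prompt_tokens(prompt_ids, output_ids))  # [29991, 2]
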
lemonade/tools/quark/__init__.py: File without changes
lemonade/tools/quark/quark_load.py
@@ -0,0 +1,172 @@
+ import argparse
+ import os
+ import sys
+
+ import torch
+ from lemonade.state import State
+ from lemonade.tools import Tool
+ import lemonade.common.printing as printing
+ import lemonade.common.build as build
+ from lemonade_install.install import DEFAULT_QUARK_DIR
+
+
+ class QuarkLoad(Tool):
+     """
+     Load a model quantized and exported using Quark.
+
+     Required input state:
+         - state.model: Pretrained model instance to be quantized.
+         - state.tokenizer: Tokenizer instance from Hugging Face.
+
+     Output:
+         - state of the loaded model
+
+     See docs/quark.md for more details.
+     """
+
+     unique_name = "quark-load"
+
+     def __init__(self):
+         super().__init__(monitor_message="Load Quark Quantized model")
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Load a quantized model using Quark",
+             add_help=add_help,
+         )
+
+         parser.add_argument(
+             "--quant-scheme",
+             type=str,
+             required=True,
+             default=None,
+             help="Supported quantization schemes in Quark",
+         )
+
+         parser.add_argument(
+             "--quant-algo",
+             type=str,
+             required=True,
+             default=None,
+             choices=["awq", "gptq", "autosmoothquant", None],
+             help="Supported quantization algorithms in Quark",
+         )
+
+         parser.add_argument(
+             "--torch-compile", action="store_true", help="Model torch compile"
+         )
+
+         parser.add_argument(
+             "--safetensors-model-reload",
+             action="store_true",
+             help="Safetensors model reload",
+         )
+
+         parser.add_argument(
+             "--safetensors-model-dir",
+             default=None,
+             help="Directory of safetensors model",
+         )
+
+         parser.add_argument(
+             "--params-load", action="store_true", help="Model parameters load"
+         )
+
+         parser.add_argument("--json-path", help="Specify the path of saved json file")
+
+         parser.add_argument(
+             "--safetensors-path",
+             default=None,
+             help="Specify the path of saved safetensors file",
+         )
+
+         return parser
+
+     def run(
+         self,
+         state: State,
+         quant_scheme: str,
+         quant_algo: str,
+         torch_compile: bool = False,
+         safetensors_model_reload: bool = False,
+         safetensors_model_dir: str = None,
+         params_load: bool = False,
+         json_path: str = None,
+         safetensors_path: str = None,
+     ) -> State:
+         """
+         Executes the QuarkLoad process.
+
+         Returns:
+             State: The updated state after loading the model.
+
+         Raises:
+             Exception: If an error occurs during the QuarkLoad process.
+         """
+
+         try:
+             if os.path.isdir(DEFAULT_QUARK_DIR):
+                 quark_llm_path = os.path.join(
+                     DEFAULT_QUARK_DIR, "examples", "torch", "language_modeling"
+                 )
+                 sys.path.insert(0, quark_llm_path)
+             else:
+                 raise FileNotFoundError(
+                     f"The directory {DEFAULT_QUARK_DIR} does not exist. "
+                     "Please check your installation."
+                 )
+
+             # Default load path specific to the recipe.
+             # Note: this default will NOT work on its own, because the export path
+             # is now created with a unique timestamp; an explicit load path must be
+             # passed instead.
+             model_export_path = os.path.join(
+                 build.output_dir(state.cache_dir, state.build_name),
+                 "exported_model",
+                 quant_scheme,
+                 quant_algo,
+             )
+
+             # Set default paths only if current values are None
+             if safetensors_model_dir is None:
+                 safetensors_model_dir = model_export_path
+             if safetensors_path is None:
+                 safetensors_path = os.path.join(model_export_path, "model.safetensors")
+
+             printing.log_info("Loading model ...")
+             if not params_load and not safetensors_model_reload:
+                 raise ValueError(
+                     "Specify load format: 'params_load' or 'safetensors_model_reload'."
+                 )
+
+             # Reload quantized model if specified
+             from quark.torch import load_params, import_model_info
+
+             # The model to restore comes from the tool's required input state
+             model = state.model
+
+             if params_load:
+                 printing.log_info(
+                     "Restoring quantized model from JSON/safetensors files"
+                 )
+                 model = load_params(
+                     model,
+                     json_path=json_path,
+                     safetensors_path=safetensors_path,
+                 )
+             elif safetensors_model_reload:
+                 printing.log_info(
+                     "Restoring quantized model from quark_safetensors files"
+                 )
+                 model = import_model_info(model, model_info_dir=safetensors_model_dir)
+
+             if torch_compile:
+                 printing.log_info("torch.compile...")
+                 model = torch.compile(model)
+
+             state.model = model
+             state.dtype = model.dtype
+
+             printing.log_info("Quark Load process completed.")
+
+         except Exception as e:
+             printing.log_error(f"An error occurred during the QuarkLoad process: {e}")
+             raise
+
+         return state
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
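
Note: when no explicit paths are given, quark-load derives its reload locations from the build output directory plus the quantization scheme and algorithm; as the inline comment warns, the real export directory is timestamped, so explicit paths are normally required. A minimal sketch of the default path construction (hypothetical directory and scheme names, not part of the package):

    import os

    def default_quark_paths(output_dir, quant_scheme, quant_algo):
        """Mirror quark-load's default export-path layout (illustrative only)."""
        export_dir = os.path.join(output_dir, "exported_model", quant_scheme, quant_algo)
        return export_dir, os.path.join(export_dir, "model.safetensors")

    export_dir, safetensors_path = default_quark_paths(
        os.path.expanduser("~/.cache/lemonade/my_build"),  # hypothetical build dir
        "w_uint4_per_group_asym",                          # hypothetical scheme name
        "awq",
    )
    print(export_dir)
    print(safetensors_path)
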