lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
lemonade/tools/ort_genai/oga_bench.py
ADDED
@@ -0,0 +1,142 @@
+import argparse
+import statistics
+from statistics import StatisticsError
+from lemonade.state import State
+from lemonade.cache import Keys
+from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
+from lemonade.tools.bench import Bench
+
+
+class OgaBench(Bench):
+    """
+    Benchmark any model that adheres to the ModelAdapter interface.
+
+    Required input state:
+        - MODEL: model instance to benchmark.
+        - TOKENIZER: tokenizer instance used to generate inputs for the model.
+
+    Output state produced: None
+    """
+
+    unique_name = "oga-bench"
+
+    def __init__(self):
+        super().__init__()
+
+        # Additional statistics generated by this bench tool
+        self.status_stats.insert(
+            self.status_stats.index(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND) + 1,
+            Keys.STD_DEV_TOKENS_PER_SECOND,
+        )
+        self.std_dev_token_generation_tokens_per_second_list = []
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Benchmark an LLM in onnxruntime-genai (OGA)",
+            add_help=add_help,
+        )
+
+        parser = Bench.parser(parser)
+
+        return parser
+
+    def get_prompt_str(self, state, token_length):
+        """
+        Returns a string with the prescribed token length.
+        """
+        tokenizer: TokenizerAdapter = state.tokenizer
+        test_prompt = "word " * (token_length - 1)
+        input_ids = tokenizer(test_prompt, return_tensors="pt").input_ids
+        test_token_length = len(input_ids)
+        delta = test_token_length - token_length
+        if delta == 0:
+            return test_prompt
+        return "word " * max(token_length - 1 - delta, 0)
+
+    def run_prompt(
+        self,
+        state: State,
+        report_progress_fn,
+        prompt: str,
+        iterations: int,
+        warmup_iterations: int,
+        output_tokens: int,
+    ) -> State:
+
+        model: ModelAdapter = state.model
+        tokenizer: TokenizerAdapter = state.tokenizer
+
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        self.input_ids_len_list.append(len(input_ids))
+        per_iteration_time_to_first_token = []
+        per_iteration_tokens_per_second = []
+
+        # Don't capture time for warmup
+        for count in range(warmup_iterations):
+            outputs = model.generate(input_ids, max_new_tokens=output_tokens)
+            self.tokens_out_len_list.append(len(outputs[0]) - len(input_ids))
+            report_progress_fn((count + 1) / (warmup_iterations + iterations))
+
+        for count in range(iterations):
+            outputs = model.generate(
+                input_ids,
+                max_new_tokens=output_tokens,
+                min_new_tokens=output_tokens,
+            )
+            report_progress_fn(
+                (warmup_iterations + count + 1) / (warmup_iterations + iterations)
+            )
+
+            token_len = len(outputs[0]) - len(input_ids)
+            self.tokens_out_len_list.append(token_len)
+
+            # Only count an iteration if it produced enough tokens
+            if token_len >= output_tokens:
+                per_iteration_time_to_first_token.append(model.time_to_first_token)
+                per_iteration_tokens_per_second.append(model.tokens_per_second)
+
+        if not per_iteration_time_to_first_token or not per_iteration_tokens_per_second:
+            raise Bench.not_enough_tokens(output_tokens)
+
+        mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token)
+        self.mean_time_to_first_token_list.append(mean_time_to_first_token)
+        self.prefill_tokens_per_second_list.append(
+            len(input_ids) / mean_time_to_first_token
+        )
+        self.token_generation_tokens_per_second_list.append(
+            statistics.mean(per_iteration_tokens_per_second)
+        )
+        try:
+            self.std_dev_time_to_first_token_list.append(
+                statistics.stdev(per_iteration_time_to_first_token)
+            )
+        except StatisticsError:
+            # Less than 2 measurements
+            self.std_dev_time_to_first_token_list.append(None)
+        try:
+            self.std_dev_token_generation_tokens_per_second_list.append(
+                statistics.stdev(per_iteration_tokens_per_second)
+            )
+        except StatisticsError:
+            # Less than 2 measurements
+            self.std_dev_token_generation_tokens_per_second_list.append(None)
+
+    def save_stats(self, state):
+        super().save_stats(state)
+
+        # Save additional statistics
+        if not all(
+            element is None
+            for element in self.std_dev_token_generation_tokens_per_second_list
+        ):
+            state.save_stat(
+                Keys.STD_DEV_TOKENS_PER_SECOND,
+                self.get_item_or_list(
+                    self.std_dev_token_generation_tokens_per_second_list
+                ),
+            )
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD
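The measurement pattern in run_prompt above is: warmup generations are executed but never timed, each timed generation contributes one time-to-first-token and one tokens-per-second sample, and statistics.stdev falls back to None when fewer than two samples exist. The following stand-alone sketch of that pattern is not part of the package; fake_generate is a made-up stand-in for the model adapter's generate() call:

import random
import statistics
import time
from statistics import StatisticsError

def fake_generate(n_tokens):
    """Stand-in for a model adapter's generate(): pretend to prefill, then decode."""
    start = time.perf_counter()
    time.sleep(random.uniform(0.01, 0.02))  # pretend prefill
    time_to_first_token = time.perf_counter() - start
    time.sleep(random.uniform(0.05, 0.08))  # pretend decode
    total = time.perf_counter() - start
    return time_to_first_token, n_tokens / (total - time_to_first_token)

def bench(iterations=5, warmup_iterations=2, output_tokens=64):
    ttft_samples, tps_samples = [], []

    # Warmup passes run but are never recorded, mirroring run_prompt()
    for _ in range(warmup_iterations):
        fake_generate(output_tokens)

    for _ in range(iterations):
        ttft, tps = fake_generate(output_tokens)
        ttft_samples.append(ttft)
        tps_samples.append(tps)

    stats = {
        "mean_time_to_first_token_s": statistics.mean(ttft_samples),
        "mean_tokens_per_second": statistics.mean(tps_samples),
    }
    try:
        stats["std_dev_tokens_per_second"] = statistics.stdev(tps_samples)
    except StatisticsError:
        # Fewer than 2 measurements, handled the same way as in OgaBench
        stats["std_dev_tokens_per_second"] = None
    return stats

if __name__ == "__main__":
    print(bench())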
lemonade/tools/perplexity.py
ADDED
@@ -0,0 +1,146 @@
+import os
+import argparse
+import pandas as pd
+import torch
+from datasets import load_dataset
+from lemonade.state import State
+from lemonade.tools import Tool
+import lemonade.common.printing as printing
+import lemonade.common.build as build
+
+
+class AccuracyPerplexity(Tool):
+    """
+    Measure perplexity of an LLM using the Wikitext-2 dataset.
+
+    Required input state:
+        - state.model: instance that provides a __call__() method that returns
+            output.logits and supports model.config.max_position_embeddings
+        - state.tokenizer: instance of Hugging Face PretrainedTokenizer
+
+    Output state produced: None
+
+    See docs/lemonade/perplexity.md for more details.
+    """
+
+    unique_name = "accuracy-perplexity"
+
+    def __init__(self):
+        super().__init__(monitor_message="Measuring perplexity")
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Measure perplexity score",
+            add_help=add_help,
+        )
+        return parser
+
+    def run(
+        self,
+        state: State,
+    ) -> State:
+
+        try:
+            printing.log_info("Downloading dataset ...")
+            dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+        except Exception as e:  # pylint: disable=broad-except
+            printing.log_error(f"Error during dataset load: {e}")
+            raise e
+
+        tokenizer = state.tokenizer
+        model = state.model
+        # Tokenize the entire test dataset text, joining entries with double new lines
+        encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
+
+        # Retrieve the maximum input length that the model can handle
+        try:
+            max_length = model.config.max_position_embeddings
+        except AttributeError:
+            # Some LLMs do not have the config.max_position_embeddings attribute
+            # However, most LLMs support at least 2048 context length, so this
+            # try-except will allow a few more LLMs to work
+            max_length = 2048
+        # Set stride to half of the maximum input length for overlapping window processing
+        # Refer to docs/perplexity.md for more information on sliding window
+        stride = max_length // 2
+        # Determine the total sequence length of the tokenized input
+        seq_len = encodings.input_ids.size(1)
+
+        negative_log_likelihoods = []
+        summary_data = []
+        prev_end_location = 0
+
+        model_results_dir = os.path.join(
+            build.output_dir(state.cache_dir, state.build_name), "perplexity"
+        )
+
+        for begin_location in range(0, seq_len, stride):
+            end_location = min(begin_location + max_length, seq_len)
+            target_len = end_location - prev_end_location
+            input_ids = encodings.input_ids[:, begin_location:end_location]
+            target_ids = input_ids.clone()
+            target_ids[:, :-target_len] = -100
+
+            # Forward pass the model to get logits
+            with torch.no_grad():
+                try:
+                    outputs = model(input_ids, labels=target_ids)
+                    logits = outputs.logits
+                except Exception as e:  # pylint: disable=broad-except
+                    printing.log_error(
+                        f"Error during model forward pass execution: {e}"
+                    )
+
+            # Compute loss manually for visualization
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = target_ids[..., 1:].contiguous()
+            effective_token_count = (target_ids != -100).sum().item()
+            negative_log_likelihoods.append(
+                (outputs.loss.item(), effective_token_count)
+            )
+
+            # Decode predicted and actual next words for the last token position
+            predictions = torch.argmax(shift_logits, dim=-1)
+            predicted_tokens = predictions[:, -1]
+            actual_tokens = shift_labels[:, -1]
+
+            predicted_words = tokenizer.batch_decode(
+                predicted_tokens, skip_special_tokens=True
+            )
+            actual_words = tokenizer.batch_decode(
+                actual_tokens, skip_special_tokens=True
+            )
+            context = tokenizer.decode(input_ids[0, :])
+
+            summary_data.append(
+                {
+                    "Context": context[-stride:],
+                    "Predicted next word": predicted_words,
+                    "Actual next word": actual_words,
+                    "Loss for this window": outputs.loss.item(),
+                }
+            )
+            prev_end_location = end_location
+
+        # Total loss calculation considering the number of tokens for each segment
+        total_loss = sum(loss * count for loss, count in negative_log_likelihoods)
+        total_tokens = sum(count for _, count in negative_log_likelihoods)
+
+        # Calculate average negative_log_likelihood and perplexity
+        average_negative_log_likelihood = total_loss / total_tokens
+        perplexity = torch.exp(torch.tensor(average_negative_log_likelihood))
+
+        # Save accuracy results to stats file
+        state.save_stat("perplexity_score", float(perplexity.item()))
+
+        # Save accuracy results to CSV file
+        summary_df = pd.DataFrame(summary_data)
+        summary_df.to_csv(
+            os.path.join(model_results_dir, "summary_results.csv"), index=False
+        )
+        return state
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD
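Two details of the tool above are easy to miss: windows advance by stride = max_length // 2, with tokens already scored by the previous window masked to -100, and the final perplexity is a token-weighted average of per-window losses. A small stand-alone sketch of that arithmetic with made-up numbers, not part of the package:

import math

# Overlapping window layout, mirroring the stride = max_length // 2 loop above
max_length, seq_len = 8, 20
stride = max_length // 2
prev_end = 0
for begin in range(0, seq_len, stride):
    end = min(begin + max_length, seq_len)
    target_len = end - prev_end  # tokens newly scored in this window; earlier ones masked to -100
    print(f"window [{begin}:{end}] scores the last {target_len} tokens")
    prev_end = end

# Token-weighted aggregation of per-window losses, as in the final perplexity computation.
# The (loss, token_count) pairs are made up.
per_window = [(2.31, 1024), (2.18, 512), (2.45, 512)]
total_loss = sum(loss * count for loss, count in per_window)
total_tokens = sum(count for _, count in per_window)
print("perplexity:", math.exp(total_loss / total_tokens))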
lemonade/tools/prompt.py
ADDED
@@ -0,0 +1,228 @@
+import argparse
+import os
+import matplotlib.pyplot as plt
+import lemonade.common.build as build
+import lemonade.common.printing as printing
+from lemonade.state import State
+from lemonade.tools import Tool
+from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
+from lemonade.cache import Keys
+
+DEFAULT_GENERATE_PARAMS = {
+    "do_sample": True,
+    "top_k": 50,
+    "top_p": 0.95,
+    "temperature": 0.7,
+}
+
+DEFAULT_MAX_NEW_TOKENS = 512
+DEFAULT_N_TRIALS = 1
+
+
+def sanitize_string(input_string):
+    return input_string.encode("charmap", "ignore").decode("charmap")
+
+
+def sanitize_text(text):
+    if isinstance(text, str):
+        return sanitize_string(text)
+    elif isinstance(text, list):
+        return [sanitize_string(item) for item in text]
+    else:
+        raise TypeError("Input must be a string or a list of strings.")
+
+
+def positive_int(x):
+    """Conversion function for argparse"""
+    i = int(x)
+    if i < 1:
+        raise ValueError("Non-positive values are not allowed")
+    return i
+
+
+class LLMPrompt(Tool):
+    """
+    Send a prompt to an LLM instance and print the response to the screen.
+
+    Required input state:
+        - state.model: LLM instance that supports the generate() method.
+        - state.tokenizer: LLM tokenizer instance that supports the __call__() (ie, encode)
+            and decode() methods.
+
+    Output state produced:
+        - "response": text response from the LLM.
+    """
+
+    unique_name = "llm-prompt"
+
+    def __init__(self):
+        super().__init__(monitor_message="Prompting LLM")
+
+        self.status_stats = [
+            Keys.PROMPT_TOKENS,
+            Keys.PROMPT,
+            Keys.PROMPT_TEMPLATE,
+            Keys.RESPONSE_TOKENS,
+            Keys.RESPONSE,
+            Keys.RESPONSE_LENGTHS_HISTOGRAM,
+        ]
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Prompt an LLM and print the result",
+            add_help=add_help,
+        )
+
+        parser.add_argument(
+            "--prompt",
+            "-p",
+            help="Input prompt to the LLM. Two formats are supported: "
+            "1) str: use a user-provided prompt string, and "
+            "2) path/to/prompt.txt: load the prompt from a .txt file.",
+            required=True,
+        )
+
+        parser.add_argument(
+            "--template",
+            "-t",
+            action="store_true",
+            help="Insert the prompt into the model's chat template before processing.",
+        )
+
+        parser.add_argument(
+            "--max-new-tokens",
+            "-m",
+            default=DEFAULT_MAX_NEW_TOKENS,
+            type=int,
+            help=f"Maximum number of new tokens in the response "
+            f"(default is {DEFAULT_MAX_NEW_TOKENS})",
+        )
+
+        parser.add_argument(
+            "--n-trials",
+            "-n",
+            default=DEFAULT_N_TRIALS,
+            type=positive_int,
+            help=f"Number of responses the LLM will generate for the prompt "
+            f"(useful for testing, default is {DEFAULT_N_TRIALS})",
+        )
+
+        return parser
+
+    def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+        """
+        Helper function to parse CLI arguments into the args expected
+        by run()
+        """
+
+        parsed_args = super().parse(state, args, known_only)
+
+        # Decode prompt arg into a string prompt
+        if parsed_args.prompt.endswith(".txt") and os.path.exists(parsed_args.prompt):
+            with open(parsed_args.prompt, "r", encoding="utf-8") as f:
+                parsed_args.prompt = f.read()
+
+        return parsed_args
+
+    def run(
+        self,
+        state: State,
+        prompt: str = "Hello",
+        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+        n_trials: int = DEFAULT_N_TRIALS,
+        template: bool = False,
+    ) -> State:
+
+        model: ModelAdapter = state.model
+        tokenizer: TokenizerAdapter = state.tokenizer
+
+        # If template flag is set, then wrap prompt in template
+        if template:
+            # Embed prompt in model's chat template
+            if tokenizer.chat_template:
+                # Use the model's built-in chat template if available
+                messages_dict = [{"role": "user", "content": prompt}]
+                prompt = tokenizer.apply_chat_template(
+                    messages_dict, tokenize=False, add_generation_prompt=True
+                )
+                state.save_stat(Keys.PROMPT_TEMPLATE, "Model-specific")
+            else:
+                # Fallback to a standardized template
+                printing.log_info("No chat template found. Using default template.")
+                prompt = f"<|user|>\n{prompt} <|end|>\n<|assistant|>"
+                state.save_stat(Keys.PROMPT_TEMPLATE, "Default")
+
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        if isinstance(input_ids, (list, str)):
+            # OGA models return a list of tokens
+            # Our llama.cpp adapter returns a string
+            len_tokens_in = len(input_ids)
+        else:
+            # HF models return a 2-D tensor
+            len_tokens_in = input_ids.shape[1]
+
+        len_tokens_out = []
+        response_texts = []
+        for trial in range(n_trials):
+            if n_trials > 1:
+                self.set_percent_progress(100.0 * trial / n_trials)
+
+            # Get the response from the LLM, which may include the prompt in it
+            response = model.generate(
+                input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
+            )
+
+            # Flatten the input and response
+            input_ids_array = (
+                input_ids if isinstance(input_ids, (list, str)) else input_ids[0]
+            )
+            response_array = response if isinstance(response, str) else response[0]
+
+            # Separate the prompt from the response
+            len_tokens_out.append(len(response_array) - len_tokens_in)
+
+            input_token = 0
+            while (
+                input_token < len_tokens_in
+                and input_ids_array[input_token] == response_array[input_token]
+            ):
+                input_token += 1
+
+            # Only decode the actual response (not the prompt)
+            response_text = tokenizer.decode(
+                response_array[input_token:], skip_special_tokens=True
+            ).strip()
+            response_texts.append(response_text)
+
+        state.response = response_texts
+
+        if n_trials == 1:
+            len_tokens_out = len_tokens_out[0]
+            response_texts = response_texts[0]
+        else:
+            self.set_percent_progress(None)
+
+            # Plot data
+            plt.figure()
+            plt.hist(len_tokens_out, bins=20)
+            plt.xlabel("Response Length (tokens)")
+            plt.ylabel("Frequency")
+            plt.title(f"Histogram of Response Lengths\n{state.build_name}")
+            figure_path = os.path.join(
+                build.output_dir(state.cache_dir, state.build_name),
+                "response_lengths.png",
+            )
+            plt.savefig(figure_path)
+            state.save_stat(Keys.RESPONSE_LENGTHS_HISTOGRAM, figure_path)
+
+        state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
+        state.save_stat(Keys.PROMPT, prompt)
+        state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
+        state.save_stat(Keys.RESPONSE, sanitize_text(response_texts))
+
+        return state
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD
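Because some adapters echo the prompt at the start of the generated sequence, run() above walks the longest matching token prefix between input and output before decoding. A stand-alone sketch of that separation, using made-up integer token ids instead of a real tokenizer and not part of the package:

def split_response(input_tokens, generated_tokens):
    """Return only the tokens that follow the echoed prompt, mirroring LLMPrompt.run()."""
    i = 0
    while (
        i < len(input_tokens)
        and i < len(generated_tokens)
        and generated_tokens[i] == input_tokens[i]
    ):
        i += 1
    return generated_tokens[i:]

prompt_ids = [101, 7592, 2088]                    # made-up token ids
output_ids = [101, 7592, 2088, 2003, 1037, 3231]  # prompt echoed, then the new tokens
print(split_response(prompt_ids, output_ids))     # [2003, 1037, 3231]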
lemonade/tools/quark/__init__.py
File without changes
lemonade/tools/quark/quark_load.py
ADDED
@@ -0,0 +1,172 @@
+import argparse
+import os
+import sys
+
+import torch
+from lemonade.state import State
+from lemonade.tools import Tool
+import lemonade.common.printing as printing
+import lemonade.common.build as build
+from lemonade_install.install import DEFAULT_QUARK_DIR
+
+
+class QuarkLoad(Tool):
+    """
+    Load a model Quantized and exported using Quark.
+    Required Input State:
+        - state.model: Pretrained model instance to be quantized.
+        - state.tokenizer: Tokenizer instance from Hugging Face.
+    Output:
+        - state of the loaded model
+
+    See docs/quark.md for more details.
+    """
+
+    unique_name = "quark-load"
+
+    def __init__(self):
+        super().__init__(monitor_message="Load Quark Quantized model")
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Load a quantized model using Quark",
+            add_help=add_help,
+        )
+
+        parser.add_argument(
+            "--quant-scheme",
+            type=str,
+            required=True,
+            default=None,
+            help="Supported quantization schemes in Quark",
+        )
+
+        parser.add_argument(
+            "--quant-algo",
+            type=str,
+            required=True,
+            default=None,
+            choices=["awq", "gptq", "autosmoothquant", None],
+            help="Supported quantization algorithms in Quark",
+        )
+
+        parser.add_argument(
+            "--torch-compile", action="store_true", help="Model torch compile"
+        )
+
+        parser.add_argument(
+            "--safetensors-model-reload",
+            action="store_true",
+            help="Safetensors model reload",
+        )
+
+        parser.add_argument(
+            "--safetensors-model-dir",
+            default=None,
+            help="Directory of safetensors model",
+        )
+
+        parser.add_argument(
+            "--params-load", action="store_true", help="Model parameters load"
+        )
+
+        parser.add_argument("--json-path", help="Specify the path of saved json file")
+
+        parser.add_argument(
+            "--safetensors-path",
+            default=None,
+            help="Specify the path of saved safetensors file",
+        )
+
+        return parser
+
+    def run(
+        self,
+        state: State,
+        quant_scheme: str,
+        quant_algo: str,
+        torch_compile: bool = False,
+        safetensors_model_reload: bool = False,
+        safetensors_model_dir: str = None,
+        params_load: bool = False,
+        json_path: str = None,
+        safetensors_path: str = None,
+    ) -> State:
+        """
+        Executes the QuarkLoad process.
+        Returns:
+            State: The updated state after loading the model.
+        Raises:
+            Exception: If an error occurs during the QuarkLoad process.
+        """
+
+        try:
+            if os.path.isdir(DEFAULT_QUARK_DIR):
+                quark_llm_path = os.path.join(
+                    DEFAULT_QUARK_DIR, "examples", "torch", "language_modeling"
+                )
+                sys.path.insert(0, quark_llm_path)
+            else:
+                raise FileNotFoundError(
+                    f"The directory {DEFAULT_QUARK_DIR} does not exist. \
+                    Please check your installation."
+                )
+
+            # Default load path specific to recipe
+            # This will NOT work
+            # The default path is now uniquely craeated with timestamp
+            # Default load path will not work. Need to pass explicit load path
+            model_export_path = os.path.join(
+                build.output_dir(state.cache_dir, state.build_name),
+                "exported_model",
+                quant_scheme,
+                quant_algo,
+            )
+
+            # Set default paths only if current values are None
+            if safetensors_model_dir is None:
+                safetensors_model_dir = model_export_path
+            if safetensors_path is None:
+                safetensors_path = os.path.join(model_export_path, "model.safetensors")
+            printing.log_info("Loading model ...")
+            if not params_load and not safetensors_model_reload:
+                raise ValueError(
+                    " Specify load format: 'params_load' or 'safetensors_model_reload'."
+                )
+
+            # Reload quantized model if specified
+            from quark.torch import load_params, import_model_info
+
+            if params_load:
+                printing.log_info(
+                    "Restoring quantized model from JSON/safetensors files"
+                )
+                model = load_params(
+                    model,
+                    json_path=json_path,
+                    safetensors_path=safetensors_path,
+                )
+            elif safetensors_model_reload:
+                printing.log_info(
+                    "Restoring quantized model from quark_safetensors files"
+                )
+                model = import_model_info(model, model_info_dir=safetensors_model_dir)
+
+            if torch_compile:
+                printing.log_info("torch.compile...")
+                model = torch.compile(model)
+
+            state.model = model
+            state.dtype = model.dtype
+
+            printing.log_info("Quark Load process completed.")
+
+        except Exception as e:
+            printing.log_error(f"An error occurred during the QuarkLoad process: {e}")
+            raise
+        return state
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD