lemonade-sdk 7.0.3__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
Potentially problematic release: this version of lemonade-sdk might be problematic.
- lemonade/api.py +3 -3
- lemonade/cli.py +11 -17
- lemonade/common/build.py +0 -47
- lemonade/common/network.py +50 -0
- lemonade/common/status.py +2 -21
- lemonade/common/system_info.py +19 -4
- lemonade/profilers/memory_tracker.py +3 -1
- lemonade/tools/accuracy.py +3 -4
- lemonade/tools/adapter.py +1 -2
- lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
- lemonade/tools/humaneval.py +9 -3
- lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
- lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
- lemonade/tools/mmlu.py +7 -15
- lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
- lemonade/tools/oga/utils.py +423 -0
- lemonade/tools/perplexity.py +4 -3
- lemonade/tools/prompt.py +2 -1
- lemonade/tools/quark/quark_load.py +2 -1
- lemonade/tools/quark/quark_quantize.py +5 -5
- lemonade/tools/report/table.py +3 -3
- lemonade/tools/server/llamacpp.py +159 -34
- lemonade/tools/server/serve.py +169 -147
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/styles.css +568 -0
- lemonade/tools/server/static/webapp.html +439 -0
- lemonade/tools/server/tray.py +458 -0
- lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
- lemonade/tools/server/utils/system_tray.py +395 -0
- lemonade/tools/server/{instructions.py → webapp.py} +4 -10
- lemonade/version.py +1 -1
- lemonade_install/install.py +46 -28
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/METADATA +84 -22
- lemonade_sdk-8.0.0.dist-info/RECORD +70 -0
- lemonade_server/cli.py +182 -27
- lemonade_server/model_manager.py +192 -20
- lemonade_server/pydantic_models.py +9 -4
- lemonade_server/server_models.json +5 -3
- lemonade/common/analyze_model.py +0 -26
- lemonade/common/labels.py +0 -61
- lemonade/common/onnx_helpers.py +0 -176
- lemonade/common/plugins.py +0 -10
- lemonade/common/tensor_helpers.py +0 -83
- lemonade/tools/server/static/instructions.html +0 -262
- lemonade_sdk-7.0.3.dist-info/RECORD +0 -69
- /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
- /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
- /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/top_level.txt +0 -0
lemonade/tools/oga/utils.py ADDED
@@ -0,0 +1,423 @@
+import os
+import time
+import json
+import logging
+from queue import Queue
+from packaging.version import Version
+import onnxruntime_genai as og
+from transformers import AutoTokenizer
+from lemonade.tools.adapter import (
+    ModelAdapter,
+    TokenizerAdapter,
+    PassthroughTokenizerResult,
+)
+
+
+class OrtGenaiTokenizer(TokenizerAdapter):
+    def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
+        super().__init__(hf_tokenizer)
+        # Initialize OGA tokenizer
+        self.tokenizer = og.Tokenizer(model)
+
+        # Placeholder value since some code will try to query it
+        # If we actually need this to return a proper value, then
+        # og.GeneratorParams.eos_token_id has it
+        self.eos_token_id = None
+
+    def __call__(self, prompt: str, return_tensors="np"):
+        tokens = self.tokenizer.encode(prompt)
+        return PassthroughTokenizerResult(tokens)
+
+    # pylint: disable=unused-argument
+    def decode(self, response, skip_special_tokens=True) -> str:
+        return self.tokenizer.decode(response)
+
+
+class OrtGenaiStreamer:
+    def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
+        self.tokenizer = tokenizer
+        self.text_queue = Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+
+    def add_text(self, text: str):
+        self.text_queue.put(text, timeout=self.timeout)
+
+    def done(self):
+        self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get(timeout=self.timeout)
+        if value == self.stop_signal:
+            raise StopIteration()
+        else:
+            return value
+
+
+class OrtGenaiModel(ModelAdapter):
+
+    def __init__(self, input_folder):
+        super().__init__()
+        self.model = og.Model(input_folder)
+        self.type = "ort-genai"
+        self.config = self.load_config(input_folder)
+
+    def load_config(self, input_folder):
+        rai_config_path = os.path.join(input_folder, "rai_config.json")
+        if os.path.exists(rai_config_path):
+            with open(rai_config_path, "r", encoding="utf-8") as f:
+                max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
+        else:
+            max_prompt_length = None
+
+        config_path = os.path.join(input_folder, "genai_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_dict = json.load(f)
+            if max_prompt_length:
+                config_dict["max_prompt_length"] = max_prompt_length
+            return config_dict
+        return None
+
+    def generate(
+        self,
+        input_ids,
+        max_new_tokens=512,
+        min_new_tokens=0,
+        do_sample=True,
+        top_k=50,
+        top_p=1.0,
+        temperature=0.7,
+        streamer: OrtGenaiStreamer = None,
+        pad_token_id=None,
+        stopping_criteria=None,
+        max_length=None,
+        random_seed=1,
+    ):
+        params = og.GeneratorParams(self.model)
+
+        prompt_length = len(input_ids)
+        max_prompt_length = self.config.get("max_prompt_length")
+        if max_prompt_length and prompt_length > max_prompt_length:
+            raise ValueError(
+                f"This prompt (length {prompt_length}) exceeds the model's "
+                f"maximum allowed prompt length ({max_prompt_length})."
+            )
+
+        # There is a breaking API change in OGA 0.6.0
+        # Determine whether we should use the old or new APIs
+        # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
+        use_oga_post_6_api = (
+            Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
+        )
+        use_oga_pre_6_api = not use_oga_post_6_api
+
+        if pad_token_id:
+            params.pad_token_id = pad_token_id
+
+        # Handle max_length and max_new_tokens
+        if max_length and max_new_tokens:
+            logging.warning(
+                "Both max_length and max_new_tokens were provided. "
+                "max_length will take precedence. "
+                "When setting max_length, please explicitly set max_new_tokens to None."
+            )
+        max_length_to_use = None
+        if max_length:
+            max_length_to_use = max_length
+        elif max_new_tokens:
+            max_length_to_use = prompt_length + max_new_tokens
+
+        min_length = prompt_length + min_new_tokens
+
+        if use_oga_pre_6_api:
+            params.input_ids = input_ids
+
+        if random_seed is None:
+            random_seed = -1  # In og.Generator, -1 = seed with random device
+
+        if self.config and "search" in self.config:
+            search_config = self.config["search"]
+            params.set_search_options(
+                do_sample=search_config.get("do_sample", do_sample),
+                top_k=search_config.get("top_k", top_k),
+                top_p=search_config.get("top_p", top_p),
+                temperature=search_config.get("temperature", temperature),
+                max_length=max_length_to_use,
+                min_length=min_length,
+                early_stopping=search_config.get("early_stopping", False),
+                length_penalty=search_config.get("length_penalty", 1.0),
+                num_beams=search_config.get("num_beams", 1),
+                num_return_sequences=search_config.get("num_return_sequences", 1),
+                repetition_penalty=search_config.get("repetition_penalty", 1.0),
+                past_present_share_buffer=search_config.get(
+                    "past_present_share_buffer", True
+                ),
+                random_seed=random_seed,
+                # Not currently supported by OGA
+                # diversity_penalty=search_config.get('diversity_penalty', 0.0),
+                # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
+            )
+        else:
+            params.set_search_options(
+                do_sample=do_sample,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                max_length=max_length_to_use,
+                min_length=min_length,
+                random_seed=random_seed,
+            )
+            params.try_graph_capture_with_max_batch_size(1)
+
+        generator = og.Generator(self.model, params)
+
+        if streamer is None:
+            prompt_start_time = time.perf_counter()
+            if use_oga_post_6_api:
+                generator.append_tokens(input_ids)
+            if use_oga_pre_6_api:
+                generator.compute_logits()
+            generator.generate_next_token()
+            prompt_end_time = time.perf_counter()
+
+            self.time_to_first_token = prompt_end_time - prompt_start_time
+
+            if max_new_tokens > 1:
+
+                token_gen_times = []
+                while not generator.is_done():
+                    token_gen_start_time = time.perf_counter()
+                    if use_oga_pre_6_api:
+                        generator.compute_logits()
+                    generator.generate_next_token()
+                    token_gen_end_time = time.perf_counter()
+
+                    token_gen_times.append(token_gen_end_time - token_gen_start_time)
+
+                if token_gen_times:
+                    # List will be empty if we generated 1 or 0 tokens, and we don't
+                    # want a divide-by-zero error in those cases
+                    avg_token_gen_latency_s = sum(token_gen_times) / len(
+                        token_gen_times
+                    )
+                    self.tokens_per_second = 1 / avg_token_gen_latency_s
+
+            return [generator.get_sequence(0)]
+        else:
+            if use_oga_post_6_api:
+                generator.append_tokens(input_ids)
+            tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
+
+            stop_early = False
+
+            while not generator.is_done() and not stop_early:
+                if use_oga_pre_6_api:
+                    generator.compute_logits()
+                generator.generate_next_token()
+
+                new_token = generator.get_next_tokens()[0]
+                new_text = tokenizer_stream.decode(new_token)
+
+                streamer.add_text(new_text)
+
+                if stopping_criteria is not None:
+                    if stopping_criteria[0].stop_event.is_set():
+                        stop_early = True
+
+            streamer.done()
+
+    def _model_call(self, input_ids):
+        """
+        Run the model on input_ids and get logits.
+
+        This method directly accesses model logits rather than using the full generate pipeline for
+        several important reasons:
+        1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
+           producing multiple tokens through iterative inference
+        2. Efficiency: Direct access is more efficient for logprob calculations with no
+           sampling overhead
+        3. Precision: Logprob calculations require exact control over input-to-output mapping
+        4. Consistency: Similar approach used in both HF and OGA implementations
+
+        Args:
+            input_ids: Input token IDs
+
+        Returns:
+            Logits for each token in the sequence
+        """
+        import torch
+
+        # Setup generator params
+        params = og.GeneratorParams(self.model)
+
+        # Configure for a simple forward pass
+        params.set_search_options(
+            do_sample=False,
+            temperature=0.0,
+            max_length=len(input_ids),
+        )
+
+        # Initialize generator
+        generator = og.Generator(self.model, params)
+
+        # Feed tokens to model based on API version
+        generator.append_tokens(input_ids)
+
+        # Extract logits - this returns a list of logits tensors
+        logits = generator.get_output("logits")
+
+        # Convert to torch tensor for easier processing
+        return torch.tensor(logits[0])
+
+    def _select_cont_toks(self, logits, context_len, continuation_tokens):
+        """
+        Select and process logits for continuation tokens.
+
+        Args:
+            logits: Full sequence logits
+            context_len: Length of context tokens
+            continuation_tokens: List or tensor of continuation token IDs
+
+        Returns:
+            Log probabilities for continuation tokens
+        """
+        import torch
+
+        # Extract relevant logits for continuation prediction (shift by one)
+        cont_logits = logits[
+            context_len - 1 : context_len - 1 + len(continuation_tokens)
+        ]
+
+        # Convert to torch tensors if needed
+        if not isinstance(continuation_tokens, torch.Tensor):
+            continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
+
+        # Apply log softmax to get log probabilities
+        log_probs = torch.log_softmax(cont_logits, dim=-1)
+
+        # Get log probs for the specific continuation tokens
+        token_log_probs = torch.gather(
+            log_probs, 1, continuation_tokens.unsqueeze(-1)
+        ).squeeze(-1)
+
+        return token_log_probs
+
+    def compute_logprobs(
+        self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
+    ):
+        """
+        Compute log probabilities for all tokens in the given text.
+
+        Args:
+            text: The full text to analyze (e.g., prompt + completion)
+            prompt_length: Number of tokens in the prompt. If provided and echo=False,
+                only completion tokens after this position will be returned.
+            logprobs: If not None, return log probabilities. Value indicates how many top
+                alternatives to return. If True but not an integer, defaults to 5 alternatives.
+            echo: If True, include logprobs for prompt tokens. If False, only return logprobs
+                for completion tokens.
+
+        Returns:
+            - text_offset: Character offsets for each token in the text
+            - token_logprobs: Log probability for each token
+            - tokens: The actual tokens used
+            - top_logprobs: Top alternative log probabilities for each position
+        """
+        import torch
+
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required for logprob calculation")
+
+        # Encode the full text
+        tokens = tokenizer(text).input_ids  # pylint: disable=E1102
+
+        # Track character offsets for each token
+        text_offset = []
+        start_idx = 0
+
+        token_strings = []
+        for token_id in tokens:
+            token_str = tokenizer.decode([token_id])
+            token_strings.append(token_str)
+
+            # Calculate character offsets for tokens - handles cases where tokens
+            # may not directly match in the original text due to encoding differences,
+            # special characters, or tokenization artifacts
+            try:
+                pos = text[start_idx:].find(token_str)
+                if pos != -1:
+                    text_offset.append(start_idx + pos)
+                    start_idx += pos + len(token_str)
+                else:
+                    text_offset.append(start_idx)
+            except (TypeError, ValueError, UnicodeError):
+                # Fallback to current position when matching fails due to encoding issues
+                text_offset.append(start_idx)
+
+        # Get logits from model
+        logits = self._model_call(tokens)
+
+        # Calculate log probabilities for each token
+        all_log_probs = torch.log_softmax(logits, dim=-1)
+
+        # The first token doesn't have a conditional probability
+        # For tokens after the first, get the predicted probability
+        token_log_probs = []
+        top_logprobs_list = []
+
+        # For each position, get the actual token probability and top alternatives
+        for i in range(len(tokens)):
+            # Get previous token position logits
+            if i > 0:  # First token has no preceding context
+                prev_logits = all_log_probs[i - 1]
+                curr_token_id = tokens[i]
+                # Get probability of the actual token that appeared
+                token_logprob = prev_logits[curr_token_id].item()
+                token_log_probs.append(token_logprob)
+
+                # Get top-k alternatives if requested
+                if logprobs is not None:
+                    num_alternatives = logprobs if isinstance(logprobs, int) else 5
+                    topk_values, topk_indices = torch.topk(
+                        prev_logits, min(num_alternatives, prev_logits.size(-1))
+                    )
+
+                    # Create dictionary of token: logprob
+                    position_logprobs = {}
+                    for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
+                        token_str = tokenizer.decode([idx])
+                        position_logprobs[token_str] = val
+
+                    top_logprobs_list.append(position_logprobs)
+            else:
+                # For the first token, we don't have a conditional probability
+                token_log_probs.append(None)
+                top_logprobs_list.append({})
+
+        # If we don't want to echo prompt tokens, filter them out
+        if not echo and prompt_length is not None:
+            # Ensure prompt_length is within bounds
+            prompt_length = min(prompt_length, len(tokens))
+
+            # Filter results to only include completion tokens
+            if prompt_length < len(tokens):
+                filtered_text_offset = text_offset[prompt_length:]
+                filtered_token_logprobs = token_log_probs[prompt_length:]
+                filtered_tokens = token_strings[prompt_length:]
+                filtered_top_logprobs = top_logprobs_list[prompt_length:]
+
+                return (
+                    filtered_text_offset,
+                    filtered_token_logprobs,
+                    filtered_tokens,
+                    filtered_top_logprobs,
+                )
+            else:
+                # No completion tokens
+                return [], [], [], []
+
+        return text_offset, token_log_probs, token_strings, top_logprobs_list
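The new module above carries the OGA (onnxruntime-genai) adapter code moved out of ort_genai/oga.py: OrtGenaiModel wraps og.Model and reads genai_config.json/rai_config.json, OrtGenaiTokenizer wraps og.Tokenizer behind the lemonade TokenizerAdapter interface, and OrtGenaiStreamer is a queue-backed iterator that generate() feeds when a streamer is passed. A minimal usage sketch follows; the model folder and tokenizer checkpoint are placeholders, and it assumes PassthroughTokenizerResult exposes the encoded prompt as .input_ids, as the other lemonade adapters do.

# Hypothetical usage sketch of the adapters added in lemonade/tools/oga/utils.py.
# The model folder and tokenizer checkpoint below are placeholders.
from threading import Thread

from transformers import AutoTokenizer
from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiStreamer, OrtGenaiTokenizer

model = OrtGenaiModel("path/to/oga-model")  # folder containing genai_config.json
hf_tokenizer = AutoTokenizer.from_pretrained("some-org/matching-tokenizer")  # placeholder
tokenizer = OrtGenaiTokenizer(model.model, hf_tokenizer)

# Assumption: PassthroughTokenizerResult exposes the encoded tokens as .input_ids
input_ids = tokenizer("Write one sentence about lemons.").input_ids

# With a streamer, generate() pushes decoded text into the queue and returns only
# after calling streamer.done(), so it runs on a worker thread while the caller
# iterates the streamer for incremental text.
streamer = OrtGenaiStreamer(tokenizer)
worker = Thread(
    target=model.generate,
    args=(input_ids,),
    kwargs={"streamer": streamer, "max_new_tokens": 64},
)
worker.start()
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
worker.join()

Without a streamer, generate() instead returns the full sequence and records time_to_first_token and tokens_per_second on the adapter.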
lemonade/tools/perplexity.py CHANGED
@@ -1,8 +1,5 @@
 import os
 import argparse
-import pandas as pd
-import torch
-from datasets import load_dataset
 from lemonade.state import State
 from lemonade.tools import Tool
 import lemonade.common.printing as printing
@@ -41,6 +38,10 @@ class AccuracyPerplexity(Tool):
         state: State,
     ) -> State:

+        import pandas as pd
+        import torch
+        from datasets import load_dataset
+
         try:
             printing.log_info("Downloading dataset ...")
             dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
lemonade/tools/prompt.py CHANGED
@@ -1,6 +1,5 @@
 import argparse
 import os
-import matplotlib.pyplot as plt
 import lemonade.common.build as build
 import lemonade.common.printing as printing
 from lemonade.state import State
@@ -154,6 +153,8 @@ class LLMPrompt(Tool):
         random_seed: int = DEFAULT_RANDOM_SEED,
     ) -> State:

+        import matplotlib.pyplot as plt
+
         model: ModelAdapter = state.model
         tokenizer: TokenizerAdapter = state.tokenizer

lemonade/tools/quark/quark_load.py CHANGED
@@ -2,7 +2,6 @@ import argparse
 import os
 import sys

-import torch
 from lemonade.state import State
 from lemonade.tools import Tool
 import lemonade.common.printing as printing
@@ -101,6 +100,8 @@ class QuarkLoad(Tool):
             Exception: If an error occurs during the QuarkLoad process.
         """

+        import torch
+
         try:
             if os.path.isdir(DEFAULT_QUARK_DIR):
                 quark_llm_path = os.path.join(
lemonade/tools/quark/quark_quantize.py CHANGED
@@ -2,9 +2,6 @@ import argparse
 import os
 import sys
 from pathlib import Path
-
-import torch
-from transformers import AutoProcessor
 from lemonade.state import State
 from lemonade.tools import Tool
 import lemonade.common.printing as printing
@@ -319,8 +316,8 @@ class QuarkQuantize(Tool):
         - Optionally exporting, compiling, and evaluating the model.
         """

-
-
+        import torch
+        from transformers import AutoProcessor

         # Importing quark utils after adding to sys.path
         from llm_utils.data_preparation import get_calib_dataloader
@@ -328,6 +325,9 @@
         from llm_ptq.configuration_preparation import get_config, get_export_config
         from quark.torch import ModelQuantizer, ModelExporter, save_params

+        model = state.model.model
+        tokenizer = state.tokenizer
+
         # 1. Load Model
         printing.log_info("Loading model ...")
         model_type = get_model_type(model)
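A pattern repeated across the perplexity.py, prompt.py, quark_load.py, and quark_quantize.py hunks above is deferring heavy imports (torch, pandas, matplotlib, datasets, transformers) from module scope into the run method that actually needs them, which keeps importing the tool modules cheap. A minimal sketch of the pattern, using a hypothetical tool class, looks like:

# Illustrative only; ExampleTool is not part of the lemonade API.
class ExampleTool:
    """The module itself imports nothing heavy, so importing it stays fast."""

    def run(self, state):
        # The heavy dependency is paid for only when the tool actually runs.
        import torch

        state.scores = torch.zeros(3)
        return state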
lemonade/tools/report/table.py CHANGED
@@ -7,10 +7,10 @@ from tabulate import tabulate
 import lemonade.common.build as build
 import lemonade.common.filesystem as fs
 from lemonade.cache import Keys
-from lemonade.tools.
-from lemonade.tools.
+from lemonade.tools.huggingface.bench import HuggingfaceBench
+from lemonade.tools.llamacpp.bench import LlamaCppBench
 from lemonade.tools.mmlu import AccuracyMMLU
-from lemonade.tools.
+from lemonade.tools.oga.bench import OgaBench

 # List of python packages for which to log the version
 PYTHON_PACKAGES = ["onnxruntime", "transformers", "lemonade-sdk", "voe"]
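The report/table.py hunk reflects the broader 8.0.0 reorganization visible in the file list: the flat huggingface_bench.py, llamacpp_bench.py, and ort_genai modules become huggingface/, llamacpp/, and oga/ subpackages. Downstream code that imports the bench tools needs the new paths; assuming the class names survived the renames unchanged, the migration is roughly:

# lemonade-sdk 7.0.3 (old flat layout; paths inferred from the renames listed above)
# from lemonade.tools.huggingface_bench import HuggingfaceBench
# from lemonade.tools.llamacpp_bench import LlamaCppBench
# from lemonade.tools.ort_genai.oga_bench import OgaBench

# lemonade-sdk 8.0.0 (new subpackage layout, as imported in report/table.py)
from lemonade.tools.huggingface.bench import HuggingfaceBench
from lemonade.tools.llamacpp.bench import LlamaCppBench
from lemonade.tools.oga.bench import OgaBench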
|