lemonade_sdk-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
lemonade/tools/oga/utils.py
@@ -0,0 +1,462 @@
+import os
+import time
+import json
+import logging
+from queue import Queue
+from packaging.version import Version
+import onnxruntime_genai as og
+from transformers import AutoTokenizer
+from lemonade.tools.adapter import (
+    ModelAdapter,
+    TokenizerAdapter,
+    PassthroughTokenizerResult,
+)
+from lemonade_install.install import _get_ryzenai_version_info
+
+
+class OrtGenaiTokenizer(TokenizerAdapter):
+    def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
+        super().__init__(hf_tokenizer)
+        # Initialize OGA tokenizer
+        self.tokenizer = og.Tokenizer(model)
+
+        # Placeholder value since some code will try to query it
+        # If we actually need this to return a proper value, then
+        # og.GeneratorParams.eos_token_id has it
+        self.eos_token_id = None
+
+    def __call__(self, prompt: str, return_tensors="np"):
+        tokens = self.tokenizer.encode(prompt)
+        return PassthroughTokenizerResult(tokens)
+
+    # pylint: disable=unused-argument
+    def decode(self, response, skip_special_tokens=True) -> str:
+        return self.tokenizer.decode(response)
+
+
+class OrtGenaiStreamer:
+    def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
+        self.tokenizer = tokenizer
+        self.text_queue = Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+
+    def add_text(self, text: str):
+        self.text_queue.put(text, timeout=self.timeout)
+
+    def done(self):
+        self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get(timeout=self.timeout)
+        if value == self.stop_signal:
+            raise StopIteration()
+        else:
+            return value
+
+
+class OrtGenaiModel(ModelAdapter):
+
+    def __init__(self, input_folder):
+        super().__init__()
+        self.model = og.Model(input_folder)
+        self.type = "ort-genai"
+        self.config = self.load_config(input_folder)
+
+    def load_config(self, input_folder):
+        rai_config_path = os.path.join(input_folder, "rai_config.json")
+        max_prompt_length = None
+
+        try:
+            detected_version, _ = _get_ryzenai_version_info()
+
+            if os.path.exists(rai_config_path):
+                with open(rai_config_path, "r", encoding="utf-8") as f:
+                    rai_config = json.load(f)
+                    if (
+                        "max_prompt_length" in rai_config
+                        and detected_version in rai_config["max_prompt_length"]
+                    ):
+                        max_prompt_length = rai_config["max_prompt_length"][
+                            detected_version
+                        ]
+        except:  # pylint: disable=bare-except
+            pass
+
+        config_path = os.path.join(input_folder, "genai_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_dict = json.load(f)
+                config_dict["max_prompt_length"] = max_prompt_length
+                return config_dict
+        return None
+
+    def generate(
+        self,
+        input_ids,
+        max_new_tokens=512,
+        min_new_tokens=0,
+        do_sample=True,
+        top_k=None,
+        top_p=None,
+        temperature=None,
+        repeat_penalty=None,
+        streamer: OrtGenaiStreamer = None,
+        pad_token_id=None,
+        stopping_criteria=None,
+        max_length=None,
+        random_seed=1,
+    ):
+        params = og.GeneratorParams(self.model)
+
+        # OGA models return a list of tokens (older versions) or 1d numpy array (newer versions)
+        prompt_length = len(input_ids)
+
+        max_prompt_length = self.config.get("max_prompt_length")
+        if max_prompt_length and prompt_length > max_prompt_length:
+            raise ValueError(
+                f"This prompt (length {prompt_length}) exceeds the model's "
+                f"maximum allowed prompt length ({max_prompt_length})."
+            )
+        self.prompt_tokens = prompt_length
+
+        # There is a breaking API change in OGA 0.6.0
+        # Determine whether we should use the old or new APIs
+        # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
+        use_oga_post_6_api = (
+            Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
+        )
+        use_oga_pre_6_api = not use_oga_post_6_api
+
+        if pad_token_id:
+            params.pad_token_id = pad_token_id
+
+        # Handle max_length and max_new_tokens
+        if max_length and max_new_tokens:
+            logging.warning(
+                "Both max_length and max_new_tokens were provided. "
+                "max_length will take precedence. "
+                "When setting max_length, please explicitly set max_new_tokens to None."
+            )
+        max_length_to_use = None
+        if max_length:
+            max_length_to_use = max_length
+        elif max_new_tokens:
+            max_length_to_use = prompt_length + max_new_tokens
+
+        min_length = prompt_length + min_new_tokens
+
+        if use_oga_pre_6_api:
+            params.input_ids = input_ids
+
+        if random_seed is None:
+            random_seed = -1  # In og.Generator, -1 = seed with random device
+
+        # Get search config if available, otherwise use empty dict
+        # Thanks to the empty dict, if the model doesn't have a built-in search
+        # config, the .get() calls will all just use the default values
+        search_config = {}
+        if self.config and "search" in self.config:
+            search_config = self.config["search"]
+
+        # Apply parameter hierarchy: user provided > search config > defaults
+        default_top_k = 50
+        default_top_p = 1.0
+        default_temperature = 0.7
+        default_repetition_penalty = 1.0
+
+        top_k_to_use = (
+            top_k if top_k is not None else search_config.get("top_k", default_top_k)
+        )
+        top_p_to_use = (
+            top_p if top_p is not None else search_config.get("top_p", default_top_p)
+        )
+        temperature_to_use = (
+            temperature
+            if temperature is not None
+            else search_config.get("temperature", default_temperature)
+        )
+        # Map the llamacpp name, `repeat_penalty`, to the OGA name, `repetition_penalty`
+        repetition_penalty_to_use = (
+            repeat_penalty
+            if repeat_penalty is not None
+            else search_config.get("repetition_penalty", default_repetition_penalty)
+        )
+
+        # Set search options once with all parameters
+        params.set_search_options(
+            do_sample=search_config.get("do_sample", do_sample),
+            top_k=top_k_to_use,
+            top_p=top_p_to_use,
+            temperature=temperature_to_use,
+            repetition_penalty=repetition_penalty_to_use,
+            max_length=max_length_to_use,
+            min_length=min_length,
+            early_stopping=search_config.get("early_stopping", False),
+            length_penalty=search_config.get("length_penalty", 1.0),
+            num_beams=search_config.get("num_beams", 1),
+            num_return_sequences=search_config.get("num_return_sequences", 1),
+            past_present_share_buffer=search_config.get(
+                "past_present_share_buffer", True
+            ),
+            random_seed=random_seed,
+            # Not currently supported by OGA
+            # diversity_penalty=search_config.get('diversity_penalty', 0.0),
+            # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
+        )
+        params.try_graph_capture_with_max_batch_size(1)
+
+        generator = og.Generator(self.model, params)
+
+        if streamer is None:
+            prompt_start_time = time.perf_counter()
+            if use_oga_post_6_api:
+                generator.append_tokens(input_ids)
+            if use_oga_pre_6_api:
+                generator.compute_logits()
+            generator.generate_next_token()
+            prompt_end_time = time.perf_counter()
+
+            self.time_to_first_token = prompt_end_time - prompt_start_time
+
+            if max_new_tokens > 1:
+
+                token_gen_times = []
+                while not generator.is_done():
+                    token_gen_start_time = time.perf_counter()
+                    if use_oga_pre_6_api:
+                        generator.compute_logits()
+                    generator.generate_next_token()
+                    token_gen_end_time = time.perf_counter()
+
+                    token_gen_times.append(token_gen_end_time - token_gen_start_time)
+
+                if token_gen_times:
+                    # List will be empty if we generated 1 or 0 tokens, and we don't
+                    # want a divide-by-zero error in those cases
+                    avg_token_gen_latency_s = sum(token_gen_times) / len(
+                        token_gen_times
+                    )
+                    self.tokens_per_second = 1 / avg_token_gen_latency_s
+
+            response = generator.get_sequence(0)
+            self.response_tokens = len(response) - self.prompt_tokens
+            return [response]
+        else:
+            if use_oga_post_6_api:
+                generator.append_tokens(input_ids)
+            tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
+            self.response_tokens = 0
+            stop_early = False
+
+            while not generator.is_done() and not stop_early:
+                if use_oga_pre_6_api:
+                    generator.compute_logits()
+                generator.generate_next_token()
+                self.response_tokens += 1

+                new_token = generator.get_next_tokens()[0]
+                new_text = tokenizer_stream.decode(new_token)
+
+                streamer.add_text(new_text)
+
+                if stopping_criteria is not None:
+                    if stopping_criteria[0].stop_event.is_set():
+                        stop_early = True
+
+            streamer.done()
+
+    def _model_call(self, input_ids):
+        """
+        Run the model on input_ids and get logits.
+
+        This method directly accesses model logits rather than using the full generate pipeline for
+        several important reasons:
+        1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
+           producing multiple tokens through iterative inference
+        2. Efficiency: Direct access is more efficient for logprob calculations with no
+           sampling overhead
+        3. Precision: Logprob calculations require exact control over input-to-output mapping
+        4. Consistency: Similar approach used in both HF and OGA implementations
+
+        Args:
+            input_ids: Input token IDs
+
+        Returns:
+            Logits for each token in the sequence
+        """
+        import torch
+
+        # Setup generator params
+        params = og.GeneratorParams(self.model)
+
+        # Configure for a simple forward pass
+        params.set_search_options(
+            do_sample=False,
+            temperature=0.0,
+            max_length=len(input_ids),
+        )
+
+        # Initialize generator
+        generator = og.Generator(self.model, params)
+
+        # Feed tokens to model based on API version
+        generator.append_tokens(input_ids)
+
+        # Extract logits - this returns a list of logits tensors
+        logits = generator.get_output("logits")
+
+        # Convert to torch tensor for easier processing
+        return torch.tensor(logits[0])
+
+    def _select_cont_toks(self, logits, context_len, continuation_tokens):
+        """
+        Select and process logits for continuation tokens.
+
+        Args:
+            logits: Full sequence logits
+            context_len: Length of context tokens
+            continuation_tokens: List or tensor of continuation token IDs
+
+        Returns:
+            Log probabilities for continuation tokens
+        """
+        import torch
+
+        # Extract relevant logits for continuation prediction (shift by one)
+        cont_logits = logits[
+            context_len - 1 : context_len - 1 + len(continuation_tokens)
+        ]
+
+        # Convert to torch tensors if needed
+        if not isinstance(continuation_tokens, torch.Tensor):
+            continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
+
+        # Apply log softmax to get log probabilities
+        log_probs = torch.log_softmax(cont_logits, dim=-1)
+
+        # Get log probs for the specific continuation tokens
+        token_log_probs = torch.gather(
+            log_probs, 1, continuation_tokens.unsqueeze(-1)
+        ).squeeze(-1)
+
+        return token_log_probs
+
+    def compute_logprobs(
+        self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
+    ):
+        """
+        Compute log probabilities for all tokens in the given text.
+
+        Args:
+            text: The full text to analyze (e.g., prompt + completion)
+            prompt_length: Number of tokens in the prompt. If provided and echo=False,
+                only completion tokens after this position will be returned.
+            logprobs: If not None, return log probabilities. Value indicates how many top
+                alternatives to return. If True but not an integer, defaults to 5 alternatives.
+            echo: If True, include logprobs for prompt tokens. If False, only return logprobs
+                for completion tokens.
+
+        Returns:
+            - text_offset: Character offsets for each token in the text
+            - token_logprobs: Log probability for each token
+            - tokens: The actual tokens used
+            - top_logprobs: Top alternative log probabilities for each position
+        """
+        import torch
+
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required for logprob calculation")
+
+        # Encode the full text
+        tokens = tokenizer(text).input_ids  # pylint: disable=E1102
+
+        # Track character offsets for each token
+        text_offset = []
+        start_idx = 0
+
+        token_strings = []
+        for token_id in tokens:
+            token_str = tokenizer.decode([token_id])
+            token_strings.append(token_str)
+
+            # Calculate character offsets for tokens - handles cases where tokens
+            # may not directly match in the original text due to encoding differences,
+            # special characters, or tokenization artifacts
+            try:
+                pos = text[start_idx:].find(token_str)
+                if pos != -1:
+                    text_offset.append(start_idx + pos)
+                    start_idx += pos + len(token_str)
+                else:
+                    text_offset.append(start_idx)
+            except (TypeError, ValueError, UnicodeError):
+                # Fallback to current position when matching fails due to encoding issues
+                text_offset.append(start_idx)
+
+        # Get logits from model
+        logits = self._model_call(tokens)
+
+        # Calculate log probabilities for each token
+        all_log_probs = torch.log_softmax(logits, dim=-1)
+
+        # The first token doesn't have a conditional probability
+        # For tokens after the first, get the predicted probability
+        token_log_probs = []
+        top_logprobs_list = []
+
+        # For each position, get the actual token probability and top alternatives
+        for i in range(len(tokens)):
+            # Get previous token position logits
+            if i > 0:  # First token has no preceding context
+                prev_logits = all_log_probs[i - 1]
+                curr_token_id = tokens[i]
+                # Get probability of the actual token that appeared
+                token_logprob = prev_logits[curr_token_id].item()
+                token_log_probs.append(token_logprob)
+
+                # Get top-k alternatives if requested
+                if logprobs is not None:
+                    num_alternatives = logprobs if isinstance(logprobs, int) else 5
+                    topk_values, topk_indices = torch.topk(
+                        prev_logits, min(num_alternatives, prev_logits.size(-1))
+                    )
+
+                    # Create dictionary of token: logprob
+                    position_logprobs = {}
+                    for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
+                        token_str = tokenizer.decode([idx])
+                        position_logprobs[token_str] = val
+
+                    top_logprobs_list.append(position_logprobs)
+            else:
+                # For the first token, we don't have a conditional probability
+                token_log_probs.append(None)
+                top_logprobs_list.append({})
+
+        # If we don't want to echo prompt tokens, filter them out
+        if not echo and prompt_length is not None:
+            # Ensure prompt_length is within bounds
+            prompt_length = min(prompt_length, len(tokens))
+
+            # Filter results to only include completion tokens
+            if prompt_length < len(tokens):
+                filtered_text_offset = text_offset[prompt_length:]
+                filtered_token_logprobs = token_log_probs[prompt_length:]
+                filtered_tokens = token_strings[prompt_length:]
+                filtered_top_logprobs = top_logprobs_list[prompt_length:]
+
+                return (
+                    filtered_text_offset,
+                    filtered_token_logprobs,
+                    filtered_tokens,
+                    filtered_top_logprobs,
+                )
+            else:
+                # No completion tokens
+                return [], [], [], []
+
+        return text_offset, token_log_probs, token_strings, top_logprobs_list
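For readers skimming the diff, here is a minimal usage sketch (not part of the package) showing how the classes above fit together. It assumes a hypothetical local OGA model folder ("./my-oga-model") and a matching Hugging Face checkpoint ("my-org/my-model"), and it assumes generate() is driven from a worker thread so the queue-backed OrtGenaiStreamer can be consumed as an iterator; the package's server code may wire this up differently.

# Hedged usage sketch; the model folder and checkpoint names are hypothetical.
from threading import Thread

from transformers import AutoTokenizer

from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiStreamer, OrtGenaiTokenizer

model = OrtGenaiModel("./my-oga-model")  # folder containing genai_config.json, model weights, etc.
hf_tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")
tokenizer = OrtGenaiTokenizer(model.model, hf_tokenizer)

# OrtGenaiTokenizer.__call__ returns a PassthroughTokenizerResult whose .input_ids holds OGA tokens
input_ids = tokenizer("What is ONNX Runtime GenAI?").input_ids
streamer = OrtGenaiStreamer(tokenizer)

# generate() pushes decoded text into the streamer and calls done() when finished,
# so it runs on its own thread while the main thread drains the queue.
worker = Thread(target=model.generate, args=(input_ids,), kwargs={"streamer": streamer})
worker.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
worker.join()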
lemonade/tools/perplexity.py
@@ -0,0 +1,147 @@
+import os
+import argparse
+from lemonade.state import State
+from lemonade.tools import Tool
+import lemonade.common.printing as printing
+import lemonade.common.build as build
+
+
+class AccuracyPerplexity(Tool):
+    """
+    Measure perplexity of an LLM using the Wikitext-2 dataset.
+
+    Required input state:
+        - state.model: instance that provides a __call__() method that returns
+            output.logits and supports model.config.max_position_embeddings
+        - state.tokenizer: instance of Hugging Face PretrainedTokenizer
+
+    Output state produced: None
+
+    See docs/dev_cli/perplexity.md for more details.
+    """
+
+    unique_name = "accuracy-perplexity"
+
+    def __init__(self):
+        super().__init__(monitor_message="Measuring perplexity")
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Measure perplexity score",
+            add_help=add_help,
+        )
+        return parser
+
+    def run(
+        self,
+        state: State,
+    ) -> State:
+
+        import pandas as pd
+        import torch
+        from datasets import load_dataset
+
+        try:
+            printing.log_info("Downloading dataset ...")
+            dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+        except Exception as e:  # pylint: disable=broad-except
+            printing.log_error(f"Error during dataset load: {e}")
+            raise e
+
+        tokenizer = state.tokenizer
+        model = state.model
+        # Tokenize the entire test dataset text, joining entries with double new lines
+        encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
+
+        # Retrieve the maximum input length that the model can handle
+        try:
+            max_length = model.config.max_position_embeddings
+        except AttributeError:
+            # Some LLMs do not have the config.max_position_embeddings attribute
+            # However, most LLMs support at least 2048 context length, so this
+            # try-except will allow a few more LLMs to work
+            max_length = 2048
+        # Set stride to half of the maximum input length for overlapping window processing
+        # Refer to docs/dev_cli/perplexity.md for more information on sliding window
+        stride = max_length // 2
+        # Determine the total sequence length of the tokenized input
+        seq_len = encodings.input_ids.size(1)
+
+        negative_log_likelihoods = []
+        summary_data = []
+        prev_end_location = 0
+
+        model_results_dir = os.path.join(
+            build.output_dir(state.cache_dir, state.build_name), "perplexity"
+        )
+
+        for begin_location in range(0, seq_len, stride):
+            end_location = min(begin_location + max_length, seq_len)
+            target_len = end_location - prev_end_location
+            input_ids = encodings.input_ids[:, begin_location:end_location]
+            target_ids = input_ids.clone()
+            target_ids[:, :-target_len] = -100
+
+            # Forward pass the model to get logits
+            with torch.no_grad():
+                try:
+                    outputs = model(input_ids, labels=target_ids)
+                    logits = outputs.logits
+                except Exception as e:  # pylint: disable=broad-except
+                    printing.log_error(
+                        f"Error during model forward pass execution: {e}"
+                    )
+
+            # Compute loss manually for visualization
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = target_ids[..., 1:].contiguous()
+            effective_token_count = (target_ids != -100).sum().item()
+            negative_log_likelihoods.append(
+                (outputs.loss.item(), effective_token_count)
+            )
+
+            # Decode predicted and actual next words for the last token position
+            predictions = torch.argmax(shift_logits, dim=-1)
+            predicted_tokens = predictions[:, -1]
+            actual_tokens = shift_labels[:, -1]
+
+            predicted_words = tokenizer.batch_decode(
+                predicted_tokens, skip_special_tokens=True
+            )
+            actual_words = tokenizer.batch_decode(
+                actual_tokens, skip_special_tokens=True
+            )
+            context = tokenizer.decode(input_ids[0, :])
+
+            summary_data.append(
+                {
+                    "Context": context[-stride:],
+                    "Predicted next word": predicted_words,
+                    "Actual next word": actual_words,
+                    "Loss for this window": outputs.loss.item(),
+                }
+            )
+            prev_end_location = end_location
+
+        # Total loss calculation considering the number of tokens for each segment
+        total_loss = sum(loss * count for loss, count in negative_log_likelihoods)
+        total_tokens = sum(count for _, count in negative_log_likelihoods)
+
+        # Calculate average negative_log_likelihood and perplexity
+        average_negative_log_likelihood = total_loss / total_tokens
+        perplexity = torch.exp(torch.tensor(average_negative_log_likelihood))
+
+        # Save accuracy results to stats file
+        state.save_stat("perplexity_score", float(perplexity.item()))
+
+        # Save accuracy results to CSV file
+        summary_df = pd.DataFrame(summary_data)
+        summary_df.to_csv(
+            os.path.join(model_results_dir, "summary_results.csv"), index=False
+        )
+        return state
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD
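As a quick sanity check on the aggregation at the end of AccuracyPerplexity.run() (not part of the package; the numbers below are made up): each sliding window contributes its mean loss weighted by its effective token count, and perplexity is the exponential of the weighted-average negative log-likelihood.

# Toy recomputation of the perplexity aggregation; the window stats are invented.
import math

window_stats = [(2.10, 1024), (1.95, 512), (2.30, 512)]  # (mean loss, effective token count) per window

total_loss = sum(loss * count for loss, count in window_stats)
total_tokens = sum(count for _, count in window_stats)
avg_nll = total_loss / total_tokens  # weighted average negative log-likelihood
perplexity = math.exp(avg_nll)
print(f"avg NLL = {avg_nll:.4f}, perplexity = {perplexity:.2f}")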