lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
|
@@ -0,0 +1,520 @@
|
|
|
1
|
+
import argparse
from typing import Dict, List, Optional, Union
import json
import socket
import transformers
import torch
from huggingface_hub import model_info
from lemonade.state import State
import lemonade.common.status as status
import lemonade.common.printing as printing
from lemonade.tools import FirstTool
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
from lemonade.cache import Keys
|
|
14
|
+
|
|
15
|
+
# Command-line tools receive data types as plain strings, but the internal
# tool logic needs the matching torch dtype. This table performs that lookup.
# Note that both int8 entries map to the same torch dtype (torch.int8).
str_to_dtype = dict(
    float32=torch.float32,
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    int8_static=torch.int8,
    int8_dynamic=torch.int8,
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def make_example_inputs(state: State) -> Dict:
    """
    Build a small dictionary of example LLM inputs that can be passed as an
    argument into quantization, ONNX export, etc.

    The tokenizer stored on ``state`` encodes the fixed prompt "Hello there"
    with ``return_tensors="pt"``. Only the resulting ``input_ids`` are
    returned, under the "input_ids" key.
    """
    encoded = state.tokenizer("Hello there", return_tensors="pt")
    return dict(input_ids=encoded.input_ids)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class HuggingfaceTokenizerAdapter(TokenizerAdapter):
    """
    TokenizerAdapter that forwards every call to a Hugging Face tokenizer.

    If ``device`` is truthy, the output of ``__call__`` is moved there with
    ``.to(device)``. Otherwise the tokenizer output is returned unchanged.
    """

    def __init__(self, tokenizer: transformers.AutoTokenizer, device: str):
        super().__init__(tokenizer)
        # Keep direct handles to the wrapped tokenizer and the target device
        self.tokenizer = tokenizer
        self.device = device

    def __call__(self, prompt, **kwargs):
        """Tokenize ``prompt``; extra kwargs go straight to the tokenizer."""
        encoding = self.tokenizer(prompt, **kwargs)
        return encoding.to(self.device) if self.device else encoding

    def decode(self, response, **kwargs):
        """Forward to the wrapped tokenizer's ``decode``."""
        return self.tokenizer.decode(response, **kwargs)

    def batch_decode(self, tokens, **kwargs):
        """Forward to the wrapped tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(tokens, **kwargs)

    @property
    def eos_token_id(self):
        """End-of-sequence token id reported by the wrapped tokenizer."""
        return self.tokenizer.eos_token_id

    def save_pretrained(self, model_dir, **kwargs):
        """Forward to the wrapped tokenizer's ``save_pretrained``."""
        return self.tokenizer.save_pretrained(model_dir, **kwargs)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def is_offline():
    """
    Check whether the system is offline by trying to resolve huggingface.co.

    Only a hostname lookup (``socket.gethostbyname``) is performed; no
    connection to the host is opened.

    Returns:
        bool: True if huggingface.co cannot be resolved (treated as offline),
            False otherwise.
    """
    try:
        socket.gethostbyname("huggingface.co")
    except OSError:
        # socket.gaierror (the usual resolver failure) subclasses OSError.
        # Catching the base class also treats other socket-level failures as
        # "offline" instead of crashing the caller.
        return True
    return False
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_base_model(checkpoint: str) -> Optional[Union[str, List[str]]]:
    """
    Get the base model information for a given checkpoint from the Hugging Face Hub.
    Will auto-detect if we're offline and skip the network call in that case.

    Args:
        checkpoint: The model checkpoint to query

    Returns:
        - The model card's ``base_model`` value, returned unchanged (not
          normalized), when the card declares it with a non-None value
          (i.e. a derived model).
        - ``[checkpoint]`` (a one-element list) when the card declares
          ``base_model`` with a None value, meaning the checkpoint is itself
          a base model.
        - None when offline, when the card has no ``base_model`` entry, or
          when the Hub query or card inspection raises any exception.
    """
    # Skip network call in offline mode
    if is_offline():
        return None

    try:
        card_data = model_info(checkpoint).cardData
        if not card_data or "base_model" not in card_data:
            return None

        declared_base = card_data["base_model"]
        if declared_base is not None:
            # This is a derived model
            return declared_base

        # This is itself a base model
        return [checkpoint]
    except Exception:  # pylint: disable=broad-except
        # Deliberate best-effort lookup: any Hub/network/metadata failure just
        # means "no base model information available".
        return None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class HuggingfaceLoad(FirstTool):
    """
    Load an LLM as a torch.nn.Module using the Hugging Face transformers
    from_pretrained() API.

    Expected input: a checkpoint to load

    Output state produced:
        - state.model: HuggingfaceAdapter wrapping the loaded torch.nn.Module.
        - state.tokenizer: HuggingfaceTokenizerAdapter wrapping the Hugging Face
            tokenizer.
        - state.dtype: torch data type the model was loaded in.
        - state.checkpoint: pretrained checkpoint used to load the model.
        - state.device: device requested with --device (None when not given).
    """

    unique_name = "huggingface-load"

    def __init__(self):
        super().__init__(monitor_message="Loading Huggingface checkpoint")

        self.status_stats = [Keys.DTYPE]

    @staticmethod
    def _str_to_bool(value) -> bool:
        """
        argparse ``type=`` converter for boolean options.

        ``type=bool`` is wrong for argparse because ``bool("False")`` is True.
        This maps the common spellings explicitly and rejects anything else.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognized
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        normalized = str(value).strip().lower()
        if normalized in ("true", "t", "yes", "y", "1", "on"):
            return True
        if normalized in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value (true/false), got {value!r}"
        )

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        parser = __class__.helpful_parser(
            short_description="Load an LLM in PyTorch using huggingface transformers",
            add_help=add_help,
        )

        default_dtype = "float32"
        parser.add_argument(
            "--dtype",
            "-d",
            required=False,
            default=default_dtype,
            # Only names that parse() can translate into a torch dtype are
            # accepted, so a typo is reported by argparse instead of as a
            # KeyError later on
            choices=list(str_to_dtype.keys()),
            help=f"Data type to load the model in (default: {default_dtype}).",
        )

        device_choices = ["cpu", "cuda"] + [f"cuda:{index}" for index in range(15)]
        parser.add_argument(
            "--device",
            required=False,
            default=None,
            choices=device_choices,
            help="Move the model and inputs to a device using the .to() method "
            "(default: don't call the .to() method)",
        )

        parser.add_argument(
            "--load-kwargs",
            required=False,
            default="{}",
            type=json.loads,
            help="Arbitrary kwargs, in json format, that will be passed as "
            "from_pretrained(**kwargs). "
            r"Example: --load-kwargs='{\"trust_remote_code\": true}' would result in "
            "from_pretrained(trust_remote_code=True)",
        )

        parser.add_argument(
            "--channels-last",
            default=True,
            # type=bool would turn every non-empty string (including "False")
            # into True, so the value is parsed explicitly
            type=__class__._str_to_bool,
            help="Whether to format the model in memory using "
            "channels-last (default: True)",
        )

        return parser

    def parse(self, state: State, args, known_only=True) -> argparse.Namespace:

        parsed_args = super().parse(state, args, known_only)

        # Save stats about the user's input (do this prior to decoding)
        state.save_stat(Keys.CHECKPOINT, parsed_args.input)
        state.save_stat(Keys.DTYPE, parsed_args.dtype)

        # Decode dtype arg into a torch value
        parsed_args.dtype = str_to_dtype[parsed_args.dtype]

        return parsed_args

    @staticmethod
    def _load_tokenizer(checkpoint: str, offline: bool):
        """
        Load the tokenizer for ``checkpoint``.

        First try the preferred options (slow tokenizer, 4096 max length, left
        padding). If that raises ValueError, retry with default options.
        ``local_files_only=True`` is added to both attempts when offline.

        Raises:
            ValueError: if offline and the fallback attempt fails with a
                "Can't load tokenizer for" error.
        """
        tokenizer_kwargs = {
            "use_fast": False,
            "model_max_length": 4096,
            "padding_side": "left",
        }
        if offline:
            tokenizer_kwargs["local_files_only"] = True

        try:
            return transformers.AutoTokenizer.from_pretrained(
                checkpoint, **tokenizer_kwargs
            )
        except ValueError:
            # Sometimes those specific tokenizer flags are not supported, in which
            # case we try to just load a simple tokenizer
            fallback_kwargs = {"local_files_only": True} if offline else {}

            try:
                return transformers.AutoTokenizer.from_pretrained(
                    checkpoint, **fallback_kwargs
                )
            except Exception as e:
                if offline and "Can't load tokenizer for" in str(e):
                    raise ValueError(
                        f"Cannot load tokenizer for {checkpoint} in offline mode. "
                        "The tokenizer files may not be available locally. "
                        f"Original error: {str(e)}"
                    ) from e
                raise

    def run(
        self,
        state: State,
        input: str = "",  # pylint: disable=redefined-builtin
        dtype: torch.dtype = torch.float32,
        device: Optional[str] = None,
        load_kwargs: Optional[Dict] = None,
        channels_last: bool = True,
    ) -> State:
        """
        Load the checkpoint ``input`` with transformers and store the wrapped
        model and tokenizer, plus dtype/checkpoint/device, on ``state``.

        Args:
            state: sequence state to populate.
            input: Hugging Face checkpoint name or path.
            dtype: torch dtype passed to from_pretrained(torch_dtype=...).
            device: if set, the model is moved there with model.to(device).
            load_kwargs: extra kwargs forwarded to from_pretrained(); the
                caller's dict is copied, never modified.
            channels_last: if True, model.to(memory_format=torch.channels_last).

        Returns:
            The same ``state`` object, populated.

        Raises:
            ValueError: if ``state`` already holds a model, or if offline and the
                model config or tokenizer cannot be loaded from local files.
        """
        # Auto-detect offline status
        offline = is_offline()
        if offline:
            printing.log_warning(
                "Network connectivity to huggingface.co not detected. Running in offline mode."
            )

        checkpoint = input

        # Copy so that adding local_files_only below never mutates a dict the
        # caller still owns
        load_kwargs_to_use = dict(load_kwargs) if load_kwargs is not None else {}

        # Add local_files_only to kwargs in offline mode
        if offline:
            load_kwargs_to_use["local_files_only"] = True

        if vars(state).get(Keys.MODEL):
            raise ValueError("HuggingfaceLoad must be the first tool in the sequence")

        try:
            model = transformers.AutoModelForCausalLM.from_pretrained(
                checkpoint,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                **load_kwargs_to_use,
            )
        except Exception as e:
            if offline and "Can't load config for" in str(e):
                raise ValueError(
                    f"Cannot load model {checkpoint} in offline mode. "
                    f"The model files may not be available locally. Original error: {str(e)}"
                ) from e
            raise

        # Only call the model.to() method if an argument to this function
        # provides a reason to do so
        to_args = {}
        if channels_last:
            to_args["memory_format"] = torch.channels_last
        if device:
            to_args["device"] = device
        if to_args:
            model.to(**to_args)

        model = model.eval()

        tokenizer = self._load_tokenizer(checkpoint, offline)

        # Pass the model and inputs into state
        state.model = HuggingfaceAdapter(model, dtype, device, tokenizer)

        state.tokenizer = HuggingfaceTokenizerAdapter(tokenizer, device)
        state.dtype = dtype
        state.checkpoint = checkpoint
        state.device = device

        # Save stats about the model
        state.save_stat(Keys.CHECKPOINT, checkpoint)
        state.save_stat(Keys.DTYPE, str(dtype).split(".")[1])
        state.save_stat(Keys.DEVICE, device)

        # Get base model information
        base_model = get_base_model(checkpoint)
        if base_model is not None:
            state.save_stat("base_model", base_model)

        # Create a UniqueInvocationInfo and ModelInfo so that we can display status
        # at the end of the sequence
        status.add_to_state(state=state, name=input, model=model)

        return state
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
class HuggingfaceAdapter(ModelAdapter):
    """
    Wrapper class for Huggingface LLMs that handle generation arguments
    from callers to match HF specification.

    repetition_penalty: helps the LLM avoid repeating the same short
        phrase in the response over and over.
    temperature: helps the LLM stay focused on the prompt.
    do_sample: apply the temperature.
    """

    def __init__(self, model, dtype=torch.float32, device="cpu", tokenizer=None):
        super().__init__()
        self.model = model
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer

    def generate(
        self,
        input_ids,
        **kwargs,
    ):
        """
        Run ``self.model.generate`` on ``input_ids`` with no gradient tracking.

        Defaults are max_new_tokens=512 and do_sample=True; any caller kwargs
        override them and are forwarded to the model's generate().
        A temperature of exactly 0.0 forces greedy decoding (do_sample=False).
        In that case ``temperature`` is dropped from the forwarded kwargs,
        since it has no effect without sampling and would otherwise trigger a
        Hugging Face warning.

        Returns:
            Whatever ``self.model.generate`` returns.
        """
        # Move input_ids to the same device as the model
        input_ids = input_ids.to(self.device)

        generation_kwargs = {
            "max_new_tokens": kwargs.get("max_new_tokens", 512),
            "do_sample": kwargs.get("do_sample", True),
            **kwargs,
        }

        # temperature == 0.0 means greedy decoding: force do_sample=False and
        # drop the now-meaningless temperature.
        # Note: This is the same approach taken by LM Eval Harness for handling temperature.
        if generation_kwargs.get("temperature") == 0.0:
            generation_kwargs["do_sample"] = False
            generation_kwargs.pop("temperature")

        with torch.no_grad(), torch.inference_mode():
            outputs = self.model.generate(input_ids=input_ids, **generation_kwargs)

        return outputs

    def _model_call(self, input_tensor):
        """Forward pass through the model to get logits

        This method directly calls the model forward pass rather than using model.generate() for
        several important reasons:
        1. Purpose: We need raw logits from a single forward pass, while generate() is for producing
           multiple tokens through iterative inference
        2. Efficiency: Direct calls are more efficient for logprob calculations with no sampling
           overhead
        3. Precision: Logprob calculations require exact control over input-to-output mapping
        4. Consistency: Similar approach used in both HF and OGA implementations

        Args:
            input_tensor: Input token IDs tensor

        Returns:
            Logits tensor from model forward pass
        """
        with torch.no_grad(), torch.inference_mode():
            outputs = self.model(input_tensor)
            return outputs.logits

    def _select_cont_toks(self, logits, context_len, cont_toks):
        """
        Select logits corresponding to continuation tokens and gather their probabilities

        Args:
            logits: Model output logits (indexed per position along dim 0 here)
            context_len: Length of input context
            cont_toks: List of continuation token IDs

        Returns:
            Tensor of log probabilities for continuation tokens
        """
        # The logits at position i predict token i + 1, so the continuation
        # starts one position before context_len
        cont_logits = logits[context_len - 1 : context_len - 1 + len(cont_toks)]

        # Convert cont_toks to tensor if needed
        if not isinstance(cont_toks, torch.Tensor):
            cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=logits.device)

        # Gather log probs at the corresponding token indices
        log_probs = torch.log_softmax(cont_logits, dim=-1)
        token_log_probs = torch.gather(log_probs, 1, cont_toks.unsqueeze(-1)).squeeze(
            -1
        )

        return token_log_probs

    def compute_logprobs(
        self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
    ):
        """
        Compute log probabilities for all tokens in the given text.

        Args:
            text: The full text to analyze (e.g., prompt + completion)
            tokenizer: Tokenizer used to encode ``text`` and decode token ids.
            prompt_length: Number of tokens in the prompt. If provided and echo=False,
                only completion tokens after this position will be returned.
            logprobs: If not None, return log probabilities. Value indicates how many top
                alternatives to return. If True (or any non-integer), defaults to 5
                alternatives; False or a negative value yields none.
            echo: If True, include logprobs for prompt tokens. If False, only return logprobs
                for completion tokens.

        Returns:
            - text_offset: Character offsets for each token in the text
            - token_logprobs: Log probability for each token (None for the first token)
            - tokens: The actual tokens used
            - top_logprobs: Top alternative log probabilities for each position

        Raises:
            ValueError: if ``tokenizer`` is None.
        """
        if tokenizer is None:
            raise ValueError("Tokenizer is required for logprob calculation")

        # Encode the full text
        tokens = tokenizer(text).input_ids
        if len(tokens) == 0:
            # Nothing to score; avoid running the model on a zero-length sequence
            return [], [], [], []

        # Resolve how many top alternatives to report per position. bool is a
        # subclass of int, so it must be checked first or True would silently
        # mean "1 alternative" instead of the documented 5.
        if logprobs is None:
            num_alternatives = None
        elif isinstance(logprobs, bool):
            num_alternatives = 5 if logprobs else 0
        elif isinstance(logprobs, int):
            # torch.topk rejects negative k
            num_alternatives = max(0, logprobs)
        else:
            num_alternatives = 5

        # Track character offsets for each token
        text_offset = []
        start_idx = 0

        token_strings = []
        for token_id in tokens:
            token_str = tokenizer.decode([token_id])
            token_strings.append(token_str)

            # Calculate character offsets for tokens - handles cases where tokens
            # may not directly match in the original text due to encoding differences,
            # special characters, or tokenization artifacts
            try:
                pos = text[start_idx:].find(token_str)
                if pos != -1:
                    text_offset.append(start_idx + pos)
                    start_idx += pos + len(token_str)
                else:
                    text_offset.append(start_idx)
            except (TypeError, ValueError, UnicodeError):
                # Fallback to current position when matching fails due to encoding issues
                text_offset.append(start_idx)

        # Convert to tensor and get model output
        input_tensor = torch.tensor([tokens], dtype=torch.long, device=self.device)
        logits = self._model_call(input_tensor)[0]

        # Calculate log probabilities for each token
        all_log_probs = torch.log_softmax(logits, dim=-1)

        token_log_probs = []
        top_logprobs_list = []

        for position, token_id in enumerate(tokens):
            if position == 0:
                # The first token has no preceding context, so no conditional probability
                token_log_probs.append(None)
                top_logprobs_list.append({})
                continue

            # Logits at the previous position predict the current token
            prev_log_probs = all_log_probs[position - 1]
            token_log_probs.append(prev_log_probs[token_id].item())

            # Get top-k alternatives if requested.
            # NOTE(review): when logprobs is None only position 0 gets a
            # top_logprobs entry, so that list is shorter than tokens; kept
            # as-is for compatibility with existing callers.
            if num_alternatives is not None:
                topk_values, topk_indices = torch.topk(
                    prev_log_probs, min(num_alternatives, prev_log_probs.size(-1))
                )
                # Dictionary of decoded token -> logprob
                top_logprobs_list.append(
                    {
                        tokenizer.decode([alt_id]): alt_logprob
                        for alt_logprob, alt_id in zip(
                            topk_values.tolist(), topk_indices.tolist()
                        )
                    }
                )

        # If we don't want to echo prompt tokens, filter them out
        if not echo and prompt_length is not None:
            # Ensure prompt_length is within bounds
            prompt_length = min(prompt_length, len(tokens))

            if prompt_length >= len(tokens):
                # No completion tokens
                return [], [], [], []

            # Filter results to only include completion tokens
            return (
                text_offset[prompt_length:],
                token_log_probs[prompt_length:],
                token_strings[prompt_length:],
                top_logprobs_list[prompt_length:],
            )

        return text_offset, token_log_probs, token_strings, top_logprobs_list
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
520
|
+
# Modifications Copyright (c) 2025 AMD
|