lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lemonade-sdk might be problematic.
- lemonade/api.py +3 -3
- lemonade/cli.py +11 -17
- lemonade/common/build.py +0 -47
- lemonade/common/network.py +50 -0
- lemonade/common/status.py +2 -21
- lemonade/common/system_info.py +19 -4
- lemonade/profilers/memory_tracker.py +3 -1
- lemonade/tools/accuracy.py +3 -4
- lemonade/tools/adapter.py +1 -2
- lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
- lemonade/tools/humaneval.py +9 -3
- lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
- lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
- lemonade/tools/mmlu.py +7 -15
- lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
- lemonade/tools/oga/utils.py +423 -0
- lemonade/tools/perplexity.py +4 -3
- lemonade/tools/prompt.py +2 -1
- lemonade/tools/quark/quark_load.py +2 -1
- lemonade/tools/quark/quark_quantize.py +5 -5
- lemonade/tools/report/table.py +3 -3
- lemonade/tools/server/llamacpp.py +188 -45
- lemonade/tools/server/serve.py +184 -146
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/styles.css +568 -0
- lemonade/tools/server/static/webapp.html +439 -0
- lemonade/tools/server/tray.py +458 -0
- lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
- lemonade/tools/server/utils/system_tray.py +395 -0
- lemonade/tools/server/{instructions.py → webapp.py} +4 -10
- lemonade/version.py +1 -1
- lemonade_install/install.py +46 -28
- lemonade_sdk-8.0.1.dist-info/METADATA +179 -0
- lemonade_sdk-8.0.1.dist-info/RECORD +70 -0
- lemonade_server/cli.py +182 -27
- lemonade_server/model_manager.py +192 -20
- lemonade_server/pydantic_models.py +9 -4
- lemonade_server/server_models.json +5 -3
- lemonade/common/analyze_model.py +0 -26
- lemonade/common/labels.py +0 -61
- lemonade/common/onnx_helpers.py +0 -176
- lemonade/common/plugins.py +0 -10
- lemonade/common/tensor_helpers.py +0 -83
- lemonade/tools/server/static/instructions.html +0 -262
- lemonade_sdk-7.0.4.dist-info/METADATA +0 -113
- lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
- /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
- /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
- /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/top_level.txt +0 -0
lemonade/tools/{llamacpp.py → llamacpp/load.py} CHANGED
@@ -7,11 +7,12 @@ import lemonade.common.status as status
 from lemonade.tools import FirstTool
 from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter
 from lemonade.cache import Keys
-from lemonade.tools.huggingface_load import get_base_model


 class LlamaCppAdapter(ModelAdapter):
-    def __init__(
+    def __init__(
+        self, model, output_tokens, context_size, threads, executable, lib_dir=None
+    ):
         super().__init__()

         self.model = os.path.normpath(model)
@@ -19,6 +20,7 @@ class LlamaCppAdapter(ModelAdapter):
         self.context_size = context_size
         self.threads = threads
         self.executable = os.path.normpath(executable)
+        self.lib_dir = lib_dir

     def generate(
         self,
@@ -78,6 +80,15 @@ class LlamaCppAdapter(ModelAdapter):
         cmd = [str(m) for m in cmd]

         try:
+            # Set up environment with library path for Linux
+            env = os.environ.copy()
+            if self.lib_dir and os.name != "nt":  # Not Windows
+                current_ld_path = env.get("LD_LIBRARY_PATH", "")
+                if current_ld_path:
+                    env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
+                else:
+                    env["LD_LIBRARY_PATH"] = self.lib_dir
+
             process = subprocess.Popen(
                 cmd,
                 stdout=subprocess.PIPE,
@@ -85,6 +96,7 @@ class LlamaCppAdapter(ModelAdapter):
                 universal_newlines=True,
                 encoding="utf-8",
                 errors="replace",
+                env=env,
             )

             raw_output, stderr = process.communicate(timeout=600)
@@ -208,11 +220,14 @@ class LoadLlamaCpp(FirstTool):
         output_tokens: int = 512,
         model_binary: Optional[str] = None,
         executable: str = None,
+        lib_dir: Optional[str] = None,
     ) -> State:
         """
         Load a llama.cpp model
         """

+        from lemonade.common.network import get_base_model
+
         if executable is None:
             raise Exception(f"{self.__class__.unique_name} requires an executable path")

@@ -241,6 +256,7 @@ class LoadLlamaCpp(FirstTool):
             context_size=context_size,
             threads=threads,
             executable=executable,
+            lib_dir=lib_dir,
         )
         state.tokenizer = PassthroughTokenizer()

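The lib_dir plumbing above lets LoadLlamaCpp point the spawned llama.cpp binary at bundled shared libraries on Linux by prepending them to LD_LIBRARY_PATH before launching the subprocess. Below is a minimal, self-contained sketch of that environment handling; the helper name, binary path, and library directory are illustrative placeholders, not part of lemonade's API.

import os
import subprocess

def run_with_lib_dir(cmd, lib_dir=None):
    # Copy the parent environment and, on non-Windows hosts, prepend lib_dir
    # to LD_LIBRARY_PATH so the binary can locate its bundled shared libraries.
    env = os.environ.copy()
    if lib_dir and os.name != "nt":
        current = env.get("LD_LIBRARY_PATH", "")
        env["LD_LIBRARY_PATH"] = f"{lib_dir}:{current}" if current else lib_dir
    return subprocess.run(cmd, env=env, capture_output=True, text=True, check=False)

# Hypothetical invocation; the actual paths depend on where the llama.cpp build lives.
result = run_with_lib_dir(["./llama-cli", "--version"], lib_dir="./llama_cpp/lib")
print(result.stdout)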
lemonade/tools/mmlu.py CHANGED
@@ -4,9 +4,6 @@ import tarfile
 from pathlib import Path
 from typing import List, Optional
 import subprocess
-import numpy as np
-import pandas as pd
-import requests
 from lemonade.state import State
 from lemonade.tools import Tool
 import lemonade.common.printing as printing
@@ -84,6 +81,9 @@ class AccuracyMMLU(Tool):
         tests: List[str] = None,
     ) -> State:

+        import numpy as np
+        import pandas as pd
+
         if data_dir:
             data_dir_to_use = data_dir
         else:
@@ -224,18 +224,6 @@
         return state


-def _list_tests(data_dir):
-    """Lists all available tests based on the files in the test data directory."""
-    test_files = [
-        f for f in os.listdir(os.path.join(data_dir, "test")) if f.endswith("_test.csv")
-    ]
-    print(
-        "Available tests:",
-        *[f.replace("_test.csv", "") for f in sorted(test_files)],
-        sep="\n",
-    )
-
-
 def _format_subject(subject):
     """Formats a subject string by replacing underscores with spaces."""
     return " ".join(subject.split("_"))
@@ -243,6 +231,8 @@ def _format_subject(

 def _safe_read_csv(path):
     """Safely reads a CSV file and returns a DataFrame."""
+    import pandas as pd
+
     try:
         return pd.read_csv(path, header=None)
     except FileNotFoundError:
@@ -292,6 +282,8 @@ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
     Download the dataset from the given URL and extract it into the target directory.
     """

+    import requests
+
     # Create the directory if it does not exist
     Path(data_cache_dir).mkdir(parents=True, exist_ok=True)

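The mmlu.py hunks above move numpy, pandas, and requests out of module scope and into the functions that actually use them, so importing the module no longer requires those packages to be installed. A small sketch of the same deferred-import pattern, using an illustrative function name rather than lemonade's:

def safe_read_csv(path):
    # Defer the heavy dependency until the function is actually called,
    # keeping module import cheap and the dependency optional.
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError(f"{e}\npandas is required to read MMLU CSV files") from e
    try:
        return pd.read_csv(path, header=None)
    except FileNotFoundError:
        return None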
lemonade/tools/{ort_genai/oga.py → oga/load.py} CHANGED
@@ -10,28 +10,16 @@

 import argparse
 import os
-import time
 import json
 import shutil
-import logging
 from fnmatch import fnmatch
-from queue import Queue
 import subprocess
-
-
-import onnxruntime_genai as og
-import onnxruntime_genai.models.builder as model_builder
-from transformers import AutoTokenizer
+
+
 from lemonade.state import State
 from lemonade.tools import FirstTool
 import lemonade.common.status as status
 import lemonade.common.printing as printing
-from lemonade.tools.huggingface_load import get_base_model, is_offline
-from lemonade.tools.adapter import (
-    ModelAdapter,
-    TokenizerAdapter,
-    PassthroughTokenizerResult,
-)
 from lemonade.cache import Keys
 from lemonade_install.install import (
     get_ryzen_ai_version_info,
@@ -57,414 +45,16 @@ execution_providers = {
 }


-
-
-
-
-
-
-
-
-
-
-
-    def __call__(self, prompt: str, return_tensors="np"):
-        tokens = self.tokenizer.encode(prompt)
-        return PassthroughTokenizerResult(tokens)
-
-    # pylint: disable=unused-argument
-    def decode(self, response, skip_special_tokens=True) -> str:
-        return self.tokenizer.decode(response)
-
-
-class OrtGenaiStreamer:
-    def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
-        self.tokenizer = tokenizer
-        self.text_queue = Queue()
-        self.stop_signal = None
-        self.timeout = timeout
-
-    def add_text(self, text: str):
-        self.text_queue.put(text, timeout=self.timeout)
-
-    def done(self):
-        self.text_queue.put(self.stop_signal, timeout=self.timeout)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        value = self.text_queue.get(timeout=self.timeout)
-        if value == self.stop_signal:
-            raise StopIteration()
-        else:
-            return value
-
-
-class OrtGenaiModel(ModelAdapter):
-
-    def __init__(self, input_folder):
-        super().__init__()
-        self.model = og.Model(input_folder)
-        self.type = "ort-genai"
-        self.config = self.load_config(input_folder)
-
-    def load_config(self, input_folder):
-        rai_config_path = os.path.join(input_folder, "rai_config.json")
-        if os.path.exists(rai_config_path):
-            with open(rai_config_path, "r", encoding="utf-8") as f:
-                max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
-        else:
-            max_prompt_length = None
-
-        config_path = os.path.join(input_folder, "genai_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r", encoding="utf-8") as f:
-                config_dict = json.load(f)
-            if max_prompt_length:
-                config_dict["max_prompt_length"] = max_prompt_length
-            return config_dict
-        return None
-
-    def generate(
-        self,
-        input_ids,
-        max_new_tokens=512,
-        min_new_tokens=0,
-        do_sample=True,
-        top_k=50,
-        top_p=1.0,
-        temperature=0.7,
-        streamer: OrtGenaiStreamer = None,
-        pad_token_id=None,
-        stopping_criteria=None,
-        max_length=None,
-        random_seed=1,
-    ):
-        params = og.GeneratorParams(self.model)
-
-        prompt_length = len(input_ids)
-        max_prompt_length = self.config.get("max_prompt_length")
-        if max_prompt_length and prompt_length > max_prompt_length:
-            raise ValueError(
-                f"This prompt (length {prompt_length}) exceeds the model's "
-                f"maximum allowed prompt length ({max_prompt_length})."
-            )
-
-        # There is a breaking API change in OGA 0.6.0
-        # Determine whether we should use the old or new APIs
-        # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
-        use_oga_post_6_api = (
-            Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
-        )
-        use_oga_pre_6_api = not use_oga_post_6_api
-
-        if pad_token_id:
-            params.pad_token_id = pad_token_id
-
-        # Handle max_length and max_new_tokens
-        if max_length and max_new_tokens:
-            logging.warning(
-                "Both max_length and max_new_tokens were provided. "
-                "max_length will take precedence. "
-                "When setting max_length, please explicitly set max_new_tokens to None."
-            )
-        max_length_to_use = None
-        if max_length:
-            max_length_to_use = max_length
-        elif max_new_tokens:
-            max_length_to_use = prompt_length + max_new_tokens
-
-        min_length = prompt_length + min_new_tokens
-
-        if use_oga_pre_6_api:
-            params.input_ids = input_ids
-
-        if random_seed is None:
-            random_seed = -1  # In og.Generator, -1 = seed with random device
-
-        if self.config and "search" in self.config:
-            search_config = self.config["search"]
-            params.set_search_options(
-                do_sample=search_config.get("do_sample", do_sample),
-                top_k=search_config.get("top_k", top_k),
-                top_p=search_config.get("top_p", top_p),
-                temperature=search_config.get("temperature", temperature),
-                max_length=max_length_to_use,
-                min_length=min_length,
-                early_stopping=search_config.get("early_stopping", False),
-                length_penalty=search_config.get("length_penalty", 1.0),
-                num_beams=search_config.get("num_beams", 1),
-                num_return_sequences=search_config.get("num_return_sequences", 1),
-                repetition_penalty=search_config.get("repetition_penalty", 1.0),
-                past_present_share_buffer=search_config.get(
-                    "past_present_share_buffer", True
-                ),
-                random_seed=random_seed,
-                # Not currently supported by OGA
-                # diversity_penalty=search_config.get('diversity_penalty', 0.0),
-                # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
-            )
-        else:
-            params.set_search_options(
-                do_sample=do_sample,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                max_length=max_length_to_use,
-                min_length=min_length,
-                random_seed=random_seed,
-            )
-        params.try_graph_capture_with_max_batch_size(1)
-
-        generator = og.Generator(self.model, params)
-
-        if streamer is None:
-            prompt_start_time = time.perf_counter()
-            if use_oga_post_6_api:
-                generator.append_tokens(input_ids)
-            if use_oga_pre_6_api:
-                generator.compute_logits()
-            generator.generate_next_token()
-            prompt_end_time = time.perf_counter()
-
-            self.time_to_first_token = prompt_end_time - prompt_start_time
-
-            if max_new_tokens > 1:
-
-                token_gen_times = []
-                while not generator.is_done():
-                    token_gen_start_time = time.perf_counter()
-                    if use_oga_pre_6_api:
-                        generator.compute_logits()
-                    generator.generate_next_token()
-                    token_gen_end_time = time.perf_counter()
-
-                    token_gen_times.append(token_gen_end_time - token_gen_start_time)
-
-                if token_gen_times:
-                    # List will be empty if we generated 1 or 0 tokens, and we don't
-                    # want a divide-by-zero error in those cases
-                    avg_token_gen_latency_s = sum(token_gen_times) / len(
-                        token_gen_times
-                    )
-                    self.tokens_per_second = 1 / avg_token_gen_latency_s
-
-            return [generator.get_sequence(0)]
-        else:
-            if use_oga_post_6_api:
-                generator.append_tokens(input_ids)
-            tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
-
-            stop_early = False
-
-            while not generator.is_done() and not stop_early:
-                if use_oga_pre_6_api:
-                    generator.compute_logits()
-                generator.generate_next_token()
-
-                new_token = generator.get_next_tokens()[0]
-                new_text = tokenizer_stream.decode(new_token)
-
-                streamer.add_text(new_text)
-
-                if stopping_criteria is not None:
-                    if stopping_criteria[0].stop_event.is_set():
-                        stop_early = True
-
-            streamer.done()
-
-    def _model_call(self, input_ids):
-        """
-        Run the model on input_ids and get logits.
-
-        This method directly accesses model logits rather than using the full generate pipeline for
-        several important reasons:
-        1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
-           producing multiple tokens through iterative inference
-        2. Efficiency: Direct access is more efficient for logprob calculations with no
-           sampling overhead
-        3. Precision: Logprob calculations require exact control over input-to-output mapping
-        4. Consistency: Similar approach used in both HF and OGA implementations
-
-        Args:
-            input_ids: Input token IDs
-
-        Returns:
-            Logits for each token in the sequence
-        """
-        import torch
-
-        # Setup generator params
-        params = og.GeneratorParams(self.model)
-
-        # Configure for a simple forward pass
-        params.set_search_options(
-            do_sample=False,
-            temperature=0.0,
-            max_length=len(input_ids),
-        )
-
-        # Initialize generator
-        generator = og.Generator(self.model, params)
-
-        # Feed tokens to model based on API version
-        generator.append_tokens(input_ids)
-
-        # Extract logits - this returns a list of logits tensors
-        logits = generator.get_output("logits")
-
-        # Convert to torch tensor for easier processing
-        return torch.tensor(logits[0])
-
-    def _select_cont_toks(self, logits, context_len, continuation_tokens):
-        """
-        Select and process logits for continuation tokens.
-
-        Args:
-            logits: Full sequence logits
-            context_len: Length of context tokens
-            continuation_tokens: List or tensor of continuation token IDs
-
-        Returns:
-            Log probabilities for continuation tokens
-        """
-        import torch
-
-        # Extract relevant logits for continuation prediction (shift by one)
-        cont_logits = logits[
-            context_len - 1 : context_len - 1 + len(continuation_tokens)
-        ]
-
-        # Convert to torch tensors if needed
-        if not isinstance(continuation_tokens, torch.Tensor):
-            continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(cont_logits, dim=-1)
-
-        # Get log probs for the specific continuation tokens
-        token_log_probs = torch.gather(
-            log_probs, 1, continuation_tokens.unsqueeze(-1)
-        ).squeeze(-1)
-
-        return token_log_probs
-
-    def compute_logprobs(
-        self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
-    ):
-        """
-        Compute log probabilities for all tokens in the given text.
-
-        Args:
-            text: The full text to analyze (e.g., prompt + completion)
-            prompt_length: Number of tokens in the prompt. If provided and echo=False,
-                only completion tokens after this position will be returned.
-            logprobs: If not None, return log probabilities. Value indicates how many top
-                alternatives to return. If True but not an integer, defaults to 5 alternatives.
-            echo: If True, include logprobs for prompt tokens. If False, only return logprobs
-                for completion tokens.
-
-        Returns:
-            - text_offset: Character offsets for each token in the text
-            - token_logprobs: Log probability for each token
-            - tokens: The actual tokens used
-            - top_logprobs: Top alternative log probabilities for each position
-        """
-        import torch
-
-        if tokenizer is None:
-            raise ValueError("Tokenizer is required for logprob calculation")
-
-        # Encode the full text
-        tokens = tokenizer(text).input_ids  # pylint: disable=E1102
-
-        # Track character offsets for each token
-        text_offset = []
-        start_idx = 0
-
-        token_strings = []
-        for token_id in tokens:
-            token_str = tokenizer.decode([token_id])
-            token_strings.append(token_str)
-
-            # Calculate character offsets for tokens - handles cases where tokens
-            # may not directly match in the original text due to encoding differences,
-            # special characters, or tokenization artifacts
-            try:
-                pos = text[start_idx:].find(token_str)
-                if pos != -1:
-                    text_offset.append(start_idx + pos)
-                    start_idx += pos + len(token_str)
-                else:
-                    text_offset.append(start_idx)
-            except (TypeError, ValueError, UnicodeError):
-                # Fallback to current position when matching fails due to encoding issues
-                text_offset.append(start_idx)
-
-        # Get logits from model
-        logits = self._model_call(tokens)
-
-        # Calculate log probabilities for each token
-        all_log_probs = torch.log_softmax(logits, dim=-1)
-
-        # The first token doesn't have a conditional probability
-        # For tokens after the first, get the predicted probability
-        token_log_probs = []
-        top_logprobs_list = []
-
-        # For each position, get the actual token probability and top alternatives
-        for i in range(len(tokens)):
-            # Get previous token position logits
-            if i > 0:  # First token has no preceding context
-                prev_logits = all_log_probs[i - 1]
-                curr_token_id = tokens[i]
-                # Get probability of the actual token that appeared
-                token_logprob = prev_logits[curr_token_id].item()
-                token_log_probs.append(token_logprob)
-
-                # Get top-k alternatives if requested
-                if logprobs is not None:
-                    num_alternatives = logprobs if isinstance(logprobs, int) else 5
-                    topk_values, topk_indices = torch.topk(
-                        prev_logits, min(num_alternatives, prev_logits.size(-1))
-                    )
-
-                    # Create dictionary of token: logprob
-                    position_logprobs = {}
-                    for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
-                        token_str = tokenizer.decode([idx])
-                        position_logprobs[token_str] = val
-
-                    top_logprobs_list.append(position_logprobs)
-            else:
-                # For the first token, we don't have a conditional probability
-                token_log_probs.append(None)
-                top_logprobs_list.append({})
-
-        # If we don't want to echo prompt tokens, filter them out
-        if not echo and prompt_length is not None:
-            # Ensure prompt_length is within bounds
-            prompt_length = min(prompt_length, len(tokens))
-
-            # Filter results to only include completion tokens
-            if prompt_length < len(tokens):
-                filtered_text_offset = text_offset[prompt_length:]
-                filtered_token_logprobs = token_log_probs[prompt_length:]
-                filtered_tokens = token_strings[prompt_length:]
-                filtered_top_logprobs = top_logprobs_list[prompt_length:]
-
-                return (
-                    filtered_text_offset,
-                    filtered_token_logprobs,
-                    filtered_tokens,
-                    filtered_top_logprobs,
-                )
-            else:
-                # No completion tokens
-                return [], [], [], []
-
-        return text_offset, token_log_probs, token_strings, top_logprobs_list
+def import_error_heler(e: Exception):
+    """
+    Print a helpful message in the event of an import error
+    """
+    raise ImportError(
+        f"{e}\n Please install lemonade-sdk with "
+        "one of the llm-oga extras, for example:\n"
+        "pip install lemonade-sdk[llm-oga-cpu]\n"
+        "See https://lemonade_server.ai/install_options.html for details"
+    )


 class OgaLoad(FirstTool):
@@ -624,6 +214,8 @@ class OgaLoad(FirstTool):
         files that have locally been quantized/converted to OGA format and any other
         models that have been manually added by the user.
         """
+        from huggingface_hub import snapshot_download
+
         if subfolder is None:
             subfolder = f"{execution_providers[device]}-{dtype}"
             subfolder += (
@@ -749,6 +341,12 @@
         Uses OGA model builder to quantize safetensors format model and convert to ONNX
         format. The model files are saved to the full_model_path folder.
         """
+
+        try:
+            import onnxruntime_genai.models.builder as model_builder
+        except ImportError as e:
+            import_error_heler(e)
+
         printing.log_info(f"Building {checkpoint} for {device} using {dtype}")
         extra_options = {}
         if int4_block_size is not None:
@@ -837,6 +435,14 @@
         Loads the OGA model from local folder and then loads the tokenizer.
         Will auto-detect if we're offline.
         """
+
+        try:
+            from transformers import AutoTokenizer
+            from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiTokenizer
+            from lemonade.common.network import is_offline
+        except ImportError as e:
+            import_error_heler(e)
+
         try:
             state.model = OrtGenaiModel(full_model_path)
         except Exception as e:
@@ -945,6 +551,9 @@
         trust_remote_code=False,
         subfolder: str = None,
     ) -> State:
+        from huggingface_hub import snapshot_download
+        from lemonade.common.network import get_base_model, is_offline
+
         # Auto-detect offline status
         offline = is_offline()
         if offline:
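These oga/load.py hunks follow the same pattern seen throughout this release: heavy, backend-specific imports (onnxruntime-genai, transformers, huggingface_hub) are deferred into the methods that need them, and any ImportError is routed through the new import_error_heler so the user gets a hint about which optional extra to install. A rough, self-contained sketch of that pattern, with hypothetical helper and function names rather than lemonade's exact wording:

def _raise_with_install_hint(e: Exception):
    # Re-raise the original ImportError with a pointer to the optional extra.
    raise ImportError(
        f"{e}\nInstall an OGA backend extra, for example:\n"
        "pip install lemonade-sdk[llm-oga-cpu]"
    ) from e

def load_oga_backend():
    try:
        # Deferred import: only required when an OGA model is actually loaded.
        import onnxruntime_genai as og
        return og
    except ImportError as e:
        _raise_with_install_hint(e)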
|