lemonade-sdk 9.1.1__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0

lemonade/tools/accuracy.py
@@ -0,0 +1,432 @@
+import argparse
+import json
+import os
+import socket
+import subprocess
+import sys
+import time
+from typing import Optional
+
+from lemonade.state import State
+from lemonade.tools import Tool
+import lemonade.common.printing as printing
+import lemonade.common.build as build
+
+
+def is_port_in_use(port, host="localhost"):
+    """
+    Check if a port is in use
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex((host, port)) == 0
+
+
+class LMEvalHarness(Tool):
+    """
+    Tool for evaluating LLMs using lm-eval-harness on industry standard benchmarks
+    like MMLU, GSM8k, and more. See docs/lemonade/lm_eval.md for more details.
+    """
+
+    unique_name = "lm-eval-harness"
+
+    def __init__(self):
+        super().__init__(
+            monitor_message="Evaluate model accuracy using ElutherAI's lm-eval-harness"
+        )
+        self.status_stats = []
+        self.server_runner = None
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Evaluate model using lm-eval-harness",
+            add_help=add_help,
+        )
+
+        parser.add_argument(
+            "--task",
+            type=str,
+            required=True,
+            help="Task(s) to evaluate on (e.g., gsm8k, mmlu)",
+        )
+
+        parser.add_argument(
+            "--server-port", type=int, default=8000, help="Port to use for the server"
+        )
+
+        parser.add_argument(
+            "--num-fewshot",
+            type=int,
+            default=0,
+            help="Number of examples in few-shot prompts",
+        )
+
+        parser.add_argument(
+            "--limit",
+            type=int,
+            default=None,
+            help="Limit the number of examples per task",
+        )
+
+        parser.add_argument(
+            "--log-samples",
+            action="store_true",
+            help="Log samples for each task to log file",
+        )
+
+        parser.add_argument(
+            "--output-path",
+            type=str,
+            default=None,
+            help="Path to save evaluation results",
+        )
+
+        return parser
+
+    def _scale_metric(self, metric_name, value):
+        """
+        Scale metric value appropriately based on type and range
+
+        Args:
+            metric_name: Name of the metric (e.g., "acc,none", "ppl")
+            value: Numeric value of the metric
+
+        Returns:
+            tuple: (scaled_value, units, display_string)
+        """
+        fraction_metrics = {
+            "acc",
+            "accuracy",
+            "f1",
+            "exact_match",
+            "em",
+            "win_rate",
+            "recall",
+            "precision",
+            "rouge",
+            "bleu",
+            "meteor",
+            "bertscore",
+            "match",
+            "correct",
+            "pass",
+            "success_rate",
+        }
+
+        metric_base = metric_name.split(",")[0].lower()
+        is_fraction = any(
+            frac_metric in metric_base for frac_metric in fraction_metrics
+        )
+        is_in_unit_range = 0 <= value <= 1
+
+        if is_fraction and is_in_unit_range:
+            scaled_value = float(value) * 100
+            units = "%"
+            display_str = f"{value:.4f} ({scaled_value:.2f}%)"
+        else:
+            scaled_value = float(value)
+            units = "raw"
+            display_str = f"{value:.4f}"
+
+        return scaled_value, units, display_str
+
+    def _process_results(self, results_path, state):
+        """
+        Process evaluation results and save to state stats
+
+        Args:
+            results_path: Can be either a direct JSON file path or a directory path
+            state: State object to save metrics to
+        """
+        results_file_path = None
+
+        # Determine if this is a file or directory and find the JSON file
+        if os.path.isfile(results_path) and results_path.endswith(".json"):
+            # Direct JSON file path (modern format)
+            results_file_path = results_path
+        elif os.path.isdir(results_path):
+            # Look for model subdirectories
+            model_dirs = [
+                d
+                for d in os.listdir(results_path)
+                if os.path.isdir(os.path.join(results_path, d))
+            ]
+
+            if model_dirs:
+                # Format: results_dir/model_name/results_*.json
+                model_dir = os.path.join(results_path, model_dirs[0])
+                printing.log_info(f"Found model directory: {model_dir}")
+
+                results_files = [
+                    f
+                    for f in os.listdir(model_dir)
+                    if f.startswith("results_") and f.endswith(".json")
+                ]
+
+                if results_files:
+                    results_files.sort(reverse=True)
+                    results_file_path = os.path.join(model_dir, results_files[0])
+                else:
+                    printing.log_warning(f"No results files found in {model_dir}")
+                    return
+            else:
+                printing.log_warning(f"No model directories found in {results_path}")
+                return
+        else:
+            # Handle case where lm-eval adds timestamp to expected filename
+            results_dir = os.path.dirname(results_path)
+            if os.path.exists(results_dir):
+                json_files = [f for f in os.listdir(results_dir) if f.endswith(".json")]
+                if json_files:
+                    results_file_path = os.path.join(results_dir, json_files[0])
+                    printing.log_info(f"Found results file: {results_file_path}")
+                else:
+                    printing.log_warning(f"No JSON results file found in {results_dir}")
+                    return
+            else:
+                printing.log_warning(f"Results path not found at {results_path}")
+                return
+
+        if not results_file_path or not os.path.exists(results_file_path):
+            printing.log_warning(f"Results file not found at {results_file_path}")
+            return
+
+        printing.log_info(f"Processing results from {results_file_path}")
+
+        try:
+            with open(results_file_path, "r", encoding="utf-8") as f:
+                results = json.load(f)
+
+            # Extract and display metrics
+            if "results" in results:
+                for task_name, metrics in results["results"].items():
+                    printing.log_info(f"Results for {task_name}:")
+
+                    for metric, value in metrics.items():
+                        if isinstance(value, (int, float)) and not metric.startswith(
+                            "alias"
+                        ):
+                            # Format metric name for stats - remove ,none suffix
+                            clean_metric = metric.split(",")[0]  # Remove ,none suffix
+                            stat_name = f"lm_eval_{task_name}_{clean_metric}"
+
+                            # Scale metric appropriately
+                            scaled_value, units, value_str = self._scale_metric(
+                                metric, value
+                            )
+                            display_str = f" {metric}: {value_str}"
+
+                            state.save_stat(stat_name, scaled_value)
+                            state.save_stat(f"{stat_name}_units", units)
+                            self.status_stats.append(stat_name)
+
+                            printing.log_info(display_str)
+
+            # Save summary metrics if available
+            avg_metrics = {}
+            if "higher_is_better" in results:
+                for metric_type in results["higher_is_better"].values():
+                    for metric in metric_type.keys():
+                        if metric not in avg_metrics:
+                            avg_metrics[metric] = []
+
+                for task_metrics in results["results"].values():
+                    for metric, value in task_metrics.items():
+                        if isinstance(value, (int, float)) and not metric.startswith(
+                            "alias"
+                        ):
+                            base_metric = metric.split(",")[0]
+                            if base_metric in avg_metrics:
+                                avg_metrics[base_metric].append(value)
+
+                # Calculate and save averages
+                for metric, values in avg_metrics.items():
+                    if values:
+                        avg_value = sum(values) / len(values)
+                        stat_name = f"lm_eval_average_{metric}"
+
+                        # Apply same scaling logic as individual metrics
+                        scaled_avg, units, value_str = self._scale_metric(
+                            metric, avg_value
+                        )
+                        display_str = f"Average {metric}: {value_str}"
+
+                        state.save_stat(stat_name, scaled_avg)
+                        state.save_stat(f"{stat_name}_units", units)
+                        self.status_stats.append(stat_name)
+                        printing.log_info(display_str)
+
+        except (IOError, json.JSONDecodeError) as e:
+            printing.log_error(f"Error processing results: {e}")
+
+    def run(
+        self,
+        state: State,
+        task: str,
+        server_port: int = 8000,
+        server_host: str = "localhost",
+        num_fewshot: int = 0,
+        limit: Optional[int] = None,
+        log_samples: bool = False,
+        output_path: Optional[str] = None,
+    ) -> State:
+
+        # Check if lm-eval is available
+        try:
+            # pylint: disable=unused-import
+            import lm_eval
+        except ImportError:
+            error_msg = (
+                "lm-eval-harness is required but not installed. "
+                "Please install it using one of the following commands:\n"
+                " pip install lemonade-sdk[dev]\n"
+                " pip install -e .[dev]\n"
+            )
+            printing.log_error(error_msg)
+            raise ImportError(error_msg)
+
+        import requests
+        from lemonade.tools.server.utils.thread import ServerRunner
+
+        model = state.model
+        tokenizer = state.tokenizer
+
+        if model is None or tokenizer is None:
+            raise ValueError(
+                "Model and tokenizer must be loaded in state before running lm-eval-harness"
+            )
+
+        # Set up output path
+        if output_path is None:
+            output_path = os.path.join(
+                build.output_dir(state.cache_dir, state.build_name), "lm_eval_results"
+            )
+
+        os.makedirs(output_path, exist_ok=True)
+
+        # Check if port is already in use
+        if is_port_in_use(server_port, server_host):
+            error_msg = (
+                f"Port {server_port} is already in use. "
+                "Please close all applications using this port and try again."
+            )
+            printing.log_error(error_msg)
+            raise RuntimeError(error_msg)
+
+        # Retroactively determine recipe based on model type to select correct iterator
+        # The model is already loaded in server, so we only need recipe for iterator selection
+        checkpoint = getattr(state, "checkpoint", "unknown")
+        if "OrtGenaiModel" in str(type(model)):
+            recipe = "oga-"
+        else:
+            recipe = "unknown"
+
+        # Start the server thread
+        self.server_runner = ServerRunner(
+            model=model,
+            tokenizer=tokenizer,
+            checkpoint=checkpoint,
+            recipe=recipe,
+            host=server_host,
+            port=server_port,
+        )
+        self.server_runner.start()
+
+        # Wait for server initialization
+        printing.log_info("Waiting for server initialization...")
+
+        # Wait for server to start and be responsive
+        server_url = f"http://{server_host}:{server_port}"
+        max_retries = 30
+        retry_delay = 1
+
+        printing.log_info(f"Checking if server is available at {server_url}...")
+        for i in range(max_retries):
+            try:
+                response = requests.get(f"{server_url}/api/v0/health", timeout=2)
+                if response.status_code == 200:
+                    printing.log_info(f"Server is ready after {i+1} attempts")
+                    break
+            except requests.exceptions.RequestException:
+                if i < max_retries - 1:
+                    time.sleep(retry_delay)
+                else:
+                    printing.log_error(
+                        f"Server did not start after {max_retries} attempts"
+                    )
+                    raise RuntimeError("Failed to start the server")
+
+        # Build API URL
+        results_file = os.path.join(output_path, f"{task}_results.json")
+
+        printing.log_info(f"Running lm-eval-harness on {task}...")
+
+        # Build lm-eval-harness command
+        # Use sys.executable -m to ensure cross-platform compatibility (Windows)
+        cmd = [
+            sys.executable,
+            "-m",
+            "lm_eval",
+            "--model",
+            "local-completions",
+            "--tasks",
+            task,
+            "--model_args",
+            (
+                f"model={checkpoint},"
+                f"base_url={server_url}/api/v0/completions,"
+                f"num_concurrent=1,"
+                f"max_retries=5,"
+                f"retry_timeout=10,"
+                f"tokenized_requests=False"
+            ),
+            "--num_fewshot",
+            str(num_fewshot),
+            "--output_path",
+            results_file,
+        ]
+
+        if limit is not None:
+            cmd.extend(["--limit", str(limit)])
+
+        if log_samples:
+            cmd.extend(["--log_samples"])
+
+        try:
+            # On Windows, set UTF-8 mode to handle Unicode output
+            env = os.environ.copy()
+            if sys.platform == "win32":
+                env["PYTHONIOENCODING"] = "utf-8"
+
+            # Execute lm-eval-harness command
+            result = subprocess.run(
+                cmd, check=True, text=True, capture_output=True, env=env
+            )
+
+            # Log relevant output and skip any parts that might cause encoding issues
+            try:
+                printing.log_info(result.stdout)
+            except UnicodeEncodeError:
+                printing.log_info(
+                    "Results obtained successfully but couldn't display due to encoding issues"
+                )
+
+            # Process results from the JSON file
+            self._process_results(results_file, state)
+
+        except subprocess.CalledProcessError as e:
+            printing.log_error(f"Error running lm-eval-harness: {e}")
+            printing.log_error(f"stderr: {e.stderr}")
+        except (IOError, ValueError, requests.RequestException) as e:
+            printing.log_error(f"Error: {e}")
+        finally:
+            # Shut down server
+            if self.server_runner and self.server_runner.is_alive():
+                printing.log_info("Shutting down server runner...")
+                self.server_runner.shutdown()
+
+            # Make sure we don't have any lingering references to state's model/tokenizer
+            # that could prevent garbage collection
+            self.server_runner = None
+
+        return state
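
To make the stat naming in _process_results concrete, here is a minimal, self-contained sketch (not part of the package) of how an lm-eval results payload maps to lemonade stat names. The payload values below are invented for illustration; fraction-style metrics in the range [0, 1] are reported as percentages, mirroring _scale_metric.

# Hypothetical lm-eval results payload; the values are made up.
results = {
    "results": {
        "gsm8k": {"exact_match,strict-match": 0.62, "alias": "gsm8k"},
        "mmlu": {"acc,none": 0.41},
    }
}

for task_name, metrics in results["results"].items():
    for metric, value in metrics.items():
        # Skip non-numeric entries such as "alias", as the tool does
        if isinstance(value, (int, float)) and not metric.startswith("alias"):
            clean_metric = metric.split(",")[0]  # drop the ",none"-style suffix
            stat_name = f"lm_eval_{task_name}_{clean_metric}"
            print(stat_name, f"{value * 100:.2f}%")
# lm_eval_gsm8k_exact_match 62.00%
# lm_eval_mmlu_acc 41.00%
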
lemonade/tools/adapter.py
@@ -0,0 +1,114 @@
+import abc
+
+
+class ModelAdapter(abc.ABC):
+    """
+    Base class for adapting an LLM to work with lemonade's standardized tools
+    """
+
+    def __init__(self):
+        """
+        Self-benchmarking ModelAdapters can store their results in the
+        tokens_per_second and time_to_first_token members.
+        ModelAdapters that run generate in a different process can store the
+        peak memory used (bytes) by that process in the peak_wset member.
+        """
+        self.tokens_per_second = None
+        self.time_to_first_token = None
+        self.prompt_tokens = None
+        self.response_tokens = None
+        self.peak_wset = None
+
+        self.type = "generic"
+
+    @abc.abstractmethod
+    def generate(self, input_ids, max_new_tokens=512):
+        """
+        Generate is the primary method required by lemonade's accuracy tools
+
+        We try to keep the signature here minimal to allow for maximum compatibility
+        with recipe components, which themselves may not support a lot of arguments.
+
+        The generate method should store prompt and response lengths (in tokens)
+        in the prompt_tokens and response_tokens members. If a different process is used,
+        the generate method can also store the peak memory used by that process in the
+        peak_wset member.
+        """
+
+
+class TokenizerAdapter(abc.ABC):
+    """
+    Base class for adapting an LLM's tokenizer to work with lemonade's standard tools
+    """
+
+    def __init__(self, tokenizer=None):
+        self.auto_tokenizer = tokenizer
+
+    @abc.abstractmethod
+    def __call__(self, prompt: str):
+        """
+        Args:
+            prompt: text that should be encoded and passed to the LLM as input_ids
+
+        Returns: input_ids
+        """
+
+    @abc.abstractmethod
+    def decode(self, response) -> str:
+        """
+        Args:
+            response: tokens from the LLM that should be decoded into text
+
+        Returns: text response of the LLM
+        """
+
+    def apply_chat_template(self, *args, **kwargs):
+        """
+        Convert messages into a single tokenizable string
+        """
+        return self.auto_tokenizer.apply_chat_template(*args, **kwargs)
+
+    @property
+    def chat_template(self):
+        return self.auto_tokenizer.chat_template
+
+    @property
+    def eos_token(self):
+        return self.auto_tokenizer.eos_token
+
+
+class PassthroughTokenizerResult:
+    """
+    Data structure for holding a tokenizer result where the input_ids
+    are packaged in a non-standard way, but we still want to adhere to
+    standard interfaces (e.g., result.input_ids).
+
+    For example: CLI-based tools that have their own internal tokenizer that
+    isn't exposed to the user. In this case we can pass the prompt through as
+    a string.
+    """
+
+    def __init__(self, prompt):
+        self.input_ids = prompt
+
+
+class PassthroughTokenizer(TokenizerAdapter):
+    """
+    Tokenizer adapter that forwards the prompt to input_ids as text,
+    and then forwards a text LLM response through decode() as text.
+
+    Useful for CLI-based tools that have their own internal tokenizer that
+    isn't exposed to the user.
+    """
+
+    # pylint: disable=unused-argument
+    def __call__(self, prompt: str, **kwargs):
+        return PassthroughTokenizerResult(prompt)
+
+    # pylint: disable=unused-argument
+    def decode(self, response: str, **kwargs):
+        return response
+
+
+# This file was originally licensed under Apache 2.0. It has been modified.
+# Modifications Copyright (c) 2025 AMD