lemonade-sdk 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/__init__.py +5 -0
- lemonade/api.py +125 -0
- lemonade/cache.py +85 -0
- lemonade/cli.py +135 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/analyze_model.py +26 -0
- lemonade/common/build.py +223 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/labels.py +61 -0
- lemonade/common/onnx_helpers.py +176 -0
- lemonade/common/plugins.py +10 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +490 -0
- lemonade/common/system_info.py +390 -0
- lemonade/common/tensor_helpers.py +83 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/memory_tracker.py +257 -0
- lemonade/profilers/profiler.py +55 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/adapter.py +104 -0
- lemonade/tools/bench.py +284 -0
- lemonade/tools/huggingface_bench.py +267 -0
- lemonade/tools/huggingface_load.py +520 -0
- lemonade/tools/humaneval.py +258 -0
- lemonade/tools/llamacpp.py +261 -0
- lemonade/tools/llamacpp_bench.py +154 -0
- lemonade/tools/management_tools.py +273 -0
- lemonade/tools/mmlu.py +327 -0
- lemonade/tools/ort_genai/__init__.py +0 -0
- lemonade/tools/ort_genai/oga.py +1129 -0
- lemonade/tools/ort_genai/oga_bench.py +142 -0
- lemonade/tools/perplexity.py +146 -0
- lemonade/tools/prompt.py +228 -0
- lemonade/tools/quark/__init__.py +0 -0
- lemonade/tools/quark/quark_load.py +172 -0
- lemonade/tools/quark/quark_quantize.py +439 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +739 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/serve.py +1354 -0
- lemonade/tools/server/tool_calls.py +146 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +774 -0
- lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
- lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
- lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
- lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
- lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
- lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +260 -0
- lemonade_server/model_manager.py +98 -0
- lemonade_server/server_models.json +142 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import abc
|
|
3
|
+
from typing import List
|
|
4
|
+
import lemonade.common.filesystem as fs
|
|
5
|
+
import lemonade.common.exceptions as exp
|
|
6
|
+
import lemonade.common.printing as printing
|
|
7
|
+
from lemonade.tools.tool import ToolParser
|
|
8
|
+
from lemonade.version import __version__ as lemonade_version
|
|
9
|
+
from lemonade.common.system_info import get_system_info_dict
|
|
10
|
+
from lemonade.common.build import output_dir
|
|
11
|
+
import lemonade.cache as lemonade_cache
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ManagementTool(abc.ABC):
    """
    Intended for management functions, such as managing the cache
    or printing the version number.
    """

    # Name of the tool on the command line; concrete subclasses assign it
    unique_name: str

    @classmethod
    def helpful_parser(cls, short_description: str, **kwargs):
        """
        Build the ToolParser for this tool.

        Args:
            short_description: one-line summary forwarded to ToolParser.
            **kwargs: extra keyword arguments forwarded to ToolParser.

        Returns:
            A ToolParser whose prog is `lemonade <unique_name>`, whose
            description is the docstring of `cls`, and whose epilog explains
            that the tool is meant to be invoked by itself.
        """

        tool_name = cls.unique_name
        parser_settings = {
            "prog": f"lemonade {tool_name}",
            "short_description": short_description,
            "description": cls.__doc__,
            "epilog": (
                f"`{tool_name}` is a Management Tool. "
                "It is intended to be invoked by itself "
                "(i.e., not as part of a sequence), to accomplish a utility function. "
            ),
        }

        return ToolParser(**parser_settings, **kwargs)

    @staticmethod
    @abc.abstractmethod
    def parser() -> argparse.ArgumentParser:
        """
        Static method that returns an ArgumentParser that defines the command
        line interface for this Tool.
        """

    # pylint: disable=unused-argument
    def parse(self, args, known_only=True) -> argparse.Namespace:
        """
        Run the parser and return a Namespace of keyword arguments that the user
        passed to the Tool via the command line.

        Tools should extend this function only if they require specific parsing
        logic.

        Args:
            args: command line arguments passed from the CLI.
            known_only: when True, every argument must be recognized by the
                parser (parse_args); when False, unrecognized arguments are
                ignored (parse_known_args), which lets the CLI framework
                incrementally parse complex commands.
        """

        cli_parser = self.__class__.parser()

        if not known_only:
            namespace, _unrecognized = cli_parser.parse_known_args(args)
            return namespace

        return cli_parser.parse_args(args)

    @abc.abstractmethod
    def run(self, cache_dir: str):
        """
        Execute the functionality of the Tool.
        """

    def parse_and_run(self, cache_dir: str, args, known_only=True):
        """
        Parse the CLI arguments with parse() and forward the resulting
        values to run() as keyword arguments, after cache_dir.
        """

        run_kwargs = vars(self.parse(args, known_only))
        self.run(cache_dir, **run_kwargs)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class Version(ManagementTool):
    """
    Simply prints the version number of the lemonade installation.
    """

    unique_name = "version"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        # This tool defines no options of its own; add_help is passed through
        # to helpful_parser()
        return __class__.helpful_parser(
            add_help=add_help,
            short_description="Print the lemonade version number",
        )

    def run(self, _):
        # The positional argument (the cache directory passed by
        # parse_and_run) is not used here
        print(lemonade_version)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class Cache(ManagementTool):
    # An f-string literal is never stored as a docstring, so the description
    # is assigned to __doc__ explicitly. helpful_parser() reads it through
    # cls.__doc__ to populate the parser description.
    __doc__ = f"""
    A set of functions for managing the lemonade build cache. The default
    cache location is {lemonade_cache.DEFAULT_CACHE_DIR}, and can also be
    selected with
    the global --cache-dir option or the LEMONADE_CACHE_DIR environment variable.

    Users must set either "--all" or "--build-names" to let the tool
    know what builds to operate on.

    Users must also set one of the available actions (e.g., list, stats, etc.).

    That action will be applied to all selected builds.
    """

    unique_name = "cache"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """
        Define the command line interface: one required build selection
        (--build-names or --all) and one required action
        (--list, --stats, --delete or --clean).
        """
        # NOTE: `--cache-dir` is set as a global input to the lemonade CLI and
        # passed directly to the `run()` method

        parser = __class__.helpful_parser(
            short_description="Manage the build cache ",
            add_help=add_help,
        )

        build_selection_group = parser.add_mutually_exclusive_group(required=True)

        build_selection_group.add_argument(
            "-b",
            "--build-names",
            nargs="+",
            help="Name of the specific build(s) to be operated upon, within the cache directory",
        )

        build_selection_group.add_argument(
            "-a",
            "--all",
            dest="all_builds",
            help="Operate on all the builds in the cache",
            action="store_true",
        )

        action_group = parser.add_mutually_exclusive_group(required=True)

        action_group.add_argument(
            "-l",
            "--list",
            dest="list_builds",
            action="store_true",
            help="List all of the builds in the cache",
        )

        action_group.add_argument(
            "-s",
            "--stats",
            action="store_true",
            help="Print the collected stats for the selected build(s)",
        )

        action_group.add_argument(
            "--delete",
            action="store_true",
            help="Permanently delete the selected build(s)",
        )

        action_group.add_argument(
            "--clean",
            action="store_true",
            help="Remove the build artifacts from the selected build(s)",
        )

        return parser

    def run(
        self,
        cache_dir: str,
        all_builds: bool = False,
        build_names: List[str] = None,
        list_builds: bool = False,
        stats: bool = False,
        delete: bool = False,
        clean: bool = False,
    ):
        """
        Apply one action to the selected builds in the cache.

        Args:
            cache_dir: cache directory to operate on.
            all_builds: operate on every build found in cache_dir.
            build_names: operate on these builds only.
            list_builds: print the name of each selected build.
            stats: print the stats file of each selected build.
            delete: remove the directory of each selected build.
            clean: remove the build artifacts of each selected build.

        Raises:
            ValueError: if both or neither of all_builds / build_names is set.
            exp.CacheError: if a selected build is not a build directory.
        """
        fs.check_cache_dir(cache_dir)

        if all_builds and build_names:
            raise ValueError(
                "all_builds and build_names are mutually exclusive, "
                "but both are used in this call."
            )
        elif all_builds:
            builds = fs.get_available_builds(cache_dir)
        elif build_names:
            builds = build_names
        else:
            raise ValueError(
                "Either all_builds or build_names must be set, "
                "but this call sets neither."
            )

        # Print a nice heading
        printing.log_info(f"Operating on cache directory {cache_dir}")

        if not builds:
            printing.log_warning("No builds found.")

        for build in builds:
            build_path = output_dir(cache_dir, build_name=build)
            if fs.is_build_dir(cache_dir, build):
                # Run actions on the build
                # These actions are intended to be mutually exclusive, so we
                # use an if-elif block in order from least to most destructive
                if list_builds:
                    print(build)
                elif stats:
                    fs.print_yaml_file(fs.Stats(cache_dir, build).file, "stats")
                elif clean:
                    fs.clean_output_dir(cache_dir, build)
                    printing.log_info(f"Removed the build artifacts from: {build}")

                elif delete:
                    fs.rmdir(build_path)
                    printing.log_info(f"Deleted build: {build}")
            else:
                raise exp.CacheError(
                    f"No build found with name: {build}. "
                    "Try running `lemonade cache --list` to see the builds in your build cache."
                )

        print()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class SystemInfo(ManagementTool):
    """
    Prints system information for the lemonade installation.
    """

    unique_name = "system-info"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        # This tool defines no options of its own; add_help is passed through
        # to helpful_parser()
        return __class__.helpful_parser(
            add_help=add_help,
            short_description="Print system information",
        )

    @staticmethod
    def pretty_print(my_dict: dict, level=0):
        """
        Print a dictionary as an indented outline.

        Nested dictionaries are printed recursively one level deeper, list
        values are printed one item per line one level deeper, and any other
        value is printed on the same line as its key.
        """
        for key, value in my_dict.items():
            pad = " " * level

            if isinstance(value, dict):
                print(pad + f"{key}:")
                SystemInfo.pretty_print(value, level + 1)
                continue

            if isinstance(value, list):
                print(pad + f"{key}:")
                child_pad = " " * (level + 1)
                for entry in value:
                    print(child_pad + f"{entry}")
                continue

            print(pad + f"{key}: {value}")

    def run(self, _):
        # The positional argument (the cache directory passed by
        # parse_and_run) is not used here
        self.pretty_print(get_system_info_dict())
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
273
|
+
# Modifications Copyright (c) 2025 AMD
|
lemonade/tools/mmlu.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import tarfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
import subprocess
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import requests
|
|
10
|
+
from lemonade.state import State
|
|
11
|
+
from lemonade.tools import Tool
|
|
12
|
+
import lemonade.common.printing as printing
|
|
13
|
+
import lemonade.common.build as build
|
|
14
|
+
import lemonade.common.filesystem as fs
|
|
15
|
+
|
|
16
|
+
# Constants
# Answer letters; _format_example labels the option columns with them in order
choices = list("ABCD")
# Archive passed to download_and_extract_dataset to fetch the MMLU data
dataset_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def min_handle_none(*args: Optional[int]):
    """
    Returns the minimum of the arguments. If one of the arguments is None,
    it doesn't count towards the min.

    Args:
        *args: candidate values; None entries are ignored.

    Returns:
        The smallest non-None argument, or None when there is nothing to
        compare (no arguments at all, or every argument is None).
    """

    # default=None avoids the ValueError that min() raises on an empty iterable
    return min((value for value in args if value is not None), default=None)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AccuracyMMLU(Tool):
    """
    See docs/lemonade/mmlu_accuracy.md for more details
    """

    unique_name = "accuracy-mmlu"

    def __init__(self):
        super().__init__(monitor_message="Measuring accuracy with MMLU")
        # Names of the stats that run() saves to the state
        self.status_stats = []

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """
        Define the command line interface: --ntrain, --max-evals,
        --data-dir and --tests.
        """
        parser = __class__.helpful_parser(
            short_description="Measure accuracy with Massive Multitask "
            "Language Understanding (MMLU)",
            add_help=add_help,
        )

        parser.add_argument(
            "--ntrain",
            type=int,
            default=5,
            help="Number of training examples to use. Default set to 5 for `5 Shot`",
        )
        parser.add_argument(
            "--max-evals",
            type=int,
            default=None,
            help="Maximum evaluations to run per test",
        )
        parser.add_argument(
            "--data-dir",
            type=str,
            required=False,
            help="Directory containing test and dev data (default: lemonade cache).",
        )
        parser.add_argument(
            "--tests",
            nargs="+",
            help=(
                "Specific tests to run. For a single quick test, we suggest 'management'. "
                "Default: run all tests."
            ),
        )
        return parser

    def run(
        self,
        state: State,
        ntrain: int = 5,
        max_evals: int = None,
        data_dir: Optional[str] = None,
        tests: List[str] = None,
    ) -> State:
        """
        Evaluate state.model on the selected MMLU tests.

        For every test, a detailed results CSV is written under the build's
        `mmlu` output directory and the accuracy is saved as a stat. A
        summary CSV and the average accuracy across tests are saved at the
        end.

        Args:
            state: provides cache_dir, build_name, tokenizer, model and
                save_stat().
            ntrain: number of dev examples placed in front of each question.
            max_evals: maximum number of questions evaluated per test
                (None: evaluate all of them).
            data_dir: directory holding the dataset; defaults to
                `<cache_dir>/data/mmlu`.
            tests: names of the tests to run (None: run every test found).

        Returns:
            The state that was passed in.

        Raises:
            ValueError: if `tests` contains a name with no matching
                `<name>_test.csv` file.
        """

        data_dir_to_use = data_dir or os.path.join(state.cache_dir, "data", "mmlu")

        # Setup MMLU dataset
        dataset_dir = download_and_extract_dataset(data_dir_to_use, dataset_url)

        model_results_dir = os.path.join(
            build.output_dir(state.cache_dir, state.build_name), "mmlu"
        )
        os.makedirs(model_results_dir, exist_ok=True)

        test_dir = os.path.join(dataset_dir, "test")
        available_tests = [
            f.replace("_test.csv", "")
            for f in sorted(os.listdir(test_dir))
            if f.endswith("_test.csv")
        ]
        if tests is None:
            tests_to_run = available_tests
        else:
            # Sorted so that the error message is deterministic
            unsupported_tests = sorted(set(tests) - set(available_tests))
            if unsupported_tests:
                raise ValueError(
                    f"Invalid test names provided: {', '.join(unsupported_tests)}. "
                    f"Valid tests are: {', '.join(available_tests)}"
                )
            # Keep the order (and any repetition) requested by the caller
            tests_to_run = list(tests)

        tokenizer = state.tokenizer
        model = state.model

        # Read each test file once; the frames are used both to size the
        # progress monitor and to run the evaluation
        test_frames = {
            subject: _safe_read_csv(os.path.join(test_dir, f"{subject}_test.csv"))
            for subject in set(tests_to_run)
        }

        # Update Tool progress monitor
        self.set_percent_progress(0.0)
        number_of_questions = float(
            sum(
                min_handle_none(len(test_frames[subject]), max_evals)
                for subject in tests_to_run
            )
        )

        questions_completed = 0

        summary_data = []
        for subject in tests_to_run:
            dev_df = _safe_read_csv(
                os.path.join(dataset_dir, "dev", f"{subject}_dev.csv")
            )[:ntrain]
            test_df = test_frames[subject]

            # The few-shot examples are the same for every question of a subject
            few_shot_prefix = _gen_prompt(dev_df, subject, ntrain)

            # Evaluate the model on the test data for a given subject
            detailed_results = []

            for i in range(min_handle_none(test_df.shape[0], max_evals)):
                prompt = few_shot_prefix + _format_example(
                    test_df, i, include_answer=False
                )
                input_ids = tokenizer(prompt, return_tensors="pt").input_ids

                response_text = _generate_response(tokenizer, model, input_ids)
                # The answer is taken from the last character of the response;
                # "-" stands in when the model generated an empty output
                pred_label = response_text[-1].upper() if response_text else "-"

                label = test_df.iloc[i, -1].strip().upper()
                detailed_results.append(
                    {
                        "Question": test_df.iloc[i, 0],
                        "Prompt": prompt,
                        "Correct Answer": label,
                        "Generated Answer": pred_label,
                        "Correct": pred_label == label,
                    }
                )

                # Update progress monitor
                questions_completed = questions_completed + 1
                percent_completed = questions_completed / number_of_questions * 100
                self.set_percent_progress(percent_completed)

            acc = np.mean([res["Correct"] for res in detailed_results])

            subject_results_df = pd.DataFrame(detailed_results)
            subject_csv_path = os.path.join(
                model_results_dir, f"{subject}_detailed_results.csv"
            )
            subject_results_df.to_csv(subject_csv_path, index=False)

            # Update summary_data with total questions and correct answers
            correct_answers_count = sum(
                result["Correct"] for result in detailed_results
            )

            summary_data.append(
                {
                    "Subject": subject,
                    "Accuracy": acc,
                    "Total Questions": len(test_df),
                    "Evaluated Questions": len(detailed_results),
                    "Correct Answers": correct_answers_count,
                }
            )

            # Save accuracy results to stats file
            # And display in the CLI
            stat_name = f"mmlu_{subject}_accuracy"
            stat_units_name = f"{stat_name}_units"
            state.save_stat(stat_name, float(acc) * 100)
            state.save_stat(stat_units_name, "%")
            self.status_stats.append(stat_name)

        # Calculate average of mmlu accuracy and display in the CLI
        acc_avg = np.mean([accuracy_data["Accuracy"] for accuracy_data in summary_data])
        state.save_stat(fs.Keys.AVERAGE_MMLU_ACCURACY, float(acc_avg) * 100)
        state.save_stat(f"{fs.Keys.AVERAGE_MMLU_ACCURACY}_units", "%")
        self.status_stats.append(fs.Keys.AVERAGE_MMLU_ACCURACY)

        # Save accuracy results to CSV file
        summary_df = pd.DataFrame(summary_data)
        summary_df.to_csv(
            os.path.join(model_results_dir, "summary_results.csv"), index=False
        )
        return state
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _list_tests(data_dir):
    """Prints the name of every `*_test.csv` file found in `<data_dir>/test`, one per line."""
    suffix = "_test.csv"
    test_dir = os.path.join(data_dir, "test")
    file_names = sorted(
        entry for entry in os.listdir(test_dir) if entry.endswith(suffix)
    )
    test_names = (file_name.replace(suffix, "") for file_name in file_names)
    print("Available tests:", *test_names, sep="\n")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _format_subject(subject):
    """Returns the subject name with every underscore turned into a space."""
    return subject.replace("_", " ")
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _safe_read_csv(path):
    """
    Reads a headerless CSV file and returns a DataFrame.

    A failure is logged and then re-raised. Returning None instead would
    only move the crash to the caller, which takes len() of or slices the
    result right away.

    Args:
        path: location of the CSV file.

    Returns:
        The DataFrame produced by pandas, with integer column labels.

    Raises:
        FileNotFoundError: if the file does not exist.
        Exception: any other error raised while reading the file.
    """
    try:
        return pd.read_csv(path, header=None)
    except FileNotFoundError:
        printing.log_error(f"Error: File not found - {path}")
        raise
    except Exception as e:  # pylint: disable=broad-except
        printing.log_error(f"An error occurred while reading {path}: {e}")
        raise
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _format_example(df, idx, include_answer=True):
    """
    Formats an example from the dataframe into a prompt string.

    The first column of the row is used as the question, the last column as
    the answer, and every column in between as an option labelled with the
    matching entry of `choices`.

    Args:
        df: dataframe holding one example per row.
        idx: position of the row to format.
        include_answer: when True, the answer and a blank line are appended
            after the answer marker; when False the prompt ends at the marker.

    Returns:
        The formatted prompt string.
    """
    prompt = df.iloc[idx, 0]
    option_columns = range(1, df.shape[1] - 1)
    prompt += "".join(
        f"\n{choices[column - 1]}. {df.iloc[idx, column]}" for column in option_columns
    )
    # Plain "Answer:" marker (the previous "Answer_:" carried a stray underscore)
    prompt += "\nAnswer:"
    if include_answer:
        prompt += f" {df.iloc[idx, -1]}\n\n"
    return prompt
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _gen_prompt(train_df, subject, k=-1):
    """
    Builds the few-shot part of a prompt: a header naming the subject,
    followed by the first k rows of train_df formatted as answered examples.

    With k == -1 every row of train_df is used; otherwise at most k rows.
    """
    header = (
        "The following are multiple choice questions (with answers) about "
        + f"{_format_subject(subject)}.\n\n"
    )

    rows_available = train_df.shape[0]
    shots = rows_available if k == -1 else min(k, rows_available)

    return header + "".join(_format_example(train_df, row) for row in range(shots))
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _generate_response(tokenizer, model, input_ids):
    """
    Asks the model for one new token and returns the decoded, stripped text.

    Any failure is logged as a warning and an empty string is returned
    instead.
    """
    try:
        generated = model.generate(input_ids, max_new_tokens=1)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.strip()
    except subprocess.CalledProcessError as err:
        printing.log_warning(
            f"Subprocess failed with command: {err} and error message: {err.stderr}"
        )
    except Exception as err:  # pylint: disable=broad-except
        printing.log_warning(f"Error during model generation: {err}")

    # Only reached after one of the handlers above has logged a failure
    return ""
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
    """
    Download the dataset from the given URL and extract it into the target directory.

    The download is skipped when data_cache_dir already contains anything.

    Args:
        data_cache_dir: directory that receives the extracted archive; it is
            created when missing.
        dataset_url: location of the tar archive to download.

    Returns:
        The path `<data_cache_dir>/data`, where the archive is expected to
        place its contents.

    Raises:
        requests.exceptions.HTTPError: if the server does not answer with
            status 200.
        requests.exceptions.RequestException: on connection errors or timeouts.
        tarfile.TarError: if the archive cannot be extracted.
    """

    # Create the directory if it does not exist
    Path(data_cache_dir).mkdir(parents=True, exist_ok=True)

    # Check if the data already exists to avoid re-downloading
    if os.listdir(data_cache_dir):
        printing.log_info(
            f"Dataset already exists in {data_cache_dir}, skipping download."
        )
    else:
        printing.log_info(f"Downloading dataset to {data_cache_dir}")

        tar_path = os.path.join(data_cache_dir, "data.tar")
        try:
            # Download the dataset; the timeout bounds the connection and each
            # wait for data, not the whole transfer
            with requests.get(dataset_url, stream=True, timeout=60) as response:
                if response.status_code != 200:
                    printing.log_error("Failed to download the dataset.")
                    raise requests.exceptions.HTTPError(
                        f"Downloading {dataset_url} failed with status "
                        f"{response.status_code}",
                        response=response,
                    )

                # Write the archive chunk by chunk instead of holding it in memory
                with open(tar_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)

            printing.log_info("Extracting dataset...")
            # Extract the tar file
            with tarfile.open(tar_path) as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would be written outside the target
                    # directory (available on Pythons with extraction filters)
                    tar.extractall(path=data_cache_dir, filter="data")
                else:
                    tar.extractall(path=data_cache_dir)
        finally:
            # A leftover archive would make the directory non-empty and cause
            # every later call to skip the download
            if os.path.exists(tar_path):
                os.remove(tar_path)

        printing.log_info("Dataset ready.")

    # MMLU data is stored in data.tar/data
    return os.path.join(data_cache_dir, "data")
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
327
|
+
# Modifications Copyright (c) 2025 AMD
|
|
File without changes
|