lemonade-sdk 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lemonade-sdk might be problematic. Click here for more details.

Files changed (61) hide show
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
@@ -0,0 +1,273 @@
1
+ import argparse
2
+ import abc
3
+ from typing import List
4
+ import lemonade.common.filesystem as fs
5
+ import lemonade.common.exceptions as exp
6
+ import lemonade.common.printing as printing
7
+ from lemonade.tools.tool import ToolParser
8
+ from lemonade.version import __version__ as lemonade_version
9
+ from lemonade.common.system_info import get_system_info_dict
10
+ from lemonade.common.build import output_dir
11
+ import lemonade.cache as lemonade_cache
12
+
13
+
14
class ManagementTool(abc.ABC):
    """
    Intended for management functions, such as managing the cache
    or printing the version number.
    """

    # Name used to refer to the tool in the parser's `prog` and epilog text.
    unique_name: str

    @classmethod
    def helpful_parser(cls, short_description: str, **kwargs):
        """
        Build a ToolParser for this tool.

        The parser's description is the class docstring, and its epilog is a
        standard note naming the tool as a Management Tool. Extra keyword
        arguments are forwarded to ToolParser unchanged.
        """
        parser_settings = {
            "prog": f"lemonade {cls.unique_name}",
            "short_description": short_description,
            "description": cls.__doc__,
            "epilog": (
                f"`{cls.unique_name}` is a Management Tool. It is intended to be invoked by itself "
                "(i.e., not as part of a sequence), to accomplish a utility function. "
            ),
        }

        return ToolParser(**parser_settings, **kwargs)

    @staticmethod
    @abc.abstractmethod
    def parser() -> argparse.ArgumentParser:
        """
        Static method that returns an ArgumentParser that defines the command
        line interface for this Tool.
        """

    # pylint: disable=unused-argument
    def parse(self, args, known_only=True) -> argparse.Namespace:
        """
        Run this tool's parser over `args` and return the resulting Namespace.

        Subclasses only need to override this when they require custom
        parsing logic.

        Args:
            args: command line arguments passed from the CLI.
            known_only: when True the arguments go through `parse_args()`;
                when False they go through `parse_known_args()` and any
                unrecognized arguments are discarded.
        """
        tool_parser = type(self).parser()

        if not known_only:
            namespace, _unrecognized = tool_parser.parse_known_args(args)
            return namespace

        return tool_parser.parse_args(args)

    @abc.abstractmethod
    def run(self, cache_dir: str):
        """
        Execute the functionality of the Tool.
        """

    def parse_and_run(self, cache_dir: str, args, known_only=True):
        """
        Parse the CLI arguments and hand the parsed values to run() as
        keyword arguments, alongside the cache directory.
        """
        namespace = self.parse(args, known_only)
        self.run(cache_dir, **vars(namespace))
81
+
82
+
83
class Version(ManagementTool):
    """
    Simply prints the version number of the lemonade installation.
    """

    unique_name = "version"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """Return the parser for this tool; it defines no tool-specific arguments."""
        return Version.helpful_parser(
            add_help=add_help,
            short_description="Print the lemonade version number",
        )

    def run(self, _):
        """Print the lemonade version; the positional argument is ignored."""
        print(lemonade_version)
101
+
102
+
103
class Cache(ManagementTool):
    # Python never stores an f-string as a docstring, so the text is assigned
    # to __doc__ explicitly. helpful_parser() passes cls.__doc__ to the parser
    # as its description; without this assignment that description is None.
    __doc__ = f"""
    A set of functions for managing the lemonade build cache. The default
    cache location is {lemonade_cache.DEFAULT_CACHE_DIR}, and can also be
    selected with
    the global --cache-dir option or the LEMONADE_CACHE_DIR environment variable.

    Users must set either "--all" or "--build-names" to let the tool
    know what builds to operate on.

    Users must also set one of the available actions (e.g., list, stats, etc.).

    That action will be applied to all selected builds.
    """

    unique_name = "cache"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """
        Return the parser for the cache tool.

        Exactly one build selector (--build-names or --all) and exactly one
        action (--list, --stats, --delete or --clean) are required; each set
        is a mutually exclusive argparse group.
        """
        # NOTE: `--cache-dir` is set as a global input to the lemonade CLI and
        # passed directly to the `run()` method

        parser = __class__.helpful_parser(
            short_description="Manage the build cache ",
            add_help=add_help,
        )

        build_selection_group = parser.add_mutually_exclusive_group(required=True)

        build_selection_group.add_argument(
            "-b",
            "--build-names",
            nargs="+",
            help="Name of the specific build(s) to be operated upon, within the cache directory",
        )

        build_selection_group.add_argument(
            "-a",
            "--all",
            dest="all_builds",
            help="Operate on all the builds in the cache",
            action="store_true",
        )

        action_group = parser.add_mutually_exclusive_group(required=True)

        action_group.add_argument(
            "-l",
            "--list",
            dest="list_builds",
            action="store_true",
            help="List all of the builds in the cache",
        )

        action_group.add_argument(
            "-s",
            "--stats",
            action="store_true",
            help="Print the collected stats for the selected build(s)",
        )

        action_group.add_argument(
            "--delete",
            action="store_true",
            help="Permanently delete the selected build(s)",
        )

        action_group.add_argument(
            "--clean",
            action="store_true",
            help="Remove the build artifacts from the selected build(s)",
        )

        return parser

    def run(
        self,
        cache_dir: str,
        all_builds: bool = False,
        build_names: List[str] = None,
        list_builds: bool = False,
        stats: bool = False,
        delete: bool = False,
        clean: bool = False,
    ):
        """
        Apply one action to the selected builds in `cache_dir`.

        Args:
            cache_dir: cache directory to operate on; it is first passed to
                fs.check_cache_dir().
            all_builds: operate on every build reported by
                fs.get_available_builds().
            build_names: operate on exactly these builds.
            list_builds: print the name of each selected build.
            stats: print the stats file of each selected build.
            delete: remove each selected build's directory.
            clean: remove the build artifacts of each selected build.

        Raises:
            ValueError: if both or neither of all_builds/build_names are set.
            exp.CacheError: if a selected name is not a build directory.
        """
        fs.check_cache_dir(cache_dir)

        if all_builds and build_names:
            raise ValueError(
                "all_builds and build_names are mutually exclusive, "
                "but both are used in this call."
            )
        elif all_builds:
            builds = fs.get_available_builds(cache_dir)
        elif build_names:
            builds = build_names
        else:
            raise ValueError(
                "Either all_builds or build_names must be set, "
                "but this call sets neither."
            )

        # Print a nice heading
        printing.log_info(f"Operating on cache directory {cache_dir}")

        if not builds:
            printing.log_warning("No builds found.")

        for build in builds:
            build_path = output_dir(cache_dir, build_name=build)
            if fs.is_build_dir(cache_dir, build):
                # Run actions on the build
                # These actions are intended to be mutually exclusive, so we
                # use an if-elif block in order from least to most destructive
                if list_builds:
                    print(build)
                elif stats:
                    fs.print_yaml_file(fs.Stats(cache_dir, build).file, "stats")
                elif clean:
                    fs.clean_output_dir(cache_dir, build)
                    printing.log_info(f"Removed the build artifacts from: {build}")
                elif delete:
                    fs.rmdir(build_path)
                    printing.log_info(f"Deleted build: {build}")
            else:
                raise exp.CacheError(
                    f"No build found with name: {build}. "
                    "Try running `lemonade cache --list` to see the builds in your build cache."
                )

        # Trailing blank line after the output for all builds
        print()
236
+
237
+
238
class SystemInfo(ManagementTool):
    """
    Prints system information for the lemonade installation.
    """

    unique_name = "system-info"

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """Return the parser for this tool; it defines no tool-specific arguments."""
        return SystemInfo.helpful_parser(
            add_help=add_help,
            short_description="Print system information",
        )

    @staticmethod
    def pretty_print(my_dict: dict, level=0):
        """
        Print a (possibly nested) dict, one entry per line.

        Nested dicts are printed recursively one level deeper, list values
        are printed one item per line one level deeper, and every other value
        is printed on the same line as its key. Each level adds one copy of
        the indent string as a prefix.
        """
        indent = " " * level
        for key, value in my_dict.items():
            if isinstance(value, dict):
                print(f"{indent}{key}:")
                SystemInfo.pretty_print(value, level + 1)
                continue

            if isinstance(value, list):
                print(f"{indent}{key}:")
                child_indent = " " * (level + 1)
                for entry in value:
                    print(f"{child_indent}{entry}")
                continue

            print(f"{indent}{key}: {value}")

    def run(self, _):
        """Collect the system information dict and print it; the argument is ignored."""
        self.pretty_print(get_system_info_dict())
270
+
271
+
272
+ # This file was originally licensed under Apache 2.0. It has been modified.
273
+ # Modifications Copyright (c) 2025 AMD
lemonade/tools/mmlu.py ADDED
@@ -0,0 +1,327 @@
1
+ import argparse
2
+ import os
3
+ import tarfile
4
+ from pathlib import Path
5
+ from typing import List, Optional
6
+ import subprocess
7
+ import numpy as np
8
+ import pandas as pd
9
+ import requests
10
+ from lemonade.state import State
11
+ from lemonade.tools import Tool
12
+ import lemonade.common.printing as printing
13
+ import lemonade.common.build as build
14
+ import lemonade.common.filesystem as fs
15
+
16
+ # Constants
17
+ choices = ["A", "B", "C", "D"]
18
+ dataset_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
19
+
20
+
21
def min_handle_none(*args: int):
    """
    Return the smallest argument, ignoring any argument that is None.

    As with the builtin min(), a ValueError is raised when nothing is left
    to compare (no arguments, or every argument is None).
    """
    present = [candidate for candidate in args if candidate is not None]
    return min(present)
29
+
30
+
31
class AccuracyMMLU(Tool):
    """
    See docs/lemonade/mmlu_accuracy.md for more details
    """

    unique_name = "accuracy-mmlu"

    def __init__(self):
        super().__init__(monitor_message="Measuring accuracy with MMLU")
        # Names of the stats that run() saves; appended to as subjects finish
        self.status_stats = []

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """Return the parser that defines the command line arguments of this tool."""
        parser = __class__.helpful_parser(
            short_description="Measure accuracy with Massive Multitask "
            "Language Understanding (MMLU)",
            add_help=add_help,
        )

        parser.add_argument(
            "--ntrain",
            type=int,
            default=5,
            help="Number of training examples to use. Default set to 5 for `5 Shot`",
        )
        parser.add_argument(
            "--max-evals",
            type=int,
            default=None,
            help="Maximum evaluations to run per test",
        )
        parser.add_argument(
            "--data-dir",
            type=str,
            required=False,
            help="Directory containing test and dev data (default: lemonade cache).",
        )
        parser.add_argument(
            "--tests",
            nargs="+",
            help=(
                "Specific tests to run. For a single quick test, we suggest 'management'. "
                "Default: run all tests."
            ),
        )
        return parser

    def run(
        self,
        state: State,
        ntrain: int = 5,
        max_evals: int = None,
        data_dir: Optional[str] = None,
        tests: List[str] = None,
    ) -> State:
        """
        Evaluate `state.model` on MMLU subjects and record the accuracy.

        For every selected subject a per-question CSV is written under the
        build's "mmlu" output directory and the subject accuracy is saved as
        a stat. The mean accuracy over subjects and a summary CSV are saved
        at the end.

        Args:
            state: provides cache_dir, build_name, tokenizer and model.
            ntrain: number of dev examples placed in each few-shot prompt.
            max_evals: cap on questions evaluated per subject (None: all).
            data_dir: dataset directory; defaults to <cache_dir>/data/mmlu.
            tests: subject names to run (None: every subject found on disk).

        Returns:
            The same `state` object, after stats have been saved to it.

        Raises:
            ValueError: if `tests` contains a name with no matching test file.
        """
        data_dir_to_use = data_dir or os.path.join(state.cache_dir, "data", "mmlu")

        # Setup MMLU dataset
        dataset_dir = download_and_extract_dataset(data_dir_to_use, dataset_url)

        model_results_dir = os.path.join(
            build.output_dir(state.cache_dir, state.build_name), "mmlu"
        )
        os.makedirs(model_results_dir, exist_ok=True)

        test_data_dir = os.path.join(dataset_dir, "test")
        tests_to_run = [
            f.replace("_test.csv", "")
            for f in sorted(os.listdir(test_data_dir))
            if f.endswith("_test.csv")
        ]
        if tests is not None:
            unsupported_tests = set(tests) - set(tests_to_run)
            if unsupported_tests:
                # Sorted so the error message is deterministic
                raise ValueError(
                    f"Invalid test names provided: {', '.join(sorted(unsupported_tests))}. "
                    f"Valid tests are: {', '.join(tests_to_run)}"
                )
            # Every requested name is valid at this point; keep the user's order
            tests_to_run = list(tests)

        tokenizer = state.tokenizer
        model = state.model

        # Update Tool progress monitor
        self.set_percent_progress(0.0)

        # Read each test CSV once; the frames are used both for the progress
        # total below and for the evaluation loop.
        test_frames = {
            subject: _safe_read_csv(
                os.path.join(test_data_dir, f"{subject}_test.csv")
            )
            for subject in set(tests_to_run)
        }
        number_of_questions = float(
            sum(
                min_handle_none(len(test_frames[subject]), max_evals)
                for subject in tests_to_run
            )
        )

        questions_completed = 0

        summary_data = []
        for subject in tests_to_run:
            dev_df = _safe_read_csv(
                os.path.join(dataset_dir, "dev", f"{subject}_dev.csv")
            )[:ntrain]
            test_df = test_frames[subject]

            # Evaluate the model on the test data for a given subject
            detailed_results = []

            for i in range(min_handle_none(test_df.shape[0], max_evals)):
                prompt = _gen_prompt(dev_df, subject, ntrain) + _format_example(
                    test_df, i, include_answer=False
                )
                input_ids = tokenizer(prompt, return_tensors="pt").input_ids

                response_text = _generate_response(tokenizer, model, input_ids)
                try:
                    # The predicted label is the last character of the response
                    pred_label = response_text[-1].upper()
                # Handle models generating empty outputs
                except IndexError:
                    pred_label = "-"

                label = test_df.iloc[i, -1].strip().upper()
                detailed_results.append(
                    {
                        "Question": test_df.iloc[i, 0],
                        "Prompt": prompt,
                        "Correct Answer": label,
                        "Generated Answer": pred_label,
                        "Correct": pred_label == label,
                    }
                )

                # Update progress monitor
                questions_completed = questions_completed + 1
                percent_completed = questions_completed / number_of_questions * 100
                self.set_percent_progress(percent_completed)

            acc = np.mean([res["Correct"] for res in detailed_results])

            subject_results_df = pd.DataFrame(detailed_results)
            subject_csv_path = os.path.join(
                model_results_dir, f"{subject}_detailed_results.csv"
            )
            subject_results_df.to_csv(subject_csv_path, index=False)

            # Update summary_data with total questions and correct answers
            correct_answers_count = sum(
                result["Correct"] for result in detailed_results
            )

            summary_data.append(
                {
                    "Subject": subject,
                    "Accuracy": acc,
                    "Total Questions": len(test_df),
                    # Count what was actually evaluated rather than
                    # re-deriving it from max_evals
                    "Evaluated Questions": len(detailed_results),
                    "Correct Answers": correct_answers_count,
                }
            )

            # Save accuracy results to stats file
            # And display in the CLI
            stat_name = f"mmlu_{subject}_accuracy"
            stat_units_name = f"{stat_name}_units"
            state.save_stat(stat_name, float(acc) * 100)
            state.save_stat(stat_units_name, "%")
            self.status_stats.append(stat_name)

        # Calculate average of mmlu accuracy and display in the CLI
        acc_avg = np.mean([accuracy_data["Accuracy"] for accuracy_data in summary_data])
        state.save_stat(fs.Keys.AVERAGE_MMLU_ACCURACY, float(acc_avg) * 100)
        state.save_stat(f"{fs.Keys.AVERAGE_MMLU_ACCURACY}_units", "%")
        self.status_stats.append(fs.Keys.AVERAGE_MMLU_ACCURACY)

        # Save accuracy results to CSV file
        summary_df = pd.DataFrame(summary_data)
        summary_df.to_csv(
            os.path.join(model_results_dir, "summary_results.csv"), index=False
        )
        return state
225
+
226
+
227
+ def _list_tests(data_dir):
228
+ """Lists all available tests based on the files in the test data directory."""
229
+ test_files = [
230
+ f for f in os.listdir(os.path.join(data_dir, "test")) if f.endswith("_test.csv")
231
+ ]
232
+ print(
233
+ "Available tests:",
234
+ *[f.replace("_test.csv", "") for f in sorted(test_files)],
235
+ sep="\n",
236
+ )
237
+
238
+
239
+ def _format_subject(subject):
240
+ """Formats a subject string by replacing underscores with spaces."""
241
+ return " ".join(subject.split("_"))
242
+
243
+
244
+ def _safe_read_csv(path):
245
+ """Safely reads a CSV file and returns a DataFrame."""
246
+ try:
247
+ return pd.read_csv(path, header=None)
248
+ except FileNotFoundError:
249
+ printing.log_error(f"Error: File not found - {path}")
250
+ except Exception as e: # pylint: disable=broad-except
251
+ printing.log_error(f"An error occurred while reading {path}: {e}")
252
+
253
+
254
+ def _format_example(df, idx, include_answer=True):
255
+ """Formats an example from the dataframe into a prompt string."""
256
+ prompt = df.iloc[idx, 0]
257
+ for j in range(1, df.shape[1] - 1):
258
+ prompt += f"\n{choices[j-1]}. {df.iloc[idx, j]}"
259
+ prompt += "\nAnswer_:"
260
+ if include_answer:
261
+ prompt += f" {df.iloc[idx, -1]}\n\n"
262
+ return prompt
263
+
264
+
265
def _gen_prompt(train_df, subject, k=-1):
    """
    Build the few-shot preamble for `subject`.

    The result is a heading sentence naming the subject followed by the
    first `k` rows of `train_df` formatted as answered examples. With the
    default `k=-1` every row of `train_df` is used.
    """
    available = train_df.shape[0]
    shots = available if k == -1 else min(k, available)

    heading = (
        "The following are multiple choice questions (with answers) about "
        + f"{_format_subject(subject)}.\n\n"
    )
    examples = [_format_example(train_df, row) for row in range(shots)]
    return heading + "".join(examples)
274
+
275
+
276
+ def _generate_response(tokenizer, model, input_ids):
277
+ """Generates a model response for the given input IDs."""
278
+ try:
279
+ response = model.generate(input_ids, max_new_tokens=1)
280
+ return tokenizer.decode(response[0], skip_special_tokens=True).strip()
281
+ except subprocess.CalledProcessError as e:
282
+ printing.log_warning(
283
+ f"Subprocess failed with command: {e} and error message: {e.stderr}"
284
+ )
285
+ except Exception as e: # pylint: disable=broad-except
286
+ printing.log_warning(f"Error during model generation: {e}")
287
+ return "" # Return an empty string on failure
288
+
289
+
290
def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
    """
    Download the dataset from the given URL and extract it into the target directory.

    The download is skipped when `data_cache_dir` already contains anything.
    A non-200 response is logged and otherwise ignored, so the returned path
    may not exist in that case.

    Args:
        data_cache_dir: directory that receives the archive contents; it is
            created if missing.
        dataset_url: URL of the tar archive to download.

    Returns:
        The path `<data_cache_dir>/data`, where the MMLU archive keeps its files.
    """

    # MMLU data is stored in data.tar/data
    dataset_dir = os.path.join(data_cache_dir, "data")

    # Create the directory if it does not exist
    Path(data_cache_dir).mkdir(parents=True, exist_ok=True)

    # Check if the data already exists to avoid re-downloading
    if os.listdir(data_cache_dir):
        printing.log_info(
            f"Dataset already exists in {data_cache_dir}, skipping download."
        )
        return dataset_dir

    printing.log_info(f"Downloading dataset to {data_cache_dir}")

    # The timeout stops a stalled connection from hanging forever, and the
    # context manager releases the connection when done.
    with requests.get(dataset_url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            printing.log_info("Failed to download the dataset.")
            return dataset_dir

        tar_path = os.path.join(data_cache_dir, "data.tar")
        try:
            # Stream to disk in chunks instead of holding the archive in memory
            with open(tar_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)

            printing.log_info("Extracting dataset...")
            with tarfile.open(tar_path) as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would be written outside the
                    # destination (absolute paths, "..", unsafe links)
                    tar.extractall(path=data_cache_dir, filter="data")
                else:
                    # Older Python without extraction filters
                    tar.extractall(path=data_cache_dir)
        finally:
            # Always remove the archive: a leftover partial data.tar would
            # make the directory non-empty and skip every later download.
            if os.path.exists(tar_path):
                os.remove(tar_path)

    printing.log_info("Dataset ready.")
    return dataset_dir
324
+
325
+
326
+ # This file was originally licensed under Apache 2.0. It has been modified.
327
+ # Modifications Copyright (c) 2025 AMD
File without changes