lemonade_sdk-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +180 -0
  3. lemonade/cache.py +92 -0
  4. lemonade/cli.py +173 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/build.py +176 -0
  7. lemonade/common/cli_helpers.py +139 -0
  8. lemonade/common/exceptions.py +98 -0
  9. lemonade/common/filesystem.py +368 -0
  10. lemonade/common/inference_engines.py +408 -0
  11. lemonade/common/network.py +93 -0
  12. lemonade/common/printing.py +110 -0
  13. lemonade/common/status.py +471 -0
  14. lemonade/common/system_info.py +1411 -0
  15. lemonade/common/test_helpers.py +28 -0
  16. lemonade/profilers/__init__.py +1 -0
  17. lemonade/profilers/agt_power.py +437 -0
  18. lemonade/profilers/hwinfo_power.py +429 -0
  19. lemonade/profilers/memory_tracker.py +259 -0
  20. lemonade/profilers/profiler.py +58 -0
  21. lemonade/sequence.py +363 -0
  22. lemonade/state.py +159 -0
  23. lemonade/tools/__init__.py +1 -0
  24. lemonade/tools/accuracy.py +432 -0
  25. lemonade/tools/adapter.py +114 -0
  26. lemonade/tools/bench.py +302 -0
  27. lemonade/tools/flm/__init__.py +1 -0
  28. lemonade/tools/flm/utils.py +305 -0
  29. lemonade/tools/huggingface/bench.py +187 -0
  30. lemonade/tools/huggingface/load.py +235 -0
  31. lemonade/tools/huggingface/utils.py +359 -0
  32. lemonade/tools/humaneval.py +264 -0
  33. lemonade/tools/llamacpp/bench.py +255 -0
  34. lemonade/tools/llamacpp/load.py +222 -0
  35. lemonade/tools/llamacpp/utils.py +1260 -0
  36. lemonade/tools/management_tools.py +319 -0
  37. lemonade/tools/mmlu.py +319 -0
  38. lemonade/tools/oga/__init__.py +0 -0
  39. lemonade/tools/oga/bench.py +120 -0
  40. lemonade/tools/oga/load.py +804 -0
  41. lemonade/tools/oga/migration.py +403 -0
  42. lemonade/tools/oga/utils.py +462 -0
  43. lemonade/tools/perplexity.py +147 -0
  44. lemonade/tools/prompt.py +263 -0
  45. lemonade/tools/report/__init__.py +0 -0
  46. lemonade/tools/report/llm_report.py +203 -0
  47. lemonade/tools/report/table.py +899 -0
  48. lemonade/tools/server/__init__.py +0 -0
  49. lemonade/tools/server/flm.py +133 -0
  50. lemonade/tools/server/llamacpp.py +320 -0
  51. lemonade/tools/server/serve.py +2123 -0
  52. lemonade/tools/server/static/favicon.ico +0 -0
  53. lemonade/tools/server/static/index.html +279 -0
  54. lemonade/tools/server/static/js/chat.js +1059 -0
  55. lemonade/tools/server/static/js/model-settings.js +183 -0
  56. lemonade/tools/server/static/js/models.js +1395 -0
  57. lemonade/tools/server/static/js/shared.js +556 -0
  58. lemonade/tools/server/static/logs.html +191 -0
  59. lemonade/tools/server/static/styles.css +2654 -0
  60. lemonade/tools/server/static/webapp.html +321 -0
  61. lemonade/tools/server/tool_calls.py +153 -0
  62. lemonade/tools/server/tray.py +664 -0
  63. lemonade/tools/server/utils/macos_tray.py +226 -0
  64. lemonade/tools/server/utils/port.py +77 -0
  65. lemonade/tools/server/utils/thread.py +85 -0
  66. lemonade/tools/server/utils/windows_tray.py +408 -0
  67. lemonade/tools/server/webapp.py +34 -0
  68. lemonade/tools/server/wrapped_server.py +559 -0
  69. lemonade/tools/tool.py +374 -0
  70. lemonade/version.py +1 -0
  71. lemonade_install/__init__.py +1 -0
  72. lemonade_install/install.py +239 -0
  73. lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
  74. lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
  75. lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
  76. lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
  77. lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
  78. lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
  79. lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
  80. lemonade_server/cli.py +805 -0
  81. lemonade_server/model_manager.py +758 -0
  82. lemonade_server/pydantic_models.py +159 -0
  83. lemonade_server/server_models.json +643 -0
  84. lemonade_server/settings.py +39 -0
lemonade/tools/management_tools.py ADDED
@@ -0,0 +1,319 @@
+ import argparse
+ import abc
+ import json
+ from typing import List
+ import lemonade.common.filesystem as fs
+ import lemonade.common.exceptions as exp
+ import lemonade.common.printing as printing
+ from lemonade.tools.tool import ToolParser
+ from lemonade.version import __version__ as lemonade_version
+ from lemonade.common.system_info import (
+     get_system_info_dict,
+     get_device_info_dict,
+     get_system_info,
+ )
+ from lemonade.common.build import output_dir
+ import lemonade.cache as lemonade_cache
+
+
+ class ManagementTool(abc.ABC):
+     """
+     Intended for management functions, such as managing the cache
+     or printing the version number.
+     """
+
+     unique_name: str
+
+     @classmethod
+     def helpful_parser(cls, short_description: str, **kwargs):
+         epilog = (
+             f"`{cls.unique_name}` is a Management Tool. It is intended to be invoked by itself "
+             "(i.e., not as part of a sequence), to accomplish a utility function. "
+         )
+
+         return ToolParser(
+             prog=f"lemonade {cls.unique_name}",
+             short_description=short_description,
+             description=cls.__doc__,
+             epilog=epilog,
+             **kwargs,
+         )
+
+     @staticmethod
+     @abc.abstractmethod
+     def parser() -> argparse.ArgumentParser:
+         """
+         Static method that returns an ArgumentParser that defines the command
+         line interface for this Tool.
+         """
+
+     # pylint: disable=unused-argument
+     def parse(self, args, known_only=True) -> argparse.Namespace:
+         """
+         Run the parser and return a Namespace of keyword arguments that the user
+         passed to the Tool via the command line.
+
+         Tools should extend this function only if they require specific parsing
+         logic.
+
+         Args:
+             args: command line arguments passed from the CLI.
+             known_only: this argument allows the CLI framework to
+                 incrementally parse complex commands.
+         """
+
+         if known_only:
+             parsed_args = self.__class__.parser().parse_args(args)
+         else:
+             parsed_args, _ = self.__class__.parser().parse_known_args(args)
+
+         return parsed_args
+
+     @abc.abstractmethod
+     def run(self, cache_dir: str):
+         """
+         Execute the functionality of the Tool.
+         """
+
+     def parse_and_run(self, cache_dir: str, args, known_only=True):
+         """
+         Helper function to parse CLI arguments into the args expected
+         by run(), and then forward them into the run() method.
+         """
+
+         parsed_args = self.parse(args, known_only)
+         self.run(cache_dir, **parsed_args.__dict__)
+
+
+ class Version(ManagementTool):
+     """
+     Simply prints the version number of the lemonade installation.
+     """
+
+     unique_name = "version"
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Print the lemonade version number",
+             add_help=add_help,
+         )
+
+         return parser
+
+     def run(self, _):
+         print(lemonade_version)
+
+
+ class Cache(ManagementTool):
+     # pylint: disable=pointless-statement,f-string-without-interpolation
+     f"""
+     A set of functions for managing the lemonade build cache. The default
+     cache location is {lemonade_cache.DEFAULT_CACHE_DIR}, and can also be
+     selected with
+     the global --cache-dir option or the LEMONADE_CACHE_DIR environment variable.
+
+     Users must set either "--all" or "--build-names" to let the tool
+     know what builds to operate on.
+
+     Users must also set one of the available actions (e.g., list, stats, etc.).
+
+     That action will be applied to all selected builds.
+     """
+
+     unique_name = "cache"
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         # NOTE: `--cache-dir` is set as a global input to the lemonade CLI and
+         # passed directly to the `run()` method
+
+         parser = __class__.helpful_parser(
+             short_description="Manage the build cache " f"",
+             add_help=add_help,
+         )
+
+         build_selection_group = parser.add_mutually_exclusive_group(required=True)
+
+         build_selection_group.add_argument(
+             "-b",
+             "--build-names",
+             nargs="+",
+             help="Name of the specific build(s) to be operated upon, within the cache directory",
+         )
+
+         build_selection_group.add_argument(
+             "-a",
+             "--all",
+             dest="all_builds",
+             help="Operate on all the builds in the cache",
+             action="store_true",
+         )
+
+         action_group = parser.add_mutually_exclusive_group(required=True)
+
+         action_group.add_argument(
+             "-l",
+             "--list",
+             dest="list_builds",
+             action="store_true",
+             help="List all of the builds in the cache",
+         )
+
+         action_group.add_argument(
+             "-s",
+             "--stats",
+             action="store_true",
+             help="Print the collected stats for the selected build(s)",
+         )
+
+         action_group.add_argument(
+             "--delete",
+             action="store_true",
+             help="Permanently delete the selected build(s)",
+         )
+
+         action_group.add_argument(
+             "--clean",
+             action="store_true",
+             help="Remove the build artifacts from the selected build(s)",
+         )
+
+         return parser
+
+     def run(
+         self,
+         cache_dir: str,
+         all_builds: bool = False,
+         build_names: List[str] = None,
+         list_builds: bool = False,
+         stats: bool = False,
+         delete: bool = False,
+         clean: bool = False,
+     ):
+         fs.check_cache_dir(cache_dir)
+
+         if all_builds and build_names:
+             raise ValueError(
+                 "all_builds and build_names are mutually exclusive, "
+                 "but both are used in this call."
+             )
+         elif all_builds:
+             builds = fs.get_available_builds(cache_dir)
+         elif build_names:
+             builds = build_names
+         else:
+             raise ValueError(
+                 "Either all_builds or build_names must be set, "
+                 "but this call sets neither."
+             )
+
+         # Print a nice heading
+         printing.log_info(f"Operating on cache directory {cache_dir}")
+
+         if not builds:
+             printing.log_warning("No builds found.")
+
+         for build in builds:
+             build_path = output_dir(cache_dir, build_name=build)
+             if fs.is_build_dir(cache_dir, build):
+                 # Run actions on the build
+                 # These actions are intended to be mutually exclusive, so we
+                 # use an if-elif block in order from least to most destructive
+                 if list_builds:
+                     print(build)
+                 elif stats:
+                     fs.print_yaml_file(fs.Stats(cache_dir, build).file, "stats")
+                 elif clean:
+                     fs.clean_output_dir(cache_dir, build)
+                     printing.log_info(f"Removed the build artifacts from: {build}")
+
+                 elif delete:
+                     fs.rmdir(build_path)
+                     printing.log_info(f"Deleted build: {build}")
+             else:
+                 raise exp.CacheError(
+                     f"No build found with name: {build}. "
+                     "Try running `lemonade cache --list` to see the builds in your build cache."
+                 )
+
+         print()
+
+
+ class SystemInfo(ManagementTool):
+     """
+     Prints system information for the lemonade installation.
+     """
+
+     unique_name = "system-info"
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Print system and device information",
+             add_help=add_help,
+         )
+
+         parser.add_argument(
+             "--format", choices=["table", "json"], default="table", help="Output format"
+         )
+
+         parser.add_argument(
+             "--verbose",
+             action="store_true",
+             help="Show detailed system information",
+         )
+
+         return parser
+
+     @staticmethod
+     def pretty_print(my_dict: dict, level=0):
+         for k, v in my_dict.items():
+             if k == "available" and v is True:
+                 continue
+
+             if isinstance(v, dict):
+                 # Special handling for device availability
+                 if v.get("available") is False:
+                     error_msg = v.get("error", "Not available")
+                     print(" " * level + f"{k}: {error_msg}")
+                 else:
+                     print(" " * level + f"{k}:")
+                     SystemInfo.pretty_print(v, level + 1)
+             elif isinstance(v, list):
+                 print(" " * level + f"{k}:")
+                 for item in v:
+                     if isinstance(item, dict):
+                         SystemInfo.pretty_print(item, level + 1)
+                         print()
+                     else:
+                         print(" " * (level + 1) + f"{item}")
+             else:
+                 print(" " * level + f"{k}: {v}")
+
+     def run(self, _, format="table", verbose=False):
+         # Get basic system info
+         system_info_dict = get_system_info_dict()
+
+         # Always include devices
+         system_info_dict["Devices"] = get_device_info_dict()
+
+         # Filter out verbose-only information if not in verbose mode
+         if not verbose:
+             essential_keys = ["OS Version", "Processor", "Physical Memory", "Devices"]
+             system_info_dict = {
+                 k: v for k, v in system_info_dict.items() if k in essential_keys
+             }
+         else:
+             # In verbose mode, add Python packages at the end
+             system_info = get_system_info()
+             system_info_dict["Python Packages"] = system_info.get_python_packages()
+
+         if format == "json":
+             print(json.dumps(system_info_dict, indent=2))
+         else:
+             self.pretty_print(system_info_dict)
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
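
The base class above fixes the contract for every Management Tool: parser() defines the argparse interface (exposed on the CLI as `lemonade <unique_name>`), and parse_and_run() parses an argv-style list and forwards the resulting flags into run(cache_dir, ...). The following is a minimal sketch of driving these tools programmatically, assuming an existing cache directory; the path below is a placeholder, not a lemonade default.

    # Sketch only: roughly equivalent to `lemonade version` and `lemonade cache --all --list`.
    from lemonade.tools.management_tools import Cache, Version

    Version().run(None)  # prints the installed lemonade version

    # parse_and_run() runs parser() on the argv-style list, then forwards the
    # parsed flags into Cache.run() together with the cache directory.
    Cache().parse_and_run("/path/to/lemonade/cache", ["--all", "--list"])
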
lemonade/tools/mmlu.py ADDED
@@ -0,0 +1,319 @@
+ import argparse
+ import os
+ import tarfile
+ from pathlib import Path
+ from typing import List, Optional
+ import subprocess
+ from lemonade.state import State
+ from lemonade.tools import Tool
+ import lemonade.common.printing as printing
+ import lemonade.common.build as build
+ import lemonade.common.filesystem as fs
+
+ # Constants
+ choices = ["A", "B", "C", "D"]
+ dataset_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
+
+
+ def min_handle_none(*args: int):
+     """
+     Returns the minimum of the arguments. If one of the arguments is None,
+     it doesn't count towards the min.
+     """
+
+     filter_out_none = (value for value in args if value is not None)
+     return min(filter_out_none)
+
+
+ class AccuracyMMLU(Tool):
+     """
+     See docs/dev_cli/mmlu_accuracy.md for more details
+     """
+
+     unique_name = "accuracy-mmlu"
+
+     def __init__(self):
+         super().__init__(monitor_message="Measuring accuracy with MMLU")
+         self.status_stats = []
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Measure accuracy with Massive Multitask "
+             "Language Understanding (MMLU)",
+             add_help=add_help,
+         )
+
+         parser.add_argument(
+             "--ntrain",
+             type=int,
+             default=5,
+             help="Number of training examples to use. Default set to 5 for `5 Shot`",
+         )
+         parser.add_argument(
+             "--max-evals",
+             type=int,
+             default=None,
+             help="Maximum evaluations to run per test",
+         )
+         parser.add_argument(
+             "--data-dir",
+             type=str,
+             required=False,
+             help="Directory containing test and dev data (default: lemonade cache).",
+         )
+         parser.add_argument(
+             "--tests",
+             nargs="+",
+             help=(
+                 "Specific tests to run. For a single quick test, we suggest 'management'. "
+                 + "Default: run all tests."
+             ),
+         )
+         return parser
+
+     def run(
+         self,
+         state: State,
+         ntrain: int = 5,
+         max_evals: int = None,
+         data_dir: Optional[str] = None,
+         tests: List[str] = None,
+     ) -> State:
+
+         import numpy as np
+         import pandas as pd
+
+         if data_dir:
+             data_dir_to_use = data_dir
+         else:
+             data_dir_to_use = os.path.join(state.cache_dir, "data", "mmlu")
+
+         # Setup MMLU dataset
+         dataset_dir = download_and_extract_dataset(data_dir_to_use, dataset_url)
+
+         model_results_dir = os.path.join(
+             build.output_dir(state.cache_dir, state.build_name), "mmlu"
+         )
+         os.makedirs(model_results_dir, exist_ok=True)
+
+         tests_to_run = [
+             f.replace("_test.csv", "")
+             for f in sorted(os.listdir(os.path.join(dataset_dir, "test")))
+             if f.endswith("_test.csv")
+         ]
+         if tests is not None:
+             unsupported_tests = set(tests) - set(tests_to_run)
+             if unsupported_tests:
+                 raise ValueError(
+                     f"Invalid test names provided: {', '.join(unsupported_tests)}. "
+                     f"Valid tests are: {', '.join(tests_to_run)}"
+                 )
+             tests_to_run = [test for test in tests if test in tests_to_run]
+
+         tokenizer = state.tokenizer
+         model = state.model
+
+         # Update Tool progress monitor
+         self.set_percent_progress(0.0)
+         number_of_questions = float(
+             sum(
+                 [
+                     min_handle_none(
+                         len(
+                             _safe_read_csv(
+                                 os.path.join(dataset_dir, "test", f"{subject}_test.csv")
+                             )
+                         ),
+                         max_evals,
+                     )
+                     for subject in tests_to_run
+                 ]
+             )
+         )
+
+         questions_completed = 0
+
+         summary_data = []
+         for subject in tests_to_run:
+             dev_df = _safe_read_csv(
+                 os.path.join(dataset_dir, "dev", f"{subject}_dev.csv")
+             )[:ntrain]
+             test_df = _safe_read_csv(
+                 os.path.join(dataset_dir, "test", f"{subject}_test.csv")
+             )
+
+             # Evaluate the model on the test data for a given subject
+             detailed_results = []
+
+             for i in range(min_handle_none(test_df.shape[0], max_evals)):
+                 prompt = _gen_prompt(dev_df, subject, ntrain) + _format_example(
+                     test_df, i, include_answer=False
+                 )
+                 input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+                 response_text = _generate_response(tokenizer, model, input_ids)
+                 try:
+                     pred_label = response_text[-1].upper()
+                 # Handle models generating empty outputs
+                 except IndexError:
+                     pred_label = "-"
+
+                 label = test_df.iloc[i, -1].strip().upper()
+                 detailed_results.append(
+                     {
+                         "Question": test_df.iloc[i, 0],
+                         "Prompt": prompt,
+                         "Correct Answer": label,
+                         "Generated Answer": pred_label,
+                         "Correct": pred_label == label,
+                     }
+                 )
+
+                 # Update progress monitor
+                 questions_completed = questions_completed + 1
+                 percent_completed = questions_completed / number_of_questions * 100
+                 self.set_percent_progress(percent_completed)
+
+             acc = np.mean([res["Correct"] for res in detailed_results])
+
+             subject_results_df = pd.DataFrame(detailed_results)
+             subject_csv_path = os.path.join(
+                 model_results_dir, f"{subject}_detailed_results.csv"
+             )
+             subject_results_df.to_csv(subject_csv_path, index=False)
+
+             # Update summary_data with total questions and correct answers
+             correct_answers_count = sum(
+                 result["Correct"] for result in detailed_results
+             )
+
+             summary_data.append(
+                 {
+                     "Subject": subject,
+                     "Accuracy": acc,
+                     "Total Questions": len(test_df),
+                     "Evaluated Questions": (
+                         max_evals
+                         if max_evals is not None and max_evals < len(test_df)
+                         else len(test_df)
+                     ),
+                     "Correct Answers": correct_answers_count,
+                 }
+             )
+
+             # Save accuracy results to stats file
+             # And display in the CLI
+             stat_name = f"mmlu_{subject}_accuracy"
+             stat_units_name = f"{stat_name}_units"
+             state.save_stat(stat_name, float(acc) * 100)
+             state.save_stat(stat_units_name, "%")
+             self.status_stats.append(stat_name)
+
+         # Calculate average of mmlu accuracy and display in the CLI
+         acc_avg = np.mean([accuracy_data["Accuracy"] for accuracy_data in summary_data])
+         state.save_stat(fs.Keys.AVERAGE_MMLU_ACCURACY, float(acc_avg) * 100)
+         state.save_stat(f"{fs.Keys.AVERAGE_MMLU_ACCURACY}_units", "%")
+         self.status_stats.append(fs.Keys.AVERAGE_MMLU_ACCURACY)
+
+         # Save accuracy results to CSV file
+         summary_df = pd.DataFrame(summary_data)
+         summary_df.to_csv(
+             os.path.join(model_results_dir, "summary_results.csv"), index=False
+         )
+         return state
+
+
+ def _format_subject(subject):
+     """Formats a subject string by replacing underscores with spaces."""
+     return " ".join(subject.split("_"))
+
+
+ def _safe_read_csv(path):
+     """Safely reads a CSV file and returns a DataFrame."""
+     import pandas as pd
+
+     try:
+         return pd.read_csv(path, header=None)
+     except FileNotFoundError:
+         printing.log_error(f"Error: File not found - {path}")
+     except Exception as e:  # pylint: disable=broad-except
+         printing.log_error(f"An error occurred while reading {path}: {e}")
+
+
+ def _format_example(df, idx, include_answer=True):
+     """Formats an example from the dataframe into a prompt string."""
+     prompt = df.iloc[idx, 0]
+     for j in range(1, df.shape[1] - 1):
+         prompt += f"\n{choices[j-1]}. {df.iloc[idx, j]}"
+     prompt += "\nAnswer:"
+     if include_answer:
+         prompt += f" {df.iloc[idx, -1]}\n\n"
+     return prompt
+
+
+ def _gen_prompt(train_df, subject, k=-1):
+     """Generates a prompt string from multiple examples."""
+     prompt = (
+         "The following are multiple choice questions (with answers) about "
+         + f"{_format_subject(subject)}.\n\n"
+     )
+     for i in range(min(k, train_df.shape[0]) if k != -1 else train_df.shape[0]):
+         prompt += _format_example(train_df, i)
+     return prompt
+
+
+ def _generate_response(tokenizer, model, input_ids):
+     """Generates a model response for the given input IDs."""
+     try:
+         response = model.generate(input_ids, max_new_tokens=1)
+         return tokenizer.decode(response[0], skip_special_tokens=True).strip()
+     except subprocess.CalledProcessError as e:
+         printing.log_warning(
+             f"Subprocess failed with command: {e} and error message: {e.stderr}"
+         )
+     except Exception as e:  # pylint: disable=broad-except
+         printing.log_warning(f"Error during model generation: {e}")
+     return ""  # Return an empty string on failure
+
+
+ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
+     """
+     Download the dataset from the given URL and extract it into the target directory.
+     """
+
+     import requests
+
+     # Create the directory if it does not exist
+     Path(data_cache_dir).mkdir(parents=True, exist_ok=True)
+
+     # Check if the data already exists to avoid re-downloading
+     if not os.listdir(data_cache_dir):  # Checks if the directory is empty
+         printing.log_info(f"Downloading dataset to {data_cache_dir}")
+
+         # Download the dataset
+         response = requests.get(dataset_url, stream=True)
+         if response.status_code == 200:
+             tar_path = os.path.join(data_cache_dir, "data.tar")
+             with open(tar_path, "wb") as f:
+                 f.write(response.raw.read())
+
+             printing.log_info("Extracting dataset...")
+             # Extract the tar file
+             with tarfile.open(tar_path) as tar:
+                 tar.extractall(path=data_cache_dir)
+             os.remove(tar_path)
+             printing.log_info("Dataset ready.")
+         else:
+             printing.log_info("Failed to download the dataset.")
+     else:
+         printing.log_info(
+             f"Dataset already exists in {data_cache_dir}, skipping download."
+         )
+
+     # MMLU data is stored in data.tar/data
+     return os.path.join(data_cache_dir, "data")
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
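
For reference, accuracy-mmlu writes one `<subject>_detailed_results.csv` per subject plus a `summary_results.csv` into an `mmlu` folder under the build's output directory, and records `mmlu_<subject>_accuracy` stats as percentages. Below is a minimal sketch of inspecting those results after a run; the cache directory and build name are placeholders, not lemonade defaults.

    # Sketch only: re-read the summary CSV produced by AccuracyMMLU.
    import os

    import pandas as pd

    from lemonade.common.build import output_dir

    cache_dir = "/path/to/lemonade/cache"  # placeholder
    build_name = "my_build"  # placeholder

    summary = pd.read_csv(
        os.path.join(output_dir(cache_dir, build_name), "mmlu", "summary_results.csv")
    )

    # Columns written by the tool: Subject, Accuracy, Total Questions,
    # Evaluated Questions, Correct Answers
    print(summary[["Subject", "Accuracy"]])
    print("Average MMLU accuracy (%):", summary["Accuracy"].mean() * 100)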