lemonade-sdk 7.0.1__py3-none-any.whl → 7.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic.
- lemonade/cli.py +2 -0
- lemonade/tools/accuracy.py +335 -0
- lemonade/tools/huggingface_load.py +6 -0
- lemonade/tools/ort_genai/oga.py +6 -4
- lemonade/tools/prompt.py +28 -1
- lemonade/tools/server/instructions.py +8 -265
- lemonade/tools/server/llamacpp.py +45 -19
- lemonade/tools/server/port_utils.py +57 -0
- lemonade/tools/server/serve.py +96 -44
- lemonade/tools/server/static/instructions.html +262 -0
- lemonade/tools/server/thread_utils.py +87 -0
- lemonade/version.py +1 -1
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/METADATA +1 -1
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/RECORD +22 -18
- lemonade_server/model_manager.py +45 -12
- {lemonade/tools/server → lemonade_server}/pydantic_models.py +2 -0
- lemonade_server/server_models.json +25 -4
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.1.dist-info → lemonade_sdk-7.0.3.dist-info}/top_level.txt +0 -0
lemonade/cli.py
CHANGED
@@ -19,6 +19,7 @@ import lemonade.cache as cache
 from lemonade.tools.mmlu import AccuracyMMLU
 from lemonade.tools.humaneval import AccuracyHumaneval
 from lemonade.tools.perplexity import AccuracyPerplexity
+from lemonade.tools.accuracy import LMEvalHarness
 from lemonade.tools.prompt import LLMPrompt
 from lemonade.tools.quark.quark_load import QuarkLoad
 from lemonade.tools.quark.quark_quantize import QuarkQuantize
@@ -36,6 +37,7 @@ def main():
         AccuracyMMLU,
         AccuracyHumaneval,
         AccuracyPerplexity,
+        LMEvalHarness,
         LLMPrompt,
         HuggingfaceBench,
         OgaBench,
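Taken together, the two cli.py hunks expose the new tool as a lemonade sub-command under its unique_name, lm-eval-harness. A minimal sketch of an end-to-end invocation follows; the checkpoint name and the oga-load flags are illustrative assumptions, while the lm-eval-harness arguments match the parser added in lemonade/tools/accuracy.py further down.

# Hypothetical invocation; the checkpoint name and oga-load flags are assumptions,
# not taken from this diff.
import subprocess

subprocess.run(
    [
        "lemonade",
        "-i", "microsoft/Phi-3-mini-4k-instruct",
        "oga-load", "--device", "cpu", "--dtype", "int4",
        "lm-eval-harness", "--task", "gsm8k", "--limit", "10",
    ],
    check=True,
)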
lemonade/tools/accuracy.py
ADDED
@@ -0,0 +1,335 @@
+import argparse
+import json
+import os
+import socket
+import subprocess
+import sys
+import time
+from typing import Optional
+
+import requests
+
+from lemonade.state import State
+from lemonade.tools import Tool
+import lemonade.common.printing as printing
+import lemonade.common.build as build
+
+from lemonade.tools.server.thread_utils import ServerRunner
+
+
+def is_port_in_use(port, host="localhost"):
+    """
+    Check if a port is in use
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex((host, port)) == 0
+
+
+class LMEvalHarness(Tool):
+    """
+    Tool for evaluating LLMs using lm-eval-harness on industry standard benchmarks
+    like MMLU, GSM8k, and more. See docs/lemonade/lm_eval.md for more details.
+    """
+
+    unique_name = "lm-eval-harness"
+
+    def __init__(self):
+        super().__init__(
+            monitor_message="Evaluate model accuracy using ElutherAI's lm-eval-harness"
+        )
+        self.status_stats = []
+        self.server_runner = None
+
+    @staticmethod
+    def parser(add_help: bool = True) -> argparse.ArgumentParser:
+        parser = __class__.helpful_parser(
+            short_description="Evaluate model using lm-eval-harness",
+            add_help=add_help,
+        )
+
+        parser.add_argument(
+            "--task",
+            type=str,
+            required=True,
+            help="Task(s) to evaluate on (e.g., gsm8k, mmlu)",
+        )
+
+        parser.add_argument(
+            "--server-port", type=int, default=8000, help="Port to use for the server"
+        )
+
+        parser.add_argument(
+            "--num-fewshot",
+            type=int,
+            default=0,
+            help="Number of examples in few-shot prompts",
+        )
+
+        parser.add_argument(
+            "--limit",
+            type=int,
+            default=None,
+            help="Limit the number of examples per task",
+        )
+
+        parser.add_argument(
+            "--log-samples",
+            action="store_true",
+            help="Log samples for each task to log file",
+        )
+
+        parser.add_argument(
+            "--output-path",
+            type=str,
+            default=None,
+            help="Path to save evaluation results",
+        )
+
+        return parser
+
+    def _process_results(self, results_dir, state):
+        """Process evaluation results and save to state stats"""
+        if not os.path.exists(results_dir) or not os.path.isdir(results_dir):
+            printing.log_warning(f"Results directory not found at {results_dir}")
+            return
+
+        model_dirs = [
+            d
+            for d in os.listdir(results_dir)
+            if os.path.isdir(os.path.join(results_dir, d))
+        ]
+
+        if not model_dirs:
+            printing.log_warning(f"No model directories found in {results_dir}")
+            return
+
+        model_dir = os.path.join(results_dir, model_dirs[0])
+        printing.log_info(f"Found model directory: {model_dir}")
+
+        # Find the results JSON file with timestamp
+        results_files = [
+            f
+            for f in os.listdir(model_dir)
+            if f.startswith("results_") and f.endswith(".json")
+        ]
+
+        if not results_files:
+            printing.log_warning(f"No results files found in {model_dir}")
+            return
+
+        # Sort by timestamp
+        results_files.sort(reverse=True)
+        results_file_path = os.path.join(model_dir, results_files[0])
+        printing.log_info(f"Processing results from {results_file_path}")
+
+        # Read and process results
+        try:
+            with open(results_file_path, "r", encoding="utf-8") as f:
+                results = json.load(f)
+
+            # Extract and display metrics
+            if "results" in results:
+                for task_name, metrics in results["results"].items():
+                    printing.log_info(f"Results for {task_name}:")
+
+                    for metric, value in metrics.items():
+                        if isinstance(value, (int, float)) and not metric.startswith(
+                            "alias"
+                        ):
+                            # Format metric name for stats
+                            clean_metric = metric.replace(",", "_")
+                            stat_name = f"lm_eval_{task_name}_{clean_metric}"
+
+                            # Save to state stats as percentage
+                            state.save_stat(stat_name, float(value) * 100)
+                            state.save_stat(f"{stat_name}_units", "%")
+                            self.status_stats.append(stat_name)
+
+                            printing.log_info(
+                                f" {metric}: {value:.4f} ({value*100:.2f}%)"
+                            )
+
+            # Save summary metrics if available
+            avg_metrics = {}
+            if "higher_is_better" in results:
+                for metric_type in results["higher_is_better"].values():
+                    for metric in metric_type.keys():
+                        if metric not in avg_metrics:
+                            avg_metrics[metric] = []
+
+                for task_metrics in results["results"].values():
+                    for metric, value in task_metrics.items():
+                        if isinstance(value, (int, float)) and not metric.startswith(
+                            "alias"
+                        ):
+                            base_metric = metric.split(",")[0]
+                            if base_metric in avg_metrics:
+                                avg_metrics[base_metric].append(value)
+
+                # Calculate and save averages
+                for metric, values in avg_metrics.items():
+                    if values:
+                        avg_value = sum(values) / len(values)
+                        stat_name = f"lm_eval_average_{metric}"
+                        state.save_stat(stat_name, float(avg_value) * 100)
+                        state.save_stat(f"{stat_name}_units", "%")
+                        self.status_stats.append(stat_name)
+                        printing.log_info(
+                            f"Average {metric}: {avg_value:.4f} ({avg_value*100:.2f}%)"
+                        )
+
+        except (IOError, json.JSONDecodeError) as e:
+            printing.log_error(f"Error processing results: {e}")
+
+    def run(
+        self,
+        state: State,
+        task: str,
+        server_port: int = 8000,
+        server_host: str = "localhost",
+        num_fewshot: int = 0,
+        limit: Optional[int] = None,
+        log_samples: bool = False,
+        output_path: Optional[str] = None,
+    ) -> State:
+
+        model = state.model
+        tokenizer = state.tokenizer
+
+        if model is None or tokenizer is None:
+            raise ValueError(
+                "Model and tokenizer must be loaded in state before running lm-eval-harness"
+            )
+
+        # Set up output path
+        if output_path is None:
+            output_path = os.path.join(
+                build.output_dir(state.cache_dir, state.build_name), "lm_eval_results"
+            )
+
+        os.makedirs(output_path, exist_ok=True)
+
+        # Check if port is already in use
+        if is_port_in_use(server_port, server_host):
+            error_msg = (
+                f"Port {server_port} is already in use. "
+                "Please close all applications using this port and try again."
+            )
+            printing.log_error(error_msg)
+            raise RuntimeError(error_msg)
+
+        # Retroactively determine recipe based on model type to select correct iterator
+        # The model is already loaded in server, so we only need recipe for iterator selection
+        checkpoint = getattr(state, "checkpoint", "unknown")
+        if "OrtGenaiModel" in str(type(model)):
+            recipe = "oga-"
+        else:
+            recipe = "unknown"
+
+        # Start the server thread
+        self.server_runner = ServerRunner(
+            model=model,
+            tokenizer=tokenizer,
+            checkpoint=checkpoint,
+            recipe=recipe,
+            host=server_host,
+            port=server_port,
+        )
+        self.server_runner.start()
+
+        # Wait for server initialization
+        printing.log_info("Waiting for server initialization...")
+
+        # Wait for server to start and be responsive
+        server_url = f"http://{server_host}:{server_port}"
+        max_retries = 30
+        retry_delay = 1
+
+        printing.log_info(f"Checking if server is available at {server_url}...")
+        for i in range(max_retries):
+            try:
+                response = requests.get(f"{server_url}/api/v0/health", timeout=2)
+                if response.status_code == 200:
+                    printing.log_info(f"Server is ready after {i+1} attempts")
+                    break
+            except requests.exceptions.RequestException:
+                if i < max_retries - 1:
+                    time.sleep(retry_delay)
+                else:
+                    printing.log_error(
+                        f"Server did not start after {max_retries} attempts"
+                    )
+                    raise RuntimeError("Failed to start the server")
+
+        # Build API URL
+        results_file = os.path.join(output_path, f"{task}_results")
+
+        printing.log_info(f"Running lm-eval-harness on {task}...")
+
+        # Build lm-eval-harness command
+        cmd = [
+            "lm_eval",
+            "--model",
+            "local-completions",
+            "--tasks",
+            task,
+            "--model_args",
+            (
+                f"model={checkpoint},"
+                f"base_url={server_url}/api/v0/completions,"
+                f"num_concurrent=1,"
+                f"max_retries=5,"
+                f"retry_timeout=10,"
+                f"tokenized_requests=False"
+            ),
+            "--num_fewshot",
+            str(num_fewshot),
+            "--output_path",
+            results_file,
+        ]
+
+        if limit is not None:
+            cmd.extend(["--limit", str(limit)])
+
+        if log_samples:
+            cmd.extend(["--log_samples"])
+
+        try:
+            # On Windows, set UTF-8 mode to handle Unicode output
+            env = os.environ.copy()
+            if sys.platform == "win32":
+                env["PYTHONIOENCODING"] = "utf-8"
+
+            # Execute lm-eval-harness command
+            result = subprocess.run(
+                cmd, check=True, text=True, capture_output=True, env=env
+            )
+
+            # Log relevant output and skip any parts that might cause encoding issues
+            try:
+                printing.log_info(result.stdout)
+            except UnicodeEncodeError:
+                printing.log_info(
+                    "Results obtained successfully but couldn't display due to encoding issues"
+                )
+
+            # Process results from the correct location
+            results_dir = os.path.join(output_path, f"{task}_results")
+            self._process_results(results_dir, state)
+
+        except subprocess.CalledProcessError as e:
+            printing.log_error(f"Error running lm-eval-harness: {e}")
+            printing.log_error(f"stderr: {e.stderr}")
+        except (IOError, ValueError, requests.RequestException) as e:
+            printing.log_error(f"Error: {e}")
+        finally:
+            # Shut down server
+            if self.server_runner and self.server_runner.is_alive():
+                printing.log_info("Shutting down server runner...")
+                self.server_runner.shutdown()
+
+            # Make sure we don't have any lingering references to state's model/tokenizer
+            # that could prevent garbage collection
+            self.server_runner = None
+
+        return state
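For reference, the stat naming used by _process_results above can be exercised on its own: metric keys such as exact_match,strict-match are flattened by replacing commas with underscores, prefixed with lm_eval_<task>, and stored as percentages. The sample payload below is hypothetical; only its shape mirrors what lm-eval-harness writes.

# Hypothetical results payload; only the fields read by _process_results are shown.
sample = {"results": {"gsm8k": {"exact_match,strict-match": 0.1825, "alias": "gsm8k"}}}

for task_name, metrics in sample["results"].items():
    for metric, value in metrics.items():
        if isinstance(value, (int, float)) and not metric.startswith("alias"):
            stat_name = f"lm_eval_{task_name}_{metric.replace(',', '_')}"
            print(stat_name, f"{float(value) * 100:.2f} %")
# -> lm_eval_gsm8k_exact_match_strict-match 18.25 %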
lemonade/tools/huggingface_load.py
CHANGED
@@ -326,6 +326,7 @@ class HuggingfaceAdapter(ModelAdapter):
     def generate(
         self,
         input_ids,
+        random_seed=1,
         **kwargs,
     ):
 
@@ -346,6 +347,11 @@ class HuggingfaceAdapter(ModelAdapter):
             **kwargs,
         }
 
+        if random_seed is None:
+            torch.random.seed()
+        else:
+            torch.random.manual_seed(random_seed)
+
         with torch.no_grad(), torch.inference_mode():
             outputs = self.model.generate(input_ids=input_ids, **generation_kwargs)
 
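The effect of the new random_seed argument can be sketched in isolation: a fixed seed makes Hugging Face sampling repeatable across calls, while None re-seeds the torch RNG from the system entropy source. The helper name below is made up for illustration; only the branch it wraps comes from the diff.

import torch

def seed_for_generation(random_seed=1):
    # Mirrors the branch added to HuggingfaceAdapter.generate above.
    if random_seed is None:
        torch.random.seed()  # non-repeatable sampling
    else:
        torch.random.manual_seed(random_seed)  # repeatable sampling

seed_for_generation(1)
first = torch.rand(3)
seed_for_generation(1)
assert torch.equal(first, torch.rand(3))  # same seed, same draws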
lemonade/tools/ort_genai/oga.py
CHANGED
@@ -139,6 +139,7 @@ class OrtGenaiModel(ModelAdapter):
         pad_token_id=None,
         stopping_criteria=None,
         max_length=None,
+        random_seed=1,
     ):
         params = og.GeneratorParams(self.model)
 
@@ -179,6 +180,9 @@ class OrtGenaiModel(ModelAdapter):
         if use_oga_pre_6_api:
             params.input_ids = input_ids
 
+        if random_seed is None:
+            random_seed = -1  # In og.Generator, -1 = seed with random device
+
         if self.config and "search" in self.config:
             search_config = self.config["search"]
             params.set_search_options(
@@ -196,10 +200,7 @@ class OrtGenaiModel(ModelAdapter):
                 past_present_share_buffer=search_config.get(
                     "past_present_share_buffer", True
                 ),
-
-                # by default, random_seed=-1 causes different laptops to give
-                # different results
-                random_seed=1,
+                random_seed=random_seed,
                 # Not currently supported by OGA
                 # diversity_penalty=search_config.get('diversity_penalty', 0.0),
                 # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
@@ -212,6 +213,7 @@ class OrtGenaiModel(ModelAdapter):
                 temperature=temperature,
                 max_length=max_length_to_use,
                 min_length=min_length,
+                random_seed=random_seed,
             )
         params.try_graph_capture_with_max_batch_size(1)
 
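On the OGA path the same default of 1 keeps results repeatable across machines, and None is translated to -1, which onnxruntime-genai treats as seeding from a random device, before the value is forwarded to params.set_search_options. A minimal sketch of that mapping, with a hypothetical helper name:

def resolve_oga_seed(random_seed=1):
    # Mirrors the mapping added to OrtGenaiModel.generate above:
    # None -> -1 (seed from a random device); otherwise pass the seed through.
    return -1 if random_seed is None else random_seed

assert resolve_oga_seed() == 1       # default: deterministic across machines
assert resolve_oga_seed(None) == -1  # opt in to non-repeatable sampling
assert resolve_oga_seed(42) == 42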
lemonade/tools/prompt.py
CHANGED
@@ -15,6 +15,7 @@ DEFAULT_GENERATE_PARAMS = {
     "temperature": 0.7,
 }
 
+DEFAULT_RANDOM_SEED = 1
 DEFAULT_MAX_NEW_TOKENS = 512
 DEFAULT_N_TRIALS = 1
 
@@ -108,6 +109,19 @@ class LLMPrompt(Tool):
             f"(useful for testing, default is {DEFAULT_N_TRIALS})",
         )
 
+        parser.add_argument(
+            "--random-seed",
+            "-r",
+            default=str(DEFAULT_RANDOM_SEED),
+            help="Positive integer seed for random number generator used in "
+            "sampling tokens "
+            f"(default is {DEFAULT_RANDOM_SEED}). If the number of trials is "
+            "greater than one, then the seed is incremented by one for each "
+            "trial. Set to `None` for random, non-repeatable results. This "
+            "random seed behavior only applies to models loaded with "
+            "`oga-load` or `huggingface-load`.",
+        )
+
         return parser
 
     def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
@@ -123,6 +137,11 @@ class LLMPrompt(Tool):
             with open(parsed_args.prompt, "r", encoding="utf-8") as f:
                 parsed_args.prompt = f.read()
 
+        if parsed_args.random_seed == "None":
+            parsed_args.random_seed = None
+        else:
+            parsed_args.random_seed = int(parsed_args.random_seed)
+
         return parsed_args
 
     def run(
@@ -132,6 +151,7 @@ class LLMPrompt(Tool):
         max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
         n_trials: int = DEFAULT_N_TRIALS,
         template: bool = False,
+        random_seed: int = DEFAULT_RANDOM_SEED,
     ) -> State:
 
         model: ModelAdapter = state.model
@@ -170,9 +190,16 @@ class LLMPrompt(Tool):
 
             # Get the response from the LLM, which may include the prompt in it
             response = model.generate(
-                input_ids,
+                input_ids,
+                max_new_tokens=max_new_tokens,
+                random_seed=random_seed,
+                **DEFAULT_GENERATE_PARAMS,
             )
 
+            # Increment random seed if not none
+            if random_seed is not None:
+                random_seed += 1
+
             # Flatten the input and response
             input_ids_array = (
                 input_ids if isinstance(input_ids, (list, str)) else input_ids[0]