@elizaos/training 2.0.0-alpha.11 → 2.0.0-alpha.26
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/package.json +3 -3
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/src/benchmark/BenchmarkDataViewer.ts +1 -1
- package/src/benchmark/BenchmarkValidator.ts +157 -159
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/index.ts +3 -1
- package/src/dependencies.ts +36 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +0 -2
- package/data/trader/.gitkeep +0 -2
package/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Shaw Walters and elizaOS Contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@elizaos/training",
-  "version": "2.0.0-alpha.11",
+  "version": "2.0.0-alpha.26",
   "description": "ElizaOS RL training pipeline with benchmarking and model publishing support",
   "main": "./src/index.ts",
   "types": "./src/index.ts",
@@ -48,11 +48,11 @@
     "ethers": "^6.16.0",
     "uuid": "^11.1.0"
   },
-  "peerDependencies": {},
   "devDependencies": {
     "@types/node": "^24.10.0",
     "@types/uuid": "^10.0.0",
     "bun-types": "^1.3.2",
     "typescript": "^5.9.3"
-  }
+  },
+  "gitHead": "91dceb1d2e9762af27353dbc764e40e1a0599508"
 }
package/python/scripts/benchmark_should_respond.py
ADDED
@@ -0,0 +1,190 @@
+
+import argparse
+import json
+import re
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import mlx.core as mx
+from mlx_lm import load, generate
+
+def extract_action(text):
+    match = re.search(r"<action>(.*?)</action>", text, re.DOTALL)
+    if match:
+        return match.group(1).strip().upper()
+    return "NONE"
+
+def is_valid_format(text):
+    return "<response>" in text and "</response>" in text and "<action>" in text and "</action>" in text
+
+def run_inference(model_path, adapter_path, prompts, ground_truths, label="Model", verbose=False, args=None):
+    print(f"Loading {label} from {model_path} (adapter: {adapter_path})...")
+    model, tokenizer = load(model_path, adapter_path=adapter_path)
+
+    results = []
+    correct_count = 0
+    valid_format_count = 0
+
+    for i, prompt in enumerate(prompts):
+        print(f"Generating {i+1}/{len(prompts)}...")
+
+        # Apply chat template to match training
+        system_prompt = "You are a helpful assistant."
+        if hasattr(args, "system_prompt_file") and args.system_prompt_file:
+            with open(args.system_prompt_file, 'r') as f:
+                system_prompt = f.read().strip()
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": prompt}
+        ]
+
+        try:
+            # apply_chat_template returns a string if tokenize=False
+            # We want the text prompt that the model sees
+            full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+            text = generate(model, tokenizer, prompt=full_prompt, max_tokens=100, verbose=verbose)
+        except TypeError:
+            # Fallback if arguments differ in version
+            full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            text = generate(model, tokenizer, prompt=full_prompt, max_tokens=100, verbose=verbose)
+
+        pred_action = extract_action(text)
+        true_action = extract_action(ground_truths[i])
+
+        is_correct = (pred_action == true_action)
+        is_valid = is_valid_format(text)
+
+        if is_correct: correct_count += 1
+        if is_valid: valid_format_count += 1
+
+        results.append({
+            "prompt": prompt,
+            "generated": text,
+            "predicted_action": pred_action,
+            "ground_truth_action": true_action,
+            "is_correct": is_correct,
+            "is_valid_format": is_valid
+        })
+
+    accuracy = correct_count / len(prompts)
+    format_compliance = valid_format_count / len(prompts)
+
+    return {
+        "model": label,
+        "accuracy": accuracy,
+        "format_compliance": format_compliance,
+        "details": results
+    }
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
+    parser.add_argument("--sft-adapter", type=str, required=True)
+    parser.add_argument("--rl-adapter", type=str, required=False)
+    parser.add_argument("--pure-rl-adapter", type=str, required=False, help="Path to RL-only adapter")
+    parser.add_argument("--data", type=str, required=True)
+    parser.add_argument("--count", type=int, default=10)
+    parser.add_argument("--verbose", action="store_true", help="Enable verbose generation output")
+    parser.add_argument("--system-prompt-file", type=str, help="Path to custom system prompt file")
+    args = parser.parse_args()
+
+    # Load prompts and ground truths
+    prompts = []
+    ground_truths = []
+
+    print("Loading data...")
+    with open(args.data, 'r') as f:
+        for line in f:
+            if not line.strip(): continue
+            try:
+                item = json.loads(line)
+                if "messages" in item:
+                    # User message is prompt, Assistant message is ground truth
+                    user_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "user"), None)
+                    asst_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "assistant"), None)
+
+                    if user_msg and asst_msg:
+                        prompts.append(user_msg)
+                        ground_truths.append(asst_msg)
+            except:
+                pass
+
+    # Select subset
+    subset_indices = range(min(args.count, len(prompts)))
+    selected_prompts = [prompts[i] for i in subset_indices]
+    selected_truths = [ground_truths[i] for i in subset_indices]
+
+    if not selected_prompts:
+        print("No prompts found.")
+        return
+
+    print(f"--- Running Benchmark on {len(selected_prompts)} samples ---")
+
+    metrics = []
+
+    # 1. Base Model
+    base_metrics = run_inference(args.base_model, None, selected_prompts, selected_truths, "Base Model", verbose=args.verbose, args=args)
+    metrics.append(base_metrics)
+
+    # 2. SFT Model
+    sft_metrics = run_inference(args.base_model, args.sft_adapter, selected_prompts, selected_truths, "SFT Model", verbose=args.verbose, args=args)
+    metrics.append(sft_metrics)
+
+    # 3. Pure RL Model (Base -> RL)
+    if args.pure_rl_adapter:
+        try:
+            pure_rl_metrics = run_inference(args.base_model, args.pure_rl_adapter, selected_prompts, selected_truths, "Pure RL Model", verbose=args.verbose, args=args)
+            metrics.append(pure_rl_metrics)
+        except Exception as e:
+            print(f"Failed to run Pure RL benchmark: {e}")
+
+    # 4. SFT+RL Model (SFT -> RL)
+    if args.rl_adapter:
+        rl_metrics = run_inference(args.base_model, args.rl_adapter, selected_prompts, selected_truths, "SFT+RL Model", verbose=args.verbose, args=args)
+        metrics.append(rl_metrics)
+
+    # Save Results
+    with open("benchmark_results.json", "w") as f:
+        json.dump(metrics, f, indent=2)
+    print("\nSaved detailed results to benchmark_results.json")
+
+    # Print Summary
+    print("\n=== SUMMARY ===")
+    print(f"{'Model':<15} | {'Accuracy':<10} | {'Format':<10}")
+    print("-" * 40)
+    for m in metrics:
+        print(f"{m['model']:<15} | {m['accuracy']:.2%} | {m['format_compliance']:.2%}")
+
+    # Plotting
+    try:
+        models = [m['model'] for m in metrics]
+        accuracies = [m['accuracy'] for m in metrics]
+        formats = [m['format_compliance'] for m in metrics]
+
+        x = range(len(models))
+        width = 0.35
+
+        fig, ax = plt.subplots(figsize=(10, 6))
+        rects1 = ax.bar([i - width/2 for i in x], accuracies, width, label='Action Accuracy')
+        rects2 = ax.bar([i + width/2 for i in x], formats, width, label='Format Compliance')
+
+        ax.set_ylabel('Score')
+        ax.set_title('ShouldRespond Benchmark Results')
+        ax.set_xticks(x)
+        ax.set_xticklabels(models)
+        ax.legend()
+
+        ax.bar_label(rects1, padding=3, fmt='%.2f')
+        ax.bar_label(rects2, padding=3, fmt='%.2f')
+
+        plt.tight_layout()
+        plt.savefig("benchmark_chart.png")
+        print("Saved chart to benchmark_chart.png")
+
+    except Exception as e:
+        print(f"Failed to create chart: {e}")
+
+if __name__ == "__main__":
+    main()
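For a quick sanity check of how this script scores a single completion, the two parsing helpers above can be exercised in isolation; a minimal sketch (the sample completion and ground truth below are illustrative, not taken from a real run):

import re

def extract_action(text):
    # Same regex as benchmark_should_respond.py: pull the text inside <action>...</action>.
    match = re.search(r"<action>(.*?)</action>", text, re.DOTALL)
    if match:
        return match.group(1).strip().upper()
    return "NONE"

def is_valid_format(text):
    # Same check as the script: both tag pairs must be present.
    return "<response>" in text and "</response>" in text and "<action>" in text and "</action>" in text

completion = "<response><name>Eliza</name><reasoning>Directly addressed</reasoning><action>respond</action></response>"
ground_truth = "<response><name>Eliza</name><reasoning>...</reasoning><action>RESPOND</action></response>"

print(extract_action(completion))                                  # RESPOND (case-normalised)
print(extract_action(completion) == extract_action(ground_truth))  # True  -> counts toward accuracy
print(is_valid_format(completion))                                 # True  -> counts toward format_compliance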
package/python/scripts/debug_inference.py
ADDED
@@ -0,0 +1,62 @@
+import mlx.core as mx
+from mlx_lm import load, generate
+
+model_name = "mlx-community/Qwen2.5-1.5B-Instruct-4bit"
+print(f"Loading {model_name}...")
+model, tokenizer = load(model_name)
+
+# Original Prompt
+prompt_orig = """<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>
+
+<providers>
+[RECENT_MESSAGES]
+User: Hey @Eliza, what's up?
+
+</providers>
+
+<instructions>Decide if Eliza should respond to or interact with the conversation.
+
+IMPORTANT RULES FOR RESPONDING:
+- If YOUR name (Eliza) is directly mentioned -> RESPOND
+- If someone uses a DIFFERENT name (not Eliza) -> IGNORE (they're talking to someone else)
+- If you're actively participating in a conversation and the message continues that thread -> RESPOND
+- If someone tells you to stop or be quiet -> STOP
+- Otherwise -> IGNORE
+
+The key distinction is:
+- "Talking TO Eliza" (your name mentioned, replies to you, continuing your conversation) -> RESPOND
+- "Talking ABOUT Eliza" or to someone else -> IGNORE
+</instructions>
+
+<output>
+Do NOT include any thinking, reasoning, or <think> sections in your response.
+Go directly to the XML response format without any preamble or explanation.
+
+Respond using XML format like this:
+<response>
+<name>Eliza</name>
+<reasoning>Your reasoning here</reasoning>
+<action>RESPOND | IGNORE | STOP</action>
+</response>
+
+IMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.
+</output>"""
+
+prompts = [
+    ("Original", prompt_orig),
+    ("No Providers Block", prompt_orig.replace("<providers>\n[RECENT_MESSAGES]", "Messages:\n").replace("\n\n</providers>", "")),
+    ("Simplified", """User: Hey @Eliza, what's up?
+Instructions: You are Eliza. Provide an XML response deciding to RESPOND, IGNORE, or STOP.
+Criteria: Respond if addressed directly (@Eliza).
+Format: <response><name>Eliza</name><reasoning>...</reasoning><action>...</action></response>""")
+]
+
+for name, p in prompts:
+    print(f"\n--- Testing {name} ---")
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": p}
+    ]
+    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    output = generate(model, tokenizer, prompt=text, max_tokens=100, verbose=False)
+    print(output)
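The "No Providers Block" variant above is built with two plain string replacements on prompt_orig; isolating just that transform makes the resulting text easier to see (a sketch using only the message fragment from the prompt above):

# Sketch of the "No Providers Block" transform: the <providers> wrapper is
# swapped for a plain "Messages:" label, nothing else changes.
fragment = "<providers>\n[RECENT_MESSAGES]\nUser: Hey @Eliza, what's up?\n\n</providers>"
no_providers = (
    fragment
    .replace("<providers>\n[RECENT_MESSAGES]", "Messages:\n")
    .replace("\n\n</providers>", "")
)
print(no_providers)
# Messages:
#
# User: Hey @Eliza, what's up?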
package/python/scripts/optimize_prompt_grpo.py
ADDED
@@ -0,0 +1,269 @@
+import argparse
+import json
+import re
+import random
+import os
+import mlx.core as mx
+from mlx_lm import load, generate
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# -------------------------------------------------------------------------
+# Reward Function (Copied from train_grpo.py)
+# -------------------------------------------------------------------------
+ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)
+
+def compute_rewards(prompts, completions):
+    """
+    Reward function for shouldRespond task.
+    Returns list of scores.
+    """
+    rewards = []
+
+    for prompt, text in zip(prompts, completions):
+        score = 0.0
+
+        # Parse action
+        action_match = ACTION_RE.search(text)
+        action = action_match.group(1).strip().upper() if action_match else "NONE"
+
+        # Heuristics
+        last_user_msg = prompt.split("User:")[-1] if "User:" in prompt else prompt
+        is_direct_mention = ("@Eliza" in last_user_msg or "Eliza" in last_user_msg)
+        is_stop = any(w in last_user_msg.lower() for w in ["stop", "shut up", "quiet", "be quiet"])
+        is_continuation = "Eliza:" in prompt
+        is_ambiguous = any(w in last_user_msg.lower() for w in ["anyone", "anybody", "help", "assist", "question", "somebody"])
+        should_respond = is_direct_mention or is_continuation or is_ambiguous
+
+        # Scoring
+        if is_stop:
+            if action == "STOP": score += 1.0
+            elif action == "IGNORE": score += 0.3
+            else: score -= 0.3
+        elif should_respond:
+            if action == "RESPOND": score += 1.0
+            else: score -= 0.3
+        else: # should ignore
+            if action == "IGNORE": score += 1.0
+            else: score -= 0.3
+
+        # Format bonus
+        if "<response>" in text and "</response>" in text:
+            score += 0.2
+        if action == "NONE":
+            score -= 0.5
+
+        rewards.append(score)
+
+    return rewards
+
+# -------------------------------------------------------------------------
+# Meta-Optimizer
+# -------------------------------------------------------------------------
+META_PROMPT_TEMPLATE = """You are an expert Prompt Engineer optimizing a system prompt for an AI agent.
+
+Current System Prompt:
+"{current_prompt}"
+
+Task:
+The AI needs to decide whether to REPLY, IGNORE, or STOP based on a chat history.
+It MUST output valid XML: <response><name>Eliza</name><reason>...</reason><action>...</action></response>.
+
+Here are some examples of the AI's performance with the current prompt:
+
+[SUCCESSFUL EXAMPLE]
+Input: {good_input}
+Output: {good_output}
+Reward: {good_score} (This was good!)
+
+[FAILED EXAMPLE]
+Input: {bad_input}
+Output: {bad_output}
+Reward: {bad_score} (This was bad!)
+
+[INSTRUCTIONS]
+Analyze why the successful example worked and the failed one didn't.
+Rewrite the System Prompt to fix the failure case while maintaining the success case.
+The new prompt should be concise, clear, and emphasize the XML format and the decision logic.
+Return ONLY the new system prompt text. Do not add explanations.
+"""
+
+def call_external_model(provider, model_name, prompt):
+    """
+    Call external API (Groq or Anthropic) to optimize the prompt.
+    """
+    try:
+        if provider == "groq":
+            from groq import Groq
+            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+            chat_completion = client.chat.completions.create(
+                messages=[
+                    {"role": "system", "content": "You are an expert prompt engineer."},
+                    {"role": "user", "content": prompt}
+                ],
+                model=model_name,
+                temperature=0.7,
+                max_tokens=1024,
+            )
+            return chat_completion.choices[0].message.content
+
+        elif provider == "anthropic":
+            import anthropic
+            client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+            message = client.messages.create(
+                model=model_name,
+                max_tokens=1024,
+                messages=[
+                    {"role": "user", "content": prompt}
+                ]
+            )
+            return message.content[0].text
+
+        elif provider == "openai":
+            from openai import OpenAI
+            client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+            completion = client.chat.completions.create(
+                model=model_name,
+                messages=[
+                    {"role": "system", "content": "You are an expert prompt engineer."},
+                    {"role": "user", "content": prompt}
+                ]
+            )
+            return completion.choices[0].message.content
+
+        else:
+            raise ValueError(f"Unknown provider: {provider}")
+
+    except Exception as e:
+        print(f"Error calling {provider}: {e}")
+        return None
+
+def optimize_prompt(provider, model_name, current_prompt, good_example, bad_example):
+    meta_prompt = META_PROMPT_TEMPLATE.format(
+        current_prompt=current_prompt,
+        good_input=good_example['prompt'],
+        good_output=good_example['output'],
+        good_score=good_example['score'],
+        bad_input=bad_example['prompt'],
+        bad_output=bad_example['output'],
+        bad_score=bad_example['score']
+    )
+
+    new_prompt = call_external_model(provider, model_name, meta_prompt)
+
+    if new_prompt:
+        # Simple cleanup: remove quotes if the model added them
+        new_prompt = new_prompt.strip()
+        if new_prompt.startswith('"') and new_prompt.endswith('"'):
+            new_prompt = new_prompt[1:-1]
+        return new_prompt
+    else:
+        return current_prompt
+
+# -------------------------------------------------------------------------
+# Main Loop
+# -------------------------------------------------------------------------
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
+    parser.add_argument("--data", type=str, required=True)
+    parser.add_argument("--iter", type=int, default=10)
+    parser.add_argument("--batch-size", type=int, default=4)
+    parser.add_argument("--save-path", type=str, default="optimized_prompt.txt")
+    parser.add_argument("--optimizer-provider", type=str, default="groq", choices=["groq", "anthropic", "openai"])
+    parser.add_argument("--optimizer-model", type=str, default="llama-3.3-70b-versatile")
+    args = parser.parse_args()
+
+    print(f"Loading Evaluator Model: {args.model}...")
+    model, tokenizer = load(args.model)
+
+    print(f"Using Meta-Optimizer: {args.optimizer_provider} ({args.optimizer_model})")
+
+    # Load Data
+    prompts = []
+    with open(args.data, 'r') as f:
+        for line in f:
+            if line.strip():
+                item = json.loads(line)
+                user_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "user"), None)
+                if user_msg: prompts.append(user_msg)
+
+    random.shuffle(prompts)
+
+    # Initialize Prompt
+    # Initialize Prompt
+    current_system_prompt = (
+        "You are an AI assistant named Eliza. "
+        "Decide whether to REPLY, IGNORE, or STOP. "
+        "You MUST respond in valid XML format: "
+        "<response><name>Eliza</name><reason>...</reason><action>REPLY/IGNORE/STOP</action></response>."
+    )
+    best_overall_avg_score = -float('inf')
+
+    print(f"Starting Training-Free GRPO for {args.iter} iterations...")
+
+    for i in range(args.iter):
+        print(f"\n=== Iteration {i+1}/{args.iter} ===")
+
+        # 1. Sample Batch
+        batch_prompts = random.sample(prompts, args.batch_size)
+
+        batch_results = []
+
+        # 2. Evaluate Current Prompt on Batch
+        print(f"Evaluating current prompt on {len(batch_prompts)} examples...")
+        current_scores = []
+
+        for p in batch_prompts:
+            # Construct full prompt with system instruction
+            messages = [
+                {"role": "system", "content": current_system_prompt},
+                {"role": "user", "content": p}
+            ]
+            full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            output = generate(model, tokenizer, prompt=full_text, max_tokens=100, verbose=False)
+
+            # Score
+            score = compute_rewards([p], [output])[0]
+            current_scores.append(score)
+
+            batch_results.append({
+                "prompt": p,
+                "output": output,
+                "score": score
+            })
+
+        avg_score = sum(current_scores) / len(current_scores)
+        print(f"Avg Score: {avg_score:.2f}")
+
+        if avg_score > best_overall_avg_score:
+            best_overall_avg_score = avg_score
+            print(f"New Best Score! Saving prompt to {args.save_path}...")
+            with open(args.save_path, "w") as f:
+                f.write(current_system_prompt)
+
+        # 3. Select Good and Bad Examples
+        # Sort by score
+        batch_results.sort(key=lambda x: x['score'], reverse=True)
+        good = batch_results[0]
+        bad = batch_results[-1]
+
+        if good['score'] > bad['score']:
+            # Variance exists, we can learn
+            print(f"Optimizing: Good ({good['score']}) vs Bad ({bad['score']})")
+            print(f"Good Output: {good['output'][:50]}...")
+            print(f"Bad Output: {bad['output'][:50]}...")
+
+            new_prompt = optimize_prompt(args.optimizer_provider, args.optimizer_model, current_system_prompt, good, bad)
+            print(f"\nNew Prompt Proposed:\n{new_prompt}\n")
+            current_system_prompt = new_prompt
+        else:
+            print("No variance in batch (all good or all bad). Skipping optimization step.")
+
+    print("\noptimization Complete.")
+    print(f"Final Prompt saved to {args.save_path}")
+
+if __name__ == "__main__":
+    main()
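To make the reward heuristics above concrete, a minimal sketch that calls compute_rewards on two hand-written cases (it assumes the scripts directory is importable from the package root and that the script's own dependencies, such as mlx_lm and python-dotenv, are installed; the prompts and completions are illustrative):

import sys
sys.path.append("python/scripts")  # assumption: run from the package root
from optimize_prompt_grpo import compute_rewards

prompts = [
    "User: Hey @Eliza, what's up?",  # direct mention -> should RESPOND
    "User: please stop talking",     # stop request   -> should STOP
]
completions = [
    "<response><name>Eliza</name><reason>addressed</reason><action>RESPOND</action></response>",
    "<response><name>Eliza</name><reason>...</reason><action>RESPOND</action></response>",
]

# First case: correct action (+1.0) plus format bonus (+0.2) -> 1.2
# Second case: wrong action on a stop request (-0.3) plus format bonus (+0.2) -> about -0.1
print(compute_rewards(prompts, completions))  # roughly [1.2, -0.1] (second value carries float rounding noise)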
package/python/scripts/test_generation.py
ADDED
@@ -0,0 +1,29 @@
+
+import mlx.core as mx
+from mlx_lm import load, generate
+import argparse
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
+    parser.add_argument("--adapter-path", type=str, default="trained_models/should_respond_sft/adapters")
+    parser.add_argument("--temp", type=float, default=1.0)
+    args = parser.parse_args()
+
+    print(f"Loading {args.model} with {args.adapter_path}")
+    model, tokenizer = load(args.model, adapter_path=args.adapter_path)
+
+    prompt = "<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>\n\n<providers>\n[RECENT_MESSAGES]\nUser: I heard Eliza is helping\n</providers>\n\n<instructions>Decide if Eliza should respond to or interact with the conversation.\n\nIMPORTANT RULES FOR RESPONDING:\n- If YOUR name (Eliza) is directly mentioned → RESPOND\n- If someone uses a DIFFERENT name (not Eliza) → IGNORE (they're talking to someone else)\n- If you're actively participating in a conversation and the message continues that thread → RESPOND\n- If someone tells you to stop or be quiet → STOP\n- Otherwise → IGNORE\n\nThe key distinction is:\n- \"Talking TO Eliza\" (your name mentioned, replies to you, continuing your conversation) → RESPOND\n- \"Talking ABOUT Eliza\" or to someone else → IGNORE\n</instructions>\n\n<output>\nDo NOT include any thinking, reasoning, or <think> sections in your response.\nGo directly to the XML response format without any preamble or explanation.\n\nRespond using XML format like this:\n<response>\n <name>Eliza</name>\n <reasoning>Your reasoning here</reasoning>\n <action>RESPOND | IGNORE | STOP</action>\n</response>\n\nIMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.\n</output>"
+
+    print("\n--- Gen 1 (temp={}) ---".format(args.temp))
+    from mlx_lm.sample_utils import make_sampler
+    sampler = make_sampler(temp=args.temp)
+
+    print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler))
+
+    print("\n--- Gen 2 (temp={}) ---".format(args.temp))
+    sampler2 = make_sampler(temp=args.temp)
+    print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler2))
+
+if __name__ == "__main__":
+    main()
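Taken together, these scripts are plain argparse CLIs; a sketch of how the benchmark might be invoked, using the flags defined in its source above (the adapter and dataset paths are placeholders):

import subprocess

subprocess.run(
    [
        "python", "python/scripts/benchmark_should_respond.py",
        "--base-model", "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
        "--sft-adapter", "trained_models/should_respond_sft/adapters",  # placeholder path
        "--data", "data/should_respond.jsonl",                          # placeholder path
        "--count", "10",
    ],
    check=True,
)
# On success the script writes benchmark_results.json and benchmark_chart.png
# to the current working directory, as shown in its source above.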