@elizaos/training 2.0.0-alpha.11 → 2.0.0-alpha.26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Shaw Walters and elizaOS Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@elizaos/training",
3
- "version": "2.0.0-alpha.11",
3
+ "version": "2.0.0-alpha.26",
4
4
  "description": "ElizaOS RL training pipeline with benchmarking and model publishing support",
5
5
  "main": "./src/index.ts",
6
6
  "types": "./src/index.ts",
@@ -48,11 +48,11 @@
48
48
  "ethers": "^6.16.0",
49
49
  "uuid": "^11.1.0"
50
50
  },
51
- "peerDependencies": {},
52
51
  "devDependencies": {
53
52
  "@types/node": "^24.10.0",
54
53
  "@types/uuid": "^10.0.0",
55
54
  "bun-types": "^1.3.2",
56
55
  "typescript": "^5.9.3"
57
- }
56
+ },
57
+ "gitHead": "91dceb1d2e9762af27353dbc764e40e1a0599508"
58
58
  }
@@ -0,0 +1,190 @@
1
+
2
+ import argparse
3
+ import json
4
+ import re
5
+ import matplotlib
6
+ matplotlib.use('Agg')
7
+ import matplotlib.pyplot as plt
8
+ import mlx.core as mx
9
+ from mlx_lm import load, generate
10
+
11
def extract_action(text):
    """Pull the action label out of an <action>...</action> tag, or "NONE"."""
    found = re.search(r"<action>(.*?)</action>", text, re.DOTALL)
    return found.group(1).strip().upper() if found else "NONE"
16
+
17
def is_valid_format(text):
    """Return True when the reply contains both the <response> and <action> tag pairs."""
    required = ("<response>", "</response>", "<action>", "</action>")
    return all(tag in text for tag in required)
19
+
20
def run_inference(model_path, adapter_path, prompts, ground_truths, label="Model", verbose=False, args=None):
    """Generate a completion for each prompt with one model/adapter pair and score it.

    Args:
        model_path: HF/MLX model identifier or local path.
        adapter_path: optional LoRA adapter path (None for the base model).
        prompts: list of user-message strings.
        ground_truths: expected assistant replies, parallel to `prompts`.
        label: human-readable name used in logs and the returned dict.
        verbose: forwarded to mlx_lm.generate.
        args: parsed CLI namespace; only `system_prompt_file` is read from it.

    Returns:
        dict with keys "model", "accuracy", "format_compliance", "details".
    """
    print(f"Loading {label} from {model_path} (adapter: {adapter_path})...")
    model, tokenizer = load(model_path, adapter_path=adapter_path)

    results = []
    correct_count = 0
    valid_format_count = 0

    # Resolve the system prompt once, up front — the original re-read the
    # prompt file on every loop iteration.
    system_prompt = "You are a helpful assistant."
    prompt_file = getattr(args, "system_prompt_file", None)
    if prompt_file:
        with open(prompt_file, 'r') as f:
            system_prompt = f.read().strip()

    for i, prompt in enumerate(prompts):
        print(f"Generating {i+1}/{len(prompts)}...")

        # Apply the chat template so inference matches the training format.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # apply_chat_template with tokenize=False returns the exact text
        # prompt the model sees.  (The original wrapped this in a
        # try/except TypeError whose fallback branch was byte-identical to
        # the try body — dead code, removed.)
        full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        text = generate(model, tokenizer, prompt=full_prompt, max_tokens=100, verbose=verbose)

        pred_action = extract_action(text)
        true_action = extract_action(ground_truths[i])

        is_correct = (pred_action == true_action)
        is_valid = is_valid_format(text)

        if is_correct:
            correct_count += 1
        if is_valid:
            valid_format_count += 1

        results.append({
            "prompt": prompt,
            "generated": text,
            "predicted_action": pred_action,
            "ground_truth_action": true_action,
            "is_correct": is_correct,
            "is_valid_format": is_valid
        })

    accuracy = correct_count / len(prompts)
    format_compliance = valid_format_count / len(prompts)

    return {
        "model": label,
        "accuracy": accuracy,
        "format_compliance": format_compliance,
        "details": results
    }
80
+
81
def main():
    """Benchmark base, SFT, pure-RL, and SFT+RL models on the shouldRespond
    dataset, save per-sample results to JSON, and render a comparison chart."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
    parser.add_argument("--sft-adapter", type=str, required=True)
    parser.add_argument("--rl-adapter", type=str, required=False)
    parser.add_argument("--pure-rl-adapter", type=str, required=False, help="Path to RL-only adapter")
    parser.add_argument("--data", type=str, required=True)
    parser.add_argument("--count", type=int, default=10)
    parser.add_argument("--verbose", action="store_true", help="Enable verbose generation output")
    parser.add_argument("--system-prompt-file", type=str, help="Path to custom system prompt file")
    args = parser.parse_args()

    # Load prompts (last user message) and ground truths (last assistant
    # message) from the JSONL chat dataset.
    prompts = []
    ground_truths = []

    print("Loading data...")
    with open(args.data, 'r') as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                if "messages" in item:
                    user_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "user"), None)
                    asst_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "assistant"), None)
                    if user_msg and asst_msg:
                        prompts.append(user_msg)
                        ground_truths.append(asst_msg)
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Was a bare `except: pass`, which also swallowed
                # KeyboardInterrupt/SystemExit; narrowed and made visible.
                print(f"Skipping malformed line {line_no}: {e}")

    # Select the first --count samples.
    limit = min(args.count, len(prompts))
    selected_prompts = prompts[:limit]
    selected_truths = ground_truths[:limit]

    if not selected_prompts:
        print("No prompts found.")
        return

    print(f"--- Running Benchmark on {len(selected_prompts)} samples ---")

    metrics = []

    # 1. Base Model
    base_metrics = run_inference(args.base_model, None, selected_prompts, selected_truths, "Base Model", verbose=args.verbose, args=args)
    metrics.append(base_metrics)

    # 2. SFT Model
    sft_metrics = run_inference(args.base_model, args.sft_adapter, selected_prompts, selected_truths, "SFT Model", verbose=args.verbose, args=args)
    metrics.append(sft_metrics)

    # 3. Pure RL Model (Base -> RL)
    if args.pure_rl_adapter:
        try:
            pure_rl_metrics = run_inference(args.base_model, args.pure_rl_adapter, selected_prompts, selected_truths, "Pure RL Model", verbose=args.verbose, args=args)
            metrics.append(pure_rl_metrics)
        except Exception as e:
            print(f"Failed to run Pure RL benchmark: {e}")

    # 4. SFT+RL Model (SFT -> RL)
    if args.rl_adapter:
        rl_metrics = run_inference(args.base_model, args.rl_adapter, selected_prompts, selected_truths, "SFT+RL Model", verbose=args.verbose, args=args)
        metrics.append(rl_metrics)

    # Save detailed per-sample results.
    with open("benchmark_results.json", "w") as f:
        json.dump(metrics, f, indent=2)
    print("\nSaved detailed results to benchmark_results.json")

    # Print summary table.
    print("\n=== SUMMARY ===")
    print(f"{'Model':<15} | {'Accuracy':<10} | {'Format':<10}")
    print("-" * 40)
    for m in metrics:
        print(f"{m['model']:<15} | {m['accuracy']:.2%} | {m['format_compliance']:.2%}")

    # Grouped bar chart: accuracy vs. format compliance per model.
    # Chart failure is non-fatal — the JSON results were already written.
    try:
        models = [m['model'] for m in metrics]
        accuracies = [m['accuracy'] for m in metrics]
        formats = [m['format_compliance'] for m in metrics]

        x = range(len(models))
        width = 0.35

        fig, ax = plt.subplots(figsize=(10, 6))
        rects1 = ax.bar([i - width/2 for i in x], accuracies, width, label='Action Accuracy')
        rects2 = ax.bar([i + width/2 for i in x], formats, width, label='Format Compliance')

        ax.set_ylabel('Score')
        ax.set_title('ShouldRespond Benchmark Results')
        ax.set_xticks(x)
        ax.set_xticklabels(models)
        ax.legend()

        ax.bar_label(rects1, padding=3, fmt='%.2f')
        ax.bar_label(rects2, padding=3, fmt='%.2f')

        plt.tight_layout()
        plt.savefig("benchmark_chart.png")
        print("Saved chart to benchmark_chart.png")

    except Exception as e:
        print(f"Failed to create chart: {e}")
188
+
189
# Run the benchmark only when executed as a script.
if __name__ == "__main__":
    main()
@@ -0,0 +1,62 @@
1
+ import mlx.core as mx
2
+ from mlx_lm import load, generate
3
+
4
# Ad-hoc experiment: compare three phrasings of the shouldRespond prompt
# against the raw base model (no adapters), printing each completion.
model_name = "mlx-community/Qwen2.5-1.5B-Instruct-4bit"
print(f"Loading {model_name}...")
model, tokenizer = load(model_name)

# Original Prompt — the full production shouldRespond template.
prompt_orig = """<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>

<providers>
[RECENT_MESSAGES]
User: Hey @Eliza, what's up?

</providers>

<instructions>Decide if Eliza should respond to or interact with the conversation.

IMPORTANT RULES FOR RESPONDING:
- If YOUR name (Eliza) is directly mentioned -> RESPOND
- If someone uses a DIFFERENT name (not Eliza) -> IGNORE (they're talking to someone else)
- If you're actively participating in a conversation and the message continues that thread -> RESPOND
- If someone tells you to stop or be quiet -> STOP
- Otherwise -> IGNORE

The key distinction is:
- "Talking TO Eliza" (your name mentioned, replies to you, continuing your conversation) -> RESPOND
- "Talking ABOUT Eliza" or to someone else -> IGNORE
</instructions>

<output>
Do NOT include any thinking, reasoning, or <think> sections in your response.
Go directly to the XML response format without any preamble or explanation.

Respond using XML format like this:
<response>
<name>Eliza</name>
<reasoning>Your reasoning here</reasoning>
<action>RESPOND | IGNORE | STOP</action>
</response>

IMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.
</output>"""

# Three variants: the full template, the same template with the
# <providers> wrapper stripped, and a heavily simplified one-paragraph form.
prompts = [
    ("Original", prompt_orig),
    ("No Providers Block", prompt_orig.replace("<providers>\n[RECENT_MESSAGES]", "Messages:\n").replace("\n\n</providers>", "")),
    ("Simplified", """User: Hey @Eliza, what's up?
Instructions: You are Eliza. Provide an XML response deciding to RESPOND, IGNORE, or STOP.
Criteria: Respond if addressed directly (@Eliza).
Format: <response><name>Eliza</name><reasoning>...</reasoning><action>...</action></response>""")
]

# Generate one completion per variant via the model's chat template.
for name, p in prompts:
    print(f"\n--- Testing {name} ---")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": p}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = generate(model, tokenizer, prompt=text, max_tokens=100, verbose=False)
    print(output)
@@ -0,0 +1,269 @@
1
+ import argparse
2
+ import json
3
+ import re
4
+ import random
5
+ import os
6
+ import mlx.core as mx
7
+ from mlx_lm import load, generate
8
+ from dotenv import load_dotenv
9
+
10
+ # Load environment variables
11
+ load_dotenv()
12
+
13
+ # -------------------------------------------------------------------------
14
+ # Reward Function (Copied from train_grpo.py)
15
+ # -------------------------------------------------------------------------
16
# Pre-compiled pattern for the <action>...</action> block in a completion.
ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)

def compute_rewards(prompts, completions):
    """Score each (prompt, completion) pair for the shouldRespond task.

    Mirrors the reward used in train_grpo.py.  Returns a list of floats,
    one score per pair.
    """
    scores = []

    for prompt, completion in zip(prompts, completions):
        reward = 0.0

        # Predicted action, "NONE" when the tag is missing.
        m = ACTION_RE.search(completion)
        predicted = m.group(1).strip().upper() if m else "NONE"

        # Cheap heuristics over the last user message to derive the target.
        tail = prompt.split("User:")[-1] if "User:" in prompt else prompt
        lowered = tail.lower()
        mentioned = ("@Eliza" in tail or "Eliza" in tail)
        wants_stop = any(w in lowered for w in ["stop", "shut up", "quiet", "be quiet"])
        in_thread = "Eliza:" in prompt
        open_question = any(w in lowered for w in ["anyone", "anybody", "help", "assist", "question", "somebody"])
        expect_respond = mentioned or in_thread or open_question

        # Reward the action matching the heuristic target.
        if wants_stop:
            if predicted == "STOP":
                reward += 1.0
            elif predicted == "IGNORE":
                reward += 0.3
            else:
                reward -= 0.3
        elif expect_respond:
            reward += 1.0 if predicted == "RESPOND" else -0.3
        else:  # expected to ignore
            reward += 1.0 if predicted == "IGNORE" else -0.3

        # Format bonus for well-formed XML; penalty for a missing action tag.
        if "<response>" in completion and "</response>" in completion:
            reward += 0.2
        if predicted == "NONE":
            reward -= 0.5

        scores.append(reward)

    return scores
61
+
62
+ # -------------------------------------------------------------------------
63
+ # Meta-Optimizer
64
+ # -------------------------------------------------------------------------
65
# Meta-prompt handed to the external optimizer model.
# FIX: the original said "REPLY, IGNORE, or STOP", but compute_rewards()
# only rewards the action string "RESPOND" — an optimizer steered toward
# "REPLY" could never score well.  The task description now uses RESPOND
# so the optimizer's vocabulary matches the reward function's.
META_PROMPT_TEMPLATE = """You are an expert Prompt Engineer optimizing a system prompt for an AI agent.

Current System Prompt:
"{current_prompt}"

Task:
The AI needs to decide whether to RESPOND, IGNORE, or STOP based on a chat history.
It MUST output valid XML: <response><name>Eliza</name><reason>...</reason><action>...</action></response>.

Here are some examples of the AI's performance with the current prompt:

[SUCCESSFUL EXAMPLE]
Input: {good_input}
Output: {good_output}
Reward: {good_score} (This was good!)

[FAILED EXAMPLE]
Input: {bad_input}
Output: {bad_output}
Reward: {bad_score} (This was bad!)

[INSTRUCTIONS]
Analyze why the successful example worked and the failed one didn't.
Rewrite the System Prompt to fix the failure case while maintaining the success case.
The new prompt should be concise, clear, and emphasize the XML format and the decision logic.
Return ONLY the new system prompt text. Do not add explanations.
"""
92
+
93
def call_external_model(provider, model_name, prompt):
    """Ask an external API (Groq, Anthropic, or OpenAI) to rewrite the prompt.

    Returns the model's text reply, or None when the call fails for any
    reason (missing key, network error, unknown provider, ...).
    """
    def _groq():
        from groq import Groq
        client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        resp = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an expert prompt engineer."},
                {"role": "user", "content": prompt}
            ],
            model=model_name,
            temperature=0.7,
            max_tokens=1024,
        )
        return resp.choices[0].message.content

    def _anthropic():
        import anthropic
        client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        msg = client.messages.create(
            model=model_name,
            max_tokens=1024,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        return msg.content[0].text

    def _openai():
        from openai import OpenAI
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        resp = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": "You are an expert prompt engineer."},
                {"role": "user", "content": prompt}
            ]
        )
        return resp.choices[0].message.content

    handlers = {"groq": _groq, "anthropic": _anthropic, "openai": _openai}

    try:
        handler = handlers.get(provider)
        if handler is None:
            raise ValueError(f"Unknown provider: {provider}")
        return handler()
    except Exception as e:
        # Best-effort: the caller falls back to the current prompt on None.
        print(f"Error calling {provider}: {e}")
        return None
142
+
143
def optimize_prompt(provider, model_name, current_prompt, good_example, bad_example):
    """Produce an improved system prompt via the external meta-optimizer.

    Falls back to `current_prompt` when the external call returns nothing.
    """
    filled = META_PROMPT_TEMPLATE.format(
        current_prompt=current_prompt,
        good_input=good_example['prompt'],
        good_output=good_example['output'],
        good_score=good_example['score'],
        bad_input=bad_example['prompt'],
        bad_output=bad_example['output'],
        bad_score=bad_example['score'],
    )

    proposal = call_external_model(provider, model_name, filled)
    if not proposal:
        return current_prompt

    # Strip whitespace plus a single pair of surrounding quotes the
    # optimizer model may have added.
    proposal = proposal.strip()
    if proposal.startswith('"') and proposal.endswith('"'):
        proposal = proposal[1:-1]
    return proposal
164
+
165
+ # -------------------------------------------------------------------------
166
+ # Main Loop
167
+ # -------------------------------------------------------------------------
168
def main():
    """Training-free GRPO: iteratively refine the system prompt with an
    external optimizer model, persisting the best-scoring prompt to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
    parser.add_argument("--data", type=str, required=True)
    parser.add_argument("--iter", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--save-path", type=str, default="optimized_prompt.txt")
    parser.add_argument("--optimizer-provider", type=str, default="groq", choices=["groq", "anthropic", "openai"])
    parser.add_argument("--optimizer-model", type=str, default="llama-3.3-70b-versatile")
    args = parser.parse_args()

    print(f"Loading Evaluator Model: {args.model}...")
    model, tokenizer = load(args.model)

    print(f"Using Meta-Optimizer: {args.optimizer_provider} ({args.optimizer_model})")

    # Load user messages from the JSONL chat dataset.
    prompts = []
    with open(args.data, 'r') as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                user_msg = next((m["content"] for m in reversed(item["messages"]) if m["role"] == "user"), None)
                if user_msg:
                    prompts.append(user_msg)

    # Guard: random.sample below would raise on an empty population.
    if not prompts:
        print("No prompts found in data file.")
        return

    random.shuffle(prompts)

    # Seed prompt.  Uses RESPOND (not the original's REPLY) so the
    # instructed action vocabulary matches what compute_rewards() rewards —
    # with REPLY the evaluator could never emit a positively-scored action.
    current_system_prompt = (
        "You are an AI assistant named Eliza. "
        "Decide whether to RESPOND, IGNORE, or STOP. "
        "You MUST respond in valid XML format: "
        "<response><name>Eliza</name><reason>...</reason><action>RESPOND/IGNORE/STOP</action></response>."
    )
    best_overall_avg_score = -float('inf')

    print(f"Starting Training-Free GRPO for {args.iter} iterations...")

    for i in range(args.iter):
        print(f"\n=== Iteration {i+1}/{args.iter} ===")

        # 1. Sample a batch, clamped so small datasets don't crash
        #    random.sample (ValueError when sample > population).
        batch_prompts = random.sample(prompts, min(args.batch_size, len(prompts)))

        batch_results = []

        # 2. Evaluate the current prompt on the batch.
        print(f"Evaluating current prompt on {len(batch_prompts)} examples...")
        current_scores = []

        for p in batch_prompts:
            # Construct the full prompt with the system instruction applied.
            messages = [
                {"role": "system", "content": current_system_prompt},
                {"role": "user", "content": p}
            ]
            full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            output = generate(model, tokenizer, prompt=full_text, max_tokens=100, verbose=False)

            # Score with the same reward function used in training.
            score = compute_rewards([p], [output])[0]
            current_scores.append(score)

            batch_results.append({
                "prompt": p,
                "output": output,
                "score": score
            })

        avg_score = sum(current_scores) / len(current_scores)
        print(f"Avg Score: {avg_score:.2f}")

        # Persist the prompt whenever it sets a new best batch average.
        if avg_score > best_overall_avg_score:
            best_overall_avg_score = avg_score
            print(f"New Best Score! Saving prompt to {args.save_path}...")
            with open(args.save_path, "w") as f:
                f.write(current_system_prompt)

        # 3. Pick the best and worst completions for the meta-optimizer.
        batch_results.sort(key=lambda x: x['score'], reverse=True)
        good = batch_results[0]
        bad = batch_results[-1]

        if good['score'] > bad['score']:
            # Variance exists in the batch, so there is something to learn.
            print(f"Optimizing: Good ({good['score']}) vs Bad ({bad['score']})")
            print(f"Good Output: {good['output'][:50]}...")
            print(f"Bad Output: {bad['output'][:50]}...")

            new_prompt = optimize_prompt(args.optimizer_provider, args.optimizer_model, current_system_prompt, good, bad)
            print(f"\nNew Prompt Proposed:\n{new_prompt}\n")
            current_system_prompt = new_prompt
        else:
            print("No variance in batch (all good or all bad). Skipping optimization step.")

    # (Original printed "optimization Complete." — capitalization typo.)
    print("\nOptimization Complete.")
    print(f"Final Prompt saved to {args.save_path}")
267
+
268
# Run the optimizer loop only when executed as a script.
if __name__ == "__main__":
    main()
@@ -0,0 +1,29 @@
1
+
2
+ import mlx.core as mx
3
+ from mlx_lm import load, generate
4
+ import argparse
5
+
6
def main():
    """Sanity-check sampling variance: generate twice from the same SFT
    adapter at a configurable temperature and print both completions.
    At temp > 0 the two generations are expected to differ."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
    parser.add_argument("--adapter-path", type=str, default="trained_models/should_respond_sft/adapters")
    parser.add_argument("--temp", type=float, default=1.0)
    args = parser.parse_args()

    print(f"Loading {args.model} with {args.adapter_path}")
    model, tokenizer = load(args.model, adapter_path=args.adapter_path)

    # Raw shouldRespond prompt passed as-is (no chat template applied here).
    prompt = "<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>\n\n<providers>\n[RECENT_MESSAGES]\nUser: I heard Eliza is helping\n</providers>\n\n<instructions>Decide if Eliza should respond to or interact with the conversation.\n\nIMPORTANT RULES FOR RESPONDING:\n- If YOUR name (Eliza) is directly mentioned → RESPOND\n- If someone uses a DIFFERENT name (not Eliza) → IGNORE (they're talking to someone else)\n- If you're actively participating in a conversation and the message continues that thread → RESPOND\n- If someone tells you to stop or be quiet → STOP\n- Otherwise → IGNORE\n\nThe key distinction is:\n- \"Talking TO Eliza\" (your name mentioned, replies to you, continuing your conversation) → RESPOND\n- \"Talking ABOUT Eliza\" or to someone else → IGNORE\n</instructions>\n\n<output>\nDo NOT include any thinking, reasoning, or <think> sections in your response.\nGo directly to the XML response format without any preamble or explanation.\n\nRespond using XML format like this:\n<response>\n <name>Eliza</name>\n <reasoning>Your reasoning here</reasoning>\n <action>RESPOND | IGNORE | STOP</action>\n</response>\n\nIMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.\n</output>"

    print("\n--- Gen 1 (temp={}) ---".format(args.temp))
    # Sampler built per generation from the requested temperature.
    from mlx_lm.sample_utils import make_sampler
    sampler = make_sampler(temp=args.temp)

    print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler))

    print("\n--- Gen 2 (temp={}) ---".format(args.temp))
    # Second, independent sampler — fresh randomness for the second draw.
    sampler2 = make_sampler(temp=args.temp)
    print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler2))
27
+
28
# Run the sampling check only when executed as a script.
if __name__ == "__main__":
    main()