@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,360 @@
1
+ """
2
+ GRPO (Group Relative Policy Optimization) Training for ShouldRespond.
3
+
4
+ Implements proper GRPO with:
5
+ 1. KL divergence penalty against a frozen reference model
6
+ 2. PPO-style policy ratio clipping
7
+ 3. Gradient norm clipping
8
+ 4. Per-token log-prob computation
9
+ 5. Early stopping on KL divergence blow-up
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import math
15
+ import re
16
+
17
+ import mlx.core as mx
18
+ import mlx.nn as nn
19
+ import mlx.optimizers as optim
20
+ from mlx.utils import tree_flatten, tree_map
21
+ from mlx_lm import load, generate
22
+ from mlx_lm.sample_utils import make_sampler
23
+ from mlx_lm.tuner.lora import LoRALinear
24
+
25
+ # -------------------------------------------------------------------------
26
+ # Reward Function
27
+ # -------------------------------------------------------------------------
ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)

def compute_rewards(prompts, completions):
    """
    Heuristic reward for the shouldRespond task.

    Returns an mx.array of shape (N,), one reward per completion, in
    [-1.0, +1.2]:
      +1.0  correct action for the inferred ground truth
      +0.3  IGNORE when STOP was expected (acceptable fallback)
      +0.2  bonus for a well-formed <response>...</response> wrapper
      -0.3  wrong action (softened to avoid massive negative gradients)
      -0.5  extra penalty when no <action> tag could be parsed at all
    """
    scores = []

    for prompt, text in zip(prompts, completions):
        # Pull the model's chosen action out of the completion.
        found = ACTION_RE.search(text)
        action = found.group(1).strip().upper() if found else "NONE"

        # Ground truth is inferred from the last user turn of the prompt.
        tail = prompt.split("User:")[-1] if "User:" in prompt else prompt
        lowered = tail.lower()

        # "@Eliza" contains "Eliza", so a single substring test suffices.
        mentioned = "@Eliza" in tail or "Eliza" in tail
        wants_stop = any(w in lowered for w in ("stop", "shut up", "quiet", "be quiet"))
        eliza_spoke = "Eliza:" in prompt  # Eliza spoke earlier in the thread
        open_question = any(
            w in lowered
            for w in ("anyone", "anybody", "help", "assist", "question", "somebody")
        )
        expect_reply = mentioned or eliza_spoke or open_question

        # Score the chosen action against the inferred ground truth.
        if wants_stop:
            if action == "STOP":
                reward = 1.0
            elif action == "IGNORE":
                reward = 0.3  # acceptable fallback
            else:
                reward = -0.3
        elif expect_reply:
            reward = 1.0 if action == "RESPOND" else -0.3
        else:  # should ignore
            reward = 1.0 if action == "IGNORE" else -0.3

        # Format bonus / unparseable penalty.
        if "<response>" in text and "</response>" in text:
            reward += 0.2
        if action == "NONE":
            reward -= 0.5  # couldn't parse anything

        scores.append(reward)

    return mx.array(scores)
89
+
90
+
91
+ # -------------------------------------------------------------------------
92
+ # Log-probability helpers
93
+ # -------------------------------------------------------------------------
def compute_token_log_probs(model, input_ids, mask):
    """
    Sum the log-probabilities of the completion tokens in a sequence.

    Args:
        model: The language model (callable on token IDs, returns logits).
        input_ids: [1, L] token IDs covering prompt + completion.
        mask: [1, L] float mask (0.0 over prompt positions, 1.0 over
            completion positions).

    Returns:
        Shape-[1] array: the sum of per-token log-probs over the
        completion portion only.
    """
    # Logits at position t predict token t+1, so drop the final logit
    # and shift the labels left by one.
    logits = model(input_ids)[:, :-1, :]   # [1, L-1, V]
    targets = input_ids[:, 1:]             # [1, L-1]

    # Cross-entropy is the negative log-likelihood per token.
    nll = nn.losses.cross_entropy(logits, targets, reduction="none")  # [1, L-1]
    token_log_probs = -nll

    # The mask is shifted identically so it lines up with the labels.
    completion_mask = mask[:, 1:]

    # Trajectory-level log-prob: sum over the completion tokens only.
    return mx.sum(token_log_probs * completion_mask, axis=1)  # [1]
119
+
120
+
121
+ # -------------------------------------------------------------------------
122
+ # Training Loop
123
+ # -------------------------------------------------------------------------
def train(args):
    """Run GRPO over the model's LoRA adapters.

    Per iteration: sample `group_size` completions for one prompt, score
    them with the heuristic reward, normalize rewards into group-relative
    advantages, then take one clipped policy-gradient step per completion
    with a KL penalty against the pre-update (old) policy. Stops early if
    the measured KL exceeds `args.kl_max`; saves only the adapter weights.

    Args:
        args: Parsed CLI namespace (model, adapter_path, data, iter,
            group_size, max_tokens, lr, temp, kl_weight, kl_max,
            clip_eps, max_grad_norm, save_path, reset_adapters).
    """
    print(f"Loading model: {args.model} with adapter: {args.adapter_path}")
    model, tokenizer = load(args.model, adapter_path=args.adapter_path)

    # -- Freeze all, then unfreeze LoRA adapters --
    # Only the LoRA A/B matrices stay trainable; the wrapped base linear
    # layer inside each LoRALinear is explicitly re-frozen.
    model.freeze()
    for m in model.modules():
        if isinstance(m, LoRALinear):
            m.unfreeze()
            if hasattr(m, "linear"):
                m.linear.freeze()

            # Reset adapters for pure RL (start from base model behavior)
            if args.reset_adapters:
                if hasattr(m, "lora_a"):
                    m.lora_a = mx.random.normal(m.lora_a.shape) * 0.02
                if hasattr(m, "lora_b"):
                    # lora_b == 0 makes the adapter an identity delta,
                    # so the initial policy equals the base model.
                    m.lora_b = mx.zeros(m.lora_b.shape)

    if args.reset_adapters:
        mx.eval(model.parameters())  # force MLX's lazy ops to materialize
        print("Reset LoRA adapters to random/zero (pure RL from base model behavior).")

    # Count parameters
    total_params = sum(p.size for _, p in tree_flatten(model.parameters()))
    trainable_params = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    print(f"Trainable params: {trainable_params} / {total_params} ({trainable_params/total_params:.2%})")

    # -- Optimizer --
    optimizer = optim.AdamW(learning_rate=args.lr)

    # -- Load Data --
    # Each JSONL record contributes its LAST user message as a prompt;
    # malformed lines are silently skipped.
    prompts = []
    with open(args.data, "r") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                if "messages" in item:
                    for msg in reversed(item["messages"]):
                        if msg["role"] == "user":
                            prompts.append(msg["content"])
                            break
            except Exception:
                pass

    print(f"Loaded {len(prompts)} prompts.")
    if not prompts:
        print("No prompts found! Exiting.")
        return

    # -- Sampler for generation --
    sampler = make_sampler(temp=args.temp)

    # ----------------------------------------------------------------
    # GRPO Training Loop
    # ----------------------------------------------------------------
    print(f"\nStarting GRPO training for {args.iter} iterations")
    print(f"  Group size: {args.group_size}")
    print(f"  Temperature: {args.temp}")
    print(f"  Learning rate: {args.lr}")
    print(f"  KL weight (β): {args.kl_weight}")
    print(f"  Clip epsilon: {args.clip_eps}")
    print(f"  Max grad norm: {args.max_grad_norm}")
    print()

    best_mean_reward = -float("inf")

    for i in range(args.iter):
        # Prompts are cycled round-robin across iterations.
        prompt = prompts[i % len(prompts)]
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)

        print(f"[Iter {i+1}/{args.iter}] Prompt: {prompt[:60]}...")

        # ---- 1. Generation Phase ----
        completions = []
        full_inputs = []
        masks = []

        for _ in range(args.group_size):
            text = generate(
                model, tokenizer, prompt=prompt,
                max_tokens=args.max_tokens, verbose=False, sampler=sampler,
            )
            completions.append(text)

            full_text = prompt + text
            full_tokens = tokenizer.encode(full_text)
            full_inputs.append(mx.array(full_tokens))

            # Mask is 0 over prompt tokens, 1 over completion tokens.
            # NOTE(review): prompt_len comes from encoding the prompt
            # alone; this assumes encoding prompt+completion reproduces
            # the same prompt-token prefix — confirm for this tokenizer.
            L = len(full_tokens)
            m = mx.zeros((L,), dtype=mx.float32)
            m[prompt_len:] = 1.0
            masks.append(m)

        # ---- 2. Reward Phase ----
        rewards = compute_rewards([prompt] * args.group_size, completions)
        mean_r = mx.mean(rewards)
        std_r = mx.max(mx.array([mx.std(rewards), mx.array(1e-4)]))  # floor std
        advantages = (rewards - mean_r) / std_r

        print(f"  Completions: {[c[:40] + '...' for c in completions]}")
        print(f"  Rewards: {rewards.tolist()}")
        print(f"  Advantages: {[f'{a:.3f}' for a in advantages.tolist()]}")

        # Skip update if all advantages are zero (no learning signal)
        if mx.max(mx.abs(advantages)).item() < 1e-6:
            print("  ⏭ Skipping update (zero variance in rewards)")
            continue

        # ---- 3. Pre-compute old log-probs (serve as BOTH old policy AND reference) ----
        # In GRPO, the "old policy" IS the reference policy for this iteration.
        # We compute log-probs before any parameter updates happen.
        old_log_probs = []
        for j in range(args.group_size):
            inp = full_inputs[j][None, :]
            msk = masks[j][None, :]
            lp = compute_token_log_probs(model, inp, msk)
            mx.eval(lp)
            old_log_probs.append(mx.stop_gradient(lp))

        # ---- 4. Policy Update with Clipping + KL ----
        # NOTE(review): the model is updated once per completion, so
        # later completions in the group are scored against a policy
        # that already moved; old_log_probs stay fixed for the group.
        step_loss = 0.0
        step_kl = 0.0

        for j in range(args.group_size):
            inp = full_inputs[j][None, :]
            msk = masks[j][None, :]
            adv = advantages[j]
            old_lp = old_log_probs[j]  # already stop_gradient'd

            def grpo_loss(model_inner):
                # Current policy log-prob
                cur_lp = compute_token_log_probs(model_inner, inp, msk)

                # PPO-style ratio clipping
                ratio = mx.exp(cur_lp - old_lp)
                clipped_ratio = mx.clip(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps)

                # Surrogate objective (we minimize, so negate)
                surr1 = ratio * adv
                surr2 = clipped_ratio * adv
                policy_loss = -mx.minimum(surr1, surr2)

                # KL divergence penalty: D_KL(π_θ || π_old) ≈ log(π_θ/π_old)
                # Using old policy as reference prevents drift
                kl = cur_lp - old_lp
                kl_penalty = args.kl_weight * kl

                return mx.mean(policy_loss + kl_penalty)

            loss, grads = nn.value_and_grad(model, grpo_loss)(model)

            # ---- Gradient norm clipping ----
            # Global L2 norm across all trainable tensors.
            grad_norms_sq = sum(
                mx.sum(g * g).item()
                for _, g in tree_flatten(grads)
            )
            grad_norm = math.sqrt(grad_norms_sq)

            if grad_norm > args.max_grad_norm:
                scale = args.max_grad_norm / (grad_norm + 1e-6)
                grads = tree_map(lambda g: g * scale, grads)

            optimizer.update(model, grads)
            mx.eval(model.parameters())

            step_loss += loss.item()

            # Track KL for monitoring
            # (recomputed AFTER the update, so this measures post-step drift)
            cur_lp_check = compute_token_log_probs(model, inp, msk)
            mx.eval(cur_lp_check)
            kl_val = (cur_lp_check - old_lp).item()
            step_kl += kl_val


        avg_loss = step_loss / args.group_size
        avg_kl = step_kl / args.group_size

        # NOTE(review): grad_norm printed here is from the LAST completion
        # of the group, not an average.
        print(f"  Loss: {avg_loss:.4f} | KL: {avg_kl:.4f} | GradNorm: {grad_norm:.2f}")

        # ---- Early stopping on KL blow-up ----
        if abs(avg_kl) > args.kl_max:
            print(f"\n⚠️ KL divergence ({avg_kl:.2f}) exceeded max ({args.kl_max}). Stopping early.")
            break

        # Track best reward
        mr = mean_r.item()
        if mr > best_mean_reward:
            best_mean_reward = mr

    # ---- Save ----
    print(f"\nBest mean reward: {best_mean_reward:.3f}")
    if args.save_path:
        import os
        import shutil
        os.makedirs(os.path.dirname(args.save_path) or ".", exist_ok=True)

        # Save ONLY trainable (LoRA) parameters, not the full model
        trainable = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(args.save_path, trainable)
        print(f"Saved {len(trainable)} adapter weight tensors to {args.save_path}")

        # Copy adapter_config.json so mlx_lm.load() can find it
        adapter_dir = os.path.dirname(args.save_path)
        config_src = os.path.join(args.adapter_path, "adapter_config.json") if args.adapter_path else None
        if config_src and os.path.exists(config_src):
            config_dst = os.path.join(adapter_dir, "adapter_config.json")
            if not os.path.exists(config_dst):
                shutil.copy2(config_src, config_dst)
                print(f"Copied adapter_config.json to {adapter_dir}")
337
+
338
+
# CLI entry point: every GRPO hyperparameter is exposed as a flag and
# forwarded to train() via the parsed namespace.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GRPO training for shouldRespond")
    parser.add_argument("--model", type=str, required=True, help="Base model path")
    parser.add_argument("--adapter-path", type=str, default=None, help="SFT adapter to start from")
    parser.add_argument("--data", type=str, required=True, help="Training data JSONL")
    parser.add_argument("--iter", type=int, default=50, help="Number of GRPO iterations")
    parser.add_argument("--group-size", type=int, default=8, help="Completions per prompt (G)")
    parser.add_argument("--max-tokens", type=int, default=150, help="Max generation tokens")
    parser.add_argument("--lr", type=float, default=5e-7, help="Learning rate")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--kl-weight", type=float, default=0.1, help="KL penalty coefficient β")
    parser.add_argument("--kl-max", type=float, default=10.0, help="Max KL before early stop")
    parser.add_argument("--clip-eps", type=float, default=0.2, help="PPO clip epsilon")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient norm clip")
    parser.add_argument("--save-path", type=str,
                        default="trained_models/should_respond_rl/adapters.safetensors",
                        help="Output adapter weights path")
    parser.add_argument("--reset-adapters", action="store_true",
                        help="Reset LoRA weights to random/zero for pure RL training")

    args = parser.parse_args()
    train(args)
@@ -0,0 +1,223 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train from JSONL Scored Trajectories
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import random
10
+ import argparse
11
+ import logging
12
+ from pathlib import Path
13
+ from typing import List, Dict, Any
14
+
15
+ # Add src to path
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

def detect_backend() -> str:
    """Pick the best available compute backend: "mlx" > "cuda" > "cpu"."""
    # Apple Silicon first — if MLX imports at all, prefer it.
    try:
        import mlx.core  # noqa: F401
    except ImportError:
        pass
    else:
        logger.info("MLX backend available (Apple Silicon)")
        return "mlx"

    # Then NVIDIA via PyTorch, but only when a CUDA device is actually usable.
    try:
        import torch
    except ImportError:
        pass
    else:
        if torch.cuda.is_available():
            logger.info(f"CUDA backend available: {torch.cuda.get_device_name(0)}")
            return "cuda"

    logger.warning("No GPU backend available, falling back to CPU")
    return "cpu"
43
+
def load_and_process_data(input_file: str, min_score: float) -> List[Dict[str, Any]]:
    """
    Load trajectories from JSONL, filter by score, and convert to chat format.

    Lines already shaped as ``{"messages": [...]}`` (SFT datasets) pass
    through unchanged. Trajectory-format lines are reduced to a single
    system/user/assistant exchange built from the task and the final
    step's response text; scored trajectories below ``min_score`` are
    dropped. Blank and non-JSON lines are skipped silently.

    Raises:
        FileNotFoundError: if ``input_file`` does not exist.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    logger.info(f"Loading data from {input_file}...")

    samples: List[Dict[str, Any]] = []
    with open(input_file, 'r') as f:
        for raw in f:
            if not raw.strip():
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue

            # Already in chat format (SFT dataset) — keep as-is.
            if 'messages' in record:
                samples.append(record)
                continue

            # The score gate only applies to records that were judged.
            if record.get('isScored') and record.get('score', 0) < min_score:
                continue

            # Trajectory format: we train the model to produce the final
            # response to the task, not every intermediate step.
            task = record.get('metadata', {}).get('task', '')
            if not task:
                continue

            steps = record.get('steps', [])
            final_step = steps[-1] if steps else None
            if not final_step:
                continue

            reply = final_step.get('action', {}).get('parameters', {}).get('text')
            if not reply:
                continue

            samples.append({
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": task},
                    {"role": "assistant", "content": reply}
                ]
            })

    logger.info(f"Loaded {len(samples)} valid training samples (score >= {min_score})")
    return samples
111
+
def train_mlx(
    samples: List[Dict],
    model_name: str,
    output_dir: str,
    iters: int,
    batch_size: int,
    learning_rate: float
):
    """Fine-tune with MLX LoRA by shelling out to ``mlx_lm.lora``.

    Writes a shuffled 90/10 train/valid split into ``<output_dir>/data``
    and saves adapters under ``<output_dir>/adapters``.
    """
    import subprocess

    logger.info("Starting MLX Training...")

    # Stage the JSONL files where mlx_lm.lora expects them.
    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    # Shuffled 90/10 split.
    random.shuffle(samples)
    cut = int(len(samples) * 0.9)
    train_split = samples[:cut]
    valid_split = samples[cut:]

    def write_jsonl(filename, rows):
        # One JSON object per line.
        with open(os.path.join(data_dir, filename), 'w') as fh:
            fh.writelines(json.dumps(row) + "\n" for row in rows)

    write_jsonl("train.jsonl", train_split)
    write_jsonl("valid.jsonl", valid_split)

    if not valid_split:
        # Create a dummy validation set if empty (e.g. only 1 sample)
        write_jsonl("valid.jsonl", train_split[:1])

    adapter_path = os.path.join(output_dir, "adapters")

    # Run mlx_lm.lora in a subprocess to avoid import issues with
    # conflicting arguments.
    cmd = [
        sys.executable, "-m", "mlx_lm.lora",
        "--model", model_name,
        "--train",
        "--data", data_dir,
        "--adapter-path", adapter_path,
        "--batch-size", str(batch_size),
        "--iters", str(iters),
        "--learning-rate", str(learning_rate),
        "--steps-per-report", "5",
        "--save-every", "10",
    ]

    logger.info(f"Running command: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)
    logger.info(f"Training complete. Adapters saved to {adapter_path}")
169
+
def main():
    """CLI entry point: parse args, pick a backend, load data, and train.

    Only the MLX backend actually trains; CUDA/CPU currently run a
    dry-run that just validates data loading and the train/valid split.
    """
    parser = argparse.ArgumentParser(description="Train from JSONL")
    # Data selection
    parser.add_argument("--input", default="scored_trajectories.jsonl", help="Input JSONL file")
    parser.add_argument("--output", default="trained_models/jsonl_run", help="Output directory")
    parser.add_argument("--min-score", type=float, default=0.7, help="Minimum score to include")

    # Model / backend
    parser.add_argument("--model", default="mlx-community/Qwen2.5-1.5B-Instruct-4bit", help="Base model (default: Qwen 1.5B 4bit for Mac)")
    parser.add_argument("--backend", choices=["mlx", "cuda", "cpu"], default=None)

    # Hyperparameters
    parser.add_argument("--iters", type=int, default=100, help="Training iterations")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")  # MLX handles small batch sizes well
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")

    args = parser.parse_args()

    # Detect Backend (an explicit --backend flag overrides auto-detection)
    backend = args.backend or detect_backend()
    logger.info(f"Backend: {backend}")

    # Load Data
    full_input_path = os.path.abspath(args.input)
    samples = load_and_process_data(full_input_path, args.min_score)

    if not samples:
        logger.error("No samples found. Exiting.")
        return

    # Train
    output_dir = os.path.abspath(args.output)
    os.makedirs(output_dir, exist_ok=True)

    if backend == "mlx":
        train_mlx(
            samples,
            args.model,
            output_dir,
            args.iters,
            args.batch_size,
            args.lr
        )
    else:
        logger.warning(f"Backend {backend} detected. MLX not available. Running in DRY RUN mode for verification.")
        # Dry run: just verify data processing and split
        random.shuffle(samples)
        split_idx = int(len(samples) * 0.9)
        train_samples = samples[:split_idx]
        valid_samples = samples[split_idx:]

        logger.info(f"Dry Run: Would train on {len(train_samples)} samples, validate on {len(valid_samples)} samples.")
        logger.info(f"Sample data: {json.dumps(train_samples[0] if train_samples else {}, indent=2)}")
        logger.info("Verification complete (no actual training performed on CPU).")

if __name__ == "__main__":
    main()