@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GRPO (Group Relative Policy Optimization) Training for ShouldRespond.
|
|
3
|
+
|
|
4
|
+
Implements proper GRPO with:
|
|
5
|
+
1. KL divergence penalty against a frozen reference model
|
|
6
|
+
2. PPO-style policy ratio clipping
|
|
7
|
+
3. Gradient norm clipping
|
|
8
|
+
4. Per-token log-prob computation
|
|
9
|
+
5. Early stopping on KL divergence blow-up
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import json
|
|
14
|
+
import math
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
import mlx.core as mx
|
|
18
|
+
import mlx.nn as nn
|
|
19
|
+
import mlx.optimizers as optim
|
|
20
|
+
from mlx.utils import tree_flatten, tree_map
|
|
21
|
+
from mlx_lm import load, generate
|
|
22
|
+
from mlx_lm.sample_utils import make_sampler
|
|
23
|
+
from mlx_lm.tuner.lora import LoRALinear
|
|
24
|
+
|
|
25
|
+
# -------------------------------------------------------------------------
# Reward Function
# -------------------------------------------------------------------------
ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)


def compute_rewards(prompts, completions):
    """
    Reward function for the shouldRespond task.

    Args:
        prompts: List of prompt strings, one per completion.
        completions: List of generated completion strings.

    Returns:
        mx.array of shape (N,) with rewards in [-0.8, +1.2]
        (worst case: wrong action -0.3 plus unparseable -0.5).

    Graduated scoring:
      +1.0 correct action
      +0.3 IGNORE when STOP was expected (acceptable fallback)
      +0.2 valid <response> XML format bonus
      -0.3 wrong action (softened to avoid massive negative gradients)
      -0.5 no parseable action at all
    """
    rewards = []

    for prompt, text in zip(prompts, completions):
        score = 0.0

        # Parse action from generated text.
        # NOTE: a literal <action>NONE</action> is indistinguishable here
        # from "no tag at all" and receives the unparseable penalty below.
        action_match = ACTION_RE.search(text)
        action = action_match.group(1).strip().upper() if action_match else "NONE"

        # --- Determine ground-truth action from prompt heuristics ---
        last_user_msg = prompt.split("User:")[-1] if "User:" in prompt else prompt

        # Substring heuristics. "Eliza" already covers "@Eliza", and
        # "quiet" already covers "be quiet", so the redundant variants
        # present in an earlier revision are dropped (behavior-identical).
        is_direct_mention = "Eliza" in last_user_msg
        lowered = last_user_msg.lower()  # hoisted: was lowered per check
        is_stop = any(w in lowered for w in ["stop", "shut up", "quiet"])
        is_continuation = "Eliza:" in prompt  # Eliza spoke earlier in the thread
        is_ambiguous = any(w in lowered
                           for w in ["anyone", "anybody", "help", "assist", "question", "somebody"])

        should_respond = is_direct_mention or is_continuation or is_ambiguous

        # --- Score the action ---
        # Priority: an explicit "stop" request overrides everything else.
        if is_stop:
            if action == "STOP":
                score += 1.0
            elif action == "IGNORE":
                score += 0.3  # acceptable fallback
            else:
                score -= 0.3
        elif should_respond:
            if action == "RESPOND":
                score += 1.0
            else:
                score -= 0.3
        else:  # should ignore
            if action == "IGNORE":
                score += 1.0
            else:
                score -= 0.3

        # Format bonus / penalty
        if "<response>" in text and "</response>" in text:
            score += 0.2
        if action == "NONE":
            score -= 0.5  # couldn't parse anything

        rewards.append(score)

    return mx.array(rewards)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# -------------------------------------------------------------------------
# Log-probability helpers
# -------------------------------------------------------------------------
def compute_token_log_probs(model, input_ids, mask):
    """
    Sum the log-probabilities the model assigns to the completion tokens.

    Args:
        model: The language model (callable, returns logits).
        input_ids: [1, L] token IDs (prompt + completion).
        mask: [1, L] float mask (0 = prompt, 1 = completion).

    Returns:
        [1] array: total log-prob over the masked (completion) tokens.
    """
    # Position t predicts token t+1: drop the final logit row and the
    # leading label token so predictions and targets line up.
    all_logits = model(input_ids)            # [1, L, V]
    shifted_logits = all_logits[:, :-1, :]
    shifted_labels = input_ids[:, 1:]        # [1, L-1]

    # cross_entropy yields -log p(label); negate to recover log p(token).
    per_token_lp = -nn.losses.cross_entropy(
        shifted_logits, shifted_labels, reduction="none"
    )                                        # [1, L-1]

    # Shift the mask the same way as the labels, zero out prompt
    # positions, and reduce to a trajectory-level log-prob.
    completion_mask = mask[:, 1:]
    return mx.sum(per_token_lp * completion_mask, axis=1)  # [1]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# -------------------------------------------------------------------------
# Training Loop
# -------------------------------------------------------------------------
def train(args):
    """
    Run GRPO training over prompts loaded from ``args.data``.

    Per iteration: sample ``args.group_size`` completions for one prompt,
    score them with ``compute_rewards``, normalize rewards into
    group-relative advantages, then apply one clipped policy-gradient +
    KL-penalized update per completion. Stops early if the per-iteration
    KL estimate exceeds ``args.kl_max``. Finally saves only the trainable
    (LoRA) parameters to ``args.save_path``.
    """
    print(f"Loading model: {args.model} with adapter: {args.adapter_path}")
    model, tokenizer = load(args.model, adapter_path=args.adapter_path)

    # -- Freeze all, then unfreeze LoRA adapters --
    model.freeze()
    for m in model.modules():
        if isinstance(m, LoRALinear):
            m.unfreeze()
            if hasattr(m, "linear"):
                # Keep the wrapped base projection frozen; only the
                # low-rank adapter matrices should receive gradients.
                m.linear.freeze()

        # Reset adapters for pure RL (start from base model behavior)
        if args.reset_adapters:
            if hasattr(m, "lora_a"):
                m.lora_a = mx.random.normal(m.lora_a.shape) * 0.02
            if hasattr(m, "lora_b"):
                # lora_b = 0 makes the adapter a no-op at step 0.
                m.lora_b = mx.zeros(m.lora_b.shape)

    if args.reset_adapters:
        mx.eval(model.parameters())
        print("Reset LoRA adapters to random/zero (pure RL from base model behavior).")

    # Count parameters
    total_params = sum(p.size for _, p in tree_flatten(model.parameters()))
    trainable_params = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    print(f"Trainable params: {trainable_params} / {total_params} ({trainable_params/total_params:.2%})")

    # -- Optimizer --
    optimizer = optim.AdamW(learning_rate=args.lr)

    # -- Load Data --
    # JSONL with OpenAI-style {"messages": [...]} rows; only the LAST user
    # message of each row is kept as a training prompt.
    prompts = []
    with open(args.data, "r") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                item = json.loads(line)
                if "messages" in item:
                    for msg in reversed(item["messages"]):
                        if msg["role"] == "user":
                            prompts.append(msg["content"])
                            break
            except Exception:
                pass  # best-effort load: skip malformed lines silently

    print(f"Loaded {len(prompts)} prompts.")
    if not prompts:
        print("No prompts found! Exiting.")
        return

    # -- Sampler for generation --
    sampler = make_sampler(temp=args.temp)

    # ----------------------------------------------------------------
    # GRPO Training Loop
    # ----------------------------------------------------------------
    print(f"\nStarting GRPO training for {args.iter} iterations")
    print(f" Group size: {args.group_size}")
    print(f" Temperature: {args.temp}")
    print(f" Learning rate: {args.lr}")
    print(f" KL weight (β): {args.kl_weight}")
    print(f" Clip epsilon: {args.clip_eps}")
    print(f" Max grad norm: {args.max_grad_norm}")
    print()

    best_mean_reward = -float("inf")

    for i in range(args.iter):
        # Round-robin over the prompt pool.
        prompt = prompts[i % len(prompts)]
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)

        print(f"[Iter {i+1}/{args.iter}] Prompt: {prompt[:60]}...")

        # ---- 1. Generation Phase ----
        completions = []
        full_inputs = []
        masks = []

        for _ in range(args.group_size):
            text = generate(
                model, tokenizer, prompt=prompt,
                max_tokens=args.max_tokens, verbose=False, sampler=sampler,
            )
            completions.append(text)

            full_text = prompt + text
            # NOTE(review): assumes tokenize(prompt + text) begins with the
            # same prompt_len tokens as tokenize(prompt); BPE merges at the
            # boundary could shift the mask by a token — confirm.
            full_tokens = tokenizer.encode(full_text)
            full_inputs.append(mx.array(full_tokens))

            # 0 over prompt positions, 1 over completion positions.
            L = len(full_tokens)
            m = mx.zeros((L,), dtype=mx.float32)
            m[prompt_len:] = 1.0
            masks.append(m)

        # ---- 2. Reward Phase ----
        rewards = compute_rewards([prompt] * args.group_size, completions)
        mean_r = mx.mean(rewards)
        std_r = mx.max(mx.array([mx.std(rewards), mx.array(1e-4)]))  # floor std
        # Group-relative advantages: z-score within the sampled group.
        advantages = (rewards - mean_r) / std_r

        print(f" Completions: {[c[:40] + '...' for c in completions]}")
        print(f" Rewards: {rewards.tolist()}")
        print(f" Advantages: {[f'{a:.3f}' for a in advantages.tolist()]}")

        # Skip update if all advantages are zero (no learning signal)
        if mx.max(mx.abs(advantages)).item() < 1e-6:
            print(" ⏭ Skipping update (zero variance in rewards)")
            continue

        # ---- 3. Pre-compute old log-probs (serve as BOTH old policy AND reference) ----
        # In GRPO, the "old policy" IS the reference policy for this iteration.
        # We compute log-probs before any parameter updates happen.
        old_log_probs = []
        for j in range(args.group_size):
            inp = full_inputs[j][None, :]
            msk = masks[j][None, :]
            lp = compute_token_log_probs(model, inp, msk)
            mx.eval(lp)
            old_log_probs.append(mx.stop_gradient(lp))

        # ---- 4. Policy Update with Clipping + KL ----
        # NOTE(review): parameters change after each inner update, so later
        # completions in the group are evaluated against an already-moved
        # policy — presumably an intentional sequential variant; confirm.
        step_loss = 0.0
        step_kl = 0.0

        for j in range(args.group_size):
            inp = full_inputs[j][None, :]
            msk = masks[j][None, :]
            adv = advantages[j]
            old_lp = old_log_probs[j]  # already stop_gradient'd

            def grpo_loss(model_inner):
                # Current policy log-prob
                cur_lp = compute_token_log_probs(model_inner, inp, msk)

                # PPO-style ratio clipping
                ratio = mx.exp(cur_lp - old_lp)
                clipped_ratio = mx.clip(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps)

                # Surrogate objective (we minimize, so negate)
                surr1 = ratio * adv
                surr2 = clipped_ratio * adv
                policy_loss = -mx.minimum(surr1, surr2)

                # KL divergence penalty: D_KL(π_θ || π_old) ≈ log(π_θ/π_old)
                # Using old policy as reference prevents drift
                kl = cur_lp - old_lp
                kl_penalty = args.kl_weight * kl

                return mx.mean(policy_loss + kl_penalty)

            loss, grads = nn.value_and_grad(model, grpo_loss)(model)

            # ---- Gradient norm clipping ----
            grad_norms_sq = sum(
                mx.sum(g * g).item()
                for _, g in tree_flatten(grads)
            )
            grad_norm = math.sqrt(grad_norms_sq)

            if grad_norm > args.max_grad_norm:
                scale = args.max_grad_norm / (grad_norm + 1e-6)
                grads = tree_map(lambda g: g * scale, grads)

            optimizer.update(model, grads)
            mx.eval(model.parameters())

            step_loss += loss.item()

            # Track KL for monitoring (extra forward pass after the update)
            cur_lp_check = compute_token_log_probs(model, inp, msk)
            mx.eval(cur_lp_check)
            kl_val = (cur_lp_check - old_lp).item()
            step_kl += kl_val

        avg_loss = step_loss / args.group_size
        avg_kl = step_kl / args.group_size

        # grad_norm printed here is from the LAST completion in the group.
        print(f" Loss: {avg_loss:.4f} | KL: {avg_kl:.4f} | GradNorm: {grad_norm:.2f}")

        # ---- Early stopping on KL blow-up ----
        if abs(avg_kl) > args.kl_max:
            print(f"\n⚠️ KL divergence ({avg_kl:.2f}) exceeded max ({args.kl_max}). Stopping early.")
            break

        # Track best reward
        mr = mean_r.item()
        if mr > best_mean_reward:
            best_mean_reward = mr

    # ---- Save ----
    print(f"\nBest mean reward: {best_mean_reward:.3f}")
    if args.save_path:
        import os
        import shutil
        os.makedirs(os.path.dirname(args.save_path) or ".", exist_ok=True)

        # Save ONLY trainable (LoRA) parameters, not the full model
        trainable = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(args.save_path, trainable)
        print(f"Saved {len(trainable)} adapter weight tensors to {args.save_path}")

        # Copy adapter_config.json so mlx_lm.load() can find it
        adapter_dir = os.path.dirname(args.save_path)
        config_src = os.path.join(args.adapter_path, "adapter_config.json") if args.adapter_path else None
        if config_src and os.path.exists(config_src):
            config_dst = os.path.join(adapter_dir, "adapter_config.json")
            if not os.path.exists(config_dst):
                shutil.copy2(config_src, config_dst)
                print(f"Copied adapter_config.json to {adapter_dir}")
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
# CLI entry point. Argument order is preserved so --help output is stable.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GRPO training for shouldRespond")
    # Model / data inputs
    parser.add_argument("--model", type=str, required=True, help="Base model path")
    parser.add_argument("--adapter-path", type=str, default=None, help="SFT adapter to start from")
    parser.add_argument("--data", type=str, required=True, help="Training data JSONL")
    # GRPO hyperparameters
    parser.add_argument("--iter", type=int, default=50, help="Number of GRPO iterations")
    parser.add_argument("--group-size", type=int, default=8, help="Completions per prompt (G)")
    parser.add_argument("--max-tokens", type=int, default=150, help="Max generation tokens")
    parser.add_argument("--lr", type=float, default=5e-7, help="Learning rate")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    # Stability controls (KL penalty / early stop / clipping)
    parser.add_argument("--kl-weight", type=float, default=0.1, help="KL penalty coefficient β")
    parser.add_argument("--kl-max", type=float, default=10.0, help="Max KL before early stop")
    parser.add_argument("--clip-eps", type=float, default=0.2, help="PPO clip epsilon")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient norm clip")
    # Output / initialization
    parser.add_argument("--save-path", type=str,
                        default="trained_models/should_respond_rl/adapters.safetensors",
                        help="Output adapter weights path")
    parser.add_argument("--reset-adapters", action="store_true",
                        help="Reset LoRA weights to random/zero for pure RL training")

    args = parser.parse_args()
    train(args)
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Train from JSONL Scored Trajectories
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
import json
|
|
9
|
+
import random
|
|
10
|
+
import argparse
|
|
11
|
+
import logging
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import List, Dict, Any
|
|
14
|
+
|
|
15
|
+
# Add src to path
|
|
16
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
17
|
+
|
|
18
|
+
logging.basicConfig(
|
|
19
|
+
level=logging.INFO,
|
|
20
|
+
format='%(asctime)s [%(levelname)s] %(message)s'
|
|
21
|
+
)
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
def detect_backend() -> str:
    """Pick the best available training backend.

    Preference order: MLX (Apple Silicon) > CUDA > CPU. Availability is
    probed by attempting the relevant import.
    """
    try:
        import mlx.core  # noqa: F401 -- availability probe only
    except ImportError:
        pass
    else:
        logger.info("MLX backend available (Apple Silicon)")
        return "mlx"

    try:
        import torch
    except ImportError:
        pass
    else:
        if torch.cuda.is_available():
            logger.info(f"CUDA backend available: {torch.cuda.get_device_name(0)}")
            return "cuda"

    logger.warning("No GPU backend available, falling back to CPU")
    return "cpu"
|
|
43
|
+
|
|
44
|
+
def load_and_process_data(input_file: str, min_score: float) -> List[Dict[str, Any]]:
    """Read trajectories from a JSONL file and build chat-format samples.

    Lines already in chat ``{"messages": [...]}`` form pass through
    unchanged. Trajectory records are reduced to a single
    system/user/assistant exchange built from the recorded task and the
    final step's response text; scored trajectories below *min_score*
    are dropped.

    Raises:
        FileNotFoundError: If *input_file* does not exist.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    logger.info(f"Loading data from {input_file}...")

    samples: List[Dict[str, Any]] = []
    with open(input_file, 'r') as f:
        for raw_line in f:
            if not raw_line.strip():
                continue
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue

            # SFT datasets are already in chat format -- keep them as-is.
            if 'messages' in record:
                samples.append(record)
                continue

            # Drop scored trajectories that fall below the threshold;
            # unscored trajectories are kept regardless of min_score.
            if record.get('isScored') and record.get('score', 0) < min_score:
                continue

            # Trajectory format: train on the final response to the task.
            task = record.get('metadata', {}).get('task', '')
            steps = record.get('steps', [])
            final_step = steps[-1] if steps else None
            if not task or not final_step:
                continue

            reply = final_step.get('action', {}).get('parameters', {}).get('text')
            if not reply:
                continue

            samples.append({
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": task},
                    {"role": "assistant", "content": reply},
                ]
            })

    logger.info(f"Loaded {len(samples)} valid training samples (score >= {min_score})")
    return samples
|
|
111
|
+
|
|
112
|
+
def train_mlx(
    samples: List[Dict],
    model_name: str,
    output_dir: str,
    iters: int,
    batch_size: int,
    learning_rate: float
):
    """Fine-tune *model_name* with LoRA via the ``mlx_lm.lora`` CLI.

    Writes a 90/10 train/valid split of *samples* as JSONL under
    ``<output_dir>/data``, then shells out to ``python -m mlx_lm.lora``.
    Adapters are saved to ``<output_dir>/adapters``.

    Args:
        samples: Chat-format samples (``{"messages": [...]}`` dicts).
            Shuffled in place as a side effect of splitting.
        model_name: HF repo id or local path of the base model.
        output_dir: Directory for data files and adapters.
        iters: Number of training iterations.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate.

    Raises:
        ValueError: If *samples* is empty.
        subprocess.CalledProcessError: If the mlx_lm.lora run fails.
    """
    import subprocess

    # Fail fast with a clear message instead of crashing later on
    # train_samples[0] with an opaque IndexError (previous behavior when
    # the sample list was empty).
    if not samples:
        raise ValueError("No training samples provided to train_mlx")

    logger.info("Starting MLX Training...")

    # Prepare Data Directory
    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    # Split Train/Valid (90/10). max(1, ...) guarantees a non-empty train
    # split: with a single sample, int(1 * 0.9) == 0 previously produced
    # an EMPTY train.jsonl while putting the only sample in validation.
    random.shuffle(samples)
    split_idx = max(1, int(len(samples) * 0.9))
    train_samples = samples[:split_idx]
    valid_samples = samples[split_idx:]

    # Write JSONL for MLX
    with open(os.path.join(data_dir, "train.jsonl"), 'w') as f:
        for s in train_samples:
            f.write(json.dumps(s) + "\n")

    with open(os.path.join(data_dir, "valid.jsonl"), 'w') as f:
        for s in valid_samples:
            f.write(json.dumps(s) + "\n")

    if not valid_samples:
        # mlx_lm.lora requires a non-empty validation file; reuse the first
        # train sample as a dummy validation set (tiny datasets only).
        with open(os.path.join(data_dir, "valid.jsonl"), 'w') as f:
            f.write(json.dumps(train_samples[0]) + "\n")

    adapter_path = os.path.join(output_dir, "adapters")

    # Construct MLX Command.
    # We shell out to the module via subprocess (list form, shell=False)
    # rather than importing it, to avoid argparse conflicts with this
    # script's own arguments.
    cmd = [
        sys.executable, "-m", "mlx_lm.lora",
        "--model", model_name,
        "--train",
        "--data", data_dir,
        "--adapter-path", adapter_path,
        "--batch-size", str(batch_size),
        "--iters", str(iters),
        "--learning-rate", str(learning_rate),
        "--steps-per-report", "5",
        "--save-every", "10",
    ]

    logger.info(f"Running command: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)
    logger.info(f"Training complete. Adapters saved to {adapter_path}")
|
|
169
|
+
|
|
170
|
+
def main():
    """CLI entry point: parse arguments, load scored trajectories, and train.

    Uses the auto-detected backend unless ``--backend`` is given. Only the
    MLX backend performs real training; any other backend runs a dry-run
    verification of the data pipeline.
    """
    parser = argparse.ArgumentParser(description="Train from JSONL")

    # Data options
    parser.add_argument("--input", default="scored_trajectories.jsonl", help="Input JSONL file")
    parser.add_argument("--output", default="trained_models/jsonl_run", help="Output directory")
    parser.add_argument("--min-score", type=float, default=0.7, help="Minimum score to include")

    # Model / backend options
    parser.add_argument("--model", default="mlx-community/Qwen2.5-1.5B-Instruct-4bit", help="Base model (default: Qwen 1.5B 4bit for Mac)")
    parser.add_argument("--backend", choices=["mlx", "cuda", "cpu"], default=None)

    # Optimization options (MLX handles small batch sizes well)
    parser.add_argument("--iters", type=int, default=100, help="Training iterations")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")

    args = parser.parse_args()

    # Resolve the backend: explicit flag wins, otherwise auto-detect.
    backend = args.backend if args.backend else detect_backend()
    logger.info(f"Backend: {backend}")

    # Load and filter the training data.
    input_path = os.path.abspath(args.input)
    samples = load_and_process_data(input_path, args.min_score)

    if not samples:
        logger.error("No samples found. Exiting.")
        return

    # Ensure the output directory exists before training writes into it.
    out_dir = os.path.abspath(args.output)
    os.makedirs(out_dir, exist_ok=True)

    if backend != "mlx":
        logger.warning(f"Backend {backend} detected. MLX not available. Running in DRY RUN mode for verification.")
        # Dry run: exercise only the shuffle/split step so the pipeline
        # can be verified without GPU/MLX support.
        random.shuffle(samples)
        cut = int(len(samples) * 0.9)
        train_samples, valid_samples = samples[:cut], samples[cut:]

        logger.info(f"Dry Run: Would train on {len(train_samples)} samples, validate on {len(valid_samples)} samples.")
        logger.info(f"Sample data: {json.dumps(train_samples[0] if train_samples else {}, indent=2)}")
        logger.info("Verification complete (no actual training performed on CPU).")
    else:
        train_mlx(
            samples,
            model_name=args.model,
            output_dir=out_dir,
            iters=args.iters,
            batch_size=args.batch_size,
            learning_rate=args.lr
        )
|
|
221
|
+
|
|
222
|
+
# Script entry point: run the CLI when executed directly (not on import).
if __name__ == "__main__":
    main()
|