jfl 0.5.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/commands/context-hub.d.ts +1 -0
- package/dist/commands/context-hub.d.ts.map +1 -1
- package/dist/commands/context-hub.js +246 -2
- package/dist/commands/context-hub.js.map +1 -1
- package/dist/commands/peter.d.ts +2 -0
- package/dist/commands/peter.d.ts.map +1 -1
- package/dist/commands/peter.js +242 -52
- package/dist/commands/peter.js.map +1 -1
- package/dist/commands/setup.d.ts +12 -0
- package/dist/commands/setup.d.ts.map +1 -0
- package/dist/commands/setup.js +322 -0
- package/dist/commands/setup.js.map +1 -0
- package/dist/commands/train.d.ts +33 -0
- package/dist/commands/train.d.ts.map +1 -0
- package/dist/commands/train.js +510 -0
- package/dist/commands/train.js.map +1 -0
- package/dist/commands/verify.d.ts +14 -0
- package/dist/commands/verify.d.ts.map +1 -0
- package/dist/commands/verify.js +276 -0
- package/dist/commands/verify.js.map +1 -0
- package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
- package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
- package/dist/dashboard-static/index.html +2 -2
- package/dist/index.js +99 -3
- package/dist/index.js.map +1 -1
- package/dist/lib/agent-session.d.ts.map +1 -1
- package/dist/lib/agent-session.js +12 -4
- package/dist/lib/agent-session.js.map +1 -1
- package/dist/lib/eval-snapshot.js +1 -1
- package/dist/lib/eval-snapshot.js.map +1 -1
- package/dist/lib/pi-sky/bridge.d.ts +55 -0
- package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/bridge.js +264 -0
- package/dist/lib/pi-sky/bridge.js.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.js +126 -0
- package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.js +141 -0
- package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
- package/dist/lib/pi-sky/event-router.d.ts +32 -0
- package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
- package/dist/lib/pi-sky/event-router.js +176 -0
- package/dist/lib/pi-sky/event-router.js.map +1 -0
- package/dist/lib/pi-sky/experiment.d.ts +9 -0
- package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
- package/dist/lib/pi-sky/experiment.js +83 -0
- package/dist/lib/pi-sky/experiment.js.map +1 -0
- package/dist/lib/pi-sky/index.d.ts +16 -0
- package/dist/lib/pi-sky/index.d.ts.map +1 -0
- package/dist/lib/pi-sky/index.js +16 -0
- package/dist/lib/pi-sky/index.js.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.js +61 -0
- package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
- package/dist/lib/pi-sky/swarm.d.ts +28 -0
- package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
- package/dist/lib/pi-sky/swarm.js +208 -0
- package/dist/lib/pi-sky/swarm.js.map +1 -0
- package/dist/lib/pi-sky/types.d.ts +139 -0
- package/dist/lib/pi-sky/types.d.ts.map +1 -0
- package/dist/lib/pi-sky/types.js +2 -0
- package/dist/lib/pi-sky/types.js.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.js +91 -0
- package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
- package/dist/lib/policy-head.d.ts +16 -1
- package/dist/lib/policy-head.d.ts.map +1 -1
- package/dist/lib/policy-head.js +117 -19
- package/dist/lib/policy-head.js.map +1 -1
- package/dist/lib/predictor.d.ts +10 -0
- package/dist/lib/predictor.d.ts.map +1 -1
- package/dist/lib/predictor.js +46 -7
- package/dist/lib/predictor.js.map +1 -1
- package/dist/lib/setup/agent-generator.d.ts +18 -0
- package/dist/lib/setup/agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/agent-generator.js +114 -0
- package/dist/lib/setup/agent-generator.js.map +1 -0
- package/dist/lib/setup/context-analyzer.d.ts +16 -0
- package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
- package/dist/lib/setup/context-analyzer.js +112 -0
- package/dist/lib/setup/context-analyzer.js.map +1 -0
- package/dist/lib/setup/doc-auditor.d.ts +54 -0
- package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
- package/dist/lib/setup/doc-auditor.js +629 -0
- package/dist/lib/setup/doc-auditor.js.map +1 -0
- package/dist/lib/setup/domain-generator.d.ts +7 -0
- package/dist/lib/setup/domain-generator.d.ts.map +1 -0
- package/dist/lib/setup/domain-generator.js +58 -0
- package/dist/lib/setup/domain-generator.js.map +1 -0
- package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
- package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
- package/dist/lib/setup/smart-eval-generator.js +378 -0
- package/dist/lib/setup/smart-eval-generator.js.map +1 -0
- package/dist/lib/setup/smart-recommender.d.ts +63 -0
- package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
- package/dist/lib/setup/smart-recommender.js +329 -0
- package/dist/lib/setup/smart-recommender.js.map +1 -0
- package/dist/lib/setup/spec-generator.d.ts +63 -0
- package/dist/lib/setup/spec-generator.d.ts.map +1 -0
- package/dist/lib/setup/spec-generator.js +310 -0
- package/dist/lib/setup/spec-generator.js.map +1 -0
- package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
- package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/violation-agent-generator.js +255 -0
- package/dist/lib/setup/violation-agent-generator.js.map +1 -0
- package/package.json +1 -1
- package/packages/pi/extensions/context.ts +88 -55
- package/packages/pi/extensions/hub-resolver.ts +63 -0
- package/packages/pi/extensions/index.ts +16 -3
- package/packages/pi/extensions/memory-tool.ts +9 -4
- package/packages/pi/extensions/session.ts +68 -16
- package/packages/pi/extensions/tool-renderers.ts +23 -8
- package/scripts/train/requirements.txt +5 -0
- package/scripts/train/train-policy-head.py +477 -0
- package/scripts/train/v2/dataset.py +81 -0
- package/scripts/train/v2/domain.json +18 -0
- package/scripts/train/v2/eval.py +196 -0
- package/scripts/train/v2/generate_data.py +219 -0
- package/scripts/train/v2/infer.py +188 -0
- package/scripts/train/v2/model.py +112 -0
- package/scripts/train/v2/precompute.py +132 -0
- package/scripts/train/v2/train.py +302 -0
- package/scripts/train/v2/transform_buffer.py +227 -0
- package/scripts/train/v2/validate_data.py +115 -0
- package/template/.claude/settings.json +2 -15
- package/template/scripts/session/session-cleanup.sh +2 -11
- package/template/scripts/session/session-end-hub.sh +72 -0
- package/template/scripts/session/session-start-hub.sh +105 -0
- package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
- package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
package/scripts/train/train-policy-head.py
@@ -0,0 +1,477 @@
#!/usr/bin/env python3
"""
Policy head trainer for JFL RL agents.

Trains a 3-layer MLP to predict reward from (state, action) embeddings.
Matches the PolicyWeights JSON format consumed by policy-head.ts inference.

Architecture (per Andrew @ Stratus, 2026-03-13 call):
    Input:   concat(state_emb, action_emb) = 2 * embed_dim
    Layer 1: Linear(input_dim, 512) + ReLU
    Layer 2: Linear(512, 512) + ReLU + LayerNorm + Dropout
    Layer 3: Linear(512, 1)

Usage:
    python train-policy-head.py --data /path/to/.jfl/training-buffer.jsonl
    python train-policy-head.py --embeddings /path/to/embeddings.npz --rewards /path/to/rewards.npy
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split


class PolicyHead(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

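For orientation, a quick sanity check of the head defined above. This sketch is illustrative and not part of the package; the 768-dimensional embedding size is an assumption (the v2 dataset code later in this diff falls back to 768-dim zero vectors).

    # Editor's sketch: verify shapes and parameter count of PolicyHead.
    import torch
    head = PolicyHead(input_dim=2 * 768)              # assumed 768-dim embeddings
    x = torch.randn(4, 2 * 768)                       # batch of 4 concatenated (state, action) pairs
    print(head(x).shape)                              # torch.Size([4]): one scalar reward per pair
    print(sum(p.numel() for p in head.parameters()))  # 1,051,137 (~1.05M) parameters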
def get_embeddings_from_stratus(texts: list[str], api_key: str, api_url: str) -> np.ndarray:
    """Batch-embed texts via the Stratus /v1/embeddings endpoint."""
    import requests

    embeddings = []
    batch_size = 32

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        resp = requests.post(
            f"{api_url}/v1/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={"model": "stratus-x1ac-base", "input": batch},
            timeout=30,
        )
        resp.raise_for_status()
        data = resp.json()
        for item in data["data"]:
            embeddings.append(item["embedding"])

        if i > 0 and i % 100 == 0:
            print(f" Embedded {i}/{len(texts)} texts...")

    return np.array(embeddings, dtype=np.float32)


def load_training_data(jsonl_path: str, reward_clip: float = 1.0) -> tuple[list[str], list[str], np.ndarray]:
    """Load the training buffer JSONL and return (state_texts, action_texts, rewards).

    Applies data quality filtering:
    - Clips rewards to [-reward_clip, reward_clip] (default ±1.0)
    - Drops entries with zero reward (no learning signal)
    - Drops entries with missing state/action data
    """
    state_texts = []
    action_texts = []
    rewards = []
    skipped_zero = 0
    skipped_outlier = 0
    skipped_missing = 0

    with open(jsonl_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue

            state = entry.get("state", {})
            action = entry.get("action", {})
            reward = entry.get("reward", {})

            dims = state.get("dimension_scores", {})
            dims_str = ", ".join(f"{k}={v:.4f}" for k, v in dims.items()) if dims else "none"
            deltas = state.get("recent_deltas", [])
            deltas_str = ", ".join(f"{d:+.4f}" for d in deltas) if deltas else "none"

            state_text = "\n".join([
                f"Agent: {state.get('agent', 'unknown')}",
                f"Composite: {state.get('composite_score', 0):.4f}",
                f"Tests: {state.get('tests_passing', 0)}/{state.get('tests_total', 0)}",
                f"Trajectory: {state.get('trajectory_length', 0)}",
                f"Dimensions: {dims_str}",
                f"Recent deltas: {deltas_str}",
            ])

            files = action.get("files_affected", [])[:5]
            action_text = "\n".join([
                f"Type: {action.get('type', 'unknown')}",
                f"Description: {action.get('description', '')[:150]}",
                f"Scope: {action.get('scope', 'unknown')}",
                f"Files: {', '.join(files) if files else 'none'}",
            ])

            composite_delta = reward.get("composite_delta", 0.0)

            if not action.get("description") or not state.get("agent"):
                skipped_missing += 1
                continue

            if composite_delta == 0.0:
                skipped_zero += 1
                continue

            if abs(composite_delta) > reward_clip:
                skipped_outlier += 1
                composite_delta = max(-reward_clip, min(reward_clip, composite_delta))

            state_texts.append(state_text)
            action_texts.append(action_text)
            rewards.append(composite_delta)

    # Clipped outliers are kept, so they are already counted in state_texts.
    total_raw = len(state_texts) + skipped_zero + skipped_missing
    print(f" Data quality filter (reward_clip=±{reward_clip}):")
    print(f"   Raw entries:       {total_raw}")
    print(f"   Kept:              {len(state_texts)}")
    print(f"   Skipped (zero):    {skipped_zero}")
    print(f"   Clipped (outlier): {skipped_outlier}")
    print(f"   Skipped (missing): {skipped_missing}")

    return state_texts, action_texts, np.array(rewards, dtype=np.float32)

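For concreteness, a hypothetical training-buffer entry showing the fields the loader above reads (field names come from the code; all values are invented, and the real file keeps one JSON object per line):

    {"state": {"agent": "peter", "composite_score": 0.7123, "tests_passing": 41, "tests_total": 44,
               "trajectory_length": 12, "dimension_scores": {"correctness": 0.81, "style": 0.64},
               "recent_deltas": [0.015, -0.0042]},
     "action": {"type": "fix_bug", "description": "Resolve failing session-cleanup test",
                "scope": "scripts/session", "files_affected": ["session-cleanup.sh"]},
     "reward": {"composite_delta": 0.0213}}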
def compute_metrics(predictions: np.ndarray, targets: np.ndarray) -> dict:
    """Compute direction accuracy and rank correlation."""
    direction_correct = np.sum(np.sign(predictions) == np.sign(targets))
    direction_accuracy = direction_correct / len(targets) if len(targets) > 0 else 0.0

    try:
        # Import inside the try so a missing scipy degrades gracefully.
        from scipy.stats import spearmanr
        rank_corr, _ = spearmanr(predictions, targets)
        if np.isnan(rank_corr):
            rank_corr = 0.0
    except Exception:
        rank_corr = 0.0

    return {
        "direction_accuracy": float(direction_accuracy),
        "rank_correlation": float(rank_corr),
        "mse": float(np.mean((predictions - targets) ** 2)),
        "mae": float(np.mean(np.abs(predictions - targets))),
    }


def export_weights(model: PolicyHead, embed_dim: int, target_mean: float, target_std: float,
                   train_size: int, metrics: dict, output_path: str):
    """Export model weights to the PolicyWeights JSON format for policy-head.ts inference."""
    state_dict = model.state_dict()

    def to_list(tensor):
        return tensor.cpu().detach().numpy().tolist()

    weights = {
        "version": 1,
        "architecture": "mlp-3layer-512h",
        "embed_dim": embed_dim,
        "mode": "embedding",
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
        "trained_on": train_size,
        "direction_accuracy": metrics["direction_accuracy"],
        "rank_correlation": metrics["rank_correlation"],
        "target_mean": float(target_mean),
        "target_std": float(target_std),
        "layers": {
            # Linear weights are stored transposed (input_dim x output_dim).
            "W1": to_list(state_dict["net.0.weight"].T),
            "b1": to_list(state_dict["net.0.bias"]),
            "W2": to_list(state_dict["net.2.weight"].T),
            "b2": to_list(state_dict["net.2.bias"]),
            "W3": to_list(state_dict["net.6.weight"].T),
            "b3": to_list(state_dict["net.6.bias"]),
        },
    }

    with open(output_path, "w") as f:
        json.dump(weights, f, indent=2)

    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"\n Exported weights to {output_path} ({size_mb:.1f} MB)")
    print(f" Direction accuracy: {metrics['direction_accuracy']:.3f}")
    print(f" Rank correlation: {metrics['rank_correlation']:.3f}")

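A minimal sketch of reading the exported PolicyWeights JSON back and reproducing the forward pass in numpy (editor's illustration, not package code). Two details worth noting from the export above: the Linear weights are stored transposed, so inference is a row-vector matmul, and the `layers` block contains only the three Linear layers. The LayerNorm affine parameters are not exported, so this sketch applies a plain normalization, and dropout is skipped since it is identity at inference time.

    # Editor's sketch: numpy re-implementation of the exported policy head.
    import json
    import numpy as np

    w = json.load(open(".jfl/policy-weights.json"))
    L = w["layers"]
    W1, b1 = np.array(L["W1"]), np.array(L["b1"])  # (2*embed_dim, 512), (512,)
    W2, b2 = np.array(L["W2"]), np.array(L["b2"])  # (512, 512), (512,)
    W3, b3 = np.array(L["W3"]), np.array(L["b3"])  # (512, 1), (1,)

    def predict(x: np.ndarray) -> float:
        """x: concatenated (state_emb, action_emb), shape (2*embed_dim,)."""
        h = np.maximum(x @ W1 + b1, 0.0)              # Linear + ReLU
        h = np.maximum(h @ W2 + b2, 0.0)              # Linear + ReLU
        h = (h - h.mean()) / np.sqrt(h.var() + 1e-5)  # LayerNorm, no affine params exported
        z = float(h @ W3 + b3)                        # Linear -> normalized reward
        return z * w["target_std"] + w["target_mean"] # undo target normalization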
def train(args):
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"\n Device: {device}")

    # Load data
    if args.embeddings and args.rewards:
        print(f" Loading pre-computed embeddings from {args.embeddings}")
        data = np.load(args.embeddings)
        state_embs = data["state_embeddings"]
        action_embs = data["action_embeddings"]
        rewards = np.load(args.rewards)
    else:
        print(f" Loading training data from {args.data}")
        state_texts, action_texts, rewards = load_training_data(args.data, reward_clip=args.reward_clip)
        print(f" Loaded {len(rewards)} usable entries")

        if len(rewards) < args.min_entries:
            print(f"\n Not enough data: {len(rewards)} < {args.min_entries} minimum")
            print(f" Need {args.min_entries - len(rewards)} more training entries")
            sys.exit(1)

        api_key = args.api_key or os.environ.get("STRATUS_API_KEY")
        api_url = args.api_url or os.environ.get("STRATUS_API_URL", "https://api.stratus.run")

        if not api_key:
            dotenv_path = Path(args.data).parent.parent / ".env"
            if dotenv_path.exists():
                for line in dotenv_path.read_text().splitlines():
                    if line.startswith("STRATUS_API_KEY="):
                        api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
                        break

        if not api_key:
            print("\n STRATUS_API_KEY not set. Cannot compute embeddings.")
            print(" Either set the env var or use --embeddings with pre-computed data.")
            sys.exit(1)

        # Check for cached embeddings
        cache_dir = Path(args.data).parent / "train-cache"
        cache_dir.mkdir(exist_ok=True)
        cache_path = cache_dir / f"embeddings-{len(rewards)}.npz"

        if cache_path.exists() and not args.force_embed:
            print(f" Loading cached embeddings from {cache_path}")
            cached = np.load(cache_path)
            state_embs = cached["state_embeddings"]
            action_embs = cached["action_embeddings"]
        else:
            print(f" Computing embeddings for {len(state_texts)} entries...")
            state_embs = get_embeddings_from_stratus(state_texts, api_key, api_url)
            action_embs = get_embeddings_from_stratus(action_texts, api_key, api_url)
            np.savez(cache_path, state_embeddings=state_embs, action_embeddings=action_embs)
            print(f" Cached embeddings to {cache_path}")

    embed_dim = state_embs.shape[1]
    input_dim = embed_dim * 2
    print(f" Embedding dim: {embed_dim}, Input dim: {input_dim}")
    print(f" Entries: {len(rewards)}, Reward range: [{rewards.min():.4f}, {rewards.max():.4f}]")

    # Normalize targets
    target_mean = float(rewards.mean())
    target_std = float(rewards.std()) if rewards.std() > 1e-8 else 1.0
    normalized_rewards = (rewards - target_mean) / target_std

    # Build tensors
    X = np.concatenate([state_embs, action_embs], axis=1)
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(normalized_rewards, dtype=torch.float32)

    dataset = TensorDataset(X_tensor, y_tensor)

    # Train/validation split (default 70/30)
    val_size = max(1, int(len(dataset) * args.val_ratio))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(args.seed)
    )

    print(f" Train: {train_size}, Val: {val_size}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # Build model
    model = PolicyHead(input_dim, hidden_dim=args.hidden_dim).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f" Model parameters: {param_count:,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Warmup + cosine schedule
    warmup_steps = args.warmup_epochs * len(train_loader)
    total_steps = args.epochs * len(train_loader)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    criterion = nn.MSELoss()

    # Training loop
    best_val_loss = float("inf")
    best_epoch = 0
    best_state = None
    patience_counter = 0

    print(f"\n Training for up to {args.epochs} epochs (patience: {args.patience})...\n")

    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        train_batches = 0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            pred = model(X_batch)
            loss = criterion(pred, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            train_batches += 1

        avg_train_loss = train_loss / max(1, train_batches)

        # Validation
        model.eval()
        val_loss = 0.0
        val_batches = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                pred = model(X_batch)
                loss = criterion(pred, y_batch)
                val_loss += loss.item()
                val_batches += 1
                val_preds.extend(pred.cpu().numpy())
                val_targets.extend(y_batch.cpu().numpy())

        avg_val_loss = val_loss / max(1, val_batches)

        # Denormalize for metrics
        val_preds_denorm = np.array(val_preds) * target_std + target_mean
        val_targets_denorm = np.array(val_targets) * target_std + target_mean

        if (epoch + 1) % max(1, args.epochs // 20) == 0 or epoch == 0:
            metrics = compute_metrics(val_preds_denorm, val_targets_denorm)
            lr = optimizer.param_groups[0]["lr"]
            print(f" Epoch {epoch+1:4d} train_loss={avg_train_loss:.6f} val_loss={avg_val_loss:.6f} "
                  f"dir_acc={metrics['direction_accuracy']:.3f} rank_corr={metrics['rank_correlation']:.3f} "
                  f"lr={lr:.2e}")

        # Early stopping
        if avg_val_loss < best_val_loss - args.min_delta:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"\n Early stopping at epoch {epoch+1} (best: {best_epoch})")
                break

    # Load best checkpoint
    if best_state:
        model.load_state_dict(best_state)
        model.to(device)

    # Final metrics on the full val set
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            pred = model(X_batch)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(y_batch.numpy())

    final_preds = np.array(all_preds) * target_std + target_mean
    final_targets = np.array(all_targets) * target_std + target_mean
    final_metrics = compute_metrics(final_preds, final_targets)

    print(f"\n Final metrics (best epoch {best_epoch}):")
    print(f" Direction accuracy: {final_metrics['direction_accuracy']:.3f}")
    print(f" Rank correlation: {final_metrics['rank_correlation']:.3f}")
    print(f" MSE: {final_metrics['mse']:.6f}")
    print(f" MAE: {final_metrics['mae']:.6f}")

    # Export
    export_weights(model, embed_dim, target_mean, target_std, len(rewards), final_metrics, args.output)

    # Also save training metadata
    meta_path = args.output.replace("policy-weights.json", "training-meta.json")
    meta = {
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
        "data_source": str(args.data) if args.data else "pre-computed",
        "entries": int(len(rewards)),
        "train_size": train_size,
        "val_size": val_size,
        "embed_dim": embed_dim,
        "hidden_dim": args.hidden_dim,
        "epochs_run": min(epoch + 1, args.epochs),
        "best_epoch": best_epoch,
        "best_val_loss": float(best_val_loss),
        "lr": args.lr,
        "batch_size": args.batch_size,
        "dropout": 0.1,
        "weight_decay": args.weight_decay,
        "device": device,
        "param_count": param_count,
        "metrics": final_metrics,
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f" Training metadata saved to {meta_path}")

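A small worked illustration of the warmup-plus-cosine schedule in train() above (editor's example; the 7-batches-per-epoch figure is invented). With the defaults --warmup-epochs 10 and --epochs 500, warmup_steps = 70 and total_steps = 3500:

    # Editor's sketch: sample the lr_lambda multiplier at a few steps.
    import numpy as np

    warmup_steps, total_steps = 10 * 7, 500 * 7  # 70, 3500
    for step in (35, 70, 1785, 3500):
        if step < warmup_steps:
            m = step / max(1, warmup_steps)
        else:
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            m = 0.5 * (1.0 + np.cos(np.pi * progress))
        print(step, round(float(m), 3))  # 35 -> 0.5, 70 -> 1.0, 1785 -> 0.5, 3500 -> 0.0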
def main():
    parser = argparse.ArgumentParser(description="Train JFL policy head")
    parser.add_argument("--data", type=str, help="Path to training-buffer.jsonl")
    parser.add_argument("--embeddings", type=str, help="Path to pre-computed embeddings .npz")
    parser.add_argument("--rewards", type=str, help="Path to rewards .npy (with --embeddings)")
    parser.add_argument("--output", type=str, default=".jfl/policy-weights.json",
                        help="Output path for policy weights JSON")
    parser.add_argument("--api-key", type=str, help="Stratus API key (or STRATUS_API_KEY env)")
    parser.add_argument("--api-url", type=str,
                        help="Stratus API URL (or STRATUS_API_URL env; default https://api.stratus.run)")
    parser.add_argument("--epochs", type=int, default=500, help="Max training epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--hidden-dim", type=int, default=512, help="Hidden layer dimension")
    parser.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--patience", type=int, default=30, help="Early stopping patience")
    parser.add_argument("--min-delta", type=float, default=1e-5, help="Min improvement for early stop")
    parser.add_argument("--warmup-epochs", type=int, default=10, help="LR warmup epochs")
    parser.add_argument("--val-ratio", type=float, default=0.3, help="Validation split ratio")
    parser.add_argument("--min-entries", type=int, default=50, help="Minimum entries to train")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--reward-clip", type=float, default=1.0, help="Clip rewards to ±this value (default: 1.0)")
    parser.add_argument("--force-embed", action="store_true", help="Force re-computation of embeddings")
    args = parser.parse_args()

    if not args.data and not args.embeddings:
        parser.error("Either --data or --embeddings is required")

    train(args)


if __name__ == "__main__":
    main()
package/scripts/train/v2/dataset.py
@@ -0,0 +1,81 @@
"""
PyTorch Dataset for v2 policy head training.
Supports pre-computed embedding cache for fast training.
"""

import json
import numpy as np
import torch
from torch.utils.data import Dataset


class PolicyHeadDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tool_to_index: dict[str, int],
        embeddings_matrix: np.ndarray | None = None,
        text_to_idx: dict[str, int] | None = None,
    ):
        self.examples = []
        with open(data_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                self.examples.append(json.loads(line))

        self.tool_to_index = tool_to_index
        self.embeddings_matrix = embeddings_matrix
        self.text_to_idx = text_to_idx or {}

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]

        state_text = ex["current_state"]
        goal_text = ex["goal"]
        tool_idx = self.tool_to_index.get(ex["correct_tool"], 0)

        if self.embeddings_matrix is not None and state_text in self.text_to_idx:
            state_emb = torch.tensor(
                self.embeddings_matrix[self.text_to_idx[state_text]],
                dtype=torch.float32,
            )
        else:
            state_emb = torch.zeros(768)

        if self.embeddings_matrix is not None and goal_text in self.text_to_idx:
            goal_emb = torch.tensor(
                self.embeddings_matrix[self.text_to_idx[goal_text]],
                dtype=torch.float32,
            )
        else:
            goal_emb = torch.zeros(768)

        return {
            "state_emb": state_emb,
            "goal_emb": goal_emb,
            "label": torch.tensor(tool_idx, dtype=torch.long),
            "tool_name": ex["correct_tool"],
        }


def load_embedding_cache(data_dir: str) -> tuple[np.ndarray | None, dict[str, int] | None]:
    import os

    cache_path = os.path.join(data_dir, "embeddings_cache.npz")
    index_path = os.path.join(data_dir, "text_to_idx.json")

    if not os.path.exists(cache_path) or not os.path.exists(index_path):
        return None, None

    data = np.load(cache_path, allow_pickle=True)
    embeddings = data["embeddings"]

    with open(index_path) as f:
        text_to_idx = json.load(f)

    return embeddings, text_to_idx
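A hypothetical usage sketch tying the v2 pieces together (editor's illustration; the paths are invented, and deriving tool_to_index from domain.json order is an assumed convention, not shown in the package):

    # Editor's sketch: build a loader from the v2 dataset and embedding cache.
    import json
    from torch.utils.data import DataLoader

    with open("scripts/train/v2/domain.json") as f:  # tool vocabulary (next hunk)
        domain = json.load(f)
    tool_to_index = {t["name"]: i for i, t in enumerate(domain["tools"])}

    emb, text_to_idx = load_embedding_cache("data")  # (None, None) when no cache exists
    ds = PolicyHeadDataset("data/train.jsonl", tool_to_index, emb, text_to_idx)
    loader = DataLoader(ds, batch_size=64, shuffle=True)
    batch = next(iter(loader))
    print(batch["state_emb"].shape, batch["label"].shape)  # (64, 768), (64,)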
package/scripts/train/v2/domain.json
@@ -0,0 +1,18 @@
{
  "domain": "jfl_gtm_development",
  "description": "AI-driven software development and GTM automation — code changes, testing, configuration, experiments",
  "tools": [
    {"name": "fix_bug", "description": "Fix a bug or error in existing code — address failing tests, runtime errors, or incorrect behavior", "category": "modification"},
    {"name": "refactor_code", "description": "Restructure existing code without changing behavior — improve readability, reduce complexity, extract functions", "category": "modification"},
    {"name": "add_feature", "description": "Implement a new feature or capability — add new functions, endpoints, CLI commands, or UI components", "category": "modification"},
    {"name": "add_tests", "description": "Write or improve test coverage — unit tests, integration tests, eval datasets", "category": "execution"},
    {"name": "update_config", "description": "Change configuration, settings, or infrastructure — update package.json, tsconfig, CI workflows, deploy configs", "category": "system"},
    {"name": "run_experiment", "description": "Conduct a research experiment — try a hypothesis, benchmark alternatives, A/B test approaches", "category": "execution"},
    {"name": "update_docs", "description": "Update documentation, README, or knowledge docs — specs, decision records, API docs", "category": "communication"},
    {"name": "optimize_performance", "description": "Improve speed, reduce resource usage, optimize queries or algorithms", "category": "execution"},
    {"name": "add_monitoring", "description": "Add observability — logging, metrics, health checks, alerting, telemetry", "category": "observation"},
    {"name": "security_hardening", "description": "Fix security issues — input validation, auth, secrets management, dependency updates", "category": "system"},
    {"name": "dependency_update", "description": "Update, add, or remove package dependencies — version bumps, migration to new libraries", "category": "system"},
    {"name": "data_pipeline", "description": "Build or modify data processing pipelines — ETL, data transformation, training data preparation", "category": "execution"}
  ]
}
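To connect this file to dataset.py above: PolicyHeadDataset expects JSONL rows with current_state, goal, and correct_tool, where correct_tool must be one of the tool names defined here. A hypothetical row (values invented):

    {"current_state": "3 tests failing in session-cleanup.sh after the hub migration", "goal": "Restore a green test suite", "correct_tool": "fix_bug"}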