jfl 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. package/dist/commands/context-hub.d.ts +1 -0
  2. package/dist/commands/context-hub.d.ts.map +1 -1
  3. package/dist/commands/context-hub.js +246 -2
  4. package/dist/commands/context-hub.js.map +1 -1
  5. package/dist/commands/peter.d.ts +2 -0
  6. package/dist/commands/peter.d.ts.map +1 -1
  7. package/dist/commands/peter.js +242 -52
  8. package/dist/commands/peter.js.map +1 -1
  9. package/dist/commands/setup.d.ts +12 -0
  10. package/dist/commands/setup.d.ts.map +1 -0
  11. package/dist/commands/setup.js +322 -0
  12. package/dist/commands/setup.js.map +1 -0
  13. package/dist/commands/train.d.ts +33 -0
  14. package/dist/commands/train.d.ts.map +1 -0
  15. package/dist/commands/train.js +510 -0
  16. package/dist/commands/train.js.map +1 -0
  17. package/dist/commands/verify.d.ts +14 -0
  18. package/dist/commands/verify.d.ts.map +1 -0
  19. package/dist/commands/verify.js +276 -0
  20. package/dist/commands/verify.js.map +1 -0
  21. package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
  22. package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
  23. package/dist/dashboard-static/index.html +2 -2
  24. package/dist/index.js +99 -3
  25. package/dist/index.js.map +1 -1
  26. package/dist/lib/agent-session.d.ts.map +1 -1
  27. package/dist/lib/agent-session.js +12 -4
  28. package/dist/lib/agent-session.js.map +1 -1
  29. package/dist/lib/eval-snapshot.js +1 -1
  30. package/dist/lib/eval-snapshot.js.map +1 -1
  31. package/dist/lib/pi-sky/bridge.d.ts +55 -0
  32. package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
  33. package/dist/lib/pi-sky/bridge.js +264 -0
  34. package/dist/lib/pi-sky/bridge.js.map +1 -0
  35. package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
  36. package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
  37. package/dist/lib/pi-sky/cost-monitor.js +126 -0
  38. package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
  39. package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
  40. package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
  41. package/dist/lib/pi-sky/eval-sweep.js +141 -0
  42. package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
  43. package/dist/lib/pi-sky/event-router.d.ts +32 -0
  44. package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
  45. package/dist/lib/pi-sky/event-router.js +176 -0
  46. package/dist/lib/pi-sky/event-router.js.map +1 -0
  47. package/dist/lib/pi-sky/experiment.d.ts +9 -0
  48. package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
  49. package/dist/lib/pi-sky/experiment.js +83 -0
  50. package/dist/lib/pi-sky/experiment.js.map +1 -0
  51. package/dist/lib/pi-sky/index.d.ts +16 -0
  52. package/dist/lib/pi-sky/index.d.ts.map +1 -0
  53. package/dist/lib/pi-sky/index.js +16 -0
  54. package/dist/lib/pi-sky/index.js.map +1 -0
  55. package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
  56. package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
  57. package/dist/lib/pi-sky/stratus-gate.js +61 -0
  58. package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
  59. package/dist/lib/pi-sky/swarm.d.ts +28 -0
  60. package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
  61. package/dist/lib/pi-sky/swarm.js +208 -0
  62. package/dist/lib/pi-sky/swarm.js.map +1 -0
  63. package/dist/lib/pi-sky/types.d.ts +139 -0
  64. package/dist/lib/pi-sky/types.d.ts.map +1 -0
  65. package/dist/lib/pi-sky/types.js +2 -0
  66. package/dist/lib/pi-sky/types.js.map +1 -0
  67. package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
  68. package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
  69. package/dist/lib/pi-sky/voice-bridge.js +91 -0
  70. package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
  71. package/dist/lib/policy-head.d.ts +16 -1
  72. package/dist/lib/policy-head.d.ts.map +1 -1
  73. package/dist/lib/policy-head.js +117 -19
  74. package/dist/lib/policy-head.js.map +1 -1
  75. package/dist/lib/predictor.d.ts +10 -0
  76. package/dist/lib/predictor.d.ts.map +1 -1
  77. package/dist/lib/predictor.js +46 -7
  78. package/dist/lib/predictor.js.map +1 -1
  79. package/dist/lib/setup/agent-generator.d.ts +18 -0
  80. package/dist/lib/setup/agent-generator.d.ts.map +1 -0
  81. package/dist/lib/setup/agent-generator.js +114 -0
  82. package/dist/lib/setup/agent-generator.js.map +1 -0
  83. package/dist/lib/setup/context-analyzer.d.ts +16 -0
  84. package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
  85. package/dist/lib/setup/context-analyzer.js +112 -0
  86. package/dist/lib/setup/context-analyzer.js.map +1 -0
  87. package/dist/lib/setup/doc-auditor.d.ts +54 -0
  88. package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
  89. package/dist/lib/setup/doc-auditor.js +629 -0
  90. package/dist/lib/setup/doc-auditor.js.map +1 -0
  91. package/dist/lib/setup/domain-generator.d.ts +7 -0
  92. package/dist/lib/setup/domain-generator.d.ts.map +1 -0
  93. package/dist/lib/setup/domain-generator.js +58 -0
  94. package/dist/lib/setup/domain-generator.js.map +1 -0
  95. package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
  96. package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
  97. package/dist/lib/setup/smart-eval-generator.js +378 -0
  98. package/dist/lib/setup/smart-eval-generator.js.map +1 -0
  99. package/dist/lib/setup/smart-recommender.d.ts +63 -0
  100. package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
  101. package/dist/lib/setup/smart-recommender.js +329 -0
  102. package/dist/lib/setup/smart-recommender.js.map +1 -0
  103. package/dist/lib/setup/spec-generator.d.ts +63 -0
  104. package/dist/lib/setup/spec-generator.d.ts.map +1 -0
  105. package/dist/lib/setup/spec-generator.js +310 -0
  106. package/dist/lib/setup/spec-generator.js.map +1 -0
  107. package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
  108. package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
  109. package/dist/lib/setup/violation-agent-generator.js +255 -0
  110. package/dist/lib/setup/violation-agent-generator.js.map +1 -0
  111. package/package.json +1 -1
  112. package/packages/pi/extensions/context.ts +88 -55
  113. package/packages/pi/extensions/hub-resolver.ts +63 -0
  114. package/packages/pi/extensions/index.ts +16 -3
  115. package/packages/pi/extensions/memory-tool.ts +9 -4
  116. package/packages/pi/extensions/session.ts +68 -16
  117. package/packages/pi/extensions/tool-renderers.ts +23 -8
  118. package/scripts/train/requirements.txt +5 -0
  119. package/scripts/train/train-policy-head.py +477 -0
  120. package/scripts/train/v2/dataset.py +81 -0
  121. package/scripts/train/v2/domain.json +18 -0
  122. package/scripts/train/v2/eval.py +196 -0
  123. package/scripts/train/v2/generate_data.py +219 -0
  124. package/scripts/train/v2/infer.py +188 -0
  125. package/scripts/train/v2/model.py +112 -0
  126. package/scripts/train/v2/precompute.py +132 -0
  127. package/scripts/train/v2/train.py +302 -0
  128. package/scripts/train/v2/transform_buffer.py +227 -0
  129. package/scripts/train/v2/validate_data.py +115 -0
  130. package/template/.claude/settings.json +2 -15
  131. package/template/scripts/session/session-cleanup.sh +2 -11
  132. package/template/scripts/session/session-end-hub.sh +72 -0
  133. package/template/scripts/session/session-start-hub.sh +105 -0
  134. package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
  135. package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
@@ -0,0 +1,477 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Policy head trainer for JFL RL agents.
4
+
5
+ Trains a 3-layer MLP to predict reward from (state, action) embeddings.
6
+ Matches PolicyWeights JSON format consumed by policy-head.ts inference.
7
+
8
+ Architecture (per Andrew @ Stratus, 2026-03-13 call):
9
+ Input: concat(state_emb, action_emb) = 2 * embed_dim
10
+ Layer 1: Linear(input_dim, 512) + ReLU
11
+ Layer 2: Linear(512, 512) + ReLU + LayerNorm + Dropout
12
+ Layer 3: Linear(512, 1)
13
+
14
+ Usage:
15
+ python train-policy-head.py --data /path/to/.jfl/training-buffer.jsonl
16
+ python train-policy-head.py --embeddings /path/to/embeddings.npz --rewards /path/to/rewards.npy
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ import sys
23
+ import time
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.utils.data import DataLoader, TensorDataset, random_split
30
+
31
+
32
class PolicyHead(nn.Module):
    """Reward-prediction head: a 3-layer MLP over concat(state_emb, action_emb).

    NOTE: the ordering inside ``self.net`` is load-bearing — the weight
    exporter addresses layers by Sequential index (net.0 / net.2 / net.6),
    so layers must not be inserted, removed, or reordered.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 512):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),   # net.0
            nn.ReLU(),                          # net.1
            nn.Linear(hidden_dim, hidden_dim),  # net.2
            nn.ReLU(),                          # net.3
            nn.LayerNorm(hidden_dim),           # net.4
            nn.Dropout(0.1),                    # net.5
            nn.Linear(hidden_dim, 1),           # net.6
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        # Drop the trailing singleton output dim: (batch, 1) -> (batch,)
        return self.net(x).squeeze(-1)
47
+
48
+
49
def get_embeddings_from_stratus(texts: list[str], api_key: str, api_url: str) -> np.ndarray:
    """Batch-embed texts via the Stratus /v1/embeddings endpoint.

    Args:
        texts: Texts to embed; output row i corresponds to texts[i].
        api_key: Bearer token for the API.
        api_url: Base URL of the API (without the /v1 suffix).

    Returns:
        float32 array of shape (len(texts), embed_dim).

    Raises:
        requests.HTTPError: on any non-2xx response (via raise_for_status).
    """
    import requests

    embeddings = []
    batch_size = 32

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        resp = requests.post(
            f"{api_url}/v1/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={"model": "stratus-x1ac-base", "input": batch},
            timeout=30,
        )
        resp.raise_for_status()
        data = resp.json()
        # Sort by the API-provided "index" so rows stay aligned with the
        # input batch even if the server returns items out of order.
        for item in sorted(data["data"], key=lambda d: d["index"]):
            embeddings.append(item["embedding"])

        # BUG FIX: the old condition `i % 100 == 0` only fired when i was a
        # multiple of lcm(32, 100) = 800, so progress was almost never shown.
        # Report every 4 batches (128 texts) instead.
        if i > 0 and (i // batch_size) % 4 == 0:
            print(f" Embedded {i}/{len(texts)} texts...")

    return np.array(embeddings, dtype=np.float32)
76
+
77
+
78
def load_training_data(jsonl_path: str, reward_clip: float = 1.0) -> tuple[list[str], list[str], np.ndarray]:
    """Load training buffer JSONL, return (state_texts, action_texts, rewards).

    Applies data quality filtering:
    - Drops entries missing an action description or state agent
    - Drops entries with zero reward (no learning signal)
    - Clips rewards to [-reward_clip, reward_clip] (default ±1.0) — clipped
      entries are KEPT, not dropped
    Malformed JSON lines and blank lines are skipped silently (best-effort).
    """
    state_texts = []
    action_texts = []
    rewards = []
    skipped_zero = 0
    skipped_outlier = 0
    skipped_missing = 0

    with open(jsonl_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue

            state = entry.get("state", {})
            action = entry.get("action", {})
            reward = entry.get("reward", {})

            # Filter first — no point formatting text for dropped entries.
            if not action.get("description") or not state.get("agent"):
                skipped_missing += 1
                continue

            composite_delta = reward.get("composite_delta", 0.0)
            if composite_delta == 0.0:
                skipped_zero += 1
                continue

            if abs(composite_delta) > reward_clip:
                # Clip but keep: outliers carry signal, just bounded.
                skipped_outlier += 1
                composite_delta = max(-reward_clip, min(reward_clip, composite_delta))

            dims = state.get("dimension_scores", {})
            dims_str = ", ".join(f"{k}={v:.4f}" for k, v in dims.items()) if dims else "none"
            deltas = state.get("recent_deltas", [])
            deltas_str = ", ".join(f"{d:+.4f}" for d in deltas) if deltas else "none"

            state_text = "\n".join([
                f"Agent: {state.get('agent', 'unknown')}",
                f"Composite: {state.get('composite_score', 0):.4f}",
                f"Tests: {state.get('tests_passing', 0)}/{state.get('tests_total', 0)}",
                f"Trajectory: {state.get('trajectory_length', 0)}",
                f"Dimensions: {dims_str}",
                f"Recent deltas: {deltas_str}",
            ])

            files = action.get("files_affected", [])[:5]
            action_text = "\n".join([
                f"Type: {action.get('type', 'unknown')}",
                f"Description: {action.get('description', '')[:150]}",
                f"Scope: {action.get('scope', 'unknown')}",
                f"Files: {', '.join(files) if files else 'none'}",
            ])

            state_texts.append(state_text)
            action_texts.append(action_text)
            rewards.append(composite_delta)

    # BUG FIX: clipped entries are kept, so adding skipped_outlier here
    # double-counted them in the raw total.
    total_raw = len(state_texts) + skipped_zero + skipped_missing
    print(f" Data quality filter (reward_clip=±{reward_clip}):")
    print(f" Raw entries: {total_raw}")
    print(f" Kept: {len(state_texts)}")
    print(f" Skipped (zero): {skipped_zero}")
    print(f" Clipped (outlier):{skipped_outlier}")
    print(f" Skipped (missing):{skipped_missing}")

    return state_texts, action_texts, np.array(rewards, dtype=np.float32)
156
+
157
+
158
def compute_metrics(predictions: np.ndarray, targets: np.ndarray) -> dict:
    """Return direction accuracy, Spearman rank correlation, MSE and MAE."""
    n = len(targets)
    if n > 0:
        matches = np.sum(np.sign(predictions) == np.sign(targets))
        direction_accuracy = matches / n
    else:
        direction_accuracy = 0.0

    from scipy.stats import spearmanr
    rank_corr = 0.0
    try:
        corr, _ = spearmanr(predictions, targets)
        # spearmanr yields NaN on constant inputs; report 0.0 instead.
        if not np.isnan(corr):
            rank_corr = corr
    except Exception:
        pass

    errors = predictions - targets
    return {
        "direction_accuracy": float(direction_accuracy),
        "rank_correlation": float(rank_corr),
        "mse": float(np.mean(errors ** 2)),
        "mae": float(np.mean(np.abs(errors))),
    }
177
+
178
+
179
def export_weights(model: PolicyHead, embed_dim: int, target_mean: float, target_std: float,
                   train_size: int, metrics: dict, output_path: str):
    """Serialize model weights into the PolicyWeights JSON schema for policy-head.ts.

    Weight matrices are transposed (.T) before export — presumably so the TS
    inference side can compute input @ W without transposing; confirm against
    the policy-head.ts loader.
    """
    state_dict = model.state_dict()

    def as_nested_list(tensor):
        # Detach and move to CPU so export works regardless of train device.
        return tensor.cpu().detach().numpy().tolist()

    # Layer indices follow PolicyHead's Sequential layout: 0 / 2 / 6 are the
    # three Linear layers (1, 3, 4, 5 are ReLU/LayerNorm/Dropout).
    layers = {
        "W1": as_nested_list(state_dict["net.0.weight"].T),
        "b1": as_nested_list(state_dict["net.0.bias"]),
        "W2": as_nested_list(state_dict["net.2.weight"].T),
        "b2": as_nested_list(state_dict["net.2.bias"]),
        "W3": as_nested_list(state_dict["net.6.weight"].T),
        "b3": as_nested_list(state_dict["net.6.bias"]),
    }

    weights = {
        "version": 1,
        "architecture": "mlp-3layer-512h",
        "embed_dim": embed_dim,
        "mode": "embedding",
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
        "trained_on": train_size,
        "direction_accuracy": metrics["direction_accuracy"],
        "rank_correlation": metrics["rank_correlation"],
        "target_mean": float(target_mean),
        "target_std": float(target_std),
        "layers": layers,
    }

    with open(output_path, "w") as f:
        json.dump(weights, f, indent=2)

    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"\n Exported weights to {output_path} ({size_mb:.1f} MB)")
    print(f" Direction accuracy: {metrics['direction_accuracy']:.3f}")
    print(f" Rank correlation: {metrics['rank_correlation']:.3f}")
215
+
216
+
217
def train(args):
    """Full training pipeline: load data, embed, fit the policy head, export weights.

    Reads either a raw JSONL training buffer (--data) — in which case texts
    are embedded via the Stratus API, with an on-disk .npz cache keyed by the
    entry count — or pre-computed embeddings (--embeddings/--rewards).
    Exits via sys.exit(1) on insufficient data or a missing API key.
    Side effects: writes the weights JSON (args.output), a training-meta.json
    next to it, and the embedding cache under <data dir>/train-cache/.
    """
    # Prefer the Apple-silicon GPU backend when present, else CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"\n Device: {device}")

    # Load data
    if args.embeddings and args.rewards:
        # Fast path: embeddings and rewards were precomputed elsewhere.
        print(f" Loading pre-computed embeddings from {args.embeddings}")
        data = np.load(args.embeddings)
        state_embs = data["state_embeddings"]
        action_embs = data["action_embeddings"]
        rewards = np.load(args.rewards)
    else:
        print(f" Loading training data from {args.data}")
        state_texts, action_texts, rewards = load_training_data(args.data, reward_clip=args.reward_clip)
        print(f" Loaded {len(rewards)} usable entries")

        # NOTE(review): this minimum-size gate only guards the raw-JSONL path;
        # pre-computed embeddings are accepted at any size — confirm intended.
        if len(rewards) < args.min_entries:
            print(f"\n Not enough data: {len(rewards)} < {args.min_entries} minimum")
            print(f" Need {args.min_entries - len(rewards)} more training entries")
            sys.exit(1)

        # API key resolution order: CLI flag, env var, then a .env file two
        # levels above the training buffer (presumably the project root).
        api_key = args.api_key or os.environ.get("STRATUS_API_KEY")
        api_url = args.api_url or os.environ.get("STRATUS_API_URL", "https://api.stratus.run")

        if not api_key:
            dotenv_path = Path(args.data).parent.parent / ".env"
            if dotenv_path.exists():
                for line in dotenv_path.read_text().splitlines():
                    if line.startswith("STRATUS_API_KEY="):
                        # Strip optional surrounding quotes from the value.
                        api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
                        break

        if not api_key:
            print("\n STRATUS_API_KEY not set. Cannot compute embeddings.")
            print(" Either set the env var or use --embeddings with pre-computed data.")
            sys.exit(1)

        # Check for cached embeddings. The cache key is only the entry count,
        # so a changed-but-same-length buffer reuses stale embeddings;
        # --force-embed overrides.
        cache_dir = Path(args.data).parent / "train-cache"
        cache_dir.mkdir(exist_ok=True)
        cache_path = cache_dir / f"embeddings-{len(rewards)}.npz"

        if cache_path.exists() and not args.force_embed:
            print(f" Loading cached embeddings from {cache_path}")
            cached = np.load(cache_path)
            state_embs = cached["state_embeddings"]
            action_embs = cached["action_embeddings"]
        else:
            print(f" Computing embeddings for {len(state_texts)} entries...")
            state_embs = get_embeddings_from_stratus(state_texts, api_key, api_url)
            action_embs = get_embeddings_from_stratus(action_texts, api_key, api_url)
            np.savez(cache_path, state_embeddings=state_embs, action_embeddings=action_embs)
            print(f" Cached embeddings to {cache_path}")

    embed_dim = state_embs.shape[1]
    # Model input is concat(state_emb, action_emb).
    input_dim = embed_dim * 2
    print(f" Embedding dim: {embed_dim}, Input dim: {input_dim}")
    print(f" Entries: {len(rewards)}, Reward range: [{rewards.min():.4f}, {rewards.max():.4f}]")

    # Normalize targets to zero mean / unit variance; the std floor avoids a
    # divide-by-zero on (near-)constant rewards.
    target_mean = float(rewards.mean())
    target_std = float(rewards.std()) if rewards.std() > 1e-8 else 1.0
    normalized_rewards = (rewards - target_mean) / target_std

    # Build tensors
    X = np.concatenate([state_embs, action_embs], axis=1)
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(normalized_rewards, dtype=torch.float32)

    dataset = TensorDataset(X_tensor, y_tensor)

    # 70/30 split (default; --val-ratio controls it). val_size is floored at
    # 1 so tiny datasets still get a validation example.
    val_size = max(1, int(len(dataset) * args.val_ratio))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(args.seed)
    )

    print(f" Train: {train_size}, Val: {val_size}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # Build model
    model = PolicyHead(input_dim, hidden_dim=args.hidden_dim).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f" Model parameters: {param_count:,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Warmup + cosine schedule, stepped per batch (not per epoch).
    warmup_steps = args.warmup_epochs * len(train_loader)
    total_steps = args.epochs * len(train_loader)

    def lr_lambda(step):
        # Linear warmup, then cosine decay to zero over the remaining steps.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    criterion = nn.MSELoss()

    # Training loop with early stopping on validation loss.
    best_val_loss = float("inf")
    best_epoch = 0
    best_state = None
    patience_counter = 0

    print(f"\n Training for up to {args.epochs} epochs (patience: {args.patience})...\n")

    for epoch in range(args.epochs):
        model.train()
        train_loss = 0.0
        train_batches = 0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            pred = model(X_batch)
            loss = criterion(pred, y_batch)
            loss.backward()
            # Gradient clipping for stability on small, noisy datasets.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            train_batches += 1

        avg_train_loss = train_loss / max(1, train_batches)

        # Validation
        model.eval()
        val_loss = 0.0
        val_batches = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                pred = model(X_batch)
                loss = criterion(pred, y_batch)
                val_loss += loss.item()
                val_batches += 1
                val_preds.extend(pred.cpu().numpy())
                val_targets.extend(y_batch.cpu().numpy())

        avg_val_loss = val_loss / max(1, val_batches)

        # Denormalize for metrics so accuracy/correlation are in reward units.
        val_preds_denorm = np.array(val_preds) * target_std + target_mean
        val_targets_denorm = np.array(val_targets) * target_std + target_mean

        # Log roughly 20 times over the run (and always on the first epoch).
        if (epoch + 1) % max(1, args.epochs // 20) == 0 or epoch == 0:
            metrics = compute_metrics(val_preds_denorm, val_targets_denorm)
            lr = optimizer.param_groups[0]["lr"]
            print(f" Epoch {epoch+1:4d} train_loss={avg_train_loss:.6f} val_loss={avg_val_loss:.6f} "
                  f"dir_acc={metrics['direction_accuracy']:.3f} rank_corr={metrics['rank_correlation']:.3f} "
                  f"lr={lr:.2e}")

        # Early stopping: a "better" epoch must beat the best by min_delta.
        if avg_val_loss < best_val_loss - args.min_delta:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            # Snapshot on CPU so the checkpoint is device-independent.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"\n Early stopping at epoch {epoch+1} (best: {best_epoch})")
                break

    # Load best checkpoint
    if best_state:
        model.load_state_dict(best_state)
        model.to(device)

    # Final metrics on full val set
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            pred = model(X_batch)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(y_batch.numpy())

    # Report in denormalized (raw reward) units.
    final_preds = np.array(all_preds) * target_std + target_mean
    final_targets = np.array(all_targets) * target_std + target_mean
    final_metrics = compute_metrics(final_preds, final_targets)

    print(f"\n Final metrics (best epoch {best_epoch}):")
    print(f" Direction accuracy: {final_metrics['direction_accuracy']:.3f}")
    print(f" Rank correlation: {final_metrics['rank_correlation']:.3f}")
    print(f" MSE: {final_metrics['mse']:.6f}")
    print(f" MAE: {final_metrics['mae']:.6f}")

    # Export
    export_weights(model, embed_dim, target_mean, target_std, len(rewards), final_metrics, args.output)

    # Also save training metadata alongside the weights file.
    meta_path = args.output.replace("policy-weights.json", "training-meta.json")
    meta = {
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
        "data_source": str(args.data) if args.data else "pre-computed",
        "entries": int(len(rewards)),
        "train_size": train_size,
        "val_size": val_size,
        "embed_dim": embed_dim,
        "hidden_dim": args.hidden_dim,
        # NOTE(review): `epoch` is undefined if args.epochs == 0 — assumes > 0.
        "epochs_run": min(epoch + 1, args.epochs),
        "best_epoch": best_epoch,
        "best_val_loss": float(best_val_loss),
        "lr": args.lr,
        "batch_size": args.batch_size,
        "dropout": 0.1,
        "weight_decay": args.weight_decay,
        "device": device,
        "param_count": param_count,
        "metrics": final_metrics,
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f" Training metadata saved to {meta_path}")
443
+
444
+
445
def main():
    """CLI entry point: parse arguments, validate the data source, run training."""
    parser = argparse.ArgumentParser(description="Train JFL policy head")

    # (flag, add_argument kwargs) — one row per CLI option, in help order.
    option_table = [
        ("--data", dict(type=str, help="Path to training-buffer.jsonl")),
        ("--embeddings", dict(type=str, help="Path to pre-computed embeddings .npz")),
        ("--rewards", dict(type=str, help="Path to rewards .npy (with --embeddings)")),
        ("--output", dict(type=str, default=".jfl/policy-weights.json",
                          help="Output path for policy weights JSON")),
        ("--api-key", dict(type=str, help="Stratus API key (or STRATUS_API_KEY env)")),
        ("--api-url", dict(type=str, default="https://api.stratus.run",
                           help="Stratus API URL")),
        ("--epochs", dict(type=int, default=500, help="Max training epochs")),
        ("--batch-size", dict(type=int, default=64, help="Batch size")),
        ("--lr", dict(type=float, default=3e-4, help="Learning rate")),
        ("--hidden-dim", dict(type=int, default=512, help="Hidden layer dimension")),
        ("--weight-decay", dict(type=float, default=0.01, help="Weight decay")),
        ("--patience", dict(type=int, default=30, help="Early stopping patience")),
        ("--min-delta", dict(type=float, default=1e-5, help="Min improvement for early stop")),
        ("--warmup-epochs", dict(type=int, default=10, help="LR warmup epochs")),
        ("--val-ratio", dict(type=float, default=0.3, help="Validation split ratio")),
        ("--min-entries", dict(type=int, default=50, help="Minimum entries to train")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        ("--reward-clip", dict(type=float, default=1.0, help="Clip rewards to ±this value (default: 1.0)")),
        ("--force-embed", dict(action="store_true", help="Force re-computation of embeddings")),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    # One of the two data sources must be supplied.
    if not args.data and not args.embeddings:
        parser.error("Either --data or --embeddings is required")

    train(args)


if __name__ == "__main__":
    main()
@@ -0,0 +1,81 @@
1
+ """
2
+ PyTorch Dataset for v2 policy head training.
3
+ Supports pre-computed embedding cache for fast training.
4
+ """
5
+
6
+ import json
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ class PolicyHeadDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ data_path: str,
16
+ tool_to_index: dict[str, int],
17
+ embeddings_matrix: np.ndarray | None = None,
18
+ text_to_idx: dict[str, int] | None = None,
19
+ ):
20
+ self.examples = []
21
+ with open(data_path) as f:
22
+ for line in f:
23
+ line = line.strip()
24
+ if not line:
25
+ continue
26
+ self.examples.append(json.loads(line))
27
+
28
+ self.tool_to_index = tool_to_index
29
+ self.embeddings_matrix = embeddings_matrix
30
+ self.text_to_idx = text_to_idx or {}
31
+
32
+ def __len__(self):
33
+ return len(self.examples)
34
+
35
+ def __getitem__(self, idx):
36
+ ex = self.examples[idx]
37
+
38
+ state_text = ex["current_state"]
39
+ goal_text = ex["goal"]
40
+ tool_idx = self.tool_to_index.get(ex["correct_tool"], 0)
41
+
42
+ if self.embeddings_matrix is not None and state_text in self.text_to_idx:
43
+ state_emb = torch.tensor(
44
+ self.embeddings_matrix[self.text_to_idx[state_text]],
45
+ dtype=torch.float32,
46
+ )
47
+ else:
48
+ state_emb = torch.zeros(768)
49
+
50
+ if self.embeddings_matrix is not None and goal_text in self.text_to_idx:
51
+ goal_emb = torch.tensor(
52
+ self.embeddings_matrix[self.text_to_idx[goal_text]],
53
+ dtype=torch.float32,
54
+ )
55
+ else:
56
+ goal_emb = torch.zeros(768)
57
+
58
+ return {
59
+ "state_emb": state_emb,
60
+ "goal_emb": goal_emb,
61
+ "label": torch.tensor(tool_idx, dtype=torch.long),
62
+ "tool_name": ex["correct_tool"],
63
+ }
64
+
65
+
66
+ def load_embedding_cache(data_dir: str) -> tuple[np.ndarray | None, dict[str, int] | None]:
67
+ import os
68
+
69
+ cache_path = os.path.join(data_dir, "embeddings_cache.npz")
70
+ index_path = os.path.join(data_dir, "text_to_idx.json")
71
+
72
+ if not os.path.exists(cache_path) or not os.path.exists(index_path):
73
+ return None, None
74
+
75
+ data = np.load(cache_path, allow_pickle=True)
76
+ embeddings = data["embeddings"]
77
+
78
+ with open(index_path) as f:
79
+ text_to_idx = json.load(f)
80
+
81
+ return embeddings, text_to_idx
@@ -0,0 +1,18 @@
1
+ {
2
+ "domain": "jfl_gtm_development",
3
+ "description": "AI-driven software development and GTM automation — code changes, testing, configuration, experiments",
4
+ "tools": [
5
+ {"name": "fix_bug", "description": "Fix a bug or error in existing code — address failing tests, runtime errors, or incorrect behavior", "category": "modification"},
6
+ {"name": "refactor_code", "description": "Restructure existing code without changing behavior — improve readability, reduce complexity, extract functions", "category": "modification"},
7
+ {"name": "add_feature", "description": "Implement a new feature or capability — add new functions, endpoints, CLI commands, or UI components", "category": "modification"},
8
+ {"name": "add_tests", "description": "Write or improve test coverage — unit tests, integration tests, eval datasets", "category": "execution"},
9
+ {"name": "update_config", "description": "Change configuration, settings, or infrastructure — update package.json, tsconfig, CI workflows, deploy configs", "category": "system"},
10
+ {"name": "run_experiment", "description": "Conduct a research experiment — try a hypothesis, benchmark alternatives, A/B test approaches", "category": "execution"},
11
+ {"name": "update_docs", "description": "Update documentation, README, or knowledge docs — specs, decision records, API docs", "category": "communication"},
12
+ {"name": "optimize_performance", "description": "Improve speed, reduce resource usage, optimize queries or algorithms", "category": "execution"},
13
+ {"name": "add_monitoring", "description": "Add observability — logging, metrics, health checks, alerting, telemetry", "category": "observation"},
14
+ {"name": "security_hardening", "description": "Fix security issues — input validation, auth, secrets management, dependency updates", "category": "system"},
15
+ {"name": "dependency_update", "description": "Update, add, or remove package dependencies — version bumps, migration to new libraries", "category": "system"},
16
+ {"name": "data_pipeline", "description": "Build or modify data processing pipelines — ETL, data transformation, training data preparation", "category": "execution"}
17
+ ]
18
+ }