jfl 0.5.0 → 0.6.1
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/dist/commands/context-hub.d.ts +1 -0
- package/dist/commands/context-hub.d.ts.map +1 -1
- package/dist/commands/context-hub.js +246 -2
- package/dist/commands/context-hub.js.map +1 -1
- package/dist/commands/peter.d.ts +2 -0
- package/dist/commands/peter.d.ts.map +1 -1
- package/dist/commands/peter.js +242 -52
- package/dist/commands/peter.js.map +1 -1
- package/dist/commands/setup.d.ts +12 -0
- package/dist/commands/setup.d.ts.map +1 -0
- package/dist/commands/setup.js +322 -0
- package/dist/commands/setup.js.map +1 -0
- package/dist/commands/train.d.ts +33 -0
- package/dist/commands/train.d.ts.map +1 -0
- package/dist/commands/train.js +510 -0
- package/dist/commands/train.js.map +1 -0
- package/dist/commands/verify.d.ts +14 -0
- package/dist/commands/verify.d.ts.map +1 -0
- package/dist/commands/verify.js +276 -0
- package/dist/commands/verify.js.map +1 -0
- package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
- package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
- package/dist/dashboard-static/index.html +2 -2
- package/dist/index.js +99 -3
- package/dist/index.js.map +1 -1
- package/dist/lib/agent-session.d.ts.map +1 -1
- package/dist/lib/agent-session.js +12 -4
- package/dist/lib/agent-session.js.map +1 -1
- package/dist/lib/eval-snapshot.js +1 -1
- package/dist/lib/eval-snapshot.js.map +1 -1
- package/dist/lib/pi-sky/bridge.d.ts +55 -0
- package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/bridge.js +264 -0
- package/dist/lib/pi-sky/bridge.js.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.js +126 -0
- package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.js +141 -0
- package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
- package/dist/lib/pi-sky/event-router.d.ts +32 -0
- package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
- package/dist/lib/pi-sky/event-router.js +176 -0
- package/dist/lib/pi-sky/event-router.js.map +1 -0
- package/dist/lib/pi-sky/experiment.d.ts +9 -0
- package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
- package/dist/lib/pi-sky/experiment.js +83 -0
- package/dist/lib/pi-sky/experiment.js.map +1 -0
- package/dist/lib/pi-sky/index.d.ts +16 -0
- package/dist/lib/pi-sky/index.d.ts.map +1 -0
- package/dist/lib/pi-sky/index.js +16 -0
- package/dist/lib/pi-sky/index.js.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.js +61 -0
- package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
- package/dist/lib/pi-sky/swarm.d.ts +28 -0
- package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
- package/dist/lib/pi-sky/swarm.js +208 -0
- package/dist/lib/pi-sky/swarm.js.map +1 -0
- package/dist/lib/pi-sky/types.d.ts +139 -0
- package/dist/lib/pi-sky/types.d.ts.map +1 -0
- package/dist/lib/pi-sky/types.js +2 -0
- package/dist/lib/pi-sky/types.js.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.js +91 -0
- package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
- package/dist/lib/policy-head.d.ts +16 -1
- package/dist/lib/policy-head.d.ts.map +1 -1
- package/dist/lib/policy-head.js +117 -19
- package/dist/lib/policy-head.js.map +1 -1
- package/dist/lib/predictor.d.ts +10 -0
- package/dist/lib/predictor.d.ts.map +1 -1
- package/dist/lib/predictor.js +46 -7
- package/dist/lib/predictor.js.map +1 -1
- package/dist/lib/setup/agent-generator.d.ts +18 -0
- package/dist/lib/setup/agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/agent-generator.js +114 -0
- package/dist/lib/setup/agent-generator.js.map +1 -0
- package/dist/lib/setup/context-analyzer.d.ts +16 -0
- package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
- package/dist/lib/setup/context-analyzer.js +112 -0
- package/dist/lib/setup/context-analyzer.js.map +1 -0
- package/dist/lib/setup/doc-auditor.d.ts +54 -0
- package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
- package/dist/lib/setup/doc-auditor.js +629 -0
- package/dist/lib/setup/doc-auditor.js.map +1 -0
- package/dist/lib/setup/domain-generator.d.ts +7 -0
- package/dist/lib/setup/domain-generator.d.ts.map +1 -0
- package/dist/lib/setup/domain-generator.js +58 -0
- package/dist/lib/setup/domain-generator.js.map +1 -0
- package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
- package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
- package/dist/lib/setup/smart-eval-generator.js +378 -0
- package/dist/lib/setup/smart-eval-generator.js.map +1 -0
- package/dist/lib/setup/smart-recommender.d.ts +63 -0
- package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
- package/dist/lib/setup/smart-recommender.js +329 -0
- package/dist/lib/setup/smart-recommender.js.map +1 -0
- package/dist/lib/setup/spec-generator.d.ts +63 -0
- package/dist/lib/setup/spec-generator.d.ts.map +1 -0
- package/dist/lib/setup/spec-generator.js +310 -0
- package/dist/lib/setup/spec-generator.js.map +1 -0
- package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
- package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/violation-agent-generator.js +255 -0
- package/dist/lib/setup/violation-agent-generator.js.map +1 -0
- package/package.json +1 -1
- package/packages/pi/extensions/context.ts +88 -55
- package/packages/pi/extensions/hub-resolver.ts +63 -0
- package/packages/pi/extensions/index.ts +16 -3
- package/packages/pi/extensions/memory-tool.ts +9 -4
- package/packages/pi/extensions/session.ts +68 -16
- package/packages/pi/extensions/tool-renderers.ts +23 -8
- package/scripts/train/requirements.txt +5 -0
- package/scripts/train/train-policy-head.py +477 -0
- package/scripts/train/v2/dataset.py +81 -0
- package/scripts/train/v2/domain.json +18 -0
- package/scripts/train/v2/eval.py +196 -0
- package/scripts/train/v2/generate_data.py +219 -0
- package/scripts/train/v2/infer.py +188 -0
- package/scripts/train/v2/model.py +112 -0
- package/scripts/train/v2/precompute.py +132 -0
- package/scripts/train/v2/train.py +302 -0
- package/scripts/train/v2/transform_buffer.py +227 -0
- package/scripts/train/v2/validate_data.py +115 -0
- package/template/.claude/settings.json +2 -15
- package/template/scripts/session/session-cleanup.sh +2 -11
- package/template/scripts/session/session-end-hub.sh +72 -0
- package/template/scripts/session/session-start-hub.sh +105 -0
- package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
- package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
package/scripts/train/v2/eval.py
@@ -0,0 +1,196 @@
"""
v2 Policy Head Evaluation — per-tool accuracy, confusion matrix, confidence analysis.
"""

import json
import os
import sys
import argparse
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader

from model import PolicyHead
from dataset import PolicyHeadDataset, load_embedding_cache


def load_model(checkpoint_path: str, device: str = "cpu") -> tuple:
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt["config"]
    tool_to_index = ckpt["tool_to_index"]
    index_to_tool = {v: k for k, v in tool_to_index.items()}

    model = PolicyHead(
        embedding_dim=config["embedding_dim"],
        hidden_dim=config["hidden_dim"],
        num_tools=ckpt["num_tools"],
        num_layers=config["num_layers"],
        num_heads=config["num_heads"],
        dropout=config.get("dropout", 0.1),
    ).to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, tool_to_index, index_to_tool


@torch.no_grad()
def evaluate_detailed(model, dataloader, index_to_tool, device):
    model.eval()

    per_tool_correct = defaultdict(int)
    per_tool_total = defaultdict(int)
    confusion_pairs = defaultdict(int)
    all_predictions = []

    for batch in dataloader:
        state_emb = batch["state_emb"].to(device)
        goal_emb = batch["goal_emb"].to(device)
        labels = batch["label"].to(device)
        tool_names = batch["tool_name"]

        logits = model(state_emb, goal_emb)
        probs = torch.softmax(logits, dim=-1)
        preds = logits.argmax(dim=-1)

        for i in range(labels.size(0)):
            true_tool = tool_names[i]
            pred_tool = index_to_tool[preds[i].item()]
            confidence = probs[i, preds[i]].item()

            per_tool_total[true_tool] += 1
            if preds[i] == labels[i]:
                per_tool_correct[true_tool] += 1
            else:
                confusion_pairs[(true_tool, pred_tool)] += 1

            all_predictions.append({
                "true": true_tool,
                "predicted": pred_tool,
                "correct": preds[i].item() == labels[i].item(),
                "confidence": confidence,
            })

    return per_tool_correct, per_tool_total, confusion_pairs, all_predictions


def print_report(per_tool_correct, per_tool_total, confusion_pairs, all_predictions):
    total_correct = sum(per_tool_correct.values())
    total = sum(per_tool_total.values())
    overall_acc = total_correct / max(total, 1)

    print("=" * 70)
    print(" v2 POLICY HEAD — EVALUATION REPORT")
    print("=" * 70)
    print(f"\n Overall Accuracy: {overall_acc:.1%} ({total_correct}/{total})")
    print(f"\n{'Tool':<30} {'Accuracy':>10} {'Correct':>9} {'Total':>7}")
    print("-" * 60)

    for tool in sorted(per_tool_total.keys()):
        correct = per_tool_correct.get(tool, 0)
        total_t = per_tool_total[tool]
        acc = correct / total_t if total_t > 0 else 0
        bar = "█" * int(acc * 20) + "░" * (20 - int(acc * 20))
        print(f"  {tool:<28} {acc:>8.1%} {correct:>5}/{total_t:<5} {bar}")

    if confusion_pairs:
        print(f"\n{'Top Confusion Pairs':<45} {'Count':>7}")
        print("-" * 55)
        for (true_t, pred_t), count in sorted(
            confusion_pairs.items(), key=lambda x: -x[1]
        )[:10]:
            print(f"  {true_t} -> {pred_t:<25} {count:>5}")

    correct_confs = [p["confidence"] for p in all_predictions if p["correct"]]
    wrong_confs = [p["confidence"] for p in all_predictions if not p["correct"]]

    print(f"\n Confidence Analysis:")
    if correct_confs:
        print(f"   Correct predictions: avg={sum(correct_confs) / len(correct_confs):.3f}")
    if wrong_confs:
        print(f"   Wrong predictions:   avg={sum(wrong_confs) / len(wrong_confs):.3f}")

    if overall_acc < 0.80:
        print(f"\n ⚠️  Accuracy below 80% — consider more data or warm-start")
    elif overall_acc < 0.90:
        print(f"\n ⚠️  Accuracy below 90% — good but can improve")
    else:
        print(f"\n ✅ Accuracy ≥90% — ready for deployment")

    print("=" * 70)

    return {
        "overall_accuracy": overall_acc,
        "total": total,
        "per_tool": {
            tool: {
                "accuracy": per_tool_correct.get(tool, 0) / per_tool_total[tool] if per_tool_total[tool] > 0 else 0,
                "correct": per_tool_correct.get(tool, 0),
                "total": per_tool_total[tool],
            }
            for tool in per_tool_total
        },
        "top_confusion": [
            {"true": t, "predicted": p, "count": c}
            for (t, p), c in sorted(confusion_pairs.items(), key=lambda x: -x[1])[:10]
        ],
        "avg_confidence_correct": sum(correct_confs) / len(correct_confs) if correct_confs else 0,
        "avg_confidence_wrong": sum(wrong_confs) / len(wrong_confs) if wrong_confs else 0,
    }


def main():
    parser = argparse.ArgumentParser(description="Evaluate v2 policy head")
    parser.add_argument("--checkpoint", default=".jfl/checkpoints/best_policy_head.pt", help="Path to checkpoint")
    parser.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with test.jsonl + embeddings")
    parser.add_argument("--split", default="test", help="Which split to evaluate (train/val/test)")
    parser.add_argument("--json", action="store_true", help="Output results as JSON")
    args = parser.parse_args()

    if not os.path.exists(args.checkpoint):
        print(f"Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model, tool_to_index, index_to_tool = load_model(args.checkpoint, device)

    data_path = os.path.join(args.data_dir, f"{args.split}.jsonl")
    if not os.path.exists(data_path):
        print(f"Data not found: {data_path}")
        sys.exit(1)

    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)

    test_ds = PolicyHeadDataset(data_path, tool_to_index, embeddings_matrix, text_to_idx)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    print(f"Evaluating on {args.split} split: {len(test_ds)} examples")
    print(f"Device: {device}\n")

    per_tool_correct, per_tool_total, confusion_pairs, all_predictions = evaluate_detailed(
        model, test_loader, index_to_tool, device
    )

    results = print_report(per_tool_correct, per_tool_total, confusion_pairs, all_predictions)

    if args.json:
        print(json.dumps(results, indent=2))

    # Write eval results alongside checkpoint
    results_path = os.path.join(os.path.dirname(args.checkpoint), "eval-results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")


if __name__ == "__main__":
    main()
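For reference, a typical invocation of eval.py under the defaults above — a sketch, not package documentation. The script imports model.py and dataset.py as top-level modules, so it is run from scripts/train/v2/, with the .jfl/ paths resolving against the current working directory:

    python eval.py --checkpoint .jfl/checkpoints/best_policy_head.pt --split test --json

The human-readable report always prints; --json additionally dumps the results object to stdout, and eval-results.json is written next to the checkpoint in either case.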
package/scripts/train/v2/generate_data.py
@@ -0,0 +1,219 @@
"""
Synthetic training data generator for v2 policy head.

Generates (current_state, goal, correct_tool) tuples using LLM API.
Adapted from Drew's generate_data.py for JFL's action taxonomy.
"""

import json
import os
import sys
import random
import argparse
import time
from pathlib import Path


def load_domain(path: str) -> dict:
    with open(path) as f:
        return json.load(f)


def get_llm_client():
    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY for synthetic data generation")
        sys.exit(1)

    if os.environ.get("OPENAI_API_KEY"):
        from openai import OpenAI
        return OpenAI(), "openai"
    else:
        from anthropic import Anthropic
        return Anthropic(), "anthropic"


def generate_goals_for_tool(client, client_type: str, tool: dict, all_tools: list, n: int = 50) -> list:
    tool_list_str = "\n".join(
        f"- {t['name']}: {t['description']}" for t in all_tools
    )

    prompt = f"""You are generating training data for an AI action-routing system in software development.

Given these {len(all_tools)} available action types:
{tool_list_str}

Generate exactly {n} realistic goals/situations that should route to:
Action: {tool["name"]}
Description: {tool["description"]}

Requirements:
1. Each goal should describe a real software development situation or user request
2. Include variety: specific bug reports, vague feature requests, urgent fixes, routine maintenance
3. Some should be indirect (describe the symptom, not the action needed)
4. Some should use technical jargon (CI/CD, linting, test coverage, etc.)
5. Some should include context about current project state
6. Include goals that could be confused with OTHER action types
7. Do NOT include the action type name in the goal text
8. Goals should reflect what an AI coding agent would encounter

Output as a JSON object with a "goals" key containing an array of strings. No explanation."""

    try:
        if client_type == "openai":
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.9,
                response_format={"type": "json_object"},
            )
            result = json.loads(response.choices[0].message.content)
        else:
            response = client.messages.create(
                model="claude-sonnet-4-5-20250514",
                max_tokens=4096,
                messages=[{"role": "user", "content": prompt}],
            )
            text = response.content[0].text
            json_match = text[text.index("{"):text.rindex("}") + 1]
            result = json.loads(json_match)

        goals = result.get("goals", result.get("queries", []))
        return goals[:n] if isinstance(goals, list) else []
    except Exception as e:
        print(f"  Error generating goals for {tool['name']}: {e}")
        return []


def generate_states_for_goal(client, client_type: str, goal: str, domain_desc: str, n: int = 2) -> list:
    prompt = f"""You are generating training data for an AI planning system in:
"{domain_desc}"

Given this development goal: "{goal}"

Generate {n} different realistic "current state" descriptions that an AI coding agent might be in
when this goal arises. Each state should describe:
- Current codebase state (test results, scores, recent changes)
- What the agent currently knows about the project
- Any relevant trajectory information (recent actions taken)

States should vary in how much context is available and what prior work has been done.

Output as a JSON object with a "states" key containing an array of strings (1-3 sentences each)."""

    try:
        if client_type == "openai":
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.8,
                response_format={"type": "json_object"},
            )
            result = json.loads(response.choices[0].message.content)
        else:
            response = client.messages.create(
                model="claude-sonnet-4-5-20250514",
                max_tokens=2048,
                messages=[{"role": "user", "content": prompt}],
            )
            text = response.content[0].text
            json_match = text[text.index("{"):text.rindex("}") + 1]
            result = json.loads(json_match)

        states = result.get("states", [])
        return states[:n]
    except Exception as e:
        print(f"  Error generating states: {e}")
        return [f"Agent working on codebase. Goal context: {goal[:100]}"]


def generate_dataset(
    domain_path: str,
    goals_per_tool: int = 50,
    states_per_goal: int = 2,
    output_dir: str = "data",
    seed: int = 42,
):
    domain = load_domain(domain_path)
    tools = domain["tools"]
    domain_desc = domain["description"]

    client, client_type = get_llm_client()
    print(f"Using {client_type} API for generation")

    all_examples = []

    for tool_idx, tool in enumerate(tools):
        print(f"[{tool_idx + 1}/{len(tools)}] Generating goals for: {tool['name']}")

        goals = generate_goals_for_tool(client, client_type, tool, tools, n=goals_per_tool)
        print(f"  Generated {len(goals)} goals")

        for goal_idx, goal in enumerate(goals):
            if goal_idx % 10 == 0:
                print(f"  Generating states: {goal_idx}/{len(goals)}")

            states = generate_states_for_goal(client, client_type, goal, domain_desc, n=states_per_goal)

            for state in states:
                example = {
                    "current_state": state,
                    "goal": goal,
                    "correct_tool": tool["name"],
                    "tool_category": tool["category"],
                    "source": "synthetic",
                }
                all_examples.append(example)

            time.sleep(0.1)

    random.seed(seed)
    random.shuffle(all_examples)

    n = len(all_examples)
    train_end = int(n * 0.7)
    val_end = int(n * 0.85)

    splits = {
        "train": all_examples[:train_end],
        "val": all_examples[train_end:val_end],
        "test": all_examples[val_end:],
    }

    os.makedirs(output_dir, exist_ok=True)

    for split_name, split_data in splits.items():
        path = os.path.join(output_dir, f"{split_name}.jsonl")
        with open(path, "w") as f:
            for ex in split_data:
                f.write(json.dumps(ex) + "\n")
        print(f"  {split_name}: {len(split_data)} examples -> {path}")

    print(f"\nTotal: {n} examples across {len(tools)} tools")
    print(f"  Per tool: ~{n // len(tools)} examples")
    return splits


def main():
    parser = argparse.ArgumentParser(description="Generate synthetic training data for v2 policy head")
    parser.add_argument("--domain", default=None, help="Path to domain.json")
    parser.add_argument("--output", default=".jfl/v2-data", help="Output directory")
    parser.add_argument("--goals-per-tool", type=int, default=50, help="Goals to generate per tool")
    parser.add_argument("--states-per-goal", type=int, default=2, help="States to generate per goal")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    domain_path = args.domain
    if domain_path is None:
        domain_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "domain.json")

    if not os.path.exists(domain_path):
        print(f"Domain file not found: {domain_path}")
        sys.exit(1)

    generate_dataset(
        domain_path=domain_path,
        goals_per_tool=args.goals_per_tool,
        states_per_goal=args.states_per_goal,
        output_dir=args.output,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
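generate_data.py reads three things from domain.json: a top-level "description" string and a "tools" array whose entries carry "name", "description", and "category" — exactly the fields referenced above. A minimal sketch of that shape; the action names and wording below are illustrative, not the contents of the 18-line domain.json actually shipped in this version:

    {
      "description": "Software development actions for an AI coding agent",
      "tools": [
        {"name": "run_tests", "description": "Run the test suite and report failures", "category": "verify"},
        {"name": "fix_lint", "description": "Resolve linting errors across the codebase", "category": "edit"}
      ]
    }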
package/scripts/train/v2/infer.py
@@ -0,0 +1,188 @@
"""
v2 Policy Head Inference — CLI script for action selection.

Usage:
    python infer.py --checkpoint .jfl/checkpoints/best_policy_head.pt --state "..." --goal "..." --top-k 3

Also supports JSON mode for TypeScript bridge:
    python infer.py --checkpoint ... --state "..." --goal "..." --json
"""

import json
import os
import sys
import argparse

import torch
import numpy as np

from model import PolicyHead


def load_model(checkpoint_path: str, device: str = "cpu"):
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt["config"]
    tool_to_index = ckpt["tool_to_index"]
    index_to_tool = {v: k for k, v in tool_to_index.items()}

    model = PolicyHead(
        embedding_dim=config["embedding_dim"],
        hidden_dim=config["hidden_dim"],
        num_tools=ckpt["num_tools"],
        num_layers=config["num_layers"],
        num_heads=config["num_heads"],
        dropout=config.get("dropout", 0.1),
    ).to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    return model, tool_to_index, index_to_tool, config


def get_embedding(text: str, api_url: str, api_key: str) -> list[float]:
    import requests

    response = requests.post(
        f"{api_url}/v1/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": "stratus-x1ac-base",
            "input": text,
        },
        timeout=15,
    )
    response.raise_for_status()
    data = response.json()
    return data["data"][0]["embedding"]


def infer(args):
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model, tool_to_index, index_to_tool, config = load_model(args.checkpoint, device)

    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
    api_key = os.environ.get("STRATUS_API_KEY", "")

    if not api_key:
        print("STRATUS_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    state_emb = get_embedding(args.state, api_url, api_key)
    goal_emb = get_embedding(args.goal, api_url, api_key)

    state_tensor = torch.tensor([state_emb], dtype=torch.float32).to(device)
    goal_tensor = torch.tensor([goal_emb], dtype=torch.float32).to(device)

    result = model.predict(state_tensor, goal_tensor, top_k=args.top_k)

    top_indices = result["top_k_indices"][0].cpu().tolist()
    top_probs = result["top_k_probs"][0].cpu().tolist()

    predictions = []
    for idx, prob in zip(top_indices, top_probs):
        predictions.append({
            "action": index_to_tool[idx],
            "confidence": round(prob, 4),
        })

    if args.json:
        output = {
            "action": predictions[0]["action"],
            "confidence": predictions[0]["confidence"],
            "alternatives": predictions[1:],
        }
        print(json.dumps(output))
    else:
        print(f"\nv2 Policy Head Prediction")
        print(f"{'─' * 40}")
        print(f"State: {args.state[:80]}...")
        print(f"Goal:  {args.goal[:80]}...")
        print(f"\nTop {args.top_k} actions:")
        for i, pred in enumerate(predictions):
            marker = "→" if i == 0 else " "
            bar = "█" * int(pred["confidence"] * 20)
            print(f"  {marker} {pred['action']:25s} {pred['confidence']:6.1%} {bar}")


def batch_infer(args):
    """Batch inference from JSONL input on stdin. Each line: {"state": "...", "goal": "..."}"""
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model, tool_to_index, index_to_tool, config = load_model(args.checkpoint, device)

    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
    api_key = os.environ.get("STRATUS_API_KEY", "")

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        try:
            req = json.loads(line)
            state_emb = get_embedding(req["state"], api_url, api_key)
            goal_emb = get_embedding(req["goal"], api_url, api_key)

            state_tensor = torch.tensor([state_emb], dtype=torch.float32).to(device)
            goal_tensor = torch.tensor([goal_emb], dtype=torch.float32).to(device)

            result = model.predict(state_tensor, goal_tensor, top_k=args.top_k)

            top_indices = result["top_k_indices"][0].cpu().tolist()
            top_probs = result["top_k_probs"][0].cpu().tolist()

            predictions = []
            for idx, prob in zip(top_indices, top_probs):
                predictions.append({
                    "action": index_to_tool[idx],
                    "confidence": round(prob, 4),
                })

            output = {
                "action": predictions[0]["action"],
                "confidence": predictions[0]["confidence"],
                "alternatives": predictions[1:],
            }
            print(json.dumps(output))
            sys.stdout.flush()
        except Exception as e:
            print(json.dumps({"error": str(e)}))
            sys.stdout.flush()


def main():
    parser = argparse.ArgumentParser(description="v2 policy head inference")
    parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint")
    parser.add_argument("--state", default=None, help="Current state text")
    parser.add_argument("--goal", default=None, help="Goal text")
    parser.add_argument("--top-k", type=int, default=3, help="Number of top actions")
    parser.add_argument("--json", action="store_true", help="JSON output for TypeScript bridge")
    parser.add_argument("--batch", action="store_true", help="Batch mode: read JSONL from stdin")
    args = parser.parse_args()

    if args.batch:
        batch_infer(args)
    elif args.state and args.goal:
        infer(args)
    else:
        print("Provide --state and --goal, or use --batch for stdin JSONL")
        sys.exit(1)


if __name__ == "__main__":
    main()
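The batch protocol implemented by batch_infer above: one {"state": ..., "goal": ...} JSON object per stdin line, one prediction object per stdout line, with per-line failures emitted as {"error": ...} rather than aborting the stream. A sketch of an invocation, with made-up state/goal text and STRATUS_API_KEY assumed to be set in the environment:

    echo '{"state": "Two tests failing after refactor", "goal": "Get CI green"}' \
      | python infer.py --checkpoint .jfl/checkpoints/best_policy_head.pt --batch

Each output line has the shape {"action": ..., "confidence": ..., "alternatives": [...]}, matching the single-shot --json mode.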