jfl 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. package/dist/commands/context-hub.d.ts +1 -0
  2. package/dist/commands/context-hub.d.ts.map +1 -1
  3. package/dist/commands/context-hub.js +246 -2
  4. package/dist/commands/context-hub.js.map +1 -1
  5. package/dist/commands/peter.d.ts +2 -0
  6. package/dist/commands/peter.d.ts.map +1 -1
  7. package/dist/commands/peter.js +242 -52
  8. package/dist/commands/peter.js.map +1 -1
  9. package/dist/commands/setup.d.ts +12 -0
  10. package/dist/commands/setup.d.ts.map +1 -0
  11. package/dist/commands/setup.js +322 -0
  12. package/dist/commands/setup.js.map +1 -0
  13. package/dist/commands/train.d.ts +33 -0
  14. package/dist/commands/train.d.ts.map +1 -0
  15. package/dist/commands/train.js +510 -0
  16. package/dist/commands/train.js.map +1 -0
  17. package/dist/commands/verify.d.ts +14 -0
  18. package/dist/commands/verify.d.ts.map +1 -0
  19. package/dist/commands/verify.js +276 -0
  20. package/dist/commands/verify.js.map +1 -0
  21. package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
  22. package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
  23. package/dist/dashboard-static/index.html +2 -2
  24. package/dist/index.js +99 -3
  25. package/dist/index.js.map +1 -1
  26. package/dist/lib/agent-session.d.ts.map +1 -1
  27. package/dist/lib/agent-session.js +12 -4
  28. package/dist/lib/agent-session.js.map +1 -1
  29. package/dist/lib/eval-snapshot.js +1 -1
  30. package/dist/lib/eval-snapshot.js.map +1 -1
  31. package/dist/lib/pi-sky/bridge.d.ts +55 -0
  32. package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
  33. package/dist/lib/pi-sky/bridge.js +264 -0
  34. package/dist/lib/pi-sky/bridge.js.map +1 -0
  35. package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
  36. package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
  37. package/dist/lib/pi-sky/cost-monitor.js +126 -0
  38. package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
  39. package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
  40. package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
  41. package/dist/lib/pi-sky/eval-sweep.js +141 -0
  42. package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
  43. package/dist/lib/pi-sky/event-router.d.ts +32 -0
  44. package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
  45. package/dist/lib/pi-sky/event-router.js +176 -0
  46. package/dist/lib/pi-sky/event-router.js.map +1 -0
  47. package/dist/lib/pi-sky/experiment.d.ts +9 -0
  48. package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
  49. package/dist/lib/pi-sky/experiment.js +83 -0
  50. package/dist/lib/pi-sky/experiment.js.map +1 -0
  51. package/dist/lib/pi-sky/index.d.ts +16 -0
  52. package/dist/lib/pi-sky/index.d.ts.map +1 -0
  53. package/dist/lib/pi-sky/index.js +16 -0
  54. package/dist/lib/pi-sky/index.js.map +1 -0
  55. package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
  56. package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
  57. package/dist/lib/pi-sky/stratus-gate.js +61 -0
  58. package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
  59. package/dist/lib/pi-sky/swarm.d.ts +28 -0
  60. package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
  61. package/dist/lib/pi-sky/swarm.js +208 -0
  62. package/dist/lib/pi-sky/swarm.js.map +1 -0
  63. package/dist/lib/pi-sky/types.d.ts +139 -0
  64. package/dist/lib/pi-sky/types.d.ts.map +1 -0
  65. package/dist/lib/pi-sky/types.js +2 -0
  66. package/dist/lib/pi-sky/types.js.map +1 -0
  67. package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
  68. package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
  69. package/dist/lib/pi-sky/voice-bridge.js +91 -0
  70. package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
  71. package/dist/lib/policy-head.d.ts +16 -1
  72. package/dist/lib/policy-head.d.ts.map +1 -1
  73. package/dist/lib/policy-head.js +117 -19
  74. package/dist/lib/policy-head.js.map +1 -1
  75. package/dist/lib/predictor.d.ts +10 -0
  76. package/dist/lib/predictor.d.ts.map +1 -1
  77. package/dist/lib/predictor.js +46 -7
  78. package/dist/lib/predictor.js.map +1 -1
  79. package/dist/lib/setup/agent-generator.d.ts +18 -0
  80. package/dist/lib/setup/agent-generator.d.ts.map +1 -0
  81. package/dist/lib/setup/agent-generator.js +114 -0
  82. package/dist/lib/setup/agent-generator.js.map +1 -0
  83. package/dist/lib/setup/context-analyzer.d.ts +16 -0
  84. package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
  85. package/dist/lib/setup/context-analyzer.js +112 -0
  86. package/dist/lib/setup/context-analyzer.js.map +1 -0
  87. package/dist/lib/setup/doc-auditor.d.ts +54 -0
  88. package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
  89. package/dist/lib/setup/doc-auditor.js +629 -0
  90. package/dist/lib/setup/doc-auditor.js.map +1 -0
  91. package/dist/lib/setup/domain-generator.d.ts +7 -0
  92. package/dist/lib/setup/domain-generator.d.ts.map +1 -0
  93. package/dist/lib/setup/domain-generator.js +58 -0
  94. package/dist/lib/setup/domain-generator.js.map +1 -0
  95. package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
  96. package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
  97. package/dist/lib/setup/smart-eval-generator.js +378 -0
  98. package/dist/lib/setup/smart-eval-generator.js.map +1 -0
  99. package/dist/lib/setup/smart-recommender.d.ts +63 -0
  100. package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
  101. package/dist/lib/setup/smart-recommender.js +329 -0
  102. package/dist/lib/setup/smart-recommender.js.map +1 -0
  103. package/dist/lib/setup/spec-generator.d.ts +63 -0
  104. package/dist/lib/setup/spec-generator.d.ts.map +1 -0
  105. package/dist/lib/setup/spec-generator.js +310 -0
  106. package/dist/lib/setup/spec-generator.js.map +1 -0
  107. package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
  108. package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
  109. package/dist/lib/setup/violation-agent-generator.js +255 -0
  110. package/dist/lib/setup/violation-agent-generator.js.map +1 -0
  111. package/package.json +1 -1
  112. package/packages/pi/extensions/context.ts +88 -55
  113. package/packages/pi/extensions/hub-resolver.ts +63 -0
  114. package/packages/pi/extensions/index.ts +16 -3
  115. package/packages/pi/extensions/memory-tool.ts +9 -4
  116. package/packages/pi/extensions/session.ts +68 -16
  117. package/packages/pi/extensions/tool-renderers.ts +23 -8
  118. package/scripts/train/requirements.txt +5 -0
  119. package/scripts/train/train-policy-head.py +477 -0
  120. package/scripts/train/v2/dataset.py +81 -0
  121. package/scripts/train/v2/domain.json +18 -0
  122. package/scripts/train/v2/eval.py +196 -0
  123. package/scripts/train/v2/generate_data.py +219 -0
  124. package/scripts/train/v2/infer.py +188 -0
  125. package/scripts/train/v2/model.py +112 -0
  126. package/scripts/train/v2/precompute.py +132 -0
  127. package/scripts/train/v2/train.py +302 -0
  128. package/scripts/train/v2/transform_buffer.py +227 -0
  129. package/scripts/train/v2/validate_data.py +115 -0
  130. package/template/.claude/settings.json +2 -15
  131. package/template/scripts/session/session-cleanup.sh +2 -11
  132. package/template/scripts/session/session-end-hub.sh +72 -0
  133. package/template/scripts/session/session-start-hub.sh +105 -0
  134. package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
  135. package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
@@ -0,0 +1,196 @@
1
+ """
2
+ v2 Policy Head Evaluation — per-tool accuracy, confusion matrix, confidence analysis.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import sys
8
+ import argparse
9
+ from collections import defaultdict
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+
15
+ from model import PolicyHead
16
+ from dataset import PolicyHeadDataset, load_embedding_cache
17
+
18
+
19
def load_model(checkpoint_path: str, device: str = "cpu") -> tuple:
    """Restore a trained PolicyHead from *checkpoint_path*.

    Returns (model, tool_to_index, index_to_tool) with the network moved to
    *device* and switched to eval mode.
    """
    # weights_only=False: the checkpoint bundles plain-dict metadata
    # (config, tool vocabulary) alongside the state dict.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = checkpoint["config"]
    tool_to_index = checkpoint["tool_to_index"]
    index_to_tool = {idx: name for name, idx in tool_to_index.items()}

    net = PolicyHead(
        embedding_dim=cfg["embedding_dim"],
        hidden_dim=cfg["hidden_dim"],
        num_tools=checkpoint["num_tools"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        dropout=cfg.get("dropout", 0.1),
    )
    net = net.to(device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()

    return net, tool_to_index, index_to_tool
38
+
39
+
40
@torch.no_grad()
def evaluate_detailed(model, dataloader, index_to_tool, device):
    """Run *model* over *dataloader* and collect per-tool accuracy counts,
    (true, predicted) confusion pairs, and one record per example."""
    model.eval()

    correct_by_tool = defaultdict(int)
    total_by_tool = defaultdict(int)
    confusions = defaultdict(int)
    records = []

    for batch in dataloader:
        states = batch["state_emb"].to(device)
        goals = batch["goal_emb"].to(device)
        targets = batch["label"].to(device)
        names = batch["tool_name"]

        logits = model(states, goals)
        probabilities = torch.softmax(logits, dim=-1)
        choices = logits.argmax(dim=-1)

        for row in range(targets.size(0)):
            actual = names[row]
            chosen = index_to_tool[choices[row].item()]
            conf = probabilities[row, choices[row]].item()

            total_by_tool[actual] += 1
            hit = choices[row].item() == targets[row].item()
            if hit:
                correct_by_tool[actual] += 1
            else:
                confusions[(actual, chosen)] += 1

            records.append({
                "true": actual,
                "predicted": chosen,
                "correct": hit,
                "confidence": conf,
            })

    return correct_by_tool, total_by_tool, confusions, records
78
+
79
+
80
def print_report(per_tool_correct, per_tool_total, confusion_pairs, all_predictions):
    """Print a human-readable evaluation report and return it as a dict.

    The returned dict mirrors what is printed: overall accuracy, per-tool
    breakdown, top confusion pairs, and average confidence for correct vs
    wrong predictions.
    """
    n_correct = sum(per_tool_correct.values())
    n_total = sum(per_tool_total.values())
    accuracy = n_correct / max(n_total, 1)

    rule = "=" * 70
    print(rule)
    print(" v2 POLICY HEAD — EVALUATION REPORT")
    print(rule)
    print(f"\n Overall Accuracy: {accuracy:.1%} ({n_correct}/{n_total})")
    print(f"\n{'Tool':<30} {'Accuracy':>10} {'Correct':>9} {'Total':>7}")
    print("-" * 60)

    for name in sorted(per_tool_total):
        hits = per_tool_correct.get(name, 0)
        seen = per_tool_total[name]
        tool_acc = hits / seen if seen > 0 else 0
        filled = int(tool_acc * 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {name:<28} {tool_acc:>8.1%} {hits:>5}/{seen:<5} {bar}")

    # Rank confusions once; reused for both the printout and the return dict.
    ranked_confusions = sorted(confusion_pairs.items(), key=lambda kv: -kv[1])
    if confusion_pairs:
        print(f"\n{'Top Confusion Pairs':<45} {'Count':>7}")
        print("-" * 55)
        for (actual, predicted), count in ranked_confusions[:10]:
            print(f" {actual} -> {predicted:<25} {count:>5}")

    right = [p["confidence"] for p in all_predictions if p["correct"]]
    wrong = [p["confidence"] for p in all_predictions if not p["correct"]]

    print("\n Confidence Analysis:")
    if right:
        print(f" Correct predictions: avg={sum(right) / len(right):.3f}")
    if wrong:
        print(f" Wrong predictions: avg={sum(wrong) / len(wrong):.3f}")

    if accuracy < 0.80:
        print("\n ⚠️ Accuracy below 80% — consider more data or warm-start")
    elif accuracy < 0.90:
        print("\n ⚠️ Accuracy below 90% — good but can improve")
    else:
        print("\n ✅ Accuracy ≥90% — ready for deployment")

    print(rule)

    return {
        "overall_accuracy": accuracy,
        "total": n_total,
        "per_tool": {
            name: {
                "accuracy": per_tool_correct.get(name, 0) / per_tool_total[name] if per_tool_total[name] > 0 else 0,
                "correct": per_tool_correct.get(name, 0),
                "total": per_tool_total[name],
            }
            for name in per_tool_total
        },
        "top_confusion": [
            {"true": a, "predicted": p, "count": c}
            for (a, p), c in ranked_confusions[:10]
        ],
        "avg_confidence_correct": sum(right) / len(right) if right else 0,
        "avg_confidence_wrong": sum(wrong) / len(wrong) if wrong else 0,
    }
143
+
144
+
145
def main():
    """CLI entry point: evaluate a trained v2 policy head on one data split."""
    parser = argparse.ArgumentParser(description="Evaluate v2 policy head")
    parser.add_argument("--checkpoint", default=".jfl/checkpoints/best_policy_head.pt", help="Path to checkpoint")
    parser.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with test.jsonl + embeddings")
    parser.add_argument("--split", default="test", help="Which split to evaluate (train/val/test)")
    parser.add_argument("--json", action="store_true", help="Output results as JSON")
    args = parser.parse_args()

    if not os.path.exists(args.checkpoint):
        print(f"Checkpoint not found: {args.checkpoint}")
        sys.exit(1)

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        device = "mps" if has_mps else "cpu"

    model, tool_to_index, index_to_tool = load_model(args.checkpoint, device)

    data_path = os.path.join(args.data_dir, f"{args.split}.jsonl")
    if not os.path.exists(data_path):
        print(f"Data not found: {data_path}")
        sys.exit(1)

    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)

    dataset = PolicyHeadDataset(data_path, tool_to_index, embeddings_matrix, text_to_idx)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    print(f"Evaluating on {args.split} split: {len(dataset)} examples")
    print(f"Device: {device}\n")

    stats = evaluate_detailed(model, loader, index_to_tool, device)
    results = print_report(*stats)

    if args.json:
        print(json.dumps(results, indent=2))

    # Persist the report next to the checkpoint for later inspection.
    results_path = os.path.join(os.path.dirname(args.checkpoint), "eval-results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,219 @@
1
+ """
2
+ Synthetic training data generator for v2 policy head.
3
+
4
+ Generates (current_state, goal, correct_tool) tuples using LLM API.
5
+ Adapted from Drew's generate_data.py for JFL's action taxonomy.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import sys
11
+ import random
12
+ import argparse
13
+ import time
14
+ from pathlib import Path
15
+
16
+ def load_domain(path: str) -> dict:
17
+ with open(path) as f:
18
+ return json.load(f)
19
+
20
def get_llm_client():
    """Return (client, provider_name) for whichever LLM API key is configured.

    Prefers OpenAI when both keys are present; exits the process with code 1
    when neither OPENAI_API_KEY nor ANTHROPIC_API_KEY is set.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    anthropic_key = os.environ.get("ANTHROPIC_API_KEY")

    if not (openai_key or anthropic_key):
        print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY for synthetic data generation")
        sys.exit(1)

    if openai_key:
        from openai import OpenAI
        return OpenAI(), "openai"

    from anthropic import Anthropic
    return Anthropic(), "anthropic"
32
+
33
def generate_goals_for_tool(client, client_type: str, tool: dict, all_tools: list, n: int = 50) -> list:
    """Ask the LLM for up to *n* goal texts that should route to *tool*.

    Returns a (possibly shorter) list of goal strings; returns [] on any API
    or parsing failure so the caller can continue with the next tool.
    """
    catalog = "\n".join(
        f"- {t['name']}: {t['description']}" for t in all_tools
    )

    prompt = f"""You are generating training data for an AI action-routing system in software development.

Given these {len(all_tools)} available action types:
{catalog}

Generate exactly {n} realistic goals/situations that should route to:
Action: {tool["name"]}
Description: {tool["description"]}

Requirements:
1. Each goal should describe a real software development situation or user request
2. Include variety: specific bug reports, vague feature requests, urgent fixes, routine maintenance
3. Some should be indirect (describe the symptom, not the action needed)
4. Some should use technical jargon (CI/CD, linting, test coverage, etc.)
5. Some should include context about current project state
6. Include goals that could be confused with OTHER action types
7. Do NOT include the action type name in the goal text
8. Goals should reflect what an AI coding agent would encounter

Output as a JSON object with a "goals" key containing an array of strings. No explanation."""

    try:
        if client_type == "openai":
            completion = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.9,
                response_format={"type": "json_object"},
            )
            parsed = json.loads(completion.choices[0].message.content)
        else:
            reply = client.messages.create(
                model="claude-sonnet-4-5-20250514",
                max_tokens=4096,
                messages=[{"role": "user", "content": prompt}],
            )
            body = reply.content[0].text
            # Anthropic has no JSON mode here; cut the outermost {...} span.
            parsed = json.loads(body[body.index("{"):body.rindex("}") + 1])

        # Some responses use "queries" instead of "goals"; accept both.
        candidates = parsed.get("goals", parsed.get("queries", []))
        return candidates[:n] if isinstance(candidates, list) else []
    except Exception as e:
        # Best-effort: a failed tool simply contributes no examples.
        print(f" Error generating goals for {tool['name']}: {e}")
        return []
83
+
84
def generate_states_for_goal(client, client_type: str, goal: str, domain_desc: str, n: int = 2) -> list:
    """Ask the LLM for up to *n* plausible "current state" texts for *goal*.

    Falls back to a single generic state string when the API call or response
    parsing fails, so dataset generation never stalls on one goal.
    """
    prompt = f"""You are generating training data for an AI planning system in:
"{domain_desc}"

Given this development goal: "{goal}"

Generate {n} different realistic "current state" descriptions that an AI coding agent might be in
when this goal arises. Each state should describe:
- Current codebase state (test results, scores, recent changes)
- What the agent currently knows about the project
- Any relevant trajectory information (recent actions taken)

States should vary in how much context is available and what prior work has been done.

Output as a JSON object with a "states" key containing an array of strings (1-3 sentences each)."""

    try:
        if client_type == "openai":
            completion = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.8,
                response_format={"type": "json_object"},
            )
            parsed = json.loads(completion.choices[0].message.content)
        else:
            reply = client.messages.create(
                model="claude-sonnet-4-5-20250514",
                max_tokens=2048,
                messages=[{"role": "user", "content": prompt}],
            )
            body = reply.content[0].text
            # Anthropic has no JSON mode here; cut the outermost {...} span.
            parsed = json.loads(body[body.index("{"):body.rindex("}") + 1])

        return parsed.get("states", [])[:n]
    except Exception as e:
        print(f" Error generating states: {e}")
        return [f"Agent working on codebase. Goal context: {goal[:100]}"]
124
+
125
def generate_dataset(
    domain_path: str,
    goals_per_tool: int = 50,
    states_per_goal: int = 2,
    output_dir: str = "data",
    seed: int = 42,
):
    """Generate synthetic (current_state, goal, correct_tool) examples via the LLM API.

    Shuffles deterministically (seeded), writes 70/15/15 train/val/test splits
    as JSONL files into *output_dir*, and returns the split dict.

    Raises:
        ValueError: if the domain file declares no tools. (Previously an empty
            tool list crashed at the end with ZeroDivisionError in the
            per-tool summary, after get_llm_client had already run.)
    """
    domain = load_domain(domain_path)
    tools = domain["tools"]
    domain_desc = domain["description"]

    # Fail fast on a degenerate domain instead of dividing by zero below.
    if not tools:
        raise ValueError(f"No tools defined in domain file: {domain_path}")

    client, client_type = get_llm_client()
    print(f"Using {client_type} API for generation")

    all_examples = []

    for tool_idx, tool in enumerate(tools):
        print(f"[{tool_idx + 1}/{len(tools)}] Generating goals for: {tool['name']}")

        goals = generate_goals_for_tool(client, client_type, tool, tools, n=goals_per_tool)
        print(f" Generated {len(goals)} goals")

        for goal_idx, goal in enumerate(goals):
            if goal_idx % 10 == 0:
                print(f" Generating states: {goal_idx}/{len(goals)}")

            states = generate_states_for_goal(client, client_type, goal, domain_desc, n=states_per_goal)

            for state in states:
                example = {
                    "current_state": state,
                    "goal": goal,
                    "correct_tool": tool["name"],
                    "tool_category": tool["category"],
                    "source": "synthetic",
                }
                all_examples.append(example)

            # Gentle rate limiting between per-goal API calls.
            time.sleep(0.1)

    # Deterministic shuffle so splits are reproducible for a given seed.
    random.seed(seed)
    random.shuffle(all_examples)

    n = len(all_examples)
    train_end = int(n * 0.7)
    val_end = int(n * 0.85)

    splits = {
        "train": all_examples[:train_end],
        "val": all_examples[train_end:val_end],
        "test": all_examples[val_end:],
    }

    os.makedirs(output_dir, exist_ok=True)

    for split_name, split_data in splits.items():
        path = os.path.join(output_dir, f"{split_name}.jsonl")
        with open(path, "w") as f:
            for ex in split_data:
                f.write(json.dumps(ex) + "\n")
        print(f" {split_name}: {len(split_data)} examples -> {path}")

    print(f"\nTotal: {n} examples across {len(tools)} tools")
    print(f" Per tool: ~{n // len(tools)} examples")
    return splits
190
+
191
+
192
def main():
    """CLI wrapper: resolve the domain file, then run generate_dataset()."""
    parser = argparse.ArgumentParser(description="Generate synthetic training data for v2 policy head")
    parser.add_argument("--domain", default=None, help="Path to domain.json")
    parser.add_argument("--output", default=".jfl/v2-data", help="Output directory")
    parser.add_argument("--goals-per-tool", type=int, default=50, help="Goals to generate per tool")
    parser.add_argument("--states-per-goal", type=int, default=2, help="States to generate per goal")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    if args.domain is not None:
        domain_path = args.domain
    else:
        # Default to the domain.json shipped next to this script.
        here = os.path.dirname(os.path.abspath(__file__))
        domain_path = os.path.join(here, "domain.json")

    if not os.path.exists(domain_path):
        print(f"Domain file not found: {domain_path}")
        sys.exit(1)

    generate_dataset(
        domain_path=domain_path,
        goals_per_tool=args.goals_per_tool,
        states_per_goal=args.states_per_goal,
        output_dir=args.output,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,188 @@
1
+ """
2
+ v2 Policy Head Inference — CLI script for action selection.
3
+
4
+ Usage:
5
+ python infer.py --checkpoint .jfl/checkpoints/best_policy_head.pt --state "..." --goal "..." --top-k 3
6
+
7
+ Also supports JSON mode for TypeScript bridge:
8
+ python infer.py --checkpoint ... --state "..." --goal "..." --json
9
+ """
10
+
11
+ import json
12
+ import os
13
+ import sys
14
+ import argparse
15
+
16
+ import torch
17
+ import numpy as np
18
+
19
+ from model import PolicyHead
20
+
21
+
22
def load_model(checkpoint_path: str, device: str = "cpu"):
    """Load a PolicyHead checkpoint.

    Returns (model, tool_to_index, index_to_tool, config); the model is on
    *device* and in eval mode.
    """
    # weights_only=False: checkpoint stores dict metadata, not just tensors.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = ckpt["config"]
    vocab = ckpt["tool_to_index"]
    reverse_vocab = {idx: name for name, idx in vocab.items()}

    net = PolicyHead(
        embedding_dim=cfg["embedding_dim"],
        hidden_dim=cfg["hidden_dim"],
        num_tools=ckpt["num_tools"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        dropout=cfg.get("dropout", 0.1),
    ).to(device)
    net.load_state_dict(ckpt["model_state_dict"])
    net.eval()

    return net, vocab, reverse_vocab, cfg
41
+
42
+
43
def get_embedding(text: str, api_url: str, api_key: str) -> list[float]:
    """Fetch an embedding vector for *text* from the Stratus REST API.

    Raises requests.HTTPError on a non-2xx response.
    """
    # Local import keeps the module importable when requests is absent.
    import requests

    resp = requests.post(
        f"{api_url}/v1/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": "stratus-x1ac-base",
            "input": text,
        },
        timeout=15,
    )
    resp.raise_for_status()
    payload = resp.json()
    return payload["data"][0]["embedding"]
61
+
62
+
63
def infer(args):
    """Single-shot inference: embed state/goal remotely, print top-k actions.

    With --json, emits one machine-readable object for the TypeScript bridge;
    otherwise prints a formatted table. Exits 1 when STRATUS_API_KEY is unset.
    """
    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        device = "mps" if has_mps else "cpu"

    model, tool_to_index, index_to_tool, config = load_model(args.checkpoint, device)

    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
    api_key = os.environ.get("STRATUS_API_KEY", "")
    if not api_key:
        print("STRATUS_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    # Embed both texts remotely, then score with the local policy head.
    state_vec = get_embedding(args.state, api_url, api_key)
    goal_vec = get_embedding(args.goal, api_url, api_key)
    state_t = torch.tensor([state_vec], dtype=torch.float32).to(device)
    goal_t = torch.tensor([goal_vec], dtype=torch.float32).to(device)

    result = model.predict(state_t, goal_t, top_k=args.top_k)
    indices = result["top_k_indices"][0].cpu().tolist()
    probs = result["top_k_probs"][0].cpu().tolist()

    predictions = [
        {"action": index_to_tool[i], "confidence": round(p, 4)}
        for i, p in zip(indices, probs)
    ]

    if args.json:
        print(json.dumps({
            "action": predictions[0]["action"],
            "confidence": predictions[0]["confidence"],
            "alternatives": predictions[1:],
        }))
        return

    print(f"\nv2 Policy Head Prediction")
    print(f"{'─' * 40}")
    print(f"State: {args.state[:80]}...")
    print(f"Goal: {args.goal[:80]}...")
    print(f"\nTop {args.top_k} actions:")
    for rank, pred in enumerate(predictions):
        marker = "→" if rank == 0 else " "
        bar = "█" * int(pred["confidence"] * 20)
        print(f" {marker} {pred['action']:25s} {pred['confidence']:6.1%} {bar}")
115
+
116
+
117
def batch_infer(args):
    """Batch inference from JSONL input on stdin. Each line: {"state": "...", "goal": "..."}

    Writes one JSON object per input line to stdout: either a prediction
    ({"action", "confidence", "alternatives"}) or {"error": "..."}.
    Exits 1 when STRATUS_API_KEY is unset.
    """
    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model, tool_to_index, index_to_tool, config = load_model(args.checkpoint, device)

    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
    api_key = os.environ.get("STRATUS_API_KEY", "")

    # Fix: infer() fails fast on a missing key, but batch mode previously fell
    # through and emitted one {"error": ...} line per request instead. Mirror
    # infer()'s behavior for consistency.
    if not api_key:
        print("STRATUS_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        try:
            req = json.loads(line)
            state_emb = get_embedding(req["state"], api_url, api_key)
            goal_emb = get_embedding(req["goal"], api_url, api_key)

            state_tensor = torch.tensor([state_emb], dtype=torch.float32).to(device)
            goal_tensor = torch.tensor([goal_emb], dtype=torch.float32).to(device)

            result = model.predict(state_tensor, goal_tensor, top_k=args.top_k)

            top_indices = result["top_k_indices"][0].cpu().tolist()
            top_probs = result["top_k_probs"][0].cpu().tolist()

            predictions = []
            for idx, prob in zip(top_indices, top_probs):
                predictions.append({
                    "action": index_to_tool[idx],
                    "confidence": round(prob, 4),
                })

            output = {
                "action": predictions[0]["action"],
                "confidence": predictions[0]["confidence"],
                "alternatives": predictions[1:],
            }
            print(json.dumps(output))
            # The bridge reads line-by-line; flush so it never blocks on buffering.
            sys.stdout.flush()
        except Exception as e:
            # Per-line error isolation: one bad request must not kill the stream.
            print(json.dumps({"error": str(e)}))
            sys.stdout.flush()
166
+
167
+
168
def main():
    """CLI entry point for single-shot or batch policy-head inference."""
    parser = argparse.ArgumentParser(description="v2 policy head inference")
    parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint")
    parser.add_argument("--state", default=None, help="Current state text")
    parser.add_argument("--goal", default=None, help="Goal text")
    parser.add_argument("--top-k", type=int, default=3, help="Number of top actions")
    parser.add_argument("--json", action="store_true", help="JSON output for TypeScript bridge")
    parser.add_argument("--batch", action="store_true", help="Batch mode: read JSONL from stdin")
    args = parser.parse_args()

    if args.batch:
        batch_infer(args)
        return
    if args.state and args.goal:
        infer(args)
        return
    print("Provide --state and --goal, or use --batch for stdin JSONL")
    sys.exit(1)


if __name__ == "__main__":
    main()