jfl 0.5.0 → 0.6.1
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/dist/commands/context-hub.d.ts +1 -0
- package/dist/commands/context-hub.d.ts.map +1 -1
- package/dist/commands/context-hub.js +246 -2
- package/dist/commands/context-hub.js.map +1 -1
- package/dist/commands/peter.d.ts +2 -0
- package/dist/commands/peter.d.ts.map +1 -1
- package/dist/commands/peter.js +242 -52
- package/dist/commands/peter.js.map +1 -1
- package/dist/commands/setup.d.ts +12 -0
- package/dist/commands/setup.d.ts.map +1 -0
- package/dist/commands/setup.js +322 -0
- package/dist/commands/setup.js.map +1 -0
- package/dist/commands/train.d.ts +33 -0
- package/dist/commands/train.d.ts.map +1 -0
- package/dist/commands/train.js +510 -0
- package/dist/commands/train.js.map +1 -0
- package/dist/commands/verify.d.ts +14 -0
- package/dist/commands/verify.d.ts.map +1 -0
- package/dist/commands/verify.js +276 -0
- package/dist/commands/verify.js.map +1 -0
- package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
- package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
- package/dist/dashboard-static/index.html +2 -2
- package/dist/index.js +99 -3
- package/dist/index.js.map +1 -1
- package/dist/lib/agent-session.d.ts.map +1 -1
- package/dist/lib/agent-session.js +12 -4
- package/dist/lib/agent-session.js.map +1 -1
- package/dist/lib/eval-snapshot.js +1 -1
- package/dist/lib/eval-snapshot.js.map +1 -1
- package/dist/lib/pi-sky/bridge.d.ts +55 -0
- package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/bridge.js +264 -0
- package/dist/lib/pi-sky/bridge.js.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
- package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
- package/dist/lib/pi-sky/cost-monitor.js +126 -0
- package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
- package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
- package/dist/lib/pi-sky/eval-sweep.js +141 -0
- package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
- package/dist/lib/pi-sky/event-router.d.ts +32 -0
- package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
- package/dist/lib/pi-sky/event-router.js +176 -0
- package/dist/lib/pi-sky/event-router.js.map +1 -0
- package/dist/lib/pi-sky/experiment.d.ts +9 -0
- package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
- package/dist/lib/pi-sky/experiment.js +83 -0
- package/dist/lib/pi-sky/experiment.js.map +1 -0
- package/dist/lib/pi-sky/index.d.ts +16 -0
- package/dist/lib/pi-sky/index.d.ts.map +1 -0
- package/dist/lib/pi-sky/index.js +16 -0
- package/dist/lib/pi-sky/index.js.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
- package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
- package/dist/lib/pi-sky/stratus-gate.js +61 -0
- package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
- package/dist/lib/pi-sky/swarm.d.ts +28 -0
- package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
- package/dist/lib/pi-sky/swarm.js +208 -0
- package/dist/lib/pi-sky/swarm.js.map +1 -0
- package/dist/lib/pi-sky/types.d.ts +139 -0
- package/dist/lib/pi-sky/types.d.ts.map +1 -0
- package/dist/lib/pi-sky/types.js +2 -0
- package/dist/lib/pi-sky/types.js.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
- package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
- package/dist/lib/pi-sky/voice-bridge.js +91 -0
- package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
- package/dist/lib/policy-head.d.ts +16 -1
- package/dist/lib/policy-head.d.ts.map +1 -1
- package/dist/lib/policy-head.js +117 -19
- package/dist/lib/policy-head.js.map +1 -1
- package/dist/lib/predictor.d.ts +10 -0
- package/dist/lib/predictor.d.ts.map +1 -1
- package/dist/lib/predictor.js +46 -7
- package/dist/lib/predictor.js.map +1 -1
- package/dist/lib/setup/agent-generator.d.ts +18 -0
- package/dist/lib/setup/agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/agent-generator.js +114 -0
- package/dist/lib/setup/agent-generator.js.map +1 -0
- package/dist/lib/setup/context-analyzer.d.ts +16 -0
- package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
- package/dist/lib/setup/context-analyzer.js +112 -0
- package/dist/lib/setup/context-analyzer.js.map +1 -0
- package/dist/lib/setup/doc-auditor.d.ts +54 -0
- package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
- package/dist/lib/setup/doc-auditor.js +629 -0
- package/dist/lib/setup/doc-auditor.js.map +1 -0
- package/dist/lib/setup/domain-generator.d.ts +7 -0
- package/dist/lib/setup/domain-generator.d.ts.map +1 -0
- package/dist/lib/setup/domain-generator.js +58 -0
- package/dist/lib/setup/domain-generator.js.map +1 -0
- package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
- package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
- package/dist/lib/setup/smart-eval-generator.js +378 -0
- package/dist/lib/setup/smart-eval-generator.js.map +1 -0
- package/dist/lib/setup/smart-recommender.d.ts +63 -0
- package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
- package/dist/lib/setup/smart-recommender.js +329 -0
- package/dist/lib/setup/smart-recommender.js.map +1 -0
- package/dist/lib/setup/spec-generator.d.ts +63 -0
- package/dist/lib/setup/spec-generator.d.ts.map +1 -0
- package/dist/lib/setup/spec-generator.js +310 -0
- package/dist/lib/setup/spec-generator.js.map +1 -0
- package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
- package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
- package/dist/lib/setup/violation-agent-generator.js +255 -0
- package/dist/lib/setup/violation-agent-generator.js.map +1 -0
- package/package.json +1 -1
- package/packages/pi/extensions/context.ts +88 -55
- package/packages/pi/extensions/hub-resolver.ts +63 -0
- package/packages/pi/extensions/index.ts +16 -3
- package/packages/pi/extensions/memory-tool.ts +9 -4
- package/packages/pi/extensions/session.ts +68 -16
- package/packages/pi/extensions/tool-renderers.ts +23 -8
- package/scripts/train/requirements.txt +5 -0
- package/scripts/train/train-policy-head.py +477 -0
- package/scripts/train/v2/dataset.py +81 -0
- package/scripts/train/v2/domain.json +18 -0
- package/scripts/train/v2/eval.py +196 -0
- package/scripts/train/v2/generate_data.py +219 -0
- package/scripts/train/v2/infer.py +188 -0
- package/scripts/train/v2/model.py +112 -0
- package/scripts/train/v2/precompute.py +132 -0
- package/scripts/train/v2/train.py +302 -0
- package/scripts/train/v2/transform_buffer.py +227 -0
- package/scripts/train/v2/validate_data.py +115 -0
- package/template/.claude/settings.json +2 -15
- package/template/scripts/session/session-cleanup.sh +2 -11
- package/template/scripts/session/session-end-hub.sh +72 -0
- package/template/scripts/session/session-start-hub.sh +105 -0
- package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
- package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
package/scripts/train/v2/model.py
@@ -0,0 +1,112 @@
+"""
+v2 Policy Head — Transformer-based action selector.
+
+Architecture from Drew's Stratus tutorial:
+(current_state_emb, goal_emb) -> state_proj + goal_proj -> fusion -> TransformerEncoder -> classifier -> action logits
+
+~8.7M params, ~17MB checkpoint. Replaces v1 MLP reward predictor.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyHead(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        hidden_dim: int = 512,
+        num_tools: int = 12,
+        num_layers: int = 4,
+        num_heads: int = 8,
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.num_tools = num_tools
+
+        self.state_proj = nn.Linear(embedding_dim, hidden_dim)
+        self.goal_proj = nn.Linear(embedding_dim, hidden_dim)
+
+        self.fusion = nn.Sequential(
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+        )
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim * 4,
+            dropout=dropout,
+            activation="gelu",
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(
+            encoder_layer,
+            num_layers=num_layers,
+        )
+
+        self.norm = nn.LayerNorm(hidden_dim)
+        self.classifier = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, num_tools),
+        )
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.ones_(module.weight)
+                nn.init.zeros_(module.bias)
+
+    def forward(
+        self,
+        current_state_emb: torch.Tensor,
+        goal_state_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        state_h = self.state_proj(current_state_emb)
+        goal_h = self.goal_proj(goal_state_emb)
+
+        fused = self.fusion(torch.cat([state_h, goal_h], dim=-1))
+
+        x = fused.unsqueeze(1)
+        x = self.transformer(x)
+        x = x.squeeze(1)
+
+        x = self.norm(x)
+        logits = self.classifier(x)
+
+        return logits
+
+    def predict(
+        self,
+        current_state_emb: torch.Tensor,
+        goal_state_emb: torch.Tensor,
+        top_k: int = 3,
+    ) -> dict:
+        self.eval()
+        with torch.no_grad():
+            logits = self.forward(current_state_emb, goal_state_emb)
+            probs = F.softmax(logits, dim=-1)
+            top_probs, top_indices = torch.topk(probs, k=min(top_k, self.num_tools), dim=-1)
+
+        return {
+            "top_k_indices": top_indices,
+            "top_k_probs": top_probs,
+            "all_probs": probs,
+        }
+
+    @property
+    def num_parameters(self) -> int:
+        return sum(p.numel() for p in self.parameters() if p.requires_grad)
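For orientation, a minimal smoke test of the new PolicyHead module. This snippet is illustrative, not part of the package; it assumes it runs from package/scripts/train/v2 so that `from model import PolicyHead` resolves, and the random tensors stand in for real Stratus embeddings:

import torch
from model import PolicyHead

# Batch of 4 (state, goal) embedding pairs; 768 matches the default embedding_dim.
model = PolicyHead(embedding_dim=768, num_tools=12)
state = torch.randn(4, 768)
goal = torch.randn(4, 768)

logits = model(state, goal)                 # shape (4, 12): one logit per tool
out = model.predict(state, goal, top_k=3)   # dict with top-3 indices and probs
print(logits.shape, out["top_k_indices"].shape)  # torch.Size([4, 12]) torch.Size([4, 3])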
package/scripts/train/v2/precompute.py
@@ -0,0 +1,132 @@
+"""
+Pre-compute Stratus embeddings for all unique texts in v2 training data.
+Caches embeddings as .npz files to avoid re-computation during training.
+"""
+
+import json
+import os
+import sys
+import argparse
+import numpy as np
+
+def get_stratus_embedder(api_url: str, api_key: str):
+    import requests
+
+    def embed_batch(texts: list[str]) -> list[list[float]]:
+        response = requests.post(
+            f"{api_url}/v1/embeddings",
+            headers={
+                "Authorization": f"Bearer {api_key}",
+                "Content-Type": "application/json",
+            },
+            json={
+                "model": "stratus-x1ac-base",
+                "input": texts,
+            },
+            timeout=30,
+        )
+        response.raise_for_status()
+        data = response.json()
+        return [d["embedding"] for d in data["data"]]
+
+    return embed_batch
+
+def collect_unique_texts(data_path: str) -> tuple[list[str], list[str]]:
+    states = set()
+    goals = set()
+
+    with open(data_path) as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            ex = json.loads(line)
+            states.add(ex["current_state"])
+            goals.add(ex["goal"])
+
+    return sorted(states), sorted(goals)
+
+def precompute_embeddings(
+    data_dir: str,
+    api_url: str,
+    api_key: str,
+    batch_size: int = 32,
+):
+    embedder = get_stratus_embedder(api_url, api_key)
+
+    all_states = set()
+    all_goals = set()
+
+    for split in ["train", "val", "test"]:
+        path = os.path.join(data_dir, f"{split}.jsonl")
+        if not os.path.exists(path):
+            print(f"  Skipping {split} (file not found)")
+            continue
+        states, goals = collect_unique_texts(path)
+        all_states.update(states)
+        all_goals.update(goals)
+
+    all_texts = sorted(all_states | all_goals)
+    print(f"Unique texts to embed: {len(all_texts)} ({len(all_states)} states, {len(all_goals)} goals)")
+
+    text_to_embedding = {}
+    for i in range(0, len(all_texts), batch_size):
+        batch = all_texts[i : i + batch_size]
+        try:
+            embeddings = embedder(batch)
+            for text, emb in zip(batch, embeddings):
+                text_to_embedding[text] = emb
+        except Exception as e:
+            print(f"  Error embedding batch {i}-{i + len(batch)}: {e}")
+            continue
+
+        done = min(i + batch_size, len(all_texts))
+        print(f"  Embedded {done}/{len(all_texts)} texts")
+
+    texts_list = sorted(text_to_embedding.keys())
+    text_to_idx = {t: i for i, t in enumerate(texts_list)}
+    embeddings_matrix = np.array([text_to_embedding[t] for t in texts_list], dtype=np.float32)
+
+    cache_path = os.path.join(data_dir, "embeddings_cache.npz")
+    np.savez(
+        cache_path,
+        embeddings=embeddings_matrix,
+        texts=np.array(texts_list, dtype=object),
+    )
+    print(f"Saved embedding cache: {cache_path} ({embeddings_matrix.shape})")
+
+    index_path = os.path.join(data_dir, "text_to_idx.json")
+    with open(index_path, "w") as f:
+        json.dump(text_to_idx, f)
+    print(f"Saved text index: {index_path} ({len(text_to_idx)} entries)")
+
+    return text_to_idx, embeddings_matrix
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Pre-compute Stratus embeddings for v2 training data")
+    parser.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with train/val/test JSONL files")
+    parser.add_argument("--batch-size", type=int, default=32, help="Embedding batch size")
+    args = parser.parse_args()
+
+    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
+    api_key = os.environ.get("STRATUS_API_KEY", "")
+
+    if not api_key:
+        print("STRATUS_API_KEY not set")
+        sys.exit(1)
+
+    if not os.path.exists(args.data_dir):
+        print(f"Data directory not found: {args.data_dir}")
+        sys.exit(1)
+
+    precompute_embeddings(
+        data_dir=args.data_dir,
+        api_url=api_url,
+        api_key=api_key,
+        batch_size=args.batch_size,
+    )
+
+
+if __name__ == "__main__":
+    main()
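The cache written by precompute.py can be read back with plain numpy. A sketch, assuming the default --data-dir; `allow_pickle=True` is needed because the texts array holds Python strings:

import json
import numpy as np

# Load the embedding matrix and its parallel text arrays.
data = np.load(".jfl/v2-data/embeddings_cache.npz", allow_pickle=True)
embeddings = data["embeddings"]   # (num_texts, dim) float32, rows in sorted-text order
texts = data["texts"]             # (num_texts,) object array of the source strings

with open(".jfl/v2-data/text_to_idx.json") as f:
    text_to_idx = json.load(f)

# Rows line up with text_to_idx, so a text's embedding is one lookup away.
vec = embeddings[text_to_idx[texts[0]]]
print(embeddings.shape, vec.shape)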
package/scripts/train/v2/train.py
@@ -0,0 +1,302 @@
+"""
+v2 Policy Head Training Loop.
+
+CrossEntropyLoss with label smoothing, cosine annealing with warmup,
+early stopping. Produces .pt checkpoint with model weights, config, and tool index.
+"""
+
+import json
+import os
+import sys
+import time
+import math
+import argparse
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from model import PolicyHead
+from dataset import PolicyHeadDataset, load_embedding_cache
+
+
+def load_tool_index(domain_path: str) -> dict[str, int]:
+    with open(domain_path) as f:
+        domain = json.load(f)
+    return {tool["name"]: i for i, tool in enumerate(domain["tools"])}
+
+
+def get_lr_scheduler(optimizer, warmup_steps: int, total_steps: int):
+    def lr_lambda(step):
+        if step < warmup_steps:
+            return float(step) / float(max(1, warmup_steps))
+        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
+        return max(0.0, 0.5 * (1.0 + math.cos(progress * math.pi)))
+
+    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
+def train_epoch(model, dataloader, criterion, optimizer, scheduler, device):
+    model.train()
+    total_loss = 0.0
+    correct = 0
+    total = 0
+
+    for batch in dataloader:
+        state_emb = batch["state_emb"].to(device)
+        goal_emb = batch["goal_emb"].to(device)
+        labels = batch["label"].to(device)
+
+        optimizer.zero_grad()
+
+        logits = model(state_emb, goal_emb)
+        loss = criterion(logits, labels)
+
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+        optimizer.step()
+        scheduler.step()
+
+        total_loss += loss.item() * labels.size(0)
+        preds = logits.argmax(dim=-1)
+        correct += (preds == labels).sum().item()
+        total += labels.size(0)
+
+    return total_loss / max(total, 1), correct / max(total, 1)
+
+
+@torch.no_grad()
+def evaluate(model, dataloader, criterion, device):
+    model.eval()
+    total_loss = 0.0
+    correct = 0
+    total = 0
+
+    for batch in dataloader:
+        state_emb = batch["state_emb"].to(device)
+        goal_emb = batch["goal_emb"].to(device)
+        labels = batch["label"].to(device)
+
+        logits = model(state_emb, goal_emb)
+        loss = criterion(logits, labels)
+
+        total_loss += loss.item() * labels.size(0)
+        preds = logits.argmax(dim=-1)
+        correct += (preds == labels).sum().item()
+        total += labels.size(0)
+
+    return total_loss / max(total, 1), correct / max(total, 1)
+
+
+def train(args):
+    # Device
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+    print(f"Device: {device}")
+
+    # Domain
+    domain_path = args.domain
+    tool_to_index = load_tool_index(domain_path)
+    index_to_tool = {v: k for k, v in tool_to_index.items()}
+    num_tools = len(tool_to_index)
+    print(f"Tools: {num_tools}")
+
+    # Embeddings cache
+    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)
+    if embeddings_matrix is not None:
+        print(f"Embedding cache: {embeddings_matrix.shape[0]} texts, {embeddings_matrix.shape[1]}-dim")
+    else:
+        print("WARNING: No embedding cache found. Training with zero vectors.")
+        print("  Run: python precompute.py --data-dir", args.data_dir)
+
+    # Datasets
+    train_path = os.path.join(args.data_dir, "train.jsonl")
+    val_path = os.path.join(args.data_dir, "val.jsonl")
+
+    if not os.path.exists(train_path):
+        print(f"Training data not found: {train_path}")
+        sys.exit(1)
+
+    train_ds = PolicyHeadDataset(train_path, tool_to_index, embeddings_matrix, text_to_idx)
+    val_ds = PolicyHeadDataset(val_path, tool_to_index, embeddings_matrix, text_to_idx) if os.path.exists(val_path) else None
+
+    num_workers = 0 if device == "mps" else min(4, os.cpu_count() or 1)
+    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
+    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers) if val_ds else None
+
+    print(f"Train: {len(train_ds)} examples")
+    if val_ds:
+        print(f"Val: {len(val_ds)} examples")
+
+    # Model
+    embedding_dim = embeddings_matrix.shape[1] if embeddings_matrix is not None else 768
+    model = PolicyHead(
+        embedding_dim=embedding_dim,
+        hidden_dim=args.hidden_dim,
+        num_tools=num_tools,
+        num_layers=args.num_layers,
+        num_heads=args.num_heads,
+        dropout=args.dropout,
+    ).to(device)
+
+    print(f"Parameters: {model.num_parameters:,}")
+
+    # Warm start
+    if args.warm_start and os.path.exists(args.warm_start):
+        print(f"Warm-starting from: {args.warm_start}")
+        state_dict = torch.load(args.warm_start, map_location=device, weights_only=True)
+        if "model_state_dict" in state_dict:
+            state_dict = state_dict["model_state_dict"]
+        compatible = {}
+        for k, v in state_dict.items():
+            if k in model.state_dict() and v.shape == model.state_dict()[k].shape:
+                compatible[k] = v
+        model.load_state_dict(compatible, strict=False)
+        print(f"  Loaded {len(compatible)}/{len(state_dict)} layers")
+
+    # Loss, optimizer, scheduler
+    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+    optimizer = optim.AdamW(
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+    )
+    total_steps = len(train_loader) * args.epochs
+    scheduler = get_lr_scheduler(optimizer, args.warmup_steps, total_steps)
+
+    # Training loop
+    os.makedirs(args.output_dir, exist_ok=True)
+    best_val_acc = 0.0
+    best_val_loss = float("inf")
+    patience_counter = 0
+
+    print(f"\nStarting training for {args.epochs} epochs...")
+    print(f"{'Epoch':>5} {'Train Loss':>12} {'Train Acc':>10} {'Val Loss':>10} {'Val Acc':>9} {'LR':>10} {'Time':>8}")
+    print("-" * 75)
+
+    for epoch in range(1, args.epochs + 1):
+        t0 = time.time()
+
+        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
+
+        val_loss, val_acc = (0.0, 0.0)
+        if val_loader:
+            val_loss, val_acc = evaluate(model, val_loader, criterion, device)
+
+        elapsed = time.time() - t0
+        lr = scheduler.get_last_lr()[0]
+
+        print(
+            f"{epoch:5d} {train_loss:12.4f} {train_acc:9.1%} {val_loss:10.4f} {val_acc:8.1%} {lr:10.2e} {elapsed:7.1f}s"
+        )
+
+        # Save best model
+        is_best = False
+        if val_loader:
+            if val_acc > best_val_acc:
+                best_val_acc = val_acc
+                best_val_loss = val_loss
+                is_best = True
+        elif train_loss < best_val_loss:
+            best_val_loss = train_loss
+            best_val_acc = train_acc
+            is_best = True
+
+        if is_best:
+            patience_counter = 0
+            checkpoint = {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "val_accuracy": best_val_acc,
+                "val_loss": best_val_loss,
+                "num_tools": num_tools,
+                "tool_to_index": tool_to_index,
+                "index_to_tool": index_to_tool,
+                "config": {
+                    "embedding_dim": embedding_dim,
+                    "hidden_dim": args.hidden_dim,
+                    "num_layers": args.num_layers,
+                    "num_heads": args.num_heads,
+                    "dropout": args.dropout,
+                },
+            }
+            ckpt_path = os.path.join(args.output_dir, "best_policy_head.pt")
+            torch.save(checkpoint, ckpt_path)
+            print(f"  ✅ New best model (val_acc={best_val_acc:.1%})")
+        else:
+            patience_counter += 1
+
+        # Early stopping
+        if patience_counter >= args.patience:
+            print(f"\n  Early stopping at epoch {epoch} (no improvement for {args.patience} epochs)")
+            break
+
+    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.1%}")
+
+    ckpt_path = os.path.join(args.output_dir, "best_policy_head.pt")
+    if os.path.exists(ckpt_path):
+        size_mb = os.path.getsize(ckpt_path) / 1024 / 1024
+        print(f"Checkpoint: {ckpt_path} ({size_mb:.1f} MB)")
+
+    # Write metadata for TypeScript bridge
+    meta = {
+        "version": 2,
+        "architecture": "transformer-4layer-512h",
+        "embedding_dim": embedding_dim,
+        "hidden_dim": args.hidden_dim,
+        "num_tools": num_tools,
+        "num_layers": args.num_layers,
+        "num_heads": args.num_heads,
+        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
+        "trained_on": len(train_ds),
+        "val_accuracy": best_val_acc,
+        "val_loss": best_val_loss,
+        "device": device,
+        "parameters": model.num_parameters,
+        "tool_to_index": tool_to_index,
+        "index_to_tool": {str(k): v for k, v in index_to_tool.items()},
+        "checkpoint_path": os.path.abspath(ckpt_path),
+    }
+    meta_path = os.path.join(args.output_dir, "policy-head-v2.json")
+    with open(meta_path, "w") as f:
+        json.dump(meta, f, indent=2)
+    print(f"Metadata: {meta_path}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train v2 policy head (transformer action selector)")
+    parser.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with train/val/test JSONL + embeddings")
+    parser.add_argument("--domain", default=None, help="Path to domain.json")
+    parser.add_argument("--output-dir", default=".jfl/checkpoints", help="Output directory for checkpoints")
+    parser.add_argument("--warm-start", default=None, help="Path to pretrained checkpoint for warm start")
+    parser.add_argument("--epochs", type=int, default=50, help="Max training epochs")
+    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
+    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
+    parser.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
+    parser.add_argument("--warmup-steps", type=int, default=100, help="Warmup steps for LR scheduler")
+    parser.add_argument("--patience", type=int, default=7, help="Early stopping patience")
+    parser.add_argument("--label-smoothing", type=float, default=0.1, help="Label smoothing for CrossEntropyLoss")
+    parser.add_argument("--hidden-dim", type=int, default=512, help="Hidden dimension")
+    parser.add_argument("--num-layers", type=int, default=4, help="Transformer encoder layers")
+    parser.add_argument("--num-heads", type=int, default=8, help="Attention heads")
+    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
+    args = parser.parse_args()
+
+    if args.domain is None:
+        args.domain = os.path.join(os.path.dirname(os.path.abspath(__file__)), "domain.json")
+
+    if not os.path.exists(args.domain):
+        print(f"Domain file not found: {args.domain}")
+        sys.exit(1)
+
+    train(args)
+
+
+if __name__ == "__main__":
+    main()
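Finally, a sketch of consuming the checkpoint train.py saves, using only the field names written above (not part of the package; the zero tensors are placeholders for real Stratus embeddings, and the default --output-dir is assumed):

import torch
from model import PolicyHead

# Restore the best checkpoint and rebuild the model from its saved config.
ckpt = torch.load(".jfl/checkpoints/best_policy_head.pt", map_location="cpu", weights_only=True)
model = PolicyHead(num_tools=ckpt["num_tools"], **ckpt["config"])
model.load_state_dict(ckpt["model_state_dict"])

# Map top-k logit indices back to tool names via the saved index.
state_emb = torch.zeros(1, ckpt["config"]["embedding_dim"])
goal_emb = torch.zeros(1, ckpt["config"]["embedding_dim"])
out = model.predict(state_emb, goal_emb, top_k=3)
tools = [ckpt["index_to_tool"][i.item()] for i in out["top_k_indices"][0]]
print(tools, out["top_k_probs"][0])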