jfl 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. package/dist/commands/context-hub.d.ts +1 -0
  2. package/dist/commands/context-hub.d.ts.map +1 -1
  3. package/dist/commands/context-hub.js +246 -2
  4. package/dist/commands/context-hub.js.map +1 -1
  5. package/dist/commands/peter.d.ts +2 -0
  6. package/dist/commands/peter.d.ts.map +1 -1
  7. package/dist/commands/peter.js +242 -52
  8. package/dist/commands/peter.js.map +1 -1
  9. package/dist/commands/setup.d.ts +12 -0
  10. package/dist/commands/setup.d.ts.map +1 -0
  11. package/dist/commands/setup.js +322 -0
  12. package/dist/commands/setup.js.map +1 -0
  13. package/dist/commands/train.d.ts +33 -0
  14. package/dist/commands/train.d.ts.map +1 -0
  15. package/dist/commands/train.js +510 -0
  16. package/dist/commands/train.js.map +1 -0
  17. package/dist/commands/verify.d.ts +14 -0
  18. package/dist/commands/verify.d.ts.map +1 -0
  19. package/dist/commands/verify.js +276 -0
  20. package/dist/commands/verify.js.map +1 -0
  21. package/dist/dashboard-static/assets/index-CW9ZxqX8.css +1 -0
  22. package/dist/dashboard-static/assets/index-DNN__p4K.js +121 -0
  23. package/dist/dashboard-static/index.html +2 -2
  24. package/dist/index.js +99 -3
  25. package/dist/index.js.map +1 -1
  26. package/dist/lib/agent-session.d.ts.map +1 -1
  27. package/dist/lib/agent-session.js +12 -4
  28. package/dist/lib/agent-session.js.map +1 -1
  29. package/dist/lib/eval-snapshot.js +1 -1
  30. package/dist/lib/eval-snapshot.js.map +1 -1
  31. package/dist/lib/pi-sky/bridge.d.ts +55 -0
  32. package/dist/lib/pi-sky/bridge.d.ts.map +1 -0
  33. package/dist/lib/pi-sky/bridge.js +264 -0
  34. package/dist/lib/pi-sky/bridge.js.map +1 -0
  35. package/dist/lib/pi-sky/cost-monitor.d.ts +21 -0
  36. package/dist/lib/pi-sky/cost-monitor.d.ts.map +1 -0
  37. package/dist/lib/pi-sky/cost-monitor.js +126 -0
  38. package/dist/lib/pi-sky/cost-monitor.js.map +1 -0
  39. package/dist/lib/pi-sky/eval-sweep.d.ts +27 -0
  40. package/dist/lib/pi-sky/eval-sweep.d.ts.map +1 -0
  41. package/dist/lib/pi-sky/eval-sweep.js +141 -0
  42. package/dist/lib/pi-sky/eval-sweep.js.map +1 -0
  43. package/dist/lib/pi-sky/event-router.d.ts +32 -0
  44. package/dist/lib/pi-sky/event-router.d.ts.map +1 -0
  45. package/dist/lib/pi-sky/event-router.js +176 -0
  46. package/dist/lib/pi-sky/event-router.js.map +1 -0
  47. package/dist/lib/pi-sky/experiment.d.ts +9 -0
  48. package/dist/lib/pi-sky/experiment.d.ts.map +1 -0
  49. package/dist/lib/pi-sky/experiment.js +83 -0
  50. package/dist/lib/pi-sky/experiment.js.map +1 -0
  51. package/dist/lib/pi-sky/index.d.ts +16 -0
  52. package/dist/lib/pi-sky/index.d.ts.map +1 -0
  53. package/dist/lib/pi-sky/index.js +16 -0
  54. package/dist/lib/pi-sky/index.js.map +1 -0
  55. package/dist/lib/pi-sky/stratus-gate.d.ts +28 -0
  56. package/dist/lib/pi-sky/stratus-gate.d.ts.map +1 -0
  57. package/dist/lib/pi-sky/stratus-gate.js +61 -0
  58. package/dist/lib/pi-sky/stratus-gate.js.map +1 -0
  59. package/dist/lib/pi-sky/swarm.d.ts +28 -0
  60. package/dist/lib/pi-sky/swarm.d.ts.map +1 -0
  61. package/dist/lib/pi-sky/swarm.js +208 -0
  62. package/dist/lib/pi-sky/swarm.js.map +1 -0
  63. package/dist/lib/pi-sky/types.d.ts +139 -0
  64. package/dist/lib/pi-sky/types.d.ts.map +1 -0
  65. package/dist/lib/pi-sky/types.js +2 -0
  66. package/dist/lib/pi-sky/types.js.map +1 -0
  67. package/dist/lib/pi-sky/voice-bridge.d.ts +20 -0
  68. package/dist/lib/pi-sky/voice-bridge.d.ts.map +1 -0
  69. package/dist/lib/pi-sky/voice-bridge.js +91 -0
  70. package/dist/lib/pi-sky/voice-bridge.js.map +1 -0
  71. package/dist/lib/policy-head.d.ts +16 -1
  72. package/dist/lib/policy-head.d.ts.map +1 -1
  73. package/dist/lib/policy-head.js +117 -19
  74. package/dist/lib/policy-head.js.map +1 -1
  75. package/dist/lib/predictor.d.ts +10 -0
  76. package/dist/lib/predictor.d.ts.map +1 -1
  77. package/dist/lib/predictor.js +46 -7
  78. package/dist/lib/predictor.js.map +1 -1
  79. package/dist/lib/setup/agent-generator.d.ts +18 -0
  80. package/dist/lib/setup/agent-generator.d.ts.map +1 -0
  81. package/dist/lib/setup/agent-generator.js +114 -0
  82. package/dist/lib/setup/agent-generator.js.map +1 -0
  83. package/dist/lib/setup/context-analyzer.d.ts +16 -0
  84. package/dist/lib/setup/context-analyzer.d.ts.map +1 -0
  85. package/dist/lib/setup/context-analyzer.js +112 -0
  86. package/dist/lib/setup/context-analyzer.js.map +1 -0
  87. package/dist/lib/setup/doc-auditor.d.ts +54 -0
  88. package/dist/lib/setup/doc-auditor.d.ts.map +1 -0
  89. package/dist/lib/setup/doc-auditor.js +629 -0
  90. package/dist/lib/setup/doc-auditor.js.map +1 -0
  91. package/dist/lib/setup/domain-generator.d.ts +7 -0
  92. package/dist/lib/setup/domain-generator.d.ts.map +1 -0
  93. package/dist/lib/setup/domain-generator.js +58 -0
  94. package/dist/lib/setup/domain-generator.js.map +1 -0
  95. package/dist/lib/setup/smart-eval-generator.d.ts +38 -0
  96. package/dist/lib/setup/smart-eval-generator.d.ts.map +1 -0
  97. package/dist/lib/setup/smart-eval-generator.js +378 -0
  98. package/dist/lib/setup/smart-eval-generator.js.map +1 -0
  99. package/dist/lib/setup/smart-recommender.d.ts +63 -0
  100. package/dist/lib/setup/smart-recommender.d.ts.map +1 -0
  101. package/dist/lib/setup/smart-recommender.js +329 -0
  102. package/dist/lib/setup/smart-recommender.js.map +1 -0
  103. package/dist/lib/setup/spec-generator.d.ts +63 -0
  104. package/dist/lib/setup/spec-generator.d.ts.map +1 -0
  105. package/dist/lib/setup/spec-generator.js +310 -0
  106. package/dist/lib/setup/spec-generator.js.map +1 -0
  107. package/dist/lib/setup/violation-agent-generator.d.ts +32 -0
  108. package/dist/lib/setup/violation-agent-generator.d.ts.map +1 -0
  109. package/dist/lib/setup/violation-agent-generator.js +255 -0
  110. package/dist/lib/setup/violation-agent-generator.js.map +1 -0
  111. package/package.json +1 -1
  112. package/packages/pi/extensions/context.ts +88 -55
  113. package/packages/pi/extensions/hub-resolver.ts +63 -0
  114. package/packages/pi/extensions/index.ts +16 -3
  115. package/packages/pi/extensions/memory-tool.ts +9 -4
  116. package/packages/pi/extensions/session.ts +68 -16
  117. package/packages/pi/extensions/tool-renderers.ts +23 -8
  118. package/scripts/train/requirements.txt +5 -0
  119. package/scripts/train/train-policy-head.py +477 -0
  120. package/scripts/train/v2/dataset.py +81 -0
  121. package/scripts/train/v2/domain.json +18 -0
  122. package/scripts/train/v2/eval.py +196 -0
  123. package/scripts/train/v2/generate_data.py +219 -0
  124. package/scripts/train/v2/infer.py +188 -0
  125. package/scripts/train/v2/model.py +112 -0
  126. package/scripts/train/v2/precompute.py +132 -0
  127. package/scripts/train/v2/train.py +302 -0
  128. package/scripts/train/v2/transform_buffer.py +227 -0
  129. package/scripts/train/v2/validate_data.py +115 -0
  130. package/template/.claude/settings.json +2 -15
  131. package/template/scripts/session/session-cleanup.sh +2 -11
  132. package/template/scripts/session/session-end-hub.sh +72 -0
  133. package/template/scripts/session/session-start-hub.sh +105 -0
  134. package/dist/dashboard-static/assets/index-B6b867Pv.js +0 -121
  135. package/dist/dashboard-static/assets/index-Y4BrqxV-.css +0 -1
@@ -0,0 +1,112 @@
1
+ """
2
+ v2 Policy Head — Transformer-based action selector.
3
+
4
+ Architecture from Drew's Stratus tutorial:
5
+ (current_state_emb, goal_emb) -> state_proj + goal_proj -> fusion -> TransformerEncoder -> classifier -> action logits
6
+
7
+ ~8.7M params, ~17MB checkpoint. Replaces v1 MLP reward predictor.
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
class PolicyHead(nn.Module):
    """Transformer-based action selector (v2 policy head).

    Pipeline: (current_state_emb, goal_emb) -> state_proj / goal_proj ->
    fusion MLP -> TransformerEncoder over a single fused token ->
    classifier -> per-tool logits.  ~8.7M params at the default config,
    ~17MB checkpoint; replaces the v1 MLP reward predictor.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        hidden_dim: int = 512,
        num_tools: int = 12,
        num_layers: int = 4,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_tools = num_tools

        # Separate projections for the two input embeddings.
        self.state_proj = nn.Linear(embedding_dim, hidden_dim)
        self.goal_proj = nn.Linear(embedding_dim, hidden_dim)

        # Fuse the projected (state, goal) pair into one hidden vector.
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        self.norm = nn.LayerNorm(hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_tools),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for Linear weights; zeros/ones for biases and LayerNorms."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(
        self,
        current_state_emb: torch.Tensor,
        goal_state_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Return raw action logits of shape (batch, num_tools)."""
        state_h = self.state_proj(current_state_emb)
        goal_h = self.goal_proj(goal_state_emb)

        fused = self.fusion(torch.cat([state_h, goal_h], dim=-1))

        # The encoder expects a sequence; treat the fused vector as length-1.
        x = fused.unsqueeze(1)
        x = self.transformer(x)
        x = x.squeeze(1)

        x = self.norm(x)
        return self.classifier(x)

    def predict(
        self,
        current_state_emb: torch.Tensor,
        goal_state_emb: torch.Tensor,
        top_k: int = 3,
    ) -> dict:
        """Inference helper: softmax probabilities plus the top-k tools.

        Returns a dict with "top_k_indices", "top_k_probs", "all_probs".

        Fix: the previous version called self.eval() and left the module
        permanently in eval mode as a side effect; the training/eval mode
        is now saved and restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(current_state_emb, goal_state_emb)
                probs = F.softmax(logits, dim=-1)
                top_probs, top_indices = torch.topk(
                    probs, k=min(top_k, self.num_tools), dim=-1
                )
        finally:
            # Restore whatever mode the caller had set.
            if was_training:
                self.train()

        return {
            "top_k_indices": top_indices,
            "top_k_probs": top_probs,
            "all_probs": probs,
        }

    @property
    def num_parameters(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
@@ -0,0 +1,132 @@
1
+ """
2
+ Pre-compute Stratus embeddings for all unique texts in v2 training data.
3
+ Caches embeddings as .npz files to avoid re-computation during training.
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import numpy as np
11
+
12
def get_stratus_embedder(api_url: str, api_key: str):
    """Build and return a closure that embeds a batch of texts via the Stratus API."""
    import requests

    # Request pieces that never change across calls are prepared once.
    endpoint = f"{api_url}/v1/embeddings"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    def embed_batch(texts: list[str]) -> list[list[float]]:
        """POST one batch of texts; return their embedding vectors in order."""
        resp = requests.post(
            endpoint,
            headers=headers,
            json={
                "model": "stratus-x1ac-base",
                "input": texts,
            },
            timeout=30,
        )
        resp.raise_for_status()
        return [item["embedding"] for item in resp.json()["data"]]

    return embed_batch
33
+
34
def collect_unique_texts(data_path: str) -> tuple[list[str], list[str]]:
    """Read a JSONL file and return (sorted unique states, sorted unique goals).

    Blank lines are skipped; each remaining line must be a JSON object with
    "current_state" and "goal" keys.
    """
    with open(data_path) as fh:
        records = [
            json.loads(stripped)
            for stripped in (raw.strip() for raw in fh)
            if stripped
        ]

    unique_states = {rec["current_state"] for rec in records}
    unique_goals = {rec["goal"] for rec in records}
    return sorted(unique_states), sorted(unique_goals)
48
+
49
def precompute_embeddings(
    data_dir: str,
    api_url: str,
    api_key: str,
    batch_size: int = 32,
):
    """Embed every unique state/goal text across the splits and cache to disk.

    Writes embeddings_cache.npz and text_to_idx.json into data_dir and
    returns (text_to_idx, embeddings_matrix).
    """
    embedder = get_stratus_embedder(api_url, api_key)

    # Gather the union of unique texts over all available splits.
    all_states: set = set()
    all_goals: set = set()
    for split in ("train", "val", "test"):
        path = os.path.join(data_dir, f"{split}.jsonl")
        if not os.path.exists(path):
            print(f" Skipping {split} (file not found)")
            continue
        states, goals = collect_unique_texts(path)
        all_states |= set(states)
        all_goals |= set(goals)

    all_texts = sorted(all_states | all_goals)
    print(f"Unique texts to embed: {len(all_texts)} ({len(all_states)} states, {len(all_goals)} goals)")

    # Embed in batches; a failed batch is reported and skipped (best-effort).
    text_to_embedding = {}
    for i in range(0, len(all_texts), batch_size):
        batch = all_texts[i : i + batch_size]
        try:
            for text, emb in zip(batch, embedder(batch)):
                text_to_embedding[text] = emb
        except Exception as e:
            print(f" Error embedding batch {i}-{i + len(batch)}: {e}")
            continue

        done = min(i + batch_size, len(all_texts))
        print(f" Embedded {done}/{len(all_texts)} texts")

    # Only successfully embedded texts are indexed, so rows and indices agree.
    texts_list = sorted(text_to_embedding)
    text_to_idx = {t: i for i, t in enumerate(texts_list)}
    embeddings_matrix = np.array(
        [text_to_embedding[t] for t in texts_list], dtype=np.float32
    )

    cache_path = os.path.join(data_dir, "embeddings_cache.npz")
    np.savez(
        cache_path,
        embeddings=embeddings_matrix,
        texts=np.array(texts_list, dtype=object),
    )
    print(f"Saved embedding cache: {cache_path} ({embeddings_matrix.shape})")

    index_path = os.path.join(data_dir, "text_to_idx.json")
    with open(index_path, "w") as f:
        json.dump(text_to_idx, f)
    print(f"Saved text index: {index_path} ({len(text_to_idx)} entries)")

    return text_to_idx, embeddings_matrix
104
+
105
+
106
def main():
    """CLI entry point: validate environment, then run the precompute step."""
    ap = argparse.ArgumentParser(description="Pre-compute Stratus embeddings for v2 training data")
    ap.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with train/val/test JSONL files")
    ap.add_argument("--batch-size", type=int, default=32, help="Embedding batch size")
    opts = ap.parse_args()

    # Credentials/endpoint come from the environment, never from flags.
    api_url = os.environ.get("STRATUS_API_URL", "https://api.stratus.run")
    api_key = os.environ.get("STRATUS_API_KEY", "")

    # Guard clauses: fail fast on missing key or missing data directory.
    if not api_key:
        print("STRATUS_API_KEY not set")
        sys.exit(1)

    if not os.path.exists(opts.data_dir):
        print(f"Data directory not found: {opts.data_dir}")
        sys.exit(1)

    precompute_embeddings(
        data_dir=opts.data_dir,
        api_url=api_url,
        api_key=api_key,
        batch_size=opts.batch_size,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,302 @@
1
+ """
2
+ v2 Policy Head Training Loop.
3
+
4
+ CrossEntropyLoss with label smoothing, cosine annealing with warmup,
5
+ early stopping. Produces .pt checkpoint with model weights, config, and tool index.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import sys
11
+ import time
12
+ import math
13
+ import argparse
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.optim as optim
19
+ from torch.utils.data import DataLoader
20
+
21
+ from model import PolicyHead
22
+ from dataset import PolicyHeadDataset, load_embedding_cache
23
+
24
+
25
def load_tool_index(domain_path: str) -> dict[str, int]:
    """Map each tool name to its position in the domain.json "tools" list."""
    with open(domain_path) as fh:
        spec = json.load(fh)
    return {entry["name"]: idx for idx, entry in enumerate(spec["tools"])}
29
+
30
+
31
def get_lr_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """LambdaLR: linear warmup to 1.0, then cosine decay to 0 over total_steps."""

    def schedule(step: int) -> float:
        # Warmup phase: multiplier ramps 0 -> 1 linearly.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Cosine phase: multiplier follows 0.5*(1+cos(pi*progress)), clamped at 0.
        frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * frac)))

    return optim.lr_scheduler.LambdaLR(optimizer, schedule)
39
+
40
+
41
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device):
    """Run one optimization pass over dataloader.

    Each batch is a dict with "state_emb", "goal_emb", "label" tensors.
    Returns (example-weighted mean loss, accuracy); both 0-safe on empty data.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in dataloader:
        states = batch["state_emb"].to(device)
        goals = batch["goal_emb"].to(device)
        targets = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(states, goals)
        loss = criterion(logits, targets)
        loss.backward()
        # Gradient clipping keeps transformer updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # per-step (not per-epoch) LR schedule

        n = targets.size(0)
        loss_sum += loss.item() * n
        n_correct += (logits.argmax(dim=-1) == targets).sum().item()
        n_seen += n

    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)
68
+
69
+
70
@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Score the model on dataloader without gradients.

    Mirrors train_epoch's batch format and metrics; returns
    (example-weighted mean loss, accuracy), both 0-safe on empty data.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in dataloader:
        states = batch["state_emb"].to(device)
        goals = batch["goal_emb"].to(device)
        targets = batch["label"].to(device)

        logits = model(states, goals)
        loss = criterion(logits, targets)

        n = targets.size(0)
        loss_sum += loss.item() * n
        n_correct += (logits.argmax(dim=-1) == targets).sum().item()
        n_seen += n

    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)
91
+
92
+
93
def train(args):
    """End-to-end training driver for the v2 policy head.

    Reads JSONL splits and the embedding cache from args.data_dir, trains a
    PolicyHead with early stopping, writes the best checkpoint to
    args.output_dir/best_policy_head.pt and a policy-head-v2.json metadata
    file consumed by the TypeScript bridge.
    """
    # Device
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Device: {device}")

    # Domain: tool vocabulary defines the classifier's output space.
    domain_path = args.domain
    tool_to_index = load_tool_index(domain_path)
    index_to_tool = {v: k for k, v in tool_to_index.items()}
    num_tools = len(tool_to_index)
    print(f"Tools: {num_tools}")

    # Embeddings cache — load_embedding_cache appears to return
    # (matrix-or-None, text index); missing cache degrades to zero vectors.
    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)
    if embeddings_matrix is not None:
        print(f"Embedding cache: {embeddings_matrix.shape[0]} texts, {embeddings_matrix.shape[1]}-dim")
    else:
        print("WARNING: No embedding cache found. Training with zero vectors.")
        print(" Run: python precompute.py --data-dir", args.data_dir)

    # Datasets: train split is required, val split optional.
    train_path = os.path.join(args.data_dir, "train.jsonl")
    val_path = os.path.join(args.data_dir, "val.jsonl")

    if not os.path.exists(train_path):
        print(f"Training data not found: {train_path}")
        sys.exit(1)

    train_ds = PolicyHeadDataset(train_path, tool_to_index, embeddings_matrix, text_to_idx)
    val_ds = PolicyHeadDataset(val_path, tool_to_index, embeddings_matrix, text_to_idx) if os.path.exists(val_path) else None

    # MPS + DataLoader workers is avoided here — workers forced to 0 on mps.
    num_workers = 0 if device == "mps" else min(4, os.cpu_count() or 1)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers) if val_ds else None

    print(f"Train: {len(train_ds)} examples")
    if val_ds:
        print(f"Val: {len(val_ds)} examples")

    # Model: embedding dim follows the cache when present, else 768.
    embedding_dim = embeddings_matrix.shape[1] if embeddings_matrix is not None else 768
    model = PolicyHead(
        embedding_dim=embedding_dim,
        hidden_dim=args.hidden_dim,
        num_tools=num_tools,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        dropout=args.dropout,
    ).to(device)

    print(f"Parameters: {model.num_parameters:,}")

    # Warm start: copy over only layers whose name AND shape match the new
    # model, so a checkpoint with a different tool count still partially loads.
    if args.warm_start and os.path.exists(args.warm_start):
        print(f"Warm-starting from: {args.warm_start}")
        state_dict = torch.load(args.warm_start, map_location=device, weights_only=True)
        if "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        compatible = {}
        for k, v in state_dict.items():
            if k in model.state_dict() and v.shape == model.state_dict()[k].shape:
                compatible[k] = v
        model.load_state_dict(compatible, strict=False)
        print(f" Loaded {len(compatible)}/{len(state_dict)} layers")

    # Loss, optimizer, scheduler
    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    total_steps = len(train_loader) * args.epochs
    scheduler = get_lr_scheduler(optimizer, args.warmup_steps, total_steps)

    # Training loop
    os.makedirs(args.output_dir, exist_ok=True)
    best_val_acc = 0.0
    best_val_loss = float("inf")
    patience_counter = 0

    print(f"\nStarting training for {args.epochs} epochs...")
    print(f"{'Epoch':>5} {'Train Loss':>12} {'Train Acc':>10} {'Val Loss':>10} {'Val Acc':>9} {'LR':>10} {'Time':>8}")
    print("-" * 75)

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)

        val_loss, val_acc = (0.0, 0.0)
        if val_loader:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        elapsed = time.time() - t0
        lr = scheduler.get_last_lr()[0]

        print(
            f"{epoch:5d} {train_loss:12.4f} {train_acc:9.1%} {val_loss:10.4f} {val_acc:8.1%} {lr:10.2e} {elapsed:7.1f}s"
        )

        # Save best model. NOTE(review): with no val split, best_val_loss /
        # best_val_acc actually track the best TRAIN loss/acc — the variable
        # names are misleading but the logic is intentional (fallback metric).
        is_best = False
        if val_loader:
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_val_loss = val_loss
                is_best = True
        elif train_loss < best_val_loss:
            best_val_loss = train_loss
            best_val_acc = train_acc
            is_best = True

        if is_best:
            patience_counter = 0
            # Checkpoint bundles weights plus everything needed to rebuild
            # the model and decode predictions (tool index + config).
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "val_accuracy": best_val_acc,
                "val_loss": best_val_loss,
                "num_tools": num_tools,
                "tool_to_index": tool_to_index,
                "index_to_tool": index_to_tool,
                "config": {
                    "embedding_dim": embedding_dim,
                    "hidden_dim": args.hidden_dim,
                    "num_layers": args.num_layers,
                    "num_heads": args.num_heads,
                    "dropout": args.dropout,
                },
            }
            ckpt_path = os.path.join(args.output_dir, "best_policy_head.pt")
            torch.save(checkpoint, ckpt_path)
            print(f" ✅ New best model (val_acc={best_val_acc:.1%})")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= args.patience:
            print(f"\n Early stopping at epoch {epoch} (no improvement for {args.patience} epochs)")
            break

    print(f"\nTraining complete. Best val accuracy: {best_val_acc:.1%}")

    ckpt_path = os.path.join(args.output_dir, "best_policy_head.pt")
    if os.path.exists(ckpt_path):
        size_mb = os.path.getsize(ckpt_path) / 1024 / 1024
        print(f"Checkpoint: {ckpt_path} ({size_mb:.1f} MB)")

    # Write metadata for TypeScript bridge
    meta = {
        "version": 2,
        "architecture": "transformer-4layer-512h",
        "embedding_dim": embedding_dim,
        "hidden_dim": args.hidden_dim,
        "num_tools": num_tools,
        "num_layers": args.num_layers,
        "num_heads": args.num_heads,
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "trained_on": len(train_ds),
        "val_accuracy": best_val_acc,
        "val_loss": best_val_loss,
        "device": device,
        "parameters": model.num_parameters,
        "tool_to_index": tool_to_index,
        # JSON keys must be strings, so the int index keys are stringified.
        "index_to_tool": {str(k): v for k, v in index_to_tool.items()},
        "checkpoint_path": os.path.abspath(ckpt_path),
    }
    meta_path = os.path.join(args.output_dir, "policy-head-v2.json")
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"Metadata: {meta_path}")
270
+
271
+
272
def main():
    """Parse CLI flags, resolve the domain file, and launch training."""
    ap = argparse.ArgumentParser(description="Train v2 policy head (transformer action selector)")
    ap.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with train/val/test JSONL + embeddings")
    ap.add_argument("--domain", default=None, help="Path to domain.json")
    ap.add_argument("--output-dir", default=".jfl/checkpoints", help="Output directory for checkpoints")
    ap.add_argument("--warm-start", default=None, help="Path to pretrained checkpoint for warm start")
    ap.add_argument("--epochs", type=int, default=50, help="Max training epochs")
    ap.add_argument("--batch-size", type=int, default=64, help="Batch size")
    ap.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    ap.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
    ap.add_argument("--warmup-steps", type=int, default=100, help="Warmup steps for LR scheduler")
    ap.add_argument("--patience", type=int, default=7, help="Early stopping patience")
    ap.add_argument("--label-smoothing", type=float, default=0.1, help="Label smoothing for CrossEntropyLoss")
    ap.add_argument("--hidden-dim", type=int, default=512, help="Hidden dimension")
    ap.add_argument("--num-layers", type=int, default=4, help="Transformer encoder layers")
    ap.add_argument("--num-heads", type=int, default=8, help="Attention heads")
    ap.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    opts = ap.parse_args()

    # Default the domain file to the one shipped next to this script.
    if opts.domain is None:
        opts.domain = os.path.join(os.path.dirname(os.path.abspath(__file__)), "domain.json")

    if not os.path.exists(opts.domain):
        print(f"Domain file not found: {opts.domain}")
        sys.exit(1)

    train(opts)


if __name__ == "__main__":
    main()