jfl 0.8.0 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. package/dist/commands/doctor.d.ts +1 -0
  2. package/dist/commands/doctor.d.ts.map +1 -1
  3. package/dist/commands/doctor.js +30 -1
  4. package/dist/commands/doctor.js.map +1 -1
  5. package/dist/commands/ide.d.ts +2 -1
  6. package/dist/commands/ide.d.ts.map +1 -1
  7. package/dist/commands/ide.js +60 -1
  8. package/dist/commands/ide.js.map +1 -1
  9. package/dist/commands/init-from-service.d.ts +15 -0
  10. package/dist/commands/init-from-service.d.ts.map +1 -0
  11. package/dist/commands/init-from-service.js +541 -0
  12. package/dist/commands/init-from-service.js.map +1 -0
  13. package/dist/commands/init.d.ts +1 -0
  14. package/dist/commands/init.d.ts.map +1 -1
  15. package/dist/commands/init.js +32 -1
  16. package/dist/commands/init.js.map +1 -1
  17. package/dist/commands/kanban.d.ts.map +1 -1
  18. package/dist/commands/kanban.js +13 -4
  19. package/dist/commands/kanban.js.map +1 -1
  20. package/dist/commands/linear.d.ts +41 -0
  21. package/dist/commands/linear.d.ts.map +1 -0
  22. package/dist/commands/linear.js +715 -0
  23. package/dist/commands/linear.js.map +1 -0
  24. package/dist/commands/peter.d.ts.map +1 -1
  25. package/dist/commands/peter.js +232 -25
  26. package/dist/commands/peter.js.map +1 -1
  27. package/dist/commands/services.d.ts.map +1 -1
  28. package/dist/commands/services.js +146 -0
  29. package/dist/commands/services.js.map +1 -1
  30. package/dist/commands/setup.d.ts.map +1 -1
  31. package/dist/commands/setup.js +173 -13
  32. package/dist/commands/setup.js.map +1 -1
  33. package/dist/commands/telemetry-monitor.d.ts +11 -0
  34. package/dist/commands/telemetry-monitor.d.ts.map +1 -0
  35. package/dist/commands/telemetry-monitor.js +224 -0
  36. package/dist/commands/telemetry-monitor.js.map +1 -0
  37. package/dist/commands/telemetry-test.d.ts +11 -0
  38. package/dist/commands/telemetry-test.d.ts.map +1 -0
  39. package/dist/commands/telemetry-test.js +67 -0
  40. package/dist/commands/telemetry-test.js.map +1 -0
  41. package/dist/commands/tenet-agents.d.ts +13 -0
  42. package/dist/commands/tenet-agents.d.ts.map +1 -0
  43. package/dist/commands/tenet-agents.js +191 -0
  44. package/dist/commands/tenet-agents.js.map +1 -0
  45. package/dist/commands/tenet-setup.d.ts +19 -0
  46. package/dist/commands/tenet-setup.d.ts.map +1 -0
  47. package/dist/commands/tenet-setup.js +131 -0
  48. package/dist/commands/tenet-setup.js.map +1 -0
  49. package/dist/commands/train.d.ts +18 -0
  50. package/dist/commands/train.d.ts.map +1 -1
  51. package/dist/commands/train.js +182 -0
  52. package/dist/commands/train.js.map +1 -1
  53. package/dist/commands/whoami.d.ts +2 -0
  54. package/dist/commands/whoami.d.ts.map +1 -0
  55. package/dist/commands/whoami.js +24 -0
  56. package/dist/commands/whoami.js.map +1 -0
  57. package/dist/index.js +159 -10
  58. package/dist/index.js.map +1 -1
  59. package/dist/lib/advanced-setup.d.ts +78 -0
  60. package/dist/lib/advanced-setup.d.ts.map +1 -0
  61. package/dist/lib/advanced-setup.js +433 -0
  62. package/dist/lib/advanced-setup.js.map +1 -0
  63. package/dist/lib/agent-config.d.ts +33 -0
  64. package/dist/lib/agent-config.d.ts.map +1 -1
  65. package/dist/lib/agent-config.js +26 -0
  66. package/dist/lib/agent-config.js.map +1 -1
  67. package/dist/lib/counterfactual-training-bridge.d.ts +114 -0
  68. package/dist/lib/counterfactual-training-bridge.d.ts.map +1 -0
  69. package/dist/lib/counterfactual-training-bridge.js +322 -0
  70. package/dist/lib/counterfactual-training-bridge.js.map +1 -0
  71. package/dist/lib/discovery-agent.d.ts +48 -0
  72. package/dist/lib/discovery-agent.d.ts.map +1 -0
  73. package/dist/lib/discovery-agent.js +111 -0
  74. package/dist/lib/discovery-agent.js.map +1 -0
  75. package/dist/lib/flow-engine.d.ts.map +1 -1
  76. package/dist/lib/flow-engine.js +46 -8
  77. package/dist/lib/flow-engine.js.map +1 -1
  78. package/dist/lib/gtm-generator.d.ts +29 -0
  79. package/dist/lib/gtm-generator.d.ts.map +1 -0
  80. package/dist/lib/gtm-generator.js +252 -0
  81. package/dist/lib/gtm-generator.js.map +1 -0
  82. package/dist/lib/hub-health.d.ts +40 -0
  83. package/dist/lib/hub-health.d.ts.map +1 -0
  84. package/dist/lib/hub-health.js +89 -0
  85. package/dist/lib/hub-health.js.map +1 -0
  86. package/dist/lib/invariant-monitor.d.ts +6 -2
  87. package/dist/lib/invariant-monitor.d.ts.map +1 -1
  88. package/dist/lib/invariant-monitor.js +89 -2
  89. package/dist/lib/invariant-monitor.js.map +1 -1
  90. package/dist/lib/journal-analyzer.d.ts +71 -0
  91. package/dist/lib/journal-analyzer.d.ts.map +1 -0
  92. package/dist/lib/journal-analyzer.js +306 -0
  93. package/dist/lib/journal-analyzer.js.map +1 -0
  94. package/dist/lib/linear-client.d.ts +73 -0
  95. package/dist/lib/linear-client.d.ts.map +1 -0
  96. package/dist/lib/linear-client.js +112 -0
  97. package/dist/lib/linear-client.js.map +1 -0
  98. package/dist/lib/linear-id-map.d.ts +20 -0
  99. package/dist/lib/linear-id-map.d.ts.map +1 -0
  100. package/dist/lib/linear-id-map.js +57 -0
  101. package/dist/lib/linear-id-map.js.map +1 -0
  102. package/dist/lib/linear-kanban.d.ts +66 -0
  103. package/dist/lib/linear-kanban.d.ts.map +1 -0
  104. package/dist/lib/linear-kanban.js +175 -0
  105. package/dist/lib/linear-kanban.js.map +1 -0
  106. package/dist/lib/onboarding.d.ts +40 -0
  107. package/dist/lib/onboarding.d.ts.map +1 -0
  108. package/dist/lib/onboarding.js +213 -0
  109. package/dist/lib/onboarding.js.map +1 -0
  110. package/dist/lib/physical-world-model.d.ts +50 -0
  111. package/dist/lib/physical-world-model.d.ts.map +1 -0
  112. package/dist/lib/physical-world-model.js +251 -0
  113. package/dist/lib/physical-world-model.js.map +1 -0
  114. package/dist/lib/planning-loop.d.ts +157 -0
  115. package/dist/lib/planning-loop.d.ts.map +1 -0
  116. package/dist/lib/planning-loop.js +537 -0
  117. package/dist/lib/planning-loop.js.map +1 -0
  118. package/dist/lib/policy-head.d.ts +13 -0
  119. package/dist/lib/policy-head.d.ts.map +1 -1
  120. package/dist/lib/policy-head.js +168 -2
  121. package/dist/lib/policy-head.js.map +1 -1
  122. package/dist/lib/resource-optimizer-middleware.d.ts +39 -0
  123. package/dist/lib/resource-optimizer-middleware.d.ts.map +1 -0
  124. package/dist/lib/resource-optimizer-middleware.js +222 -0
  125. package/dist/lib/resource-optimizer-middleware.js.map +1 -0
  126. package/dist/lib/resource-optimizer.d.ts +71 -0
  127. package/dist/lib/resource-optimizer.d.ts.map +1 -0
  128. package/dist/lib/resource-optimizer.js +228 -0
  129. package/dist/lib/resource-optimizer.js.map +1 -0
  130. package/dist/lib/rl-manager.d.ts +74 -0
  131. package/dist/lib/rl-manager.d.ts.map +1 -0
  132. package/dist/lib/rl-manager.js +244 -0
  133. package/dist/lib/rl-manager.js.map +1 -0
  134. package/dist/lib/service-analyzer.d.ts +76 -0
  135. package/dist/lib/service-analyzer.d.ts.map +1 -0
  136. package/dist/lib/service-analyzer.js +704 -0
  137. package/dist/lib/service-analyzer.js.map +1 -0
  138. package/dist/lib/service-gtm.js +2 -2
  139. package/dist/lib/service-gtm.js.map +1 -1
  140. package/dist/lib/service-questionnaire.d.ts +11 -0
  141. package/dist/lib/service-questionnaire.d.ts.map +1 -0
  142. package/dist/lib/service-questionnaire.js +89 -0
  143. package/dist/lib/service-questionnaire.js.map +1 -0
  144. package/dist/lib/setup/agent-generator.d.ts +2 -0
  145. package/dist/lib/setup/agent-generator.d.ts.map +1 -1
  146. package/dist/lib/setup/agent-generator.js +128 -4
  147. package/dist/lib/setup/agent-generator.js.map +1 -1
  148. package/dist/lib/setup/flow-generator.d.ts +10 -0
  149. package/dist/lib/setup/flow-generator.d.ts.map +1 -0
  150. package/dist/lib/setup/flow-generator.js +113 -0
  151. package/dist/lib/setup/flow-generator.js.map +1 -0
  152. package/dist/lib/setup/invariant-bridge.d.ts +91 -0
  153. package/dist/lib/setup/invariant-bridge.d.ts.map +1 -0
  154. package/dist/lib/setup/invariant-bridge.js +384 -0
  155. package/dist/lib/setup/invariant-bridge.js.map +1 -0
  156. package/dist/lib/setup/spec-generator.d.ts +41 -5
  157. package/dist/lib/setup/spec-generator.d.ts.map +1 -1
  158. package/dist/lib/setup/spec-generator.js +503 -29
  159. package/dist/lib/setup/spec-generator.js.map +1 -1
  160. package/dist/lib/stratus-client.js +1 -1
  161. package/dist/lib/stratus-client.js.map +1 -1
  162. package/dist/lib/surface-agent.d.ts +78 -0
  163. package/dist/lib/surface-agent.d.ts.map +1 -0
  164. package/dist/lib/surface-agent.js +105 -0
  165. package/dist/lib/surface-agent.js.map +1 -0
  166. package/dist/lib/surface-coordination-example.d.ts +30 -0
  167. package/dist/lib/surface-coordination-example.d.ts.map +1 -0
  168. package/dist/lib/surface-coordination-example.js +164 -0
  169. package/dist/lib/surface-coordination-example.js.map +1 -0
  170. package/dist/lib/telemetry/physical-world-collector.d.ts +15 -0
  171. package/dist/lib/telemetry/physical-world-collector.d.ts.map +1 -0
  172. package/dist/lib/telemetry/physical-world-collector.js +177 -0
  173. package/dist/lib/telemetry/physical-world-collector.js.map +1 -0
  174. package/dist/lib/telemetry/training-bridge.d.ts +51 -0
  175. package/dist/lib/telemetry/training-bridge.d.ts.map +1 -0
  176. package/dist/lib/telemetry/training-bridge.js +185 -0
  177. package/dist/lib/telemetry/training-bridge.js.map +1 -0
  178. package/dist/lib/telemetry.d.ts +2 -1
  179. package/dist/lib/telemetry.d.ts.map +1 -1
  180. package/dist/lib/telemetry.js +23 -2
  181. package/dist/lib/telemetry.js.map +1 -1
  182. package/dist/lib/tenet-board-agent.d.ts +52 -0
  183. package/dist/lib/tenet-board-agent.d.ts.map +1 -0
  184. package/dist/lib/tenet-board-agent.js +226 -0
  185. package/dist/lib/tenet-board-agent.js.map +1 -0
  186. package/dist/lib/tenet-ide-agent.d.ts +40 -0
  187. package/dist/lib/tenet-ide-agent.d.ts.map +1 -0
  188. package/dist/lib/tenet-ide-agent.js +199 -0
  189. package/dist/lib/tenet-ide-agent.js.map +1 -0
  190. package/dist/lib/workspace/data-pipeline.d.ts.map +1 -1
  191. package/dist/lib/workspace/data-pipeline.js +27 -5
  192. package/dist/lib/workspace/data-pipeline.js.map +1 -1
  193. package/dist/lib/workspace/sidebar-runner.d.ts +13 -0
  194. package/dist/lib/workspace/sidebar-runner.d.ts.map +1 -0
  195. package/dist/lib/workspace/sidebar-runner.js +419 -0
  196. package/dist/lib/workspace/sidebar-runner.js.map +1 -0
  197. package/dist/lib/workspace/surface-registry.d.ts.map +1 -1
  198. package/dist/lib/workspace/surface-registry.js +4 -1
  199. package/dist/lib/workspace/surface-registry.js.map +1 -1
  200. package/dist/lib/workspace/surfaces/agent-overview.d.ts +3 -3
  201. package/dist/lib/workspace/surfaces/agent-overview.d.ts.map +1 -1
  202. package/dist/lib/workspace/surfaces/agent-overview.js +3 -3
  203. package/dist/lib/workspace/surfaces/agent-overview.js.map +1 -1
  204. package/dist/lib/workspace/surfaces/index.d.ts +3 -0
  205. package/dist/lib/workspace/surfaces/index.d.ts.map +1 -1
  206. package/dist/lib/workspace/surfaces/index.js +3 -0
  207. package/dist/lib/workspace/surfaces/index.js.map +1 -1
  208. package/dist/lib/workspace/surfaces/kanban.d.ts +15 -0
  209. package/dist/lib/workspace/surfaces/kanban.d.ts.map +1 -0
  210. package/dist/lib/workspace/surfaces/kanban.js +43 -0
  211. package/dist/lib/workspace/surfaces/kanban.js.map +1 -0
  212. package/dist/lib/workspace/surfaces/physical-world.d.ts +15 -0
  213. package/dist/lib/workspace/surfaces/physical-world.d.ts.map +1 -0
  214. package/dist/lib/workspace/surfaces/physical-world.js +37 -0
  215. package/dist/lib/workspace/surfaces/physical-world.js.map +1 -0
  216. package/dist/lib/workspace/surfaces/sidebar.d.ts +22 -0
  217. package/dist/lib/workspace/surfaces/sidebar.d.ts.map +1 -0
  218. package/dist/lib/workspace/surfaces/sidebar.js +90 -0
  219. package/dist/lib/workspace/surfaces/sidebar.js.map +1 -0
  220. package/dist/types/flows.d.ts +2 -1
  221. package/dist/types/flows.d.ts.map +1 -1
  222. package/dist/types/physical-world-model.d.ts +65 -0
  223. package/dist/types/physical-world-model.d.ts.map +1 -0
  224. package/dist/types/physical-world-model.js +43 -0
  225. package/dist/types/physical-world-model.js.map +1 -0
  226. package/dist/types/telemetry.d.ts +37 -0
  227. package/dist/types/telemetry.d.ts.map +1 -1
  228. package/dist/types/world-model.d.ts.map +1 -1
  229. package/dist/types/world-model.js +14 -7
  230. package/dist/types/world-model.js.map +1 -1
  231. package/dist/utils/context-hub-port.d.ts.map +1 -1
  232. package/dist/utils/context-hub-port.js +6 -1
  233. package/dist/utils/context-hub-port.js.map +1 -1
  234. package/package.json +3 -2
  235. package/packages/pi/extensions/index.ts +34 -6
  236. package/packages/pi/extensions/onboarding-v1.ts +8 -8
  237. package/packages/pi/extensions/onboarding-v2.ts +5 -5
  238. package/scripts/telemetry-dashboard.sh +44 -0
  239. package/scripts/test-planning-loop-e2e.ts +181 -0
  240. package/scripts/test-server-inference.ts +49 -0
  241. package/scripts/test-state-sensitivity.ts +32 -0
  242. package/scripts/train/v2/benchmark.py +661 -0
  243. package/scripts/train/v2/generate_balanced.py +439 -0
  244. package/scripts/train/v2/generate_hard_negatives.py +219 -0
  245. package/scripts/train/v2/infer.py +149 -36
  246. package/scripts/train/v2/infer_server.py +224 -0
  247. package/scripts/train/v2/online_train.py +576 -0
  248. package/scripts/train/v2/precompute.py +24 -6
  249. package/template/CLAUDE.md +74 -132
@@ -0,0 +1,576 @@
1
+ """
2
+ Online Learning Harness for v2 Policy Head.
3
+
4
+ Implements Drew's recommended hybrid approach (section 8.3):
5
+ - Experience replay: 70% historical + 30% new data per batch
6
+ - Small learning rate (1e-5) to avoid catastrophic forgetting
7
+ - Validation monitoring with automatic rollback if degradation >10%
8
+ - Continuous checkpointing for recovery
9
+
10
+ Usage:
11
+ # Fine-tune on new data with experience replay
12
+ python online_train.py --new-data .jfl/v2-data/new.jsonl --checkpoint .jfl/checkpoints/best_policy_head.pt
13
+
14
+ # Continuous mode: watch for new data and retrain automatically
15
+ python online_train.py --watch --checkpoint .jfl/checkpoints/best_policy_head.pt
16
+
17
+ Drew's architecture decision:
18
+ Pre-training: offline, 3e-4 LR, full dataset
19
+ Online: continuous, 1e-5 LR, experience replay, validation gating
20
+ Batch retraining: weekly, full offline, reset to best checkpoint
21
+ """
22
+
23
+ import json
24
+ import os
25
+ import sys
26
+ import time
27
+ import math
28
+ import random
29
+ import shutil
30
+ import argparse
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.optim as optim
37
+ from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset
38
+
39
+ from model import PolicyHead
40
+ from dataset import PolicyHeadDataset, load_embedding_cache
41
+
42
+
43
+ # ============================================================================
44
+ # Experience Replay Buffer
45
+ # ============================================================================
46
+
47
class ExperienceReplayBuffer:
    """Pool of training examples blended for replay-based fine-tuning.

    Per Drew's recommendation (section 8.3), each batch mixes freshly
    collected transitions (``new_ratio``, default 30%) with uniformly
    sampled historical examples (the remaining 70%) to limit
    catastrophic forgetting during online updates.
    """

    def __init__(self, max_size: int = 10000, new_ratio: float = 0.3, seed: int = 42):
        # Cap on the historical pool; overflow is downsampled uniformly.
        self.max_size = max_size
        # Target fraction of each batch drawn from newly collected examples.
        self.new_ratio = new_ratio
        # Dedicated RNG so sampling is reproducible for a given seed.
        self.rng = random.Random(seed)
        self.historical: list[dict] = []
        self.new_examples: list[dict] = []

    def add_historical(self, examples: list[dict]):
        """Extend the historical pool, downsampling uniformly past max_size."""
        self.historical.extend(examples)
        if len(self.historical) > self.max_size:
            # Uniform downsample back to capacity.
            self.historical = self.rng.sample(self.historical, self.max_size)

    def add_new(self, examples: list[dict]):
        """Register freshly collected examples, which sampling favors."""
        self.new_examples.extend(examples)

    def sample_batch(self, batch_size: int) -> list[dict]:
        """Draw a shuffled batch of ~new_ratio new + remainder historical.

        When the new pool is short, historical examples make up the balance.
        """
        want_new = min(int(batch_size * self.new_ratio), len(self.new_examples))
        want_hist = batch_size - want_new

        mixed: list[dict] = []
        if want_new > 0 and self.new_examples:
            mixed += self.rng.sample(self.new_examples, min(want_new, len(self.new_examples)))
        if want_hist > 0 and self.historical:
            mixed += self.rng.sample(self.historical, min(want_hist, len(self.historical)))

        self.rng.shuffle(mixed)
        return mixed

    def get_mixed_dataset_indices(
        self, n_historical: int, n_new: int
    ) -> tuple[list[int], list[int]]:
        """Sample index lists into the two pools for building a mixed split."""
        if self.historical:
            hist_idx = self.rng.sample(
                range(len(self.historical)),
                min(n_historical, len(self.historical)),
            )
        else:
            hist_idx = []

        if self.new_examples:
            new_idx = self.rng.sample(
                range(len(self.new_examples)),
                min(n_new, len(self.new_examples)),
            )
        else:
            new_idx = []

        return hist_idx, new_idx

    @property
    def total_size(self) -> int:
        """Combined count of historical and new examples."""
        return len(self.historical) + len(self.new_examples)

    def stats(self) -> dict:
        """Snapshot of pool sizes and configuration for logging."""
        return {
            "historical": len(self.historical),
            "new": len(self.new_examples),
            "total": self.total_size,
            "new_ratio": self.new_ratio,
            "max_size": self.max_size,
        }
129
+
130
+
131
+ # ============================================================================
132
+ # Validation Monitor
133
+ # ============================================================================
134
+
135
class ValidationMonitor:
    """
    Tracks validation metrics and triggers rollback on degradation.

    Drew's recommendation:
        Track performance on a held-out set and roll back if accuracy drops
        more than ``degradation_threshold`` (relative to the baseline) for
        ``patience`` consecutive checks.
    """

    def __init__(
        self,
        degradation_threshold: float = 0.10,
        patience: int = 3,
        checkpoint_dir: str = ".jfl/checkpoints",
    ):
        # Fractional accuracy drop (relative to baseline) that counts as degraded.
        self.degradation_threshold = degradation_threshold
        # Consecutive degraded checks required before recommending rollback.
        self.patience = patience
        self.checkpoint_dir = checkpoint_dir
        # Baseline comes from the pre-trained checkpoint (or the first check).
        self.baseline_accuracy: float | None = None
        self.best_accuracy: float = 0.0
        # One entry per check(): epoch, val metrics, wall-clock timestamp.
        self.history: list[dict] = []
        # Count of consecutive degraded checks seen so far.
        self.degradation_count: int = 0

    def set_baseline(self, accuracy: float):
        """Set the baseline accuracy from the pre-trained model."""
        self.baseline_accuracy = accuracy
        self.best_accuracy = accuracy
        print(f" Baseline accuracy: {accuracy:.1%}")

    def check(self, epoch: int, val_accuracy: float, val_loss: float) -> dict:
        """
        Check if the model has degraded beyond the rollback threshold.

        Args:
            epoch: Current training epoch (recorded in history).
            val_accuracy: Validation accuracy for this epoch.
            val_loss: Validation loss for this epoch.

        Returns:
            {
                "action": "continue" | "rollback" | "save_best",
                "reason": str,
                "degradation": float,   # relative drop vs. baseline
            }
        """
        self.history.append({
            "epoch": epoch,
            "val_accuracy": val_accuracy,
            "val_loss": val_loss,
            "timestamp": time.time(),
        })

        result = {
            "action": "continue",
            "reason": "",
            "degradation": 0.0,
        }

        # First observation becomes the baseline when none was provided.
        if self.baseline_accuracy is None:
            self.set_baseline(val_accuracy)
            return result

        # Check for improvement
        if val_accuracy > self.best_accuracy:
            # BUGFIX: capture the previous best before overwriting it, so the
            # "(was ...)" message reports the old value instead of the new one.
            previous_best = self.best_accuracy
            self.best_accuracy = val_accuracy
            self.degradation_count = 0
            result["action"] = "save_best"
            result["reason"] = f"New best accuracy: {val_accuracy:.1%} (was {previous_best:.1%})"
            return result

        # Check for degradation relative to the baseline.
        # BUGFIX: guard against a zero baseline (e.g. a checkpoint saved without
        # val_accuracy, which the caller defaults to 0.0) that previously raised
        # ZeroDivisionError here.
        if self.baseline_accuracy > 0:
            degradation = (self.baseline_accuracy - val_accuracy) / self.baseline_accuracy
        else:
            degradation = 0.0
        result["degradation"] = degradation

        if degradation > self.degradation_threshold:
            self.degradation_count += 1
            if self.degradation_count >= self.patience:
                result["action"] = "rollback"
                result["reason"] = (
                    f"Accuracy degraded {degradation:.1%} from baseline "
                    f"({val_accuracy:.1%} vs {self.baseline_accuracy:.1%}) "
                    f"for {self.degradation_count} consecutive checks"
                )
            else:
                result["reason"] = (
                    f"Degradation {degradation:.1%} detected "
                    f"({self.degradation_count}/{self.patience} until rollback)"
                )
        else:
            # Recovered: reset the consecutive-degradation counter.
            self.degradation_count = 0

        return result

    def save_rollback_checkpoint(self, model, optimizer, epoch: int, path: str):
        """Save a checkpoint that can be rolled back to."""
        # os.path.dirname is "" for a bare filename; makedirs("") would raise.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "baseline_accuracy": self.baseline_accuracy,
            "best_accuracy": self.best_accuracy,
        }, path)
233
+
234
+
235
+ # ============================================================================
236
+ # Online Training Loop
237
+ # ============================================================================
238
+
239
def online_train(args):
    """
    Fine-tune policy head on new data with experience replay.

    Key differences from offline train.py:
    - Lower learning rate (1e-5 vs 3e-4)
    - Experience replay (70% historical + 30% new)
    - Validation monitoring with rollback
    - Warm-starts from existing checkpoint (required)

    Args:
        args: Parsed CLI namespace from main() — checkpoint, data_dir,
            new_data, output_dir, epochs, batch_size, lr, weight_decay,
            replay_ratio, degradation_threshold, rollback_patience, seed.
    """
    # Device selection: prefer CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Device: {device}")

    # Load existing checkpoint (required for online learning)
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        print("Online learning requires a pre-trained checkpoint.")
        print("Run offline training first: python train.py")
        sys.exit(1)

    # weights_only=False: checkpoint stores plain Python dicts (tool maps,
    # config) alongside tensors, so full unpickling is required.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    config = checkpoint.get("config", {})
    tool_to_index = checkpoint["tool_to_index"]
    # Older checkpoints may lack the reverse map; rebuild it from tool_to_index.
    index_to_tool = checkpoint.get("index_to_tool", {str(v): k for k, v in tool_to_index.items()})
    num_tools = checkpoint.get("num_tools", len(tool_to_index))
    # NOTE(review): a checkpoint without val_accuracy yields a 0.0 baseline,
    # which makes the relative-degradation check in ValidationMonitor fragile
    # — confirm upstream checkpoints always record val_accuracy.
    baseline_accuracy = checkpoint.get("val_accuracy", 0.0)

    print(f"Loaded checkpoint: {args.checkpoint}")
    print(f" Baseline val accuracy: {baseline_accuracy:.1%}")
    print(f" Tools: {num_tools}")

    # Embeddings: precomputed cache of text -> vector (see precompute.py).
    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)
    if embeddings_matrix is not None:
        print(f"Embedding cache: {embeddings_matrix.shape[0]} texts, {embeddings_matrix.shape[1]}-dim")

    embedding_dim = config.get("embedding_dim", 768)

    # Model: rebuild with the architecture recorded in the checkpoint config.
    model = PolicyHead(
        embedding_dim=embedding_dim,
        hidden_dim=config.get("hidden_dim", 512),
        num_tools=num_tools,
        num_layers=config.get("num_layers", 4),
        num_heads=config.get("num_heads", 8),
        dropout=config.get("dropout", 0.1),
    ).to(device)

    # Load pre-trained weights (warm start — required for online learning).
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f" Loaded {model.num_parameters:,} parameters")

    # Load datasets
    # Historical = existing train data
    train_path = os.path.join(args.data_dir, "train.jsonl")
    val_path = os.path.join(args.data_dir, "val.jsonl")

    if not os.path.exists(train_path):
        print(f"Training data not found: {train_path}")
        sys.exit(1)

    historical_ds = PolicyHeadDataset(train_path, tool_to_index, embeddings_matrix, text_to_idx)
    # Validation is optional; without it val metrics stay 0.0 below.
    val_ds = PolicyHeadDataset(val_path, tool_to_index, embeddings_matrix, text_to_idx) if os.path.exists(val_path) else None

    # New data (counterfactual + recent real)
    new_paths = []
    if args.new_data and os.path.exists(args.new_data):
        new_paths.append(args.new_data)

    # Also check for counterfactual data
    cf_path = os.path.join(args.data_dir, "counterfactual.jsonl")
    if os.path.exists(cf_path) and cf_path not in new_paths:
        new_paths.append(cf_path)

    new_datasets = []
    for p in new_paths:
        ds = PolicyHeadDataset(p, tool_to_index, embeddings_matrix, text_to_idx)
        if len(ds) > 0:
            new_datasets.append(ds)
            print(f" New data: {p} ({len(ds)} examples)")

    if not new_datasets:
        print("No new data to train on. Nothing to do.")
        return

    new_ds = ConcatDataset(new_datasets) if len(new_datasets) > 1 else new_datasets[0]

    # Experience replay: mix historical (70%) + new (30%).
    # Solve n_hist from the ratio: new/(new+hist) == replay_ratio.
    replay_ratio = args.replay_ratio
    n_new = len(new_ds)
    n_historical = int(n_new * (1 - replay_ratio) / replay_ratio)
    n_historical = min(n_historical, len(historical_ds))

    print(f"\n Experience replay:")
    print(f" Historical pool: {len(historical_ds)} examples")
    print(f" New data: {n_new} examples")
    print(f" Sampling: {n_historical} historical + {n_new} new = {n_historical + n_new} total")
    print(f" Ratio: {n_new/(n_historical+n_new):.0%} new / {n_historical/(n_historical+n_new):.0%} historical")

    # Create mixed dataset via random sampling (seeded for reproducibility).
    rng = random.Random(args.seed)
    historical_indices = rng.sample(range(len(historical_ds)), n_historical)
    historical_subset = Subset(historical_ds, historical_indices)

    mixed_ds = ConcatDataset([historical_subset, new_ds])

    # MPS + worker processes is flaky; fall back to the main process there.
    num_workers = 0 if device == "mps" else min(4, os.cpu_count() or 1)
    train_loader = DataLoader(mixed_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers) if val_ds else None

    # Optimizer with low LR (Drew: 1e-5 for online vs 1e-3 for pre-training)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)  # Less smoothing for fine-tuning

    # Validation monitor: recommends continue / save_best / rollback per epoch.
    monitor = ValidationMonitor(
        degradation_threshold=args.degradation_threshold,
        patience=args.rollback_patience,
        checkpoint_dir=args.output_dir,
    )
    monitor.set_baseline(baseline_accuracy)

    # Save rollback checkpoint (a copy of the pre-trained weights we can
    # restore if fine-tuning degrades the model).
    rollback_path = os.path.join(args.output_dir, "rollback_checkpoint.pt")
    shutil.copy2(args.checkpoint, rollback_path)
    print(f" Rollback checkpoint: {rollback_path}")

    # Training loop
    print(f"\n Online fine-tuning for {args.epochs} epochs (lr={args.lr})...")
    print(f" {'Epoch':>5} {'Train Loss':>12} {'Train Acc':>10} {'Val Loss':>10} {'Val Acc':>9} {'Status':>12}")
    print(" " + "-" * 70)

    # Imported lazily to avoid circular/heavy imports at module load time.
    from train import train_epoch, evaluate

    # NOTE(review): if args.epochs < 1 this loop never runs and epoch,
    # val_acc and check_result below are unbound — confirm epochs >= 1 is
    # enforced (CLI default is 10).
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer,
            # Use constant LR (no scheduler for online): a throwaway object
            # with no-op step() and a get_last_lr() that reports args.lr.
            type("FakeScheduler", (), {"step": lambda self: None, "get_last_lr": lambda self: [args.lr]})(),
            device,
        )

        val_loss, val_acc = (0.0, 0.0)
        if val_loader:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        # Check validation monitor
        check_result = monitor.check(epoch, val_acc, val_loss)
        status = check_result["action"]

        status_str = {
            "continue": "✓",
            "save_best": "★ best",
            "rollback": "⚠ ROLLBACK",
        }.get(status, status)

        print(
            f" {epoch:5d} {train_loss:12.4f} {train_acc:9.1%} {val_loss:10.4f} {val_acc:8.1%} {status_str:>12}"
        )

        if status == "save_best":
            # Save new best
            save_checkpoint(model, optimizer, epoch, val_acc, val_loss, config,
                            tool_to_index, index_to_tool, num_tools,
                            len(mixed_ds), device, args.output_dir)

        elif status == "rollback":
            print(f"\n ⚠ ROLLING BACK: {check_result['reason']}")
            print(f" Restoring: {rollback_path}")

            # Restore from rollback checkpoint
            rollback_ckpt = torch.load(rollback_path, map_location=device, weights_only=False)
            model.load_state_dict(rollback_ckpt["model_state_dict"])
            print(f" Rolled back to baseline accuracy: {baseline_accuracy:.1%}")

            # Write rollback event
            write_rollback_event(args.output_dir, epoch, check_result)
            break

    # Final save if we completed without rollback
    if check_result.get("action") != "rollback":
        save_checkpoint(model, optimizer, args.epochs, val_acc, val_loss, config,
                        tool_to_index, index_to_tool, num_tools,
                        len(mixed_ds), device, args.output_dir)
        print(f"\n Online training complete. Final val accuracy: {val_acc:.1%}")

    # Write training event.
    # NOTE(review): placed outside the if-block so the event is logged on
    # rollback too (the "rollback" field records which case occurred) —
    # confirm against the original's intended indentation.
    write_training_event(args.output_dir, {
        "type": "online_training",
        "epochs": epoch,
        "final_val_accuracy": val_acc,
        "baseline_accuracy": baseline_accuracy,
        "new_examples": n_new,
        "historical_examples": n_historical,
        "replay_ratio": replay_ratio,
        "rollback": check_result.get("action") == "rollback",
    })
447
+
448
+
449
def save_checkpoint(model, optimizer, epoch, val_acc, val_loss, config,
                    tool_to_index, index_to_tool, num_tools, trained_on,
                    device, output_dir):
    """Persist the fine-tuned model plus a sidecar metadata JSON.

    Writes ``best_policy_head.pt`` (model/optimizer state, tool maps,
    config) and ``policy-head-v2.json`` (architecture summary and paths)
    into ``output_dir``, creating the directory if needed. The ``device``
    argument is accepted for signature compatibility but not used here.
    """
    os.makedirs(output_dir, exist_ok=True)

    checkpoint_file = os.path.join(output_dir, "best_policy_head.pt")
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_accuracy": val_acc,
        "val_loss": val_loss,
        "num_tools": num_tools,
        "tool_to_index": tool_to_index,
        "index_to_tool": index_to_tool,
        "config": config,
        "training_mode": "online",
    }
    torch.save(state, checkpoint_file)

    layers = config.get("num_layers", 4)
    hidden = config.get("hidden_dim", 512)
    metadata = {
        "version": 2,
        "architecture": f"transformer-{layers}layer-{hidden}h",
        "embedding_dim": config.get("embedding_dim", 768),
        "hidden_dim": hidden,
        "num_tools": num_tools,
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "trained_on": trained_on,
        "val_accuracy": val_acc,
        "training_mode": "online",
        "tool_to_index": tool_to_index,
        # JSON object keys must be strings; stringify the indices.
        "index_to_tool": {str(k): v for k, v in index_to_tool.items()},
        "checkpoint_path": os.path.abspath(checkpoint_file),
    }
    metadata_file = os.path.join(output_dir, "policy-head-v2.json")
    with open(metadata_file, "w") as handle:
        json.dump(metadata, handle, indent=2)
486
+
487
+
488
def write_rollback_event(output_dir, epoch, check_result):
    """Append a rollback event to the training-events log.

    Args:
        output_dir: Directory holding ``training-events.jsonl``
            (created if missing).
        epoch: Epoch at which the rollback was triggered.
        check_result: Result dict from ``ValidationMonitor.check``;
            reads "reason" and "degradation".
    """
    event_path = os.path.join(output_dir, "training-events.jsonl")
    event = {
        "type": "rollback",
        "epoch": epoch,
        "reason": check_result["reason"],
        "degradation": check_result["degradation"],
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    # BUGFIX/consistency: write_training_event creates the directory before
    # appending, but this function did not — appending failed when
    # output_dir did not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    with open(event_path, "a") as f:
        f.write(json.dumps(event) + "\n")
500
+
501
+
502
def write_training_event(output_dir, data):
    """Append one training event to training-events.jsonl for monitoring.

    Stamps ``data`` with a UTC ISO-8601 "timestamp" key (mutating the
    caller's dict) and appends it as one JSON line, creating
    ``output_dir`` if needed.
    """
    data["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "training-events.jsonl")
    with open(log_path, "a") as log:
        log.write(json.dumps(data) + "\n")
509
+
510
+
511
def main():
    """CLI entry point: parse arguments and run online fine-tuning."""
    cli = argparse.ArgumentParser(
        description="Online learning for v2 Policy Head with experience replay"
    )
    # Required: online learning always warm-starts from an existing model.
    cli.add_argument("--checkpoint", required=True, help="Path to pre-trained checkpoint (.pt)")
    # Data locations.
    cli.add_argument("--data-dir", default=".jfl/v2-data", help="Directory with train/val JSONL + embeddings")
    cli.add_argument("--new-data", default=None, help="Path to new training data JSONL")
    cli.add_argument("--output-dir", default=".jfl/checkpoints", help="Output directory for checkpoints")
    cli.add_argument("--domain", default=None, help="Path to domain.json (uses default if not specified)")
    # Fine-tuning hyperparameters (conservative relative to offline training).
    cli.add_argument("--epochs", type=int, default=10, help="Max fine-tuning epochs (default: 10, much less than offline)")
    cli.add_argument("--batch-size", type=int, default=32, help="Batch size")
    cli.add_argument("--lr", type=float, default=1e-5, help="Learning rate (default: 1e-5, much lower than offline 3e-4)")
    cli.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
    cli.add_argument("--replay-ratio", type=float, default=0.3, help="Fraction of batch that is new data (default: 0.3 = 30%% new)")
    # Rollback safety controls.
    cli.add_argument("--degradation-threshold", type=float, default=0.10, help="Rollback if val accuracy drops more than this fraction (default: 0.10 = 10%%)")
    cli.add_argument("--rollback-patience", type=int, default=3, help="Number of consecutive degraded epochs before rollback")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    args = cli.parse_args()

    # Default domain.json sits next to this script.
    if args.domain is None:
        args.domain = os.path.join(os.path.dirname(os.path.abspath(__file__)), "domain.json")

    online_train(args)
573
+
574
+
575
# Script entry point: run the online-training CLI when executed directly.
if __name__ == "__main__":
    main()
@@ -57,21 +57,39 @@ def precompute_embeddings(
57
57
  all_states = set()
58
58
  all_goals = set()
59
59
 
60
- for split in ["train", "val", "test"]:
60
+ for split in ["train", "val", "test", "benchmark", "counterfactual", "synthetic"]:
61
61
  path = os.path.join(data_dir, f"{split}.jsonl")
62
62
  if not os.path.exists(path):
63
- print(f" Skipping {split} (file not found)")
63
+ if split in ["train", "val", "test"]:
64
+ print(f" Skipping {split} (file not found)")
64
65
  continue
65
66
  states, goals = collect_unique_texts(path)
66
67
  all_states.update(states)
67
68
  all_goals.update(goals)
69
+ if split not in ["train", "val", "test"]:
70
+ print(f" Added {split}: {len(states)} states, {len(goals)} goals")
68
71
 
69
72
  all_texts = sorted(all_states | all_goals)
70
73
  print(f"Unique texts to embed: {len(all_texts)} ({len(all_states)} states, {len(all_goals)} goals)")
71
74
 
75
+ # Load existing cache to avoid re-embedding
72
76
  text_to_embedding = {}
73
- for i in range(0, len(all_texts), batch_size):
74
- batch = all_texts[i : i + batch_size]
77
+ cache_path = os.path.join(data_dir, "embeddings_cache.npz")
78
+ index_path = os.path.join(data_dir, "text_to_idx.json")
79
+ if os.path.exists(cache_path) and os.path.exists(index_path):
80
+ existing_idx = json.load(open(index_path))
81
+ existing_emb = np.load(cache_path, allow_pickle=True)["embeddings"]
82
+ for text, idx in existing_idx.items():
83
+ if idx < len(existing_emb):
84
+ text_to_embedding[text] = existing_emb[idx].tolist()
85
+ print(f" Loaded {len(text_to_embedding)} cached embeddings")
86
+
87
+ # Only embed new texts
88
+ new_texts = [t for t in all_texts if t not in text_to_embedding]
89
+ print(f" New texts to embed: {len(new_texts)} (cached: {len(text_to_embedding)})")
90
+
91
+ for i in range(0, len(new_texts), batch_size):
92
+ batch = new_texts[i : i + batch_size]
75
93
  try:
76
94
  embeddings = embedder(batch)
77
95
  for text, emb in zip(batch, embeddings):
@@ -80,8 +98,8 @@ def precompute_embeddings(
80
98
  print(f" Error embedding batch {i}-{i + len(batch)}: {e}")
81
99
  continue
82
100
 
83
- done = min(i + batch_size, len(all_texts))
84
- print(f" Embedded {done}/{len(all_texts)} texts")
101
+ done = min(i + batch_size, len(new_texts))
102
+ print(f" Embedded {done}/{len(new_texts)} new texts")
85
103
 
86
104
  texts_list = sorted(text_to_embedding.keys())
87
105
  text_to_idx = {t: i for i, t in enumerate(texts_list)}