jfl 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/commands/doctor.d.ts +1 -0
- package/dist/commands/doctor.d.ts.map +1 -1
- package/dist/commands/doctor.js +30 -1
- package/dist/commands/doctor.js.map +1 -1
- package/dist/commands/ide.d.ts +2 -1
- package/dist/commands/ide.d.ts.map +1 -1
- package/dist/commands/ide.js +60 -1
- package/dist/commands/ide.js.map +1 -1
- package/dist/commands/init-from-service.d.ts +15 -0
- package/dist/commands/init-from-service.d.ts.map +1 -0
- package/dist/commands/init-from-service.js +541 -0
- package/dist/commands/init-from-service.js.map +1 -0
- package/dist/commands/init.d.ts +1 -0
- package/dist/commands/init.d.ts.map +1 -1
- package/dist/commands/init.js +32 -1
- package/dist/commands/init.js.map +1 -1
- package/dist/commands/kanban.d.ts.map +1 -1
- package/dist/commands/kanban.js +13 -4
- package/dist/commands/kanban.js.map +1 -1
- package/dist/commands/linear.d.ts +41 -0
- package/dist/commands/linear.d.ts.map +1 -0
- package/dist/commands/linear.js +715 -0
- package/dist/commands/linear.js.map +1 -0
- package/dist/commands/peter.d.ts.map +1 -1
- package/dist/commands/peter.js +232 -25
- package/dist/commands/peter.js.map +1 -1
- package/dist/commands/services.d.ts.map +1 -1
- package/dist/commands/services.js +146 -0
- package/dist/commands/services.js.map +1 -1
- package/dist/commands/setup.d.ts.map +1 -1
- package/dist/commands/setup.js +173 -13
- package/dist/commands/setup.js.map +1 -1
- package/dist/commands/telemetry-monitor.d.ts +11 -0
- package/dist/commands/telemetry-monitor.d.ts.map +1 -0
- package/dist/commands/telemetry-monitor.js +224 -0
- package/dist/commands/telemetry-monitor.js.map +1 -0
- package/dist/commands/telemetry-test.d.ts +11 -0
- package/dist/commands/telemetry-test.d.ts.map +1 -0
- package/dist/commands/telemetry-test.js +67 -0
- package/dist/commands/telemetry-test.js.map +1 -0
- package/dist/commands/tenet-agents.d.ts +13 -0
- package/dist/commands/tenet-agents.d.ts.map +1 -0
- package/dist/commands/tenet-agents.js +191 -0
- package/dist/commands/tenet-agents.js.map +1 -0
- package/dist/commands/tenet-setup.d.ts +19 -0
- package/dist/commands/tenet-setup.d.ts.map +1 -0
- package/dist/commands/tenet-setup.js +131 -0
- package/dist/commands/tenet-setup.js.map +1 -0
- package/dist/commands/train.d.ts +18 -0
- package/dist/commands/train.d.ts.map +1 -1
- package/dist/commands/train.js +182 -0
- package/dist/commands/train.js.map +1 -1
- package/dist/commands/whoami.d.ts +2 -0
- package/dist/commands/whoami.d.ts.map +1 -0
- package/dist/commands/whoami.js +24 -0
- package/dist/commands/whoami.js.map +1 -0
- package/dist/index.js +159 -10
- package/dist/index.js.map +1 -1
- package/dist/lib/advanced-setup.d.ts +78 -0
- package/dist/lib/advanced-setup.d.ts.map +1 -0
- package/dist/lib/advanced-setup.js +433 -0
- package/dist/lib/advanced-setup.js.map +1 -0
- package/dist/lib/agent-config.d.ts +33 -0
- package/dist/lib/agent-config.d.ts.map +1 -1
- package/dist/lib/agent-config.js +26 -0
- package/dist/lib/agent-config.js.map +1 -1
- package/dist/lib/counterfactual-training-bridge.d.ts +114 -0
- package/dist/lib/counterfactual-training-bridge.d.ts.map +1 -0
- package/dist/lib/counterfactual-training-bridge.js +322 -0
- package/dist/lib/counterfactual-training-bridge.js.map +1 -0
- package/dist/lib/discovery-agent.d.ts +48 -0
- package/dist/lib/discovery-agent.d.ts.map +1 -0
- package/dist/lib/discovery-agent.js +111 -0
- package/dist/lib/discovery-agent.js.map +1 -0
- package/dist/lib/flow-engine.d.ts.map +1 -1
- package/dist/lib/flow-engine.js +46 -8
- package/dist/lib/flow-engine.js.map +1 -1
- package/dist/lib/gtm-generator.d.ts +29 -0
- package/dist/lib/gtm-generator.d.ts.map +1 -0
- package/dist/lib/gtm-generator.js +252 -0
- package/dist/lib/gtm-generator.js.map +1 -0
- package/dist/lib/hub-health.d.ts +40 -0
- package/dist/lib/hub-health.d.ts.map +1 -0
- package/dist/lib/hub-health.js +89 -0
- package/dist/lib/hub-health.js.map +1 -0
- package/dist/lib/invariant-monitor.d.ts +6 -2
- package/dist/lib/invariant-monitor.d.ts.map +1 -1
- package/dist/lib/invariant-monitor.js +89 -2
- package/dist/lib/invariant-monitor.js.map +1 -1
- package/dist/lib/journal-analyzer.d.ts +71 -0
- package/dist/lib/journal-analyzer.d.ts.map +1 -0
- package/dist/lib/journal-analyzer.js +306 -0
- package/dist/lib/journal-analyzer.js.map +1 -0
- package/dist/lib/linear-client.d.ts +73 -0
- package/dist/lib/linear-client.d.ts.map +1 -0
- package/dist/lib/linear-client.js +112 -0
- package/dist/lib/linear-client.js.map +1 -0
- package/dist/lib/linear-id-map.d.ts +20 -0
- package/dist/lib/linear-id-map.d.ts.map +1 -0
- package/dist/lib/linear-id-map.js +57 -0
- package/dist/lib/linear-id-map.js.map +1 -0
- package/dist/lib/linear-kanban.d.ts +66 -0
- package/dist/lib/linear-kanban.d.ts.map +1 -0
- package/dist/lib/linear-kanban.js +175 -0
- package/dist/lib/linear-kanban.js.map +1 -0
- package/dist/lib/onboarding.d.ts +40 -0
- package/dist/lib/onboarding.d.ts.map +1 -0
- package/dist/lib/onboarding.js +213 -0
- package/dist/lib/onboarding.js.map +1 -0
- package/dist/lib/physical-world-model.d.ts +50 -0
- package/dist/lib/physical-world-model.d.ts.map +1 -0
- package/dist/lib/physical-world-model.js +251 -0
- package/dist/lib/physical-world-model.js.map +1 -0
- package/dist/lib/planning-loop.d.ts +157 -0
- package/dist/lib/planning-loop.d.ts.map +1 -0
- package/dist/lib/planning-loop.js +537 -0
- package/dist/lib/planning-loop.js.map +1 -0
- package/dist/lib/policy-head.d.ts +13 -0
- package/dist/lib/policy-head.d.ts.map +1 -1
- package/dist/lib/policy-head.js +168 -2
- package/dist/lib/policy-head.js.map +1 -1
- package/dist/lib/resource-optimizer-middleware.d.ts +39 -0
- package/dist/lib/resource-optimizer-middleware.d.ts.map +1 -0
- package/dist/lib/resource-optimizer-middleware.js +222 -0
- package/dist/lib/resource-optimizer-middleware.js.map +1 -0
- package/dist/lib/resource-optimizer.d.ts +71 -0
- package/dist/lib/resource-optimizer.d.ts.map +1 -0
- package/dist/lib/resource-optimizer.js +228 -0
- package/dist/lib/resource-optimizer.js.map +1 -0
- package/dist/lib/rl-manager.d.ts +74 -0
- package/dist/lib/rl-manager.d.ts.map +1 -0
- package/dist/lib/rl-manager.js +244 -0
- package/dist/lib/rl-manager.js.map +1 -0
- package/dist/lib/service-analyzer.d.ts +76 -0
- package/dist/lib/service-analyzer.d.ts.map +1 -0
- package/dist/lib/service-analyzer.js +704 -0
- package/dist/lib/service-analyzer.js.map +1 -0
- package/dist/lib/service-gtm.js +2 -2
- package/dist/lib/service-gtm.js.map +1 -1
- package/dist/lib/service-questionnaire.d.ts +11 -0
- package/dist/lib/service-questionnaire.d.ts.map +1 -0
- package/dist/lib/service-questionnaire.js +89 -0
- package/dist/lib/service-questionnaire.js.map +1 -0
- package/dist/lib/setup/agent-generator.d.ts +2 -0
- package/dist/lib/setup/agent-generator.d.ts.map +1 -1
- package/dist/lib/setup/agent-generator.js +128 -4
- package/dist/lib/setup/agent-generator.js.map +1 -1
- package/dist/lib/setup/flow-generator.d.ts +10 -0
- package/dist/lib/setup/flow-generator.d.ts.map +1 -0
- package/dist/lib/setup/flow-generator.js +113 -0
- package/dist/lib/setup/flow-generator.js.map +1 -0
- package/dist/lib/setup/invariant-bridge.d.ts +91 -0
- package/dist/lib/setup/invariant-bridge.d.ts.map +1 -0
- package/dist/lib/setup/invariant-bridge.js +384 -0
- package/dist/lib/setup/invariant-bridge.js.map +1 -0
- package/dist/lib/setup/spec-generator.d.ts +41 -5
- package/dist/lib/setup/spec-generator.d.ts.map +1 -1
- package/dist/lib/setup/spec-generator.js +503 -29
- package/dist/lib/setup/spec-generator.js.map +1 -1
- package/dist/lib/stratus-client.js +1 -1
- package/dist/lib/stratus-client.js.map +1 -1
- package/dist/lib/surface-agent.d.ts +78 -0
- package/dist/lib/surface-agent.d.ts.map +1 -0
- package/dist/lib/surface-agent.js +105 -0
- package/dist/lib/surface-agent.js.map +1 -0
- package/dist/lib/surface-coordination-example.d.ts +30 -0
- package/dist/lib/surface-coordination-example.d.ts.map +1 -0
- package/dist/lib/surface-coordination-example.js +164 -0
- package/dist/lib/surface-coordination-example.js.map +1 -0
- package/dist/lib/telemetry/physical-world-collector.d.ts +15 -0
- package/dist/lib/telemetry/physical-world-collector.d.ts.map +1 -0
- package/dist/lib/telemetry/physical-world-collector.js +177 -0
- package/dist/lib/telemetry/physical-world-collector.js.map +1 -0
- package/dist/lib/telemetry/training-bridge.d.ts +51 -0
- package/dist/lib/telemetry/training-bridge.d.ts.map +1 -0
- package/dist/lib/telemetry/training-bridge.js +185 -0
- package/dist/lib/telemetry/training-bridge.js.map +1 -0
- package/dist/lib/telemetry.d.ts +2 -1
- package/dist/lib/telemetry.d.ts.map +1 -1
- package/dist/lib/telemetry.js +23 -2
- package/dist/lib/telemetry.js.map +1 -1
- package/dist/lib/tenet-board-agent.d.ts +52 -0
- package/dist/lib/tenet-board-agent.d.ts.map +1 -0
- package/dist/lib/tenet-board-agent.js +226 -0
- package/dist/lib/tenet-board-agent.js.map +1 -0
- package/dist/lib/tenet-ide-agent.d.ts +40 -0
- package/dist/lib/tenet-ide-agent.d.ts.map +1 -0
- package/dist/lib/tenet-ide-agent.js +199 -0
- package/dist/lib/tenet-ide-agent.js.map +1 -0
- package/dist/lib/workspace/data-pipeline.d.ts.map +1 -1
- package/dist/lib/workspace/data-pipeline.js +27 -5
- package/dist/lib/workspace/data-pipeline.js.map +1 -1
- package/dist/lib/workspace/sidebar-runner.d.ts +13 -0
- package/dist/lib/workspace/sidebar-runner.d.ts.map +1 -0
- package/dist/lib/workspace/sidebar-runner.js +419 -0
- package/dist/lib/workspace/sidebar-runner.js.map +1 -0
- package/dist/lib/workspace/surface-registry.d.ts.map +1 -1
- package/dist/lib/workspace/surface-registry.js +4 -1
- package/dist/lib/workspace/surface-registry.js.map +1 -1
- package/dist/lib/workspace/surfaces/agent-overview.d.ts +3 -3
- package/dist/lib/workspace/surfaces/agent-overview.d.ts.map +1 -1
- package/dist/lib/workspace/surfaces/agent-overview.js +3 -3
- package/dist/lib/workspace/surfaces/agent-overview.js.map +1 -1
- package/dist/lib/workspace/surfaces/index.d.ts +3 -0
- package/dist/lib/workspace/surfaces/index.d.ts.map +1 -1
- package/dist/lib/workspace/surfaces/index.js +3 -0
- package/dist/lib/workspace/surfaces/index.js.map +1 -1
- package/dist/lib/workspace/surfaces/kanban.d.ts +15 -0
- package/dist/lib/workspace/surfaces/kanban.d.ts.map +1 -0
- package/dist/lib/workspace/surfaces/kanban.js +43 -0
- package/dist/lib/workspace/surfaces/kanban.js.map +1 -0
- package/dist/lib/workspace/surfaces/physical-world.d.ts +15 -0
- package/dist/lib/workspace/surfaces/physical-world.d.ts.map +1 -0
- package/dist/lib/workspace/surfaces/physical-world.js +37 -0
- package/dist/lib/workspace/surfaces/physical-world.js.map +1 -0
- package/dist/lib/workspace/surfaces/sidebar.d.ts +22 -0
- package/dist/lib/workspace/surfaces/sidebar.d.ts.map +1 -0
- package/dist/lib/workspace/surfaces/sidebar.js +90 -0
- package/dist/lib/workspace/surfaces/sidebar.js.map +1 -0
- package/dist/types/flows.d.ts +2 -1
- package/dist/types/flows.d.ts.map +1 -1
- package/dist/types/physical-world-model.d.ts +65 -0
- package/dist/types/physical-world-model.d.ts.map +1 -0
- package/dist/types/physical-world-model.js +43 -0
- package/dist/types/physical-world-model.js.map +1 -0
- package/dist/types/telemetry.d.ts +37 -0
- package/dist/types/telemetry.d.ts.map +1 -1
- package/dist/types/world-model.d.ts.map +1 -1
- package/dist/types/world-model.js +14 -7
- package/dist/types/world-model.js.map +1 -1
- package/dist/utils/context-hub-port.d.ts.map +1 -1
- package/dist/utils/context-hub-port.js +6 -1
- package/dist/utils/context-hub-port.js.map +1 -1
- package/package.json +3 -2
- package/packages/pi/extensions/index.ts +34 -6
- package/packages/pi/extensions/onboarding-v1.ts +8 -8
- package/packages/pi/extensions/onboarding-v2.ts +5 -5
- package/scripts/telemetry-dashboard.sh +44 -0
- package/scripts/test-planning-loop-e2e.ts +181 -0
- package/scripts/test-server-inference.ts +49 -0
- package/scripts/test-state-sensitivity.ts +32 -0
- package/scripts/train/v2/benchmark.py +661 -0
- package/scripts/train/v2/generate_balanced.py +439 -0
- package/scripts/train/v2/generate_hard_negatives.py +219 -0
- package/scripts/train/v2/infer.py +149 -36
- package/scripts/train/v2/infer_server.py +224 -0
- package/scripts/train/v2/online_train.py +576 -0
- package/scripts/train/v2/precompute.py +24 -6
- package/template/CLAUDE.md +74 -132
|
@@ -0,0 +1,576 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Online Learning Harness for v2 Policy Head.
|
|
3
|
+
|
|
4
|
+
Implements Drew's recommended hybrid approach (section 8.3):
|
|
5
|
+
- Experience replay: 70% historical + 30% new data per batch
|
|
6
|
+
- Small learning rate (1e-5) to avoid catastrophic forgetting
|
|
7
|
+
- Validation monitoring with automatic rollback if degradation >10%
|
|
8
|
+
- Continuous checkpointing for recovery
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
# Fine-tune on new data with experience replay
|
|
12
|
+
python online_train.py --new-data .jfl/v2-data/new.jsonl --checkpoint .jfl/checkpoints/best_policy_head.pt
|
|
13
|
+
|
|
14
|
+
# Continuous mode: watch for new data and retrain automatically
|
|
15
|
+
python online_train.py --watch --checkpoint .jfl/checkpoints/best_policy_head.pt
|
|
16
|
+
|
|
17
|
+
Drew's architecture decision:
|
|
18
|
+
Pre-training: offline, 3e-4 LR, full dataset
|
|
19
|
+
Online: continuous, 1e-5 LR, experience replay, validation gating
|
|
20
|
+
Batch retraining: weekly, full offline, reset to best checkpoint
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import os
|
|
25
|
+
import sys
|
|
26
|
+
import time
|
|
27
|
+
import math
|
|
28
|
+
import random
|
|
29
|
+
import shutil
|
|
30
|
+
import argparse
|
|
31
|
+
from pathlib import Path
|
|
32
|
+
|
|
33
|
+
import numpy as np
|
|
34
|
+
import torch
|
|
35
|
+
import torch.nn as nn
|
|
36
|
+
import torch.optim as optim
|
|
37
|
+
from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset
|
|
38
|
+
|
|
39
|
+
from model import PolicyHead
|
|
40
|
+
from dataset import PolicyHeadDataset, load_embedding_cache
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# ============================================================================
|
|
44
|
+
# Experience Replay Buffer
|
|
45
|
+
# ============================================================================
|
|
46
|
+
|
|
47
|
+
class ExperienceReplayBuffer:
    """
    Maintains a pool of historical training examples for replay.

    Drew's recommendation (section 8.3):
    Mix new transitions (30%) + sampled historical data (70%) in each batch.

    The historical pool is capped at ``max_size`` with true reservoir sampling
    (Algorithm R): after any sequence of ``add_historical`` calls, every
    historical example offered so far has the same probability
    (``max_size / total_offered``) of being in the pool. New examples are
    kept in full.
    """

    def __init__(self, max_size: int = 10000, new_ratio: float = 0.3, seed: int = 42):
        self.max_size = max_size
        self.new_ratio = new_ratio
        self.rng = random.Random(seed)
        self.historical: list[dict] = []
        self.new_examples: list[dict] = []
        # Count of ALL historical examples ever offered (retained or not);
        # reservoir sampling needs it to keep the pool uniform across calls.
        self._historical_seen: int = 0

    def add_historical(self, examples: list[dict]):
        """Add examples to the historical pool, reservoir-sampling past max_size.

        Re-sampling the concatenated pool on each call would over-represent
        the most recent batch. Algorithm R keeps each offered example with
        equal probability regardless of which call delivered it.
        """
        for example in examples:
            self._historical_seen += 1
            if len(self.historical) < self.max_size:
                self.historical.append(example)
                continue
            # Replace a random slot with probability max_size / seen.
            slot = self.rng.randrange(self._historical_seen)
            if slot < self.max_size:
                self.historical[slot] = example

    def add_new(self, examples: list[dict]):
        """Add new examples that will be weighted higher in sampling."""
        self.new_examples.extend(examples)

    def sample_batch(self, batch_size: int) -> list[dict]:
        """
        Sample a mixed batch: new_ratio % new + (1-new_ratio) % historical.

        If not enough new examples, fill with historical. The batch can be
        smaller than ``batch_size`` when the historical pool is also too small.
        Returned items are shuffled.
        """
        # n_new is already capped at the number of available new examples.
        n_new = min(int(batch_size * self.new_ratio), len(self.new_examples))
        n_historical = batch_size - n_new

        batch = []

        if n_new > 0:
            batch.extend(self.rng.sample(self.new_examples, n_new))

        if n_historical > 0 and self.historical:
            batch.extend(self.rng.sample(
                self.historical,
                min(n_historical, len(self.historical))
            ))

        self.rng.shuffle(batch)
        return batch

    def get_mixed_dataset_indices(
        self, n_historical: int, n_new: int
    ) -> tuple[list[int], list[int]]:
        """Return random index lists into (historical, new_examples).

        Each count is capped at the size of its pool; an empty pool yields [].
        """
        hist_indices = self.rng.sample(
            range(len(self.historical)),
            min(n_historical, len(self.historical))
        ) if self.historical else []

        new_indices = self.rng.sample(
            range(len(self.new_examples)),
            min(n_new, len(self.new_examples))
        ) if self.new_examples else []

        return hist_indices, new_indices

    @property
    def total_size(self) -> int:
        """Number of examples currently held (historical + new)."""
        return len(self.historical) + len(self.new_examples)

    def stats(self) -> dict:
        """Summary of pool sizes and configuration."""
        return {
            "historical": len(self.historical),
            "new": len(self.new_examples),
            "total": self.total_size,
            "new_ratio": self.new_ratio,
            "max_size": self.max_size,
        }
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ============================================================================
|
|
132
|
+
# Validation Monitor
|
|
133
|
+
# ============================================================================
|
|
134
|
+
|
|
135
|
+
class ValidationMonitor:
    """
    Tracks validation metrics and triggers rollback on degradation.

    Drew's recommendation:
    Track performance on held-out test set, rollback if degradation detected.
    Rollback plan: Automatic rollback if L3 metrics degrade >10%.

    Degradation is relative to the baseline: ``(baseline - current) / baseline``.
    A rollback is requested only after ``patience`` consecutive checks exceed
    ``degradation_threshold``.
    """

    def __init__(
        self,
        degradation_threshold: float = 0.10,
        patience: int = 3,
        checkpoint_dir: str = ".jfl/checkpoints",
    ):
        self.degradation_threshold = degradation_threshold
        self.patience = patience
        # Stored for callers; not read by any method of this class.
        self.checkpoint_dir = checkpoint_dir
        self.baseline_accuracy: float | None = None
        self.best_accuracy: float = 0.0
        self.history: list[dict] = []
        self.degradation_count: int = 0

    def set_baseline(self, accuracy: float):
        """Set the baseline accuracy from pre-trained model (also resets best)."""
        self.baseline_accuracy = accuracy
        self.best_accuracy = accuracy
        print(f" Baseline accuracy: {accuracy:.1%}")

    def check(self, epoch: int, val_accuracy: float, val_loss: float) -> dict:
        """
        Check if model has degraded beyond threshold.

        Every call is appended to ``self.history``. If no baseline is set yet,
        this accuracy becomes the baseline.

        Returns:
            {
                "action": "continue" | "rollback" | "save_best",
                "reason": str,
                "degradation": float,
            }
        """
        self.history.append({
            "epoch": epoch,
            "val_accuracy": val_accuracy,
            "val_loss": val_loss,
            "timestamp": time.time(),
        })

        result = {
            "action": "continue",
            "reason": "",
            "degradation": 0.0,
        }

        if self.baseline_accuracy is None:
            self.set_baseline(val_accuracy)
            return result

        # Check for improvement
        if val_accuracy > self.best_accuracy:
            # Capture the old best BEFORE overwriting it so the message is right.
            previous_best = self.best_accuracy
            self.best_accuracy = val_accuracy
            self.degradation_count = 0
            result["action"] = "save_best"
            result["reason"] = f"New best accuracy: {val_accuracy:.1%} (was {previous_best:.1%})"
            return result

        # Relative degradation is undefined for a non-positive baseline (e.g. a
        # checkpoint saved without val_accuracy loads as 0.0). Treat it as
        # "nothing measurable" instead of dividing by zero.
        if self.baseline_accuracy <= 0:
            self.degradation_count = 0
            return result

        # Check for degradation
        degradation = (self.baseline_accuracy - val_accuracy) / self.baseline_accuracy
        result["degradation"] = degradation

        if degradation > self.degradation_threshold:
            self.degradation_count += 1
            if self.degradation_count >= self.patience:
                result["action"] = "rollback"
                result["reason"] = (
                    f"Accuracy degraded {degradation:.1%} from baseline "
                    f"({val_accuracy:.1%} vs {self.baseline_accuracy:.1%}) "
                    f"for {self.degradation_count} consecutive checks"
                )
            else:
                result["reason"] = (
                    f"Degradation {degradation:.1%} detected "
                    f"({self.degradation_count}/{self.patience} until rollback)"
                )
        else:
            # Patience counts CONSECUTIVE degraded checks only.
            self.degradation_count = 0

        return result

    def save_rollback_checkpoint(self, model, optimizer, epoch: int, path: str):
        """Save a checkpoint that can be rolled back to."""
        parent = os.path.dirname(path)
        # os.makedirs("") raises, so only create a directory when path has one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "baseline_accuracy": self.baseline_accuracy,
            "best_accuracy": self.best_accuracy,
        }, path)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# ============================================================================
|
|
236
|
+
# Online Training Loop
|
|
237
|
+
# ============================================================================
|
|
238
|
+
|
|
239
|
+
def online_train(args):
    """
    Fine-tune policy head on new data with experience replay.

    Key differences from offline train.py:
    - Lower learning rate (1e-5 vs 3e-4)
    - Experience replay (70% historical + 30% new)
    - Validation monitoring with rollback
    - Warm-starts from existing checkpoint (required)

    ``args.replay_ratio`` is the fraction of the mixed dataset that is NEW data
    (n_new / (n_new + n_historical)) and must lie in (0, 1].

    Exits with status 1 when replay_ratio is out of range or when the checkpoint
    or train.jsonl is missing. Returns early when there is no new data.
    Without a validation set, no rollback gating is applied.
    """
    # Validate before doing any expensive work; 0 would divide by zero below.
    replay_ratio = args.replay_ratio
    if not 0 < replay_ratio <= 1:
        print(f"ERROR: replay_ratio must be in (0, 1], got {replay_ratio}")
        sys.exit(1)

    # Device
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Device: {device}")

    # Load existing checkpoint (required for online learning)
    if not os.path.exists(args.checkpoint):
        print(f"ERROR: Checkpoint not found: {args.checkpoint}")
        print("Online learning requires a pre-trained checkpoint.")
        print("Run offline training first: python train.py")
        sys.exit(1)

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by this project's own training scripts.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    config = checkpoint.get("config", {})
    tool_to_index = checkpoint["tool_to_index"]
    index_to_tool = checkpoint.get("index_to_tool", {str(v): k for k, v in tool_to_index.items()})
    num_tools = checkpoint.get("num_tools", len(tool_to_index))
    baseline_accuracy = checkpoint.get("val_accuracy", 0.0)

    print(f"Loaded checkpoint: {args.checkpoint}")
    print(f" Baseline val accuracy: {baseline_accuracy:.1%}")
    print(f" Tools: {num_tools}")

    # Embeddings
    embeddings_matrix, text_to_idx = load_embedding_cache(args.data_dir)
    if embeddings_matrix is not None:
        print(f"Embedding cache: {embeddings_matrix.shape[0]} texts, {embeddings_matrix.shape[1]}-dim")

    embedding_dim = config.get("embedding_dim", 768)

    # Model
    model = PolicyHead(
        embedding_dim=embedding_dim,
        hidden_dim=config.get("hidden_dim", 512),
        num_tools=num_tools,
        num_layers=config.get("num_layers", 4),
        num_heads=config.get("num_heads", 8),
        dropout=config.get("dropout", 0.1),
    ).to(device)

    # Load pre-trained weights
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f" Loaded {model.num_parameters:,} parameters")

    # Load datasets
    # Historical = existing train data
    train_path = os.path.join(args.data_dir, "train.jsonl")
    val_path = os.path.join(args.data_dir, "val.jsonl")

    if not os.path.exists(train_path):
        print(f"Training data not found: {train_path}")
        sys.exit(1)

    historical_ds = PolicyHeadDataset(train_path, tool_to_index, embeddings_matrix, text_to_idx)
    val_ds = PolicyHeadDataset(val_path, tool_to_index, embeddings_matrix, text_to_idx) if os.path.exists(val_path) else None

    # New data (counterfactual + recent real)
    new_paths = []
    if args.new_data:
        if os.path.exists(args.new_data):
            new_paths.append(args.new_data)
        else:
            print(f" WARNING: --new-data not found, ignoring: {args.new_data}")

    # Also check for counterfactual data (skip if it is the same file as --new-data,
    # even when spelled differently, so it is not trained on twice).
    cf_path = os.path.join(args.data_dir, "counterfactual.jsonl")
    if os.path.exists(cf_path) and not any(os.path.samefile(cf_path, p) for p in new_paths):
        new_paths.append(cf_path)

    new_datasets = []
    for p in new_paths:
        ds = PolicyHeadDataset(p, tool_to_index, embeddings_matrix, text_to_idx)
        if len(ds) > 0:
            new_datasets.append(ds)
            print(f" New data: {p} ({len(ds)} examples)")

    if not new_datasets:
        print("No new data to train on. Nothing to do.")
        return

    new_ds = ConcatDataset(new_datasets) if len(new_datasets) > 1 else new_datasets[0]

    # Experience replay: mix historical (70%) + new (30%).
    # Solve n_new / (n_new + n_historical) == replay_ratio for n_historical.
    n_new = len(new_ds)
    n_historical = int(n_new * (1 - replay_ratio) / replay_ratio)
    n_historical = min(n_historical, len(historical_ds))

    print(f"\n Experience replay:")
    print(f" Historical pool: {len(historical_ds)} examples")
    print(f" New data: {n_new} examples")
    print(f" Sampling: {n_historical} historical + {n_new} new = {n_historical + n_new} total")
    print(f" Ratio: {n_new/(n_historical+n_new):.0%} new / {n_historical/(n_historical+n_new):.0%} historical")

    # Create mixed dataset via random sampling
    rng = random.Random(args.seed)
    historical_indices = rng.sample(range(len(historical_ds)), n_historical)
    historical_subset = Subset(historical_ds, historical_indices)

    mixed_ds = ConcatDataset([historical_subset, new_ds])

    num_workers = 0 if device == "mps" else min(4, os.cpu_count() or 1)
    train_loader = DataLoader(mixed_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    has_val = val_ds is not None and len(val_ds) > 0
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers) if has_val else None

    # Optimizer with low LR (Drew: 1e-5 for online vs 3e-4 for pre-training)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)  # Less smoothing for fine-tuning

    # Validation monitor
    monitor = ValidationMonitor(
        degradation_threshold=args.degradation_threshold,
        patience=args.rollback_patience,
        checkpoint_dir=args.output_dir,
    )
    monitor.set_baseline(baseline_accuracy)

    # Save rollback checkpoint (verbatim copy of the starting weights).
    # output_dir may not exist yet; copy2 does not create directories.
    os.makedirs(args.output_dir, exist_ok=True)
    rollback_path = os.path.join(args.output_dir, "rollback_checkpoint.pt")
    shutil.copy2(args.checkpoint, rollback_path)
    print(f" Rollback checkpoint: {rollback_path}")

    # Training loop
    print(f"\n Online fine-tuning for {args.epochs} epochs (lr={args.lr})...")
    print(f" {'Epoch':>5} {'Train Loss':>12} {'Train Acc':>10} {'Val Loss':>10} {'Val Acc':>9} {'Status':>12}")
    print(" " + "-" * 70)

    from train import train_epoch, evaluate

    class _ConstantLR:
        """Scheduler stand-in for train_epoch: online training keeps the LR fixed."""

        def step(self):
            pass

        def get_last_lr(self):
            return [args.lr]

    constant_lr = _ConstantLR()

    # Defaults so the post-loop code is well-defined even when args.epochs == 0.
    epoch = 0
    val_loss, val_acc = (0.0, 0.0)
    check_result = {"action": "continue", "reason": "", "degradation": 0.0}
    saved_best = False

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer,
            constant_lr,
            device,
        )

        if val_loader is not None:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)
            # Check validation monitor
            check_result = monitor.check(epoch, val_acc, val_loss)
        else:
            # No validation data: nothing to gate on. Feeding the 0.0 placeholder
            # to the monitor would read as 100% degradation and force a rollback.
            check_result = {"action": "continue", "reason": "no validation data", "degradation": 0.0}
        status = check_result["action"]

        status_str = {
            "continue": "✓",
            "save_best": "★ best",
            "rollback": "⚠ ROLLBACK",
        }.get(status, status)

        print(
            f" {epoch:5d} {train_loss:12.4f} {train_acc:9.1%} {val_loss:10.4f} {val_acc:8.1%} {status_str:>12}"
        )

        if status == "save_best":
            # Save new best
            save_checkpoint(model, optimizer, epoch, val_acc, val_loss, config,
                            tool_to_index, index_to_tool, num_tools,
                            len(mixed_ds), device, args.output_dir)
            saved_best = True

        elif status == "rollback":
            print(f"\n ⚠ ROLLING BACK: {check_result['reason']}")
            print(f" Restoring: {rollback_path}")

            # Restore from rollback checkpoint
            rollback_ckpt = torch.load(rollback_path, map_location=device, weights_only=False)
            model.load_state_dict(rollback_ckpt["model_state_dict"])
            print(f" Rolled back to baseline accuracy: {baseline_accuracy:.1%}")

            # Write rollback event
            write_rollback_event(args.output_dir, epoch, check_result)
            break

    if check_result.get("action") != "rollback":
        # A best checkpoint saved during this run has val accuracy >= the final
        # epoch's. Overwriting it with the final weights could only lose quality.
        if not saved_best:
            save_checkpoint(model, optimizer, args.epochs, val_acc, val_loss, config,
                            tool_to_index, index_to_tool, num_tools,
                            len(mixed_ds), device, args.output_dir)
        print(f"\n Online training complete. Final val accuracy: {val_acc:.1%}")

    # Write training event
    write_training_event(args.output_dir, {
        "type": "online_training",
        "epochs": epoch,
        "final_val_accuracy": val_acc,
        "best_val_accuracy": monitor.best_accuracy,
        "baseline_accuracy": baseline_accuracy,
        "new_examples": n_new,
        "historical_examples": n_historical,
        "replay_ratio": replay_ratio,
        "rollback": check_result.get("action") == "rollback",
    })
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def save_checkpoint(model, optimizer, epoch, val_acc, val_loss, config,
                    tool_to_index, index_to_tool, num_tools, trained_on,
                    device, output_dir):
    """Persist the online-trained policy head and its JSON metadata.

    Two files are written into ``output_dir`` (created if missing), and both
    are overwritten on every call:

    * ``best_policy_head.pt`` -- a ``torch.save`` bundle with the model and
      optimizer state dicts, the epoch, validation metrics, the tool
      vocabulary mappings, the full ``config`` and ``training_mode="online"``.
    * ``policy-head-v2.json`` -- metadata describing the architecture, the
      tool vocabulary, validation accuracy, a UTC ``trained_at`` timestamp
      and the absolute path of the ``.pt`` file.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number stored in the ``.pt`` bundle.
        val_acc: validation accuracy stored in both files.
        val_loss: validation loss stored in the ``.pt`` bundle only.
        config: mapping read with ``.get`` for ``num_layers`` (default 4),
            ``hidden_dim`` (default 512) and ``embedding_dim`` (default 768);
            stored whole in the ``.pt`` bundle.
        tool_to_index: tool name -> index mapping, stored in both files.
        index_to_tool: index -> tool name mapping; its keys are converted
            with ``str`` for the JSON file.
        num_tools: size of the tool vocabulary.
        trained_on: value recorded as ``trained_on`` in the metadata.
        device: accepted for call compatibility; not used here.
        output_dir: directory that receives both files.
    """
    os.makedirs(output_dir, exist_ok=True)

    checkpoint_file = os.path.join(output_dir, "best_policy_head.pt")
    bundle = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_accuracy=val_acc,
        val_loss=val_loss,
        num_tools=num_tools,
        tool_to_index=tool_to_index,
        index_to_tool=index_to_tool,
        config=config,
        training_mode="online",
    )
    torch.save(bundle, checkpoint_file)

    layer_count = config.get("num_layers", 4)
    hidden_width = config.get("hidden_dim", 512)
    # JSON object keys must be strings, so the integer indices are stringified.
    string_keyed_tools = {str(idx): name for idx, name in index_to_tool.items()}
    metadata = {
        "version": 2,
        "architecture": f"transformer-{layer_count}layer-{hidden_width}h",
        "embedding_dim": config.get("embedding_dim", 768),
        "hidden_dim": hidden_width,
        "num_tools": num_tools,
        "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "trained_on": trained_on,
        "val_accuracy": val_acc,
        "training_mode": "online",
        "tool_to_index": tool_to_index,
        "index_to_tool": string_keyed_tools,
        "checkpoint_path": os.path.abspath(checkpoint_file),
    }

    metadata_file = os.path.join(output_dir, "policy-head-v2.json")
    with open(metadata_file, "w") as handle:
        json.dump(metadata, handle, indent=2)
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def write_rollback_event(output_dir, epoch, check_result):
    """Append a ``rollback`` record to ``training-events.jsonl``.

    The record is one JSON object on its own line with the keys ``type``
    (always ``"rollback"``), ``epoch``, ``reason``, ``degradation`` and
    ``timestamp`` (UTC, ``%Y-%m-%dT%H:%M:%SZ``).

    Args:
        output_dir: directory holding the event log. It is created if
            missing, matching ``write_training_event``. A rollback can happen
            before any checkpoint was saved, which is the other place the
            directory would have been created.
        epoch: epoch at which the rollback was triggered.
        check_result: mapping that must provide ``reason`` and
            ``degradation``.

    Raises:
        KeyError: if ``check_result`` lacks ``reason`` or ``degradation``.
    """
    os.makedirs(output_dir, exist_ok=True)
    event_path = os.path.join(output_dir, "training-events.jsonl")
    event = {
        "type": "rollback",
        "epoch": epoch,
        "reason": check_result["reason"],
        "degradation": check_result["degradation"],
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    with open(event_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(event) + "\n")
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def write_training_event(output_dir, data):
    """Append a training event to ``training-events.jsonl`` in ``output_dir``.

    A UTC ``timestamp`` (``%Y-%m-%dT%H:%M:%SZ``) is added to the written
    record. If ``data`` already has a ``timestamp`` key, its value is replaced
    in the written record.

    Args:
        output_dir: directory holding the event log; created if missing.
        data: JSON-serialisable mapping describing the event. It is copied,
            not modified, so the caller's dict is left untouched.
    """
    event_path = os.path.join(output_dir, "training-events.jsonl")
    # Copy rather than assign into ``data`` so the caller sees no side effect.
    record = {
        **data,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    os.makedirs(output_dir, exist_ok=True)
    with open(event_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def main(argv=None):
    """Command-line entry point for online fine-tuning of the v2 policy head.

    Parses the command line and rejects out-of-range numeric options. It then
    fills in the default ``--domain`` path (``domain.json`` next to this
    script) and passes the resulting namespace to ``online_train``.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` calls are unchanged.

    Exits via ``SystemExit`` with status 2 (argparse usage error) when a
    required option is missing or a numeric option is out of range.
    """
    parser = argparse.ArgumentParser(
        description="Online learning for v2 Policy Head with experience replay"
    )
    parser.add_argument(
        "--checkpoint", required=True,
        help="Path to pre-trained checkpoint (.pt)"
    )
    parser.add_argument(
        "--data-dir", default=".jfl/v2-data",
        help="Directory with train/val JSONL + embeddings"
    )
    parser.add_argument(
        "--new-data", default=None,
        help="Path to new training data JSONL"
    )
    parser.add_argument(
        "--output-dir", default=".jfl/checkpoints",
        help="Output directory for checkpoints"
    )
    parser.add_argument(
        "--domain", default=None,
        help="Path to domain.json (uses default if not specified)"
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Max fine-tuning epochs (default: 10, much less than offline)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32,
        help="Batch size"
    )
    parser.add_argument(
        "--lr", type=float, default=1e-5,
        help="Learning rate (default: 1e-5, much lower than offline 3e-4)"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0.01,
        help="Weight decay"
    )
    parser.add_argument(
        "--replay-ratio", type=float, default=0.3,
        help="Fraction of batch that is new data (default: 0.3 = 30%% new)"
    )
    parser.add_argument(
        "--degradation-threshold", type=float, default=0.10,
        help="Rollback if val accuracy drops more than this fraction (default: 0.10 = 10%%)"
    )
    parser.add_argument(
        "--rollback-patience", type=int, default=3,
        help="Number of consecutive degraded epochs before rollback"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed"
    )
    args = parser.parse_args(argv)

    # Fail fast, before any checkpoint or data is loaded, on values that
    # cannot describe a meaningful run. With fewer than one epoch there is
    # nothing to train. A batch size below 1 cannot form a batch. The replay
    # ratio is documented as a fraction of the batch, and a negative
    # degradation threshold is not a meaningful "drop".
    if args.epochs < 1:
        parser.error(f"--epochs must be >= 1 (got {args.epochs})")
    if args.batch_size < 1:
        parser.error(f"--batch-size must be >= 1 (got {args.batch_size})")
    if not 0.0 <= args.replay_ratio <= 1.0:
        parser.error(
            f"--replay-ratio must be between 0 and 1 (got {args.replay_ratio})"
        )
    if args.degradation_threshold < 0.0:
        parser.error(
            "--degradation-threshold must be >= 0 "
            f"(got {args.degradation_threshold})"
        )

    if args.domain is None:
        args.domain = os.path.join(os.path.dirname(os.path.abspath(__file__)), "domain.json")

    online_train(args)


if __name__ == "__main__":
    main()
|
|
@@ -57,21 +57,39 @@ def precompute_embeddings(
|
|
|
57
57
|
all_states = set()
|
|
58
58
|
all_goals = set()
|
|
59
59
|
|
|
60
|
-
for split in ["train", "val", "test"]:
|
|
60
|
+
for split in ["train", "val", "test", "benchmark", "counterfactual", "synthetic"]:
|
|
61
61
|
path = os.path.join(data_dir, f"{split}.jsonl")
|
|
62
62
|
if not os.path.exists(path):
|
|
63
|
-
|
|
63
|
+
if split in ["train", "val", "test"]:
|
|
64
|
+
print(f" Skipping {split} (file not found)")
|
|
64
65
|
continue
|
|
65
66
|
states, goals = collect_unique_texts(path)
|
|
66
67
|
all_states.update(states)
|
|
67
68
|
all_goals.update(goals)
|
|
69
|
+
if split not in ["train", "val", "test"]:
|
|
70
|
+
print(f" Added {split}: {len(states)} states, {len(goals)} goals")
|
|
68
71
|
|
|
69
72
|
all_texts = sorted(all_states | all_goals)
|
|
70
73
|
print(f"Unique texts to embed: {len(all_texts)} ({len(all_states)} states, {len(all_goals)} goals)")
|
|
71
74
|
|
|
75
|
+
# Load existing cache to avoid re-embedding
|
|
72
76
|
text_to_embedding = {}
|
|
73
|
-
|
|
74
|
-
|
|
77
|
+
cache_path = os.path.join(data_dir, "embeddings_cache.npz")
|
|
78
|
+
index_path = os.path.join(data_dir, "text_to_idx.json")
|
|
79
|
+
if os.path.exists(cache_path) and os.path.exists(index_path):
|
|
80
|
+
existing_idx = json.load(open(index_path))
|
|
81
|
+
existing_emb = np.load(cache_path, allow_pickle=True)["embeddings"]
|
|
82
|
+
for text, idx in existing_idx.items():
|
|
83
|
+
if idx < len(existing_emb):
|
|
84
|
+
text_to_embedding[text] = existing_emb[idx].tolist()
|
|
85
|
+
print(f" Loaded {len(text_to_embedding)} cached embeddings")
|
|
86
|
+
|
|
87
|
+
# Only embed new texts
|
|
88
|
+
new_texts = [t for t in all_texts if t not in text_to_embedding]
|
|
89
|
+
print(f" New texts to embed: {len(new_texts)} (cached: {len(text_to_embedding)})")
|
|
90
|
+
|
|
91
|
+
for i in range(0, len(new_texts), batch_size):
|
|
92
|
+
batch = new_texts[i : i + batch_size]
|
|
75
93
|
try:
|
|
76
94
|
embeddings = embedder(batch)
|
|
77
95
|
for text, emb in zip(batch, embeddings):
|
|
@@ -80,8 +98,8 @@ def precompute_embeddings(
|
|
|
80
98
|
print(f" Error embedding batch {i}-{i + len(batch)}: {e}")
|
|
81
99
|
continue
|
|
82
100
|
|
|
83
|
-
done = min(i + batch_size, len(
|
|
84
|
-
print(f" Embedded {done}/{len(
|
|
101
|
+
done = min(i + batch_size, len(new_texts))
|
|
102
|
+
print(f" Embedded {done}/{len(new_texts)} new texts")
|
|
85
103
|
|
|
86
104
|
texts_list = sorted(text_to_embedding.keys())
|
|
87
105
|
text_to_idx = {t: i for i, t in enumerate(texts_list)}
|