claude-brain 0.27.2 → 0.28.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/VERSION +1 -1
- package/package.json +3 -1
- package/src/cli/bin.ts +14 -0
- package/src/cli/commands/export-training.ts +70 -0
- package/src/cli/commands/models.ts +681 -0
- package/src/cli/commands/status.ts +44 -0
- package/src/config/home.ts +1 -0
- package/src/config/schema.ts +30 -0
- package/src/intelligence/inference-router.ts +749 -0
- package/src/intelligence/model-manager.ts +206 -0
- package/src/intelligence/tokenizer.ts +118 -0
- package/src/knowledge/entity-extractor.ts +31 -1
- package/src/memory/compression.ts +17 -1
- package/src/memory/patterns.ts +46 -6
- package/src/retrieval/query/intent-classifier.ts +17 -1
- package/src/routing/entity-extractor.ts +30 -4
- package/src/routing/intent-classifier.ts +45 -16
- package/src/routing/router.ts +22 -2
- package/src/server/handlers/list-tools.ts +6 -6
- package/src/server/http-api.ts +83 -1
- package/src/server/services.ts +47 -0
- package/src/training/data-store.ts +298 -0
- package/src/training/retrain-pipeline.ts +394 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Training Data Store — Phase 1A (SLM Upgrade)
|
|
3
|
+
* Logs classification decisions to SQLite for model training.
|
|
4
|
+
* Async, non-blocking — never impacts main request path.
|
|
5
|
+
*
|
|
6
|
+
* Table: training_data in ~/.claude-brain/data/memory.db
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
import { Database } from 'bun:sqlite'
|
|
10
|
+
import { join } from 'node:path'
|
|
11
|
+
import { existsSync, mkdirSync } from 'node:fs'
|
|
12
|
+
import { getClaudeBrainHome } from '@/config/home'
|
|
13
|
+
|
|
14
|
+
/** Classification tasks whose decisions are logged for offline model training. */
export type TrainingTask = 'intent' | 'entity' | 'query' | 'knowledge' | 'compress' | 'pattern'

/** One logged classification decision destined for the training set. */
export interface TrainingEntry {
  task: TrainingTask
  input: string
  output: string // JSON-encoded: label, entities array, summary, etc.
  metadata?: string // JSON-encoded: confidence, scores, timing
}

/** A model-vs-regex comparison recorded for the continuous learning loop (Phase 6A). */
export interface ModelFeedbackEntry {
  task: string
  input: string
  modelPrediction: string
  modelConfidence: number
  regexPrediction: string
  actualLabel?: string // human-reviewed ground truth, when available
}

// Lazily-initialized module singletons shared by every function below.
// getDb() populates all three on first successful open.
let db: Database | null = null
let insertStmt: ReturnType<Database['prepare']> | null = null
let feedbackInsertStmt: ReturnType<Database['prepare']> | null = null
|
|
35
|
+
|
|
36
|
+
function getDb(): Database | null {
|
|
37
|
+
if (db) return db
|
|
38
|
+
try {
|
|
39
|
+
const dataDir = join(getClaudeBrainHome(), 'data')
|
|
40
|
+
if (!existsSync(dataDir)) {
|
|
41
|
+
mkdirSync(dataDir, { recursive: true })
|
|
42
|
+
}
|
|
43
|
+
const dbPath = join(dataDir, 'memory.db')
|
|
44
|
+
db = new Database(dbPath)
|
|
45
|
+
db.run('PRAGMA journal_mode = WAL')
|
|
46
|
+
ensureTable(db)
|
|
47
|
+
insertStmt = db.prepare(
|
|
48
|
+
'INSERT INTO training_data (task, input, output, metadata) VALUES (?, ?, ?, ?)'
|
|
49
|
+
)
|
|
50
|
+
feedbackInsertStmt = db.prepare(
|
|
51
|
+
'INSERT INTO model_feedback (task, input, model_prediction, model_confidence, regex_prediction, actual_label) VALUES (?, ?, ?, ?, ?, ?)'
|
|
52
|
+
)
|
|
53
|
+
return db
|
|
54
|
+
} catch {
|
|
55
|
+
return null
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
function ensureTable(database: Database): void {
|
|
60
|
+
database.run(`
|
|
61
|
+
CREATE TABLE IF NOT EXISTS training_data (
|
|
62
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
63
|
+
task TEXT NOT NULL,
|
|
64
|
+
input TEXT NOT NULL,
|
|
65
|
+
output TEXT NOT NULL,
|
|
66
|
+
metadata TEXT,
|
|
67
|
+
verified INTEGER DEFAULT 0,
|
|
68
|
+
created_at TEXT DEFAULT (datetime('now'))
|
|
69
|
+
)
|
|
70
|
+
`)
|
|
71
|
+
// Indexes for efficient querying
|
|
72
|
+
database.run('CREATE INDEX IF NOT EXISTS idx_training_task ON training_data(task)')
|
|
73
|
+
database.run('CREATE INDEX IF NOT EXISTS idx_training_verified ON training_data(verified)')
|
|
74
|
+
|
|
75
|
+
// Phase 6A: Model feedback table for continuous learning loop
|
|
76
|
+
database.run(`
|
|
77
|
+
CREATE TABLE IF NOT EXISTS model_feedback (
|
|
78
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
79
|
+
task TEXT NOT NULL,
|
|
80
|
+
input TEXT NOT NULL,
|
|
81
|
+
model_prediction TEXT NOT NULL,
|
|
82
|
+
model_confidence REAL NOT NULL,
|
|
83
|
+
regex_prediction TEXT NOT NULL,
|
|
84
|
+
actual_label TEXT,
|
|
85
|
+
created_at TEXT DEFAULT (datetime('now'))
|
|
86
|
+
)
|
|
87
|
+
`)
|
|
88
|
+
database.run('CREATE INDEX IF NOT EXISTS idx_feedback_task ON model_feedback(task)')
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/**
|
|
92
|
+
* Log a training example. Fire-and-forget — errors are silently swallowed.
|
|
93
|
+
*/
|
|
94
|
+
export function logTrainingData(entry: TrainingEntry): void {
|
|
95
|
+
setImmediate(() => {
|
|
96
|
+
try {
|
|
97
|
+
const database = getDb()
|
|
98
|
+
if (!database || !insertStmt) return
|
|
99
|
+
insertStmt.run(entry.task, entry.input, entry.output, entry.metadata || null)
|
|
100
|
+
} catch {
|
|
101
|
+
// Never block or crash the main path
|
|
102
|
+
}
|
|
103
|
+
})
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
/**
|
|
107
|
+
* Export training data as JSONL lines for a specific task.
|
|
108
|
+
*/
|
|
109
|
+
export function exportTrainingData(
|
|
110
|
+
task: TrainingTask,
|
|
111
|
+
options?: { verifiedOnly?: boolean; limit?: number }
|
|
112
|
+
): string[] {
|
|
113
|
+
const database = getDb()
|
|
114
|
+
if (!database) return []
|
|
115
|
+
|
|
116
|
+
let sql = 'SELECT input, output, metadata, verified, created_at FROM training_data WHERE task = ?'
|
|
117
|
+
const params: any[] = [task]
|
|
118
|
+
|
|
119
|
+
if (options?.verifiedOnly) {
|
|
120
|
+
sql += ' AND verified = 1'
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
sql += ' ORDER BY created_at DESC'
|
|
124
|
+
|
|
125
|
+
if (options?.limit) {
|
|
126
|
+
sql += ' LIMIT ?'
|
|
127
|
+
params.push(options.limit)
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
const rows = database.prepare(sql).all(...params) as any[]
|
|
131
|
+
return rows.map(row => JSON.stringify({
|
|
132
|
+
input: row.input,
|
|
133
|
+
output: JSON.parse(row.output),
|
|
134
|
+
metadata: row.metadata ? JSON.parse(row.metadata) : null,
|
|
135
|
+
verified: row.verified === 1,
|
|
136
|
+
created_at: row.created_at,
|
|
137
|
+
}))
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
/**
|
|
141
|
+
* Get count of training examples per task.
|
|
142
|
+
*/
|
|
143
|
+
export function getTrainingStats(): Record<TrainingTask, { total: number; verified: number }> {
|
|
144
|
+
const database = getDb()
|
|
145
|
+
const tasks: TrainingTask[] = ['intent', 'entity', 'query', 'knowledge', 'compress', 'pattern']
|
|
146
|
+
const stats = {} as Record<TrainingTask, { total: number; verified: number }>
|
|
147
|
+
|
|
148
|
+
for (const task of tasks) {
|
|
149
|
+
if (!database) {
|
|
150
|
+
stats[task] = { total: 0, verified: 0 }
|
|
151
|
+
continue
|
|
152
|
+
}
|
|
153
|
+
const total = (database.prepare('SELECT COUNT(*) as c FROM training_data WHERE task = ?').get(task) as any)?.c || 0
|
|
154
|
+
const verified = (database.prepare('SELECT COUNT(*) as c FROM training_data WHERE task = ? AND verified = 1').get(task) as any)?.c || 0
|
|
155
|
+
stats[task] = { total, verified }
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
return stats
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
// ── Phase 6A: Model Feedback Functions ──────────────────────────────
|
|
162
|
+
|
|
163
|
+
/**
|
|
164
|
+
* Log a model vs regex comparison. Fire-and-forget — errors are silently swallowed.
|
|
165
|
+
*/
|
|
166
|
+
export function logModelFeedback(entry: ModelFeedbackEntry): void {
|
|
167
|
+
setImmediate(() => {
|
|
168
|
+
try {
|
|
169
|
+
const database = getDb()
|
|
170
|
+
if (!database || !feedbackInsertStmt) return
|
|
171
|
+
feedbackInsertStmt.run(
|
|
172
|
+
entry.task,
|
|
173
|
+
entry.input,
|
|
174
|
+
entry.modelPrediction,
|
|
175
|
+
entry.modelConfidence,
|
|
176
|
+
entry.regexPrediction,
|
|
177
|
+
entry.actualLabel || null
|
|
178
|
+
)
|
|
179
|
+
} catch {
|
|
180
|
+
// Never block or crash the main path
|
|
181
|
+
}
|
|
182
|
+
})
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
/**
|
|
186
|
+
* Get per-task feedback stats: total, agreements, disagreements, disagreement rate, reviewed count.
|
|
187
|
+
*/
|
|
188
|
+
export function getModelFeedbackStats(): Record<string, {
|
|
189
|
+
total: number
|
|
190
|
+
agreements: number
|
|
191
|
+
disagreements: number
|
|
192
|
+
disagreementRate: number
|
|
193
|
+
reviewed: number
|
|
194
|
+
}> {
|
|
195
|
+
const database = getDb()
|
|
196
|
+
const tasks: TrainingTask[] = ['intent', 'entity', 'query', 'knowledge', 'compress', 'pattern']
|
|
197
|
+
const stats = {} as Record<string, {
|
|
198
|
+
total: number
|
|
199
|
+
agreements: number
|
|
200
|
+
disagreements: number
|
|
201
|
+
disagreementRate: number
|
|
202
|
+
reviewed: number
|
|
203
|
+
}>
|
|
204
|
+
|
|
205
|
+
for (const task of tasks) {
|
|
206
|
+
if (!database) {
|
|
207
|
+
stats[task] = { total: 0, agreements: 0, disagreements: 0, disagreementRate: 0, reviewed: 0 }
|
|
208
|
+
continue
|
|
209
|
+
}
|
|
210
|
+
const total = (database.prepare(
|
|
211
|
+
'SELECT COUNT(*) as c FROM model_feedback WHERE task = ?'
|
|
212
|
+
).get(task) as any)?.c || 0
|
|
213
|
+
|
|
214
|
+
const agreements = (database.prepare(
|
|
215
|
+
'SELECT COUNT(*) as c FROM model_feedback WHERE task = ? AND model_prediction = regex_prediction'
|
|
216
|
+
).get(task) as any)?.c || 0
|
|
217
|
+
|
|
218
|
+
const disagreements = total - agreements
|
|
219
|
+
|
|
220
|
+
const reviewed = (database.prepare(
|
|
221
|
+
'SELECT COUNT(*) as c FROM model_feedback WHERE task = ? AND actual_label IS NOT NULL'
|
|
222
|
+
).get(task) as any)?.c || 0
|
|
223
|
+
|
|
224
|
+
stats[task] = {
|
|
225
|
+
total,
|
|
226
|
+
agreements,
|
|
227
|
+
disagreements,
|
|
228
|
+
disagreementRate: total > 0 ? disagreements / total : 0,
|
|
229
|
+
reviewed,
|
|
230
|
+
}
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
return stats
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
/**
|
|
237
|
+
* Export feedback as JSONL lines for a specific task.
|
|
238
|
+
*/
|
|
239
|
+
export function exportModelFeedback(
|
|
240
|
+
task: string,
|
|
241
|
+
options?: { limit?: number }
|
|
242
|
+
): string[] {
|
|
243
|
+
const database = getDb()
|
|
244
|
+
if (!database) return []
|
|
245
|
+
|
|
246
|
+
let sql = 'SELECT input, model_prediction, model_confidence, regex_prediction, actual_label, created_at FROM model_feedback WHERE task = ?'
|
|
247
|
+
const params: any[] = [task]
|
|
248
|
+
|
|
249
|
+
sql += ' ORDER BY created_at DESC'
|
|
250
|
+
|
|
251
|
+
if (options?.limit) {
|
|
252
|
+
sql += ' LIMIT ?'
|
|
253
|
+
params.push(options.limit)
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
const rows = database.prepare(sql).all(...params) as any[]
|
|
257
|
+
return rows.map(row => JSON.stringify({
|
|
258
|
+
input: row.input,
|
|
259
|
+
modelPrediction: row.model_prediction,
|
|
260
|
+
modelConfidence: row.model_confidence,
|
|
261
|
+
regexPrediction: row.regex_prediction,
|
|
262
|
+
actualLabel: row.actual_label,
|
|
263
|
+
createdAt: row.created_at,
|
|
264
|
+
}))
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
/**
|
|
268
|
+
* Get the most recent disagreements for human review.
|
|
269
|
+
*/
|
|
270
|
+
export function getDisagreements(
|
|
271
|
+
task: string,
|
|
272
|
+
limit: number = 50
|
|
273
|
+
): Array<{
|
|
274
|
+
input: string
|
|
275
|
+
modelPrediction: string
|
|
276
|
+
modelConfidence: number
|
|
277
|
+
regexPrediction: string
|
|
278
|
+
createdAt: string
|
|
279
|
+
}> {
|
|
280
|
+
const database = getDb()
|
|
281
|
+
if (!database) return []
|
|
282
|
+
|
|
283
|
+
const rows = database.prepare(`
|
|
284
|
+
SELECT input, model_prediction, model_confidence, regex_prediction, created_at
|
|
285
|
+
FROM model_feedback
|
|
286
|
+
WHERE task = ? AND model_prediction != regex_prediction
|
|
287
|
+
ORDER BY created_at DESC
|
|
288
|
+
LIMIT ?
|
|
289
|
+
`).all(task, limit) as any[]
|
|
290
|
+
|
|
291
|
+
return rows.map(row => ({
|
|
292
|
+
input: row.input,
|
|
293
|
+
modelPrediction: row.model_prediction,
|
|
294
|
+
modelConfidence: row.model_confidence,
|
|
295
|
+
regexPrediction: row.regex_prediction,
|
|
296
|
+
createdAt: row.created_at,
|
|
297
|
+
}))
|
|
298
|
+
}
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Retrain Pipeline — SLM Upgrade Phase 6B
|
|
3
|
+
* Automated retraining when enough feedback accumulates.
|
|
4
|
+
*
|
|
5
|
+
* Flow:
|
|
6
|
+
* 1. Check if retraining is needed (feedback count + disagreement rate)
|
|
7
|
+
* 2. Export merged training data (existing + feedback corrections) as JSONL
|
|
8
|
+
* 3. Shell out to Python: prepare_data → train → evaluate → export_onnx
|
|
9
|
+
* 4. Compare new accuracy to old, promote model if improved
|
|
10
|
+
* 5. Update retrain-state.json with timestamp + stats
|
|
11
|
+
*/
|
|
12
|
+
|
|
13
|
+
import { execSync } from 'node:child_process'
|
|
14
|
+
import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs'
|
|
15
|
+
import { join } from 'node:path'
|
|
16
|
+
import { homedir } from 'node:os'
|
|
17
|
+
import { getClaudeBrainHome, getHomePaths } from '@/config/home'
|
|
18
|
+
import {
|
|
19
|
+
exportTrainingData,
|
|
20
|
+
exportModelFeedback,
|
|
21
|
+
getModelFeedbackStats,
|
|
22
|
+
type TrainingTask,
|
|
23
|
+
} from '@/training/data-store'
|
|
24
|
+
import type { ModelManifest, ModelTask } from '@/intelligence/model-manager'
|
|
25
|
+
|
|
26
|
+
// ── Types ────────────────────────────────────────────────────────────
|
|
27
|
+
|
|
28
|
+
/** Tunables for the automated retrain pipeline. */
export interface RetrainConfig {
  minFeedbackCount: number // new feedback entries required before a retrain triggers
  maxDisagreementRate: number // 0..1; above this, retraining is flagged as needed
  pythonPath: string // interpreter used to run the training scripts
  trainingDir: string // training workspace root; a leading '~' is expanded
  force: boolean // retrain/promote even when checks say it is not needed
}

/** Outcome of the "is retraining needed?" check for one task. */
export interface RetrainCheck {
  needed: boolean
  reason: string // human-readable explanation for the decision
  feedbackCount: number // total feedback rows seen for the task
  disagreementRate: number
  lastRetrainDate: string | null // ISO timestamp, or null if never retrained
}

/** Result of a single task's retrain run. */
export interface RetrainResult {
  task: string
  success: boolean
  error?: string // set when success is false
  oldAccuracy?: number // accuracy of the previously deployed model, if known
  newAccuracy?: number // accuracy read from the new benchmark file, if produced
  trainingDataCount: number
  duration: number // wall-clock milliseconds
}

/** Persisted per-task bookkeeping stored in retrain-state.json. */
interface RetrainStateEntry {
  lastRetrain: string // ISO timestamp of the last retrain attempt
  lastAccuracy: number
  feedbackAtRetrain: number // total feedback count at that time (baseline for "new" feedback)
}

/** Map of task name → retrain bookkeeping. */
type RetrainState = Record<string, RetrainStateEntry>
|
|
61
|
+
|
|
62
|
+
// ── Helpers ──────────────────────────────────────────────────────────
|
|
63
|
+
|
|
64
|
+
function resolveTrainingDir(trainingDir: string): string {
|
|
65
|
+
if (trainingDir.startsWith('~')) {
|
|
66
|
+
return join(homedir(), trainingDir.slice(1))
|
|
67
|
+
}
|
|
68
|
+
return trainingDir
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
function getRetrainStatePath(): string {
|
|
72
|
+
return join(getClaudeBrainHome(), 'data', 'retrain-state.json')
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
function loadRetrainState(): RetrainState {
|
|
76
|
+
const statePath = getRetrainStatePath()
|
|
77
|
+
if (!existsSync(statePath)) return {}
|
|
78
|
+
try {
|
|
79
|
+
return JSON.parse(readFileSync(statePath, 'utf-8'))
|
|
80
|
+
} catch {
|
|
81
|
+
return {}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
function saveRetrainState(state: RetrainState): void {
|
|
86
|
+
const statePath = getRetrainStatePath()
|
|
87
|
+
const dir = join(getClaudeBrainHome(), 'data')
|
|
88
|
+
if (!existsSync(dir)) {
|
|
89
|
+
mkdirSync(dir, { recursive: true })
|
|
90
|
+
}
|
|
91
|
+
writeFileSync(statePath, JSON.stringify(state, null, 2))
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
function loadManifest(): ModelManifest | null {
|
|
95
|
+
const manifestPath = join(getHomePaths().models, 'manifest.json')
|
|
96
|
+
if (!existsSync(manifestPath)) return null
|
|
97
|
+
try {
|
|
98
|
+
return JSON.parse(readFileSync(manifestPath, 'utf-8'))
|
|
99
|
+
} catch {
|
|
100
|
+
return null
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
function saveManifest(manifest: ModelManifest): void {
|
|
105
|
+
const modelsDir = getHomePaths().models
|
|
106
|
+
if (!existsSync(modelsDir)) {
|
|
107
|
+
mkdirSync(modelsDir, { recursive: true })
|
|
108
|
+
}
|
|
109
|
+
writeFileSync(join(modelsDir, 'manifest.json'), JSON.stringify(manifest, null, 2))
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
// ── Public API ───────────────────────────────────────────────────────
|
|
113
|
+
|
|
114
|
+
/**
|
|
115
|
+
* Check if retraining is needed for a task based on feedback accumulation.
|
|
116
|
+
*/
|
|
117
|
+
export function shouldRetrain(task: ModelTask, config: RetrainConfig): RetrainCheck {
|
|
118
|
+
const state = loadRetrainState()
|
|
119
|
+
const taskState = state[task]
|
|
120
|
+
const lastRetrainDate = taskState?.lastRetrain ?? null
|
|
121
|
+
const feedbackAtLastRetrain = taskState?.feedbackAtRetrain ?? 0
|
|
122
|
+
|
|
123
|
+
const feedbackStats = getModelFeedbackStats()
|
|
124
|
+
const taskStats = feedbackStats[task]
|
|
125
|
+
|
|
126
|
+
if (!taskStats) {
|
|
127
|
+
return {
|
|
128
|
+
needed: false,
|
|
129
|
+
reason: 'No feedback data available',
|
|
130
|
+
feedbackCount: 0,
|
|
131
|
+
disagreementRate: 0,
|
|
132
|
+
lastRetrainDate,
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
const newFeedbackCount = taskStats.total - feedbackAtLastRetrain
|
|
137
|
+
const disagreementRate = taskStats.disagreementRate
|
|
138
|
+
|
|
139
|
+
// Check if enough new feedback has accumulated
|
|
140
|
+
if (newFeedbackCount < config.minFeedbackCount) {
|
|
141
|
+
return {
|
|
142
|
+
needed: false,
|
|
143
|
+
reason: `Only ${newFeedbackCount} new feedback entries (need ${config.minFeedbackCount})`,
|
|
144
|
+
feedbackCount: taskStats.total,
|
|
145
|
+
disagreementRate,
|
|
146
|
+
lastRetrainDate,
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
// Check if disagreement rate exceeds threshold
|
|
151
|
+
if (disagreementRate > config.maxDisagreementRate) {
|
|
152
|
+
return {
|
|
153
|
+
needed: true,
|
|
154
|
+
reason: `Disagreement rate ${(disagreementRate * 100).toFixed(1)}% exceeds threshold ${(config.maxDisagreementRate * 100).toFixed(1)}%`,
|
|
155
|
+
feedbackCount: taskStats.total,
|
|
156
|
+
disagreementRate,
|
|
157
|
+
lastRetrainDate,
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
// Enough feedback accumulated even if disagreement rate is below threshold
|
|
162
|
+
return {
|
|
163
|
+
needed: true,
|
|
164
|
+
reason: `${newFeedbackCount} new feedback entries since last retrain`,
|
|
165
|
+
feedbackCount: taskStats.total,
|
|
166
|
+
disagreementRate,
|
|
167
|
+
lastRetrainDate,
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
/**
|
|
172
|
+
* Export merged training data (training_data + feedback corrections) as JSONL.
|
|
173
|
+
*/
|
|
174
|
+
export function exportMergedTrainingData(
|
|
175
|
+
task: ModelTask,
|
|
176
|
+
trainingDir: string
|
|
177
|
+
): { path: string; count: number } {
|
|
178
|
+
const resolvedDir = resolveTrainingDir(trainingDir)
|
|
179
|
+
const dataDir = join(resolvedDir, 'data', task)
|
|
180
|
+
if (!existsSync(dataDir)) {
|
|
181
|
+
mkdirSync(dataDir, { recursive: true })
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
const outputPath = join(dataDir, 'train.jsonl')
|
|
185
|
+
|
|
186
|
+
// Get existing training data
|
|
187
|
+
const trainingLines = exportTrainingData(task as TrainingTask)
|
|
188
|
+
|
|
189
|
+
// Get feedback data (disagreements are useful training signal)
|
|
190
|
+
const feedbackLines = exportModelFeedback(task)
|
|
191
|
+
|
|
192
|
+
// Merge: training data first, then feedback
|
|
193
|
+
const allLines = [...trainingLines, ...feedbackLines]
|
|
194
|
+
|
|
195
|
+
writeFileSync(outputPath, allLines.join('\n') + '\n')
|
|
196
|
+
|
|
197
|
+
return { path: outputPath, count: allLines.length }
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
/**
 * Run the full retrain pipeline for a single task.
 *
 * Steps (any failure short-circuits with success: false):
 * 1. Export merged training data (training_data + feedback) to JSONL.
 * 2. Shell out to Python: prepare_data → train → evaluate → export_onnx.
 * 3. Read the new accuracy from the benchmark JSON.
 * 4. Skip model promotion if accuracy regressed (unless config.force).
 * 5. Copy the new ONNX model into the models dir and bump the manifest.
 * 6. Record the attempt in retrain-state.json.
 *
 * NOTE(review): the body contains no awaits (execSync is synchronous);
 * async is presumably kept for API symmetry — confirm with callers.
 */
export async function retrainTask(
  task: ModelTask,
  config: RetrainConfig
): Promise<RetrainResult> {
  const startTime = Date.now()
  const resolvedDir = resolveTrainingDir(config.trainingDir)

  // Step 1: Export merged training data
  console.log(` [${task}] Exporting merged training data...`)
  let dataCount: number
  try {
    const { count } = exportMergedTrainingData(task, config.trainingDir)
    dataCount = count
    console.log(` [${task}] Exported ${count} training examples`)
  } catch (err) {
    return {
      task,
      success: false,
      error: `Failed to export training data: ${err instanceof Error ? err.message : String(err)}`,
      trainingDataCount: 0,
      duration: Date.now() - startTime,
    }
  }

  // Nothing to train on — bail out before invoking Python.
  if (dataCount === 0) {
    return {
      task,
      success: false,
      error: 'No training data available',
      trainingDataCount: 0,
      duration: Date.now() - startTime,
    }
  }

  // Step 2: Run Python training pipeline (script paths are relative to resolvedDir)
  const pythonSteps = [
    { script: 'scripts/prepare_data.py', args: `--task ${task}`, label: 'Preparing data' },
    { script: 'scripts/train.py', args: `--task ${task}`, label: 'Training model' },
    { script: 'scripts/evaluate.py', args: `--task ${task} --save`, label: 'Evaluating model' },
    { script: 'scripts/export_onnx.py', args: `--task ${task} --output-dir models/`, label: 'Exporting ONNX' },
  ]

  for (const step of pythonSteps) {
    console.log(` [${task}] ${step.label}...`)
    const cmd = `${config.pythonPath} ${step.script} ${step.args}`
    try {
      // stdio: 'pipe' keeps script output off our console; it is surfaced only on failure.
      execSync(cmd, {
        cwd: resolvedDir,
        stdio: 'pipe',
        timeout: 600_000, // 10 minute timeout per step
      })
    } catch (err) {
      // execSync errors carry the child's stderr; truncate to keep messages readable.
      const stderr = err instanceof Error && 'stderr' in err
        ? (err as any).stderr?.toString().slice(0, 500)
        : String(err)
      return {
        task,
        success: false,
        error: `Python step "${step.label}" failed: ${stderr}`,
        trainingDataCount: dataCount,
        duration: Date.now() - startTime,
      }
    }
  }

  // Step 3: Read evaluation results written by `evaluate.py --save`
  const benchmarkPath = join(resolvedDir, 'benchmarks', `${task}.json`)
  let newAccuracy: number | undefined
  try {
    if (existsSync(benchmarkPath)) {
      const benchData = JSON.parse(readFileSync(benchmarkPath, 'utf-8'))
      // Accepts either key; presumably the benchmark schema varies — confirm against evaluate.py.
      newAccuracy = benchData.accuracy ?? benchData.overall_accuracy
    }
  } catch {
    console.log(` [${task}] Warning: Could not read benchmark results`)
  }

  // Step 4: Compare with existing model accuracy
  const manifest = loadManifest()
  const oldAccuracy = manifest?.models[task]?.accuracy

  if (newAccuracy != null && oldAccuracy != null && newAccuracy < oldAccuracy && !config.force) {
    console.log(` [${task}] New accuracy ${(newAccuracy * 100).toFixed(1)}% < old ${(oldAccuracy * 100).toFixed(1)}% — skipping model replacement`)
    // Still record the retrain attempt so shouldRetrain() resets its feedback baseline.
    const state = loadRetrainState()
    const feedbackStats = getModelFeedbackStats()
    state[task] = {
      lastRetrain: new Date().toISOString(),
      lastAccuracy: oldAccuracy,
      feedbackAtRetrain: feedbackStats[task]?.total ?? 0,
    }
    saveRetrainState(state)

    return {
      task,
      success: false,
      error: `New accuracy (${(newAccuracy * 100).toFixed(1)}%) lower than current (${(oldAccuracy * 100).toFixed(1)}%). Use --force to override.`,
      oldAccuracy,
      newAccuracy,
      trainingDataCount: dataCount,
      duration: Date.now() - startTime,
    }
  }

  // Step 5: Copy new ONNX model to ~/.claude-brain/models/
  const modelsDir = getHomePaths().models
  if (!existsSync(modelsDir)) {
    mkdirSync(modelsDir, { recursive: true })
  }

  const onnxSource = join(resolvedDir, 'models', `${task}.onnx`)
  if (existsSync(onnxSource)) {
    const onnxDest = join(modelsDir, `${task}.onnx`)
    const onnxData = readFileSync(onnxSource)
    writeFileSync(onnxDest, onnxData)
    console.log(` [${task}] Model copied to ${onnxDest}`)

    // Update manifest: bump patch version, carry labels/maxSeqLen over from the old entry.
    const currentManifest = loadManifest() ?? { models: {} }
    const existing = currentManifest.models[task]
    currentManifest.models[task] = {
      version: bumpVersion(existing?.version),
      file: `${task}.onnx`,
      accuracy: newAccuracy,
      labels: existing?.labels,
      maxSeqLen: existing?.maxSeqLen,
    }
    saveManifest(currentManifest)
    console.log(` [${task}] Manifest updated`)
  } else {
    // Export step reported success but produced no file: warn and continue;
    // retrain state below is still updated.
    console.log(` [${task}] Warning: ONNX file not found at ${onnxSource}`)
  }

  // Step 6: Update retrain state
  const state = loadRetrainState()
  const feedbackStats = getModelFeedbackStats()
  state[task] = {
    lastRetrain: new Date().toISOString(),
    lastAccuracy: newAccuracy ?? oldAccuracy ?? 0,
    feedbackAtRetrain: feedbackStats[task]?.total ?? 0,
  }
  saveRetrainState(state)

  console.log(` [${task}] Retrain complete in ${((Date.now() - startTime) / 1000).toFixed(1)}s`)

  return {
    task,
    success: true,
    oldAccuracy,
    newAccuracy,
    trainingDataCount: dataCount,
    duration: Date.now() - startTime,
  }
}
|
|
357
|
+
|
|
358
|
+
/**
|
|
359
|
+
* Run retrain for all tasks that need it.
|
|
360
|
+
*/
|
|
361
|
+
export async function retrainAll(
|
|
362
|
+
config: RetrainConfig
|
|
363
|
+
): Promise<Map<string, RetrainResult>> {
|
|
364
|
+
const ALL_TASKS: ModelTask[] = ['intent', 'entity', 'query', 'knowledge', 'compress', 'pattern']
|
|
365
|
+
const results = new Map<string, RetrainResult>()
|
|
366
|
+
|
|
367
|
+
for (const task of ALL_TASKS) {
|
|
368
|
+
const check = shouldRetrain(task, config)
|
|
369
|
+
|
|
370
|
+
if (!check.needed && !config.force) {
|
|
371
|
+
console.log(` [${task}] Skipping — ${check.reason}`)
|
|
372
|
+
continue
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
if (config.force && !check.needed) {
|
|
376
|
+
console.log(` [${task}] Forcing retrain (${check.reason})`)
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
const result = await retrainTask(task, config)
|
|
380
|
+
results.set(task, result)
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
return results
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
// ── Internal Helpers ─────────────────────────────────────────────────
|
|
387
|
+
|
|
388
|
+
function bumpVersion(current?: string): string {
|
|
389
|
+
if (!current) return '1.0.0'
|
|
390
|
+
const parts = current.split('.').map(Number)
|
|
391
|
+
// Bump patch version
|
|
392
|
+
parts[2] = (parts[2] ?? 0) + 1
|
|
393
|
+
return parts.join('.')
|
|
394
|
+
}
|