claude-brain 0.27.2 → 0.28.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/VERSION +1 -1
- package/package.json +3 -1
- package/src/cli/bin.ts +14 -0
- package/src/cli/commands/export-training.ts +70 -0
- package/src/cli/commands/models.ts +681 -0
- package/src/cli/commands/status.ts +44 -0
- package/src/config/home.ts +1 -0
- package/src/config/schema.ts +30 -0
- package/src/intelligence/inference-router.ts +749 -0
- package/src/intelligence/model-manager.ts +206 -0
- package/src/intelligence/tokenizer.ts +118 -0
- package/src/knowledge/entity-extractor.ts +31 -1
- package/src/memory/compression.ts +17 -1
- package/src/memory/patterns.ts +46 -6
- package/src/retrieval/query/intent-classifier.ts +17 -1
- package/src/routing/entity-extractor.ts +30 -4
- package/src/routing/intent-classifier.ts +45 -16
- package/src/routing/router.ts +22 -2
- package/src/server/handlers/list-tools.ts +6 -6
- package/src/server/http-api.ts +83 -1
- package/src/server/services.ts +47 -0
- package/src/training/data-store.ts +298 -0
- package/src/training/retrain-pipeline.ts +394 -0
|
@@ -0,0 +1,681 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Models Command — SLM Upgrade Phase 5
|
|
3
|
+
* Manage SLM models: list, download, enable, disable, benchmark, stats.
|
|
4
|
+
*
|
|
5
|
+
* Usage:
|
|
6
|
+
* claude-brain models list
|
|
7
|
+
* claude-brain models download [--task <task>|all]
|
|
8
|
+
* claude-brain models enable <task>
|
|
9
|
+
* claude-brain models disable <task>
|
|
10
|
+
* claude-brain models benchmark <task>
|
|
11
|
+
* claude-brain models stats
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import { readFileSync, existsSync, mkdirSync, writeFileSync, statSync, copyFileSync } from 'node:fs'
|
|
15
|
+
import { join } from 'node:path'
|
|
16
|
+
import { homedir } from 'node:os'
|
|
17
|
+
import { parseArgs } from 'citty'
|
|
18
|
+
import { renderLogo, theme, heading, dimText, successText, warningText, errorText, box, summaryPanel } from '@/cli/ui/index.js'
|
|
19
|
+
import { getHomePaths, getClaudeBrainHome } from '@/config/home'
|
|
20
|
+
import { getTrainingStats, type TrainingTask } from '@/training/data-store'
|
|
21
|
+
import type { ModelManifest, ModelManifestEntry, ModelTask } from '@/intelligence/model-manager'
|
|
22
|
+
import { shouldRetrain, retrainTask, retrainAll, type RetrainConfig } from '@/training/retrain-pipeline'
|
|
23
|
+
|
|
24
|
+
// Every SLM task this CLI manages; iteration order here drives the display
// order in `list` and `stats`.
const ALL_TASKS: ModelTask[] = ['intent', 'entity', 'query', 'knowledge', 'compress', 'pattern']
|
|
25
|
+
|
|
26
|
+
export async function runModels() {
|
|
27
|
+
const args = parseArgs(process.argv.slice(3), {
|
|
28
|
+
subcommand: { type: 'positional', required: false, description: 'Subcommand: list, download, enable, disable, benchmark, stats, retrain' },
|
|
29
|
+
taskArg: { type: 'positional', required: false, description: 'Task name (for enable/disable/benchmark/retrain)' },
|
|
30
|
+
task: { type: 'string', description: 'Target task (for download --task)' },
|
|
31
|
+
source: { type: 'string', description: 'Source: local (default) or release' },
|
|
32
|
+
force: { type: 'boolean', description: 'Force retrain even if checks say not needed' },
|
|
33
|
+
})
|
|
34
|
+
|
|
35
|
+
const subcommand = args.subcommand || ''
|
|
36
|
+
const taskArg = (args.task || args.taskArg || '') as string
|
|
37
|
+
|
|
38
|
+
switch (subcommand) {
|
|
39
|
+
case 'list':
|
|
40
|
+
return listModels()
|
|
41
|
+
case 'download':
|
|
42
|
+
return downloadModels(taskArg || 'all', (args.source as string) || 'local')
|
|
43
|
+
case 'enable':
|
|
44
|
+
return enableTask(taskArg)
|
|
45
|
+
case 'disable':
|
|
46
|
+
return disableTask(taskArg)
|
|
47
|
+
case 'benchmark':
|
|
48
|
+
return benchmarkTask(taskArg)
|
|
49
|
+
case 'stats':
|
|
50
|
+
return showStats()
|
|
51
|
+
case 'retrain':
|
|
52
|
+
return retrainModels(taskArg || 'all', !!args.force)
|
|
53
|
+
default:
|
|
54
|
+
return printModelsHelp()
|
|
55
|
+
}
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
// ─── list ─────────────────────────────────────────────────────────
|
|
59
|
+
|
|
60
|
+
/**
 * `models list` — print a status table for every known task.
 *
 * Reads ~/.claude-brain models manifest.json; for each task shows
 * installed/missing, version, accuracy, and on-disk size of the .onnx file.
 * When no manifest exists, prints a warning panel plus "not installed" rows.
 */
function listModels() {
  console.log()
  console.log(renderLogo())
  console.log()
  console.log(heading('SLM Models'))
  console.log()

  const paths = getHomePaths()
  const manifestPath = join(paths.models, 'manifest.json')

  const items: Array<{ label: string; value: string; status?: 'success' | 'warning' | 'error' | 'info' }> = []

  // No manifest at all: nothing has been downloaded yet.
  if (!existsSync(manifestPath)) {
    items.push({ label: 'Manifest', value: 'Not found — run "models download" first', status: 'warning' })
    console.log(summaryPanel('Models', items))
    console.log()

    // Still show task status even without manifest
    for (const task of ALL_TASKS) {
      console.log(` ${theme.primary(task.padEnd(12))} ${warningText('not installed')}`)
    }
    console.log()
    return
  }

  let manifest: ModelManifest
  try {
    manifest = JSON.parse(readFileSync(manifestPath, 'utf-8'))
  } catch {
    console.log(errorText('Failed to parse manifest.json'))
    return
  }

  let totalSize = 0
  const rows: string[] = []

  for (const task of ALL_TASKS) {
    const entry = manifest.models?.[task]
    if (!entry) {
      rows.push(` ${theme.primary(task.padEnd(12))} ${warningText('not installed')}`)
      continue
    }

    // A manifest entry may exist while the model file itself is gone.
    const modelPath = join(paths.models, entry.file)
    const installed = existsSync(modelPath)

    let sizeStr = '—'
    if (installed) {
      try {
        const size = statSync(modelPath).size
        totalSize += size
        sizeStr = formatBytes(size)
      } catch {
        // Race: file disappeared between existsSync and statSync.
        sizeStr = '?'
      }
    }

    const status = installed ? successText('installed') : warningText('missing')
    const version = entry.version ? dimText(`v${entry.version}`) : ''
    const accuracy = entry.accuracy != null ? dimText(`${(entry.accuracy * 100).toFixed(1)}%`) : ''

    // NOTE: status is padded to 24 (not 14 like the header) — presumably to
    // compensate for invisible ANSI color codes in the styled string; confirm
    // against the ui theme helpers before changing these widths.
    rows.push(
      ` ${theme.primary(task.padEnd(12))} ${status.padEnd(24)} ${version.padEnd(14)} ${accuracy.padEnd(10)} ${dimText(sizeStr)}`
    )
  }

  // Header
  console.log(
    ` ${dimText('Task'.padEnd(12))} ${dimText('Status'.padEnd(14))} ${dimText('Version'.padEnd(14))} ${dimText('Accuracy'.padEnd(10))} ${dimText('Size')}`
  )
  for (const row of rows) {
    console.log(row)
  }
  console.log()
  console.log(` ${dimText('Total size:')} ${dimText(formatBytes(totalSize))}`)
  console.log()
}
|
|
137
|
+
|
|
138
|
+
// ─── download ─────────────────────────────────────────────────────
|
|
139
|
+
|
|
140
|
+
/**
 * `models download` — install model files into the claude-brain home.
 *
 * @param taskFilter a single task name or 'all'
 * @param source 'local' (copy from ~/slm-training/models/) or 'release'
 *               (not yet implemented — prints a "Coming Soon" box and exits)
 *
 * For each selected task: copies <task>.onnx (verifying size after copy),
 * copies <task>.json metadata when present, and records a manifest entry.
 * Existing manifest entries for untouched tasks are preserved. On success
 * the installed tasks are auto-enabled ('model' mode) in config.json.
 */
function downloadModels(taskFilter: string, source: string) {
  console.log()
  console.log(renderLogo())
  console.log()
  console.log(heading('Download Models'))
  console.log()

  const paths = getHomePaths()

  // Validate task filter
  if (taskFilter !== 'all' && !ALL_TASKS.includes(taskFilter as ModelTask)) {
    console.log(errorText(`Invalid task: ${taskFilter}`))
    console.log(dimText(`Valid tasks: ${ALL_TASKS.join(', ')}`))
    return
  }

  const tasks = taskFilter === 'all' ? ALL_TASKS : [taskFilter as ModelTask]

  // Ensure models directory exists
  if (!existsSync(paths.models)) {
    mkdirSync(paths.models, { recursive: true })
    console.log(successText(`Created models directory: ${paths.models}`))
  }

  // Release source — not yet implemented
  if (source === 'release') {
    console.log(
      box(
        'Downloading from release artifacts is not yet available.\n' +
        'Use --source local to install from ~/slm-training/models/ instead.',
        'Coming Soon'
      )
    )
    console.log()
    return
  }

  // Local source — copy from ~/slm-training/models/
  const sourceDir = join(homedir(), 'slm-training', 'models')

  console.log(` ${dimText('Source:')} ${sourceDir}`)
  console.log(` ${dimText('Target:')} ${paths.models}`)
  console.log()

  if (!existsSync(sourceDir)) {
    console.log(errorText(`Source directory not found: ${sourceDir}`))
    console.log(dimText('Train models first, then place .onnx and .json files in the source directory.'))
    return
  }

  let installed = 0
  let totalBytes = 0
  const installedTasks: ModelTask[] = []
  const manifestModels: Partial<Record<ModelTask, ModelManifestEntry>> = {}

  // Load existing manifest to preserve entries for tasks we're not updating
  const manifestPath = join(paths.models, 'manifest.json')
  if (existsSync(manifestPath)) {
    try {
      const existing: ModelManifest = JSON.parse(readFileSync(manifestPath, 'utf-8'))
      if (existing.models) {
        Object.assign(manifestModels, existing.models)
      }
    } catch {
      // Ignore corrupt manifest, we'll overwrite it
    }
  }

  for (const task of tasks) {
    // Convention: model weights are <task>.onnx, training metadata <task>.json.
    const onnxFile = `${task}.onnx`
    const metaFile = `${task}.json`
    const srcOnnx = join(sourceDir, onnxFile)
    const srcMeta = join(sourceDir, metaFile)

    if (!existsSync(srcOnnx)) {
      console.log(` ${warningText(`${onnxFile} not found in source — skipping`)}`)
      continue
    }

    // Copy ONNX model
    const dstOnnx = join(paths.models, onnxFile)
    copyFileSync(srcOnnx, dstOnnx)

    // Verify copied file is valid (non-empty and size matches source)
    const srcSize = statSync(srcOnnx).size
    const dstSize = statSync(dstOnnx).size
    if (dstSize === 0 || dstSize !== srcSize) {
      console.log(` ${errorText(`${onnxFile} copy verification failed (src: ${formatBytes(srcSize)}, dst: ${formatBytes(dstSize)}) — skipping`)}`)
      continue
    }

    totalBytes += dstSize
    console.log(` Copying ${onnxFile}... ${successText('done')} ${dimText(`(${formatBytes(dstSize)})`)}`)

    // Copy metadata if present
    let meta: Partial<ModelManifestEntry> = {}
    if (existsSync(srcMeta)) {
      const dstMeta = join(paths.models, metaFile)
      copyFileSync(srcMeta, dstMeta)
      try {
        meta = JSON.parse(readFileSync(dstMeta, 'utf-8'))
      } catch {
        console.log(` ${warningText(`Failed to parse ${metaFile} — using defaults`)}`)
      }
    }

    // Build manifest entry for this task
    // Map training metadata fields → manifest fields:
    // val_acc → accuracy, block_size → maxSeqLen, model_name → params
    const metaAny = meta as Record<string, any>
    manifestModels[task] = {
      version: meta.version ?? '0.1.0',
      file: onnxFile,
      sha256: meta.sha256,
      params: meta.params ?? metaAny.model_name,
      accuracy: meta.accuracy ?? metaAny.val_acc,
      labels: meta.labels,
      maxSeqLen: meta.maxSeqLen ?? metaAny.block_size,
    }

    installedTasks.push(task)
    installed++
  }

  console.log()

  if (installed === 0) {
    console.log(warningText('No models were installed.'))
    console.log()
    return
  }

  // Write manifest
  const manifest: ModelManifest = { models: manifestModels }
  writeFileSync(manifestPath, JSON.stringify(manifest, null, 2))

  // Auto-enable successfully installed models in config
  const configPath = join(getClaudeBrainHome(), 'config.json')
  const config = loadConfigFile(configPath)
  if (!config.slm) config.slm = {}
  config.slm.enabled = true
  if (!config.slm.tasks) config.slm.tasks = {}
  for (const task of installedTasks) {
    config.slm.tasks[task] = 'model'
  }
  writeFileSync(configPath, JSON.stringify(config, null, 2))

  console.log(successText(`Installed ${installed} model${installed !== 1 ? 's' : ''} (total: ${formatBytes(totalBytes)})`))
  console.log(successText(`Auto-enabled ${installedTasks.join(', ')} in config`))
  console.log(dimText(`Manifest written to ${manifestPath}`))
  console.log()
}
|
|
292
|
+
|
|
293
|
+
// ─── enable ───────────────────────────────────────────────────────
|
|
294
|
+
|
|
295
|
+
function enableTask(task: string) {
|
|
296
|
+
if (!task) {
|
|
297
|
+
console.log(errorText('Missing task argument'))
|
|
298
|
+
console.log(dimText(`Usage: claude-brain models enable <task>`))
|
|
299
|
+
console.log(dimText(`Tasks: ${ALL_TASKS.join(', ')}`))
|
|
300
|
+
return
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
if (!ALL_TASKS.includes(task as ModelTask)) {
|
|
304
|
+
console.log(errorText(`Invalid task: ${task}`))
|
|
305
|
+
console.log(dimText(`Valid tasks: ${ALL_TASKS.join(', ')}`))
|
|
306
|
+
return
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
const configPath = join(getClaudeBrainHome(), 'config.json')
|
|
310
|
+
const config = loadConfigFile(configPath)
|
|
311
|
+
|
|
312
|
+
// Ensure slm section exists
|
|
313
|
+
if (!config.slm) config.slm = {}
|
|
314
|
+
config.slm.enabled = true
|
|
315
|
+
if (!config.slm.tasks) config.slm.tasks = {}
|
|
316
|
+
|
|
317
|
+
// compress uses 'api' baseline, others use 'regex'
|
|
318
|
+
config.slm.tasks[task] = 'model'
|
|
319
|
+
|
|
320
|
+
writeFileSync(configPath, JSON.stringify(config, null, 2))
|
|
321
|
+
console.log(successText(`Enabled model inference for task: ${task}`))
|
|
322
|
+
console.log(dimText(`Config updated: ${configPath}`))
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
// ─── disable ──────────────────────────────────────────────────────
|
|
326
|
+
|
|
327
|
+
function disableTask(task: string) {
|
|
328
|
+
if (!task) {
|
|
329
|
+
console.log(errorText('Missing task argument'))
|
|
330
|
+
console.log(dimText(`Usage: claude-brain models disable <task>`))
|
|
331
|
+
console.log(dimText(`Tasks: ${ALL_TASKS.join(', ')}`))
|
|
332
|
+
return
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
if (!ALL_TASKS.includes(task as ModelTask)) {
|
|
336
|
+
console.log(errorText(`Invalid task: ${task}`))
|
|
337
|
+
console.log(dimText(`Valid tasks: ${ALL_TASKS.join(', ')}`))
|
|
338
|
+
return
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
const configPath = join(getClaudeBrainHome(), 'config.json')
|
|
342
|
+
const config = loadConfigFile(configPath)
|
|
343
|
+
|
|
344
|
+
if (!config.slm) config.slm = {}
|
|
345
|
+
if (!config.slm.tasks) config.slm.tasks = {}
|
|
346
|
+
|
|
347
|
+
// Revert to baseline: compress → 'api', others → 'regex'
|
|
348
|
+
config.slm.tasks[task] = task === 'compress' ? 'api' : 'regex'
|
|
349
|
+
|
|
350
|
+
writeFileSync(configPath, JSON.stringify(config, null, 2))
|
|
351
|
+
console.log(successText(`Disabled model inference for task: ${task}`))
|
|
352
|
+
console.log(dimText(`Reverted to ${task === 'compress' ? 'api' : 'regex'} mode`))
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
// ─── benchmark ────────────────────────────────────────────────────
|
|
356
|
+
|
|
357
|
+
/**
 * `models benchmark <task>` — run an accuracy benchmark from exported
 * JSONL test data.
 *
 * Only the intent task has a runnable (regex-classifier) benchmark here;
 * other tasks just report how many test examples are ready, since scoring
 * them would require ONNX model inference.
 */
async function benchmarkTask(task: string) {
  if (!task) {
    console.log(errorText('Missing task argument'))
    console.log(dimText(`Usage: claude-brain models benchmark <task>`))
    console.log(dimText(`Tasks: ${ALL_TASKS.join(', ')}`))
    return
  }

  if (!ALL_TASKS.includes(task as ModelTask)) {
    console.log(errorText(`Invalid task: ${task}`))
    console.log(dimText(`Valid tasks: ${ALL_TASKS.join(', ')}`))
    return
  }

  console.log()
  console.log(heading(`Benchmark: ${task}`))
  console.log()

  // Look for test data
  const home = getClaudeBrainHome()
  const testDataPath = join(home, 'training', 'benchmarks', 'baseline', `${task}_test.jsonl`)

  if (!existsSync(testDataPath)) {
    console.log(warningText(`Test data not found: ${testDataPath}`))
    console.log(dimText('Generate test data first with: claude-brain export-training ' + task))
    return
  }

  // Load test examples (one JSON object per non-empty line)
  const lines = readFileSync(testDataPath, 'utf-8').trim().split('\n').filter(Boolean)
  if (lines.length === 0) {
    console.log(warningText('Test file is empty'))
    return
  }

  console.log(dimText(`Loaded ${lines.length} test examples from ${testDataPath}`))
  console.log()

  // Run through regex classifier for intent task
  if (task === 'intent') {
    await benchmarkIntent(lines)
  } else {
    // For non-intent tasks, just report data availability
    console.log(dimText(`Benchmark for "${task}" requires model inference (ONNX).`))
    console.log(dimText(`Regex-only benchmarking is available for the intent task.`))
    console.log()
    console.log(dimText(`Test data: ${lines.length} examples ready`))
  }
}
|
|
406
|
+
|
|
407
|
+
async function benchmarkIntent(lines: string[]) {
|
|
408
|
+
// Dynamic import to avoid circular deps
|
|
409
|
+
const { IntentClassifier } = await import('@/routing/intent-classifier')
|
|
410
|
+
const classifier = new IntentClassifier()
|
|
411
|
+
|
|
412
|
+
let correct = 0
|
|
413
|
+
const labelCounts: Record<string, { tp: number; fp: number; fn: number }> = {}
|
|
414
|
+
|
|
415
|
+
for (const line of lines) {
|
|
416
|
+
let example: { input: string; output: { label: string } }
|
|
417
|
+
try {
|
|
418
|
+
example = JSON.parse(line)
|
|
419
|
+
} catch {
|
|
420
|
+
continue
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
const expected = example.output?.label
|
|
424
|
+
if (!expected) continue
|
|
425
|
+
|
|
426
|
+
const result = classifier.classify(example.input)
|
|
427
|
+
const predicted = result.primary
|
|
428
|
+
|
|
429
|
+
// Initialize counters
|
|
430
|
+
if (!labelCounts[expected]) labelCounts[expected] = { tp: 0, fp: 0, fn: 0 }
|
|
431
|
+
if (!labelCounts[predicted]) labelCounts[predicted] = { tp: 0, fp: 0, fn: 0 }
|
|
432
|
+
|
|
433
|
+
if (predicted === expected) {
|
|
434
|
+
correct++
|
|
435
|
+
labelCounts[expected].tp++
|
|
436
|
+
} else {
|
|
437
|
+
labelCounts[expected].fn++
|
|
438
|
+
labelCounts[predicted].fp++
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
const total = lines.length
|
|
443
|
+
const accuracy = total > 0 ? correct / total : 0
|
|
444
|
+
|
|
445
|
+
console.log(` ${theme.bold('Overall Accuracy:')} ${accuracy >= 0.7 ? successText(`${(accuracy * 100).toFixed(1)}%`) : warningText(`${(accuracy * 100).toFixed(1)}%`)} (${correct}/${total})`)
|
|
446
|
+
console.log()
|
|
447
|
+
|
|
448
|
+
// Per-class metrics
|
|
449
|
+
console.log(` ${dimText('Label'.padEnd(20))} ${dimText('Prec'.padEnd(8))} ${dimText('Recall'.padEnd(8))} ${dimText('F1'.padEnd(8))} ${dimText('Support')}`)
|
|
450
|
+
console.log(` ${dimText('-'.repeat(52))}`)
|
|
451
|
+
|
|
452
|
+
const sortedLabels = Object.keys(labelCounts).sort()
|
|
453
|
+
for (const label of sortedLabels) {
|
|
454
|
+
const { tp, fp, fn } = labelCounts[label]
|
|
455
|
+
const precision = tp + fp > 0 ? tp / (tp + fp) : 0
|
|
456
|
+
const recall = tp + fn > 0 ? tp / (tp + fn) : 0
|
|
457
|
+
const f1 = precision + recall > 0 ? (2 * precision * recall) / (precision + recall) : 0
|
|
458
|
+
const support = tp + fn
|
|
459
|
+
|
|
460
|
+
console.log(
|
|
461
|
+
` ${theme.primary(label.padEnd(20))} ${(precision * 100).toFixed(1).padStart(5)}% ${(recall * 100).toFixed(1).padStart(5)}% ${(f1 * 100).toFixed(1).padStart(5)}% ${String(support).padStart(5)}`
|
|
462
|
+
)
|
|
463
|
+
}
|
|
464
|
+
console.log()
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
// ─── stats ────────────────────────────────────────────────────────
|
|
468
|
+
|
|
469
|
+
async function showStats() {
|
|
470
|
+
console.log()
|
|
471
|
+
console.log(renderLogo())
|
|
472
|
+
console.log()
|
|
473
|
+
console.log(heading('Training Data Statistics'))
|
|
474
|
+
console.log()
|
|
475
|
+
|
|
476
|
+
const stats = getTrainingStats()
|
|
477
|
+
const items: Array<{ label: string; value: string; status?: 'success' | 'warning' | 'error' | 'info' }> = []
|
|
478
|
+
|
|
479
|
+
let grandTotal = 0
|
|
480
|
+
let grandVerified = 0
|
|
481
|
+
|
|
482
|
+
for (const task of ALL_TASKS) {
|
|
483
|
+
const { total, verified } = stats[task]
|
|
484
|
+
grandTotal += total
|
|
485
|
+
grandVerified += verified
|
|
486
|
+
items.push({
|
|
487
|
+
label: task,
|
|
488
|
+
value: `${total} total, ${verified} verified`,
|
|
489
|
+
status: total > 0 ? 'success' : 'warning'
|
|
490
|
+
})
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
items.push({
|
|
494
|
+
label: 'Total',
|
|
495
|
+
value: `${grandTotal} total, ${grandVerified} verified`,
|
|
496
|
+
status: grandTotal > 0 ? 'success' : 'warning'
|
|
497
|
+
})
|
|
498
|
+
|
|
499
|
+
console.log(summaryPanel('Training Data', items))
|
|
500
|
+
console.log()
|
|
501
|
+
|
|
502
|
+
// Check for model_feedback table
|
|
503
|
+
try {
|
|
504
|
+
const { Database } = await import('bun:sqlite') as any
|
|
505
|
+
const dbPath = join(getClaudeBrainHome(), 'data', 'memory.db')
|
|
506
|
+
if (existsSync(dbPath)) {
|
|
507
|
+
const db = new Database(dbPath, { readonly: true })
|
|
508
|
+
const hasTable = db.prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='model_feedback'").get()
|
|
509
|
+
if (hasTable) {
|
|
510
|
+
const disagreements = (db.prepare('SELECT COUNT(*) as c FROM model_feedback WHERE model_label != regex_label').get() as any)?.c ?? 0
|
|
511
|
+
const totalFeedback = (db.prepare('SELECT COUNT(*) as c FROM model_feedback').get() as any)?.c ?? 0
|
|
512
|
+
console.log(` ${dimText('Model Feedback:')} ${totalFeedback} entries, ${disagreements} disagreements`)
|
|
513
|
+
console.log()
|
|
514
|
+
}
|
|
515
|
+
db.close()
|
|
516
|
+
}
|
|
517
|
+
} catch {
|
|
518
|
+
// model_feedback table doesn't exist yet — that's fine
|
|
519
|
+
}
|
|
520
|
+
}
|
|
521
|
+
|
|
522
|
+
// ─── retrain ──────────────────────────────────────────────────────
|
|
523
|
+
|
|
524
|
+
async function retrainModels(taskFilter: string, force: boolean) {
|
|
525
|
+
console.log()
|
|
526
|
+
console.log(renderLogo())
|
|
527
|
+
console.log()
|
|
528
|
+
console.log(heading('Retrain Models'))
|
|
529
|
+
console.log()
|
|
530
|
+
|
|
531
|
+
// Validate task filter
|
|
532
|
+
if (taskFilter !== 'all' && !ALL_TASKS.includes(taskFilter as ModelTask)) {
|
|
533
|
+
console.log(errorText(`Invalid task: ${taskFilter}`))
|
|
534
|
+
console.log(dimText(`Valid tasks: ${ALL_TASKS.join(', ')}, all`))
|
|
535
|
+
return
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
// Build config from schema defaults + config.json overrides
|
|
539
|
+
const configPath = join(getClaudeBrainHome(), 'config.json')
|
|
540
|
+
const userConfig = loadConfigFile(configPath)
|
|
541
|
+
const retrainCfg = userConfig?.slm?.retrain ?? {}
|
|
542
|
+
|
|
543
|
+
const config: RetrainConfig = {
|
|
544
|
+
minFeedbackCount: retrainCfg.minFeedbackCount ?? 100,
|
|
545
|
+
maxDisagreementRate: retrainCfg.maxDisagreementRate ?? 0.15,
|
|
546
|
+
pythonPath: retrainCfg.pythonPath ?? 'python3',
|
|
547
|
+
trainingDir: retrainCfg.trainingDir ?? '~/slm-training',
|
|
548
|
+
force,
|
|
549
|
+
}
|
|
550
|
+
|
|
551
|
+
console.log(` ${dimText('Training dir:')} ${config.trainingDir}`)
|
|
552
|
+
console.log(` ${dimText('Python path:')} ${config.pythonPath}`)
|
|
553
|
+
console.log(` ${dimText('Min feedback:')} ${config.minFeedbackCount}`)
|
|
554
|
+
console.log(` ${dimText('Max disagreement:')} ${(config.maxDisagreementRate * 100).toFixed(0)}%`)
|
|
555
|
+
if (force) console.log(` ${warningText('Force mode enabled — skipping checks')}`)
|
|
556
|
+
console.log()
|
|
557
|
+
|
|
558
|
+
if (taskFilter === 'all') {
|
|
559
|
+
// Retrain all tasks that need it
|
|
560
|
+
const results = await retrainAll(config)
|
|
561
|
+
|
|
562
|
+
if (results.size === 0) {
|
|
563
|
+
console.log(dimText(' No tasks needed retraining.'))
|
|
564
|
+
console.log()
|
|
565
|
+
return
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
// Summary
|
|
569
|
+
console.log()
|
|
570
|
+
console.log(heading('Retrain Summary'))
|
|
571
|
+
console.log()
|
|
572
|
+
for (const [task, result] of results) {
|
|
573
|
+
if (result.success) {
|
|
574
|
+
const accStr = result.newAccuracy != null
|
|
575
|
+
? `${(result.newAccuracy * 100).toFixed(1)}%`
|
|
576
|
+
: 'n/a'
|
|
577
|
+
const oldStr = result.oldAccuracy != null
|
|
578
|
+
? ` (was ${(result.oldAccuracy * 100).toFixed(1)}%)`
|
|
579
|
+
: ''
|
|
580
|
+
console.log(` ${successText(task.padEnd(12))} accuracy: ${accStr}${oldStr} ${dimText(`${result.trainingDataCount} examples, ${(result.duration / 1000).toFixed(1)}s`)}`)
|
|
581
|
+
} else {
|
|
582
|
+
console.log(` ${errorText(task.padEnd(12))} ${result.error}`)
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
} else {
|
|
586
|
+
// Single task
|
|
587
|
+
const task = taskFilter as ModelTask
|
|
588
|
+
|
|
589
|
+
if (!force) {
|
|
590
|
+
const check = shouldRetrain(task, config)
|
|
591
|
+
console.log(` ${dimText('Feedback count:')} ${check.feedbackCount}`)
|
|
592
|
+
console.log(` ${dimText('Disagreement rate:')} ${(check.disagreementRate * 100).toFixed(1)}%`)
|
|
593
|
+
console.log(` ${dimText('Last retrain:')} ${check.lastRetrainDate ?? 'never'}`)
|
|
594
|
+
console.log(` ${dimText('Needs retrain:')} ${check.needed ? 'yes' : 'no'} — ${check.reason}`)
|
|
595
|
+
console.log()
|
|
596
|
+
|
|
597
|
+
if (!check.needed) {
|
|
598
|
+
console.log(dimText(' Retrain not needed. Use --force to override.'))
|
|
599
|
+
console.log()
|
|
600
|
+
return
|
|
601
|
+
}
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
const result = await retrainTask(task, config)
|
|
605
|
+
|
|
606
|
+
console.log()
|
|
607
|
+
if (result.success) {
|
|
608
|
+
const accStr = result.newAccuracy != null
|
|
609
|
+
? `${(result.newAccuracy * 100).toFixed(1)}%`
|
|
610
|
+
: 'n/a'
|
|
611
|
+
const oldStr = result.oldAccuracy != null
|
|
612
|
+
? ` (was ${(result.oldAccuracy * 100).toFixed(1)}%)`
|
|
613
|
+
: ''
|
|
614
|
+
console.log(successText(`Retrain complete: accuracy ${accStr}${oldStr}`))
|
|
615
|
+
console.log(dimText(` ${result.trainingDataCount} examples, ${(result.duration / 1000).toFixed(1)}s`))
|
|
616
|
+
} else {
|
|
617
|
+
console.log(errorText(`Retrain failed: ${result.error}`))
|
|
618
|
+
}
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
console.log()
|
|
622
|
+
}
|
|
623
|
+
|
|
624
|
+
// ─── help ─────────────────────────────────────────────────────────
|
|
625
|
+
|
|
626
|
+
function printModelsHelp() {
|
|
627
|
+
console.log()
|
|
628
|
+
console.log(renderLogo())
|
|
629
|
+
console.log()
|
|
630
|
+
console.log(heading('SLM Model Management'))
|
|
631
|
+
console.log()
|
|
632
|
+
|
|
633
|
+
const subcommands = [
|
|
634
|
+
['list', 'Show installed models and their status'],
|
|
635
|
+
['download', 'Download pre-trained models (--task <task>|all)'],
|
|
636
|
+
['enable <task>', 'Enable model inference for a task'],
|
|
637
|
+
['disable <task>', 'Disable model inference for a task'],
|
|
638
|
+
['benchmark <task>', 'Run accuracy benchmark on test data'],
|
|
639
|
+
['stats', 'Show training data statistics'],
|
|
640
|
+
['retrain [<task>|all]', 'Retrain models from feedback (--force)'],
|
|
641
|
+
]
|
|
642
|
+
|
|
643
|
+
const lines = subcommands
|
|
644
|
+
.map(([cmd, desc]) => ` ${theme.primary(cmd!.padEnd(20))} ${dimText(desc!)}`)
|
|
645
|
+
.join('\n')
|
|
646
|
+
|
|
647
|
+
console.log(theme.bold('Usage:') + ' ' + dimText('claude-brain models <subcommand>'))
|
|
648
|
+
console.log()
|
|
649
|
+
console.log(theme.bold('Subcommands:'))
|
|
650
|
+
console.log(lines)
|
|
651
|
+
console.log()
|
|
652
|
+
console.log(theme.bold('Tasks:') + ' ' + dimText(ALL_TASKS.join(', ')))
|
|
653
|
+
console.log()
|
|
654
|
+
console.log(theme.bold('Examples:'))
|
|
655
|
+
console.log(` ${dimText('claude-brain models list')}`)
|
|
656
|
+
console.log(` ${dimText('claude-brain models enable intent')}`)
|
|
657
|
+
console.log(` ${dimText('claude-brain models benchmark intent')}`)
|
|
658
|
+
console.log(` ${dimText('claude-brain models stats')}`)
|
|
659
|
+
console.log(` ${dimText('claude-brain models retrain intent')}`)
|
|
660
|
+
console.log(` ${dimText('claude-brain models retrain all --force')}`)
|
|
661
|
+
console.log()
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
// ─── helpers ──────────────────────────────────────────────────────
|
|
665
|
+
|
|
666
|
+
function loadConfigFile(configPath: string): Record<string, any> {
|
|
667
|
+
if (!existsSync(configPath)) return {}
|
|
668
|
+
try {
|
|
669
|
+
return JSON.parse(readFileSync(configPath, 'utf-8'))
|
|
670
|
+
} catch {
|
|
671
|
+
return {}
|
|
672
|
+
}
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
function formatBytes(bytes: number): string {
|
|
676
|
+
if (bytes === 0) return '0 B'
|
|
677
|
+
const units = ['B', 'KB', 'MB', 'GB']
|
|
678
|
+
const i = Math.floor(Math.log(bytes) / Math.log(1024))
|
|
679
|
+
const val = bytes / Math.pow(1024, i)
|
|
680
|
+
return `${val.toFixed(i > 0 ? 1 : 0)} ${units[i]}`
|
|
681
|
+
}
|