claude-brain 0.27.2 → 0.28.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,206 @@
1
+ /**
2
+ * Model Manager — SLM Upgrade Phase 4A
3
+ * Discovers, loads, and caches ONNX models from ~/.claude-brain/models/
4
+ * Lazy loading: models load on first use, not at startup.
5
+ */
6
+
7
+ import { existsSync, readFileSync } from 'node:fs'
8
+ import { join } from 'node:path'
9
+ import type { Logger } from 'pino'
10
+ import { getHomePaths } from '@/config/home'
11
+
12
/** Classification/extraction tasks that may be backed by a local ONNX model (see getStatus). */
export type ModelTask = 'intent' | 'entity' | 'query' | 'knowledge' | 'compress' | 'pattern'
13
+
14
/** One entry in manifest.json describing a single model artifact on disk. */
export interface ModelManifestEntry {
  /** Version string of the model artifact */
  version: string
  /** Model file name, resolved relative to the models directory */
  file: string
  /** Integrity checksum of the model file — present but not verified in this module */
  sha256?: string
  /** Parameter-count descriptor (e.g. "33M") — assumed free-form; TODO confirm */
  params?: string
  /** Reported eval accuracy — presumably in [0, 1]; verify against training pipeline */
  accuracy?: number
  /** Classification label names — assumed index-aligned with model logits; confirm */
  labels?: string[]
  /** Maximum input sequence length the model accepts */
  maxSeqLen?: number
}
23
+
24
/** Shape of manifest.json: a possibly-partial map from task name to model entry. */
export interface ModelManifest {
  models: Partial<Record<ModelTask, ModelManifestEntry>>
}
27
+
28
/** A model whose ONNX inference session has been created and cached. */
export interface LoadedModel {
  /** Live inference session; typed loosely because onnxruntime is an optional dep */
  session: any // onnxruntime.InferenceSession
  /** Manifest entry this session was created from */
  manifest: ModelManifestEntry
  /** Epoch milliseconds when the session finished loading */
  loadedAt: number
}
33
+
34
+ export class ModelManager {
35
+ private modelsDir: string
36
+ private manifest: ModelManifest | null = null
37
+ private loadedModels = new Map<ModelTask, LoadedModel>()
38
+ private onnxRuntime: any = null
39
+ private onnxAvailable: boolean | null = null
40
+ private logger: Logger
41
+
42
+ constructor(logger: Logger, modelsDir?: string) {
43
+ this.logger = logger.child({ component: 'model-manager' })
44
+ this.modelsDir = modelsDir || getHomePaths().models
45
+ }
46
+
47
+ /**
48
+ * Check if ONNX Runtime is available (installed as optional dep)
49
+ * Tries onnxruntime-node first (faster native), falls back to onnxruntime-web (WASM)
50
+ */
51
+ private async checkOnnxRuntime(): Promise<boolean> {
52
+ if (this.onnxAvailable !== null) return this.onnxAvailable
53
+ // Try native node bindings first
54
+ try {
55
+ this.onnxRuntime = await import('onnxruntime-node')
56
+ this.onnxAvailable = true
57
+ this.logger.debug('ONNX Runtime (native) available')
58
+ return true
59
+ } catch {
60
+ // Native not available, try WASM fallback
61
+ }
62
+ try {
63
+ this.onnxRuntime = await import('onnxruntime-web')
64
+ this.onnxAvailable = true
65
+ this.logger.debug('ONNX Runtime (WASM) available')
66
+ } catch {
67
+ this.onnxAvailable = false
68
+ this.logger.debug('ONNX Runtime not installed — models will not load')
69
+ }
70
+ return this.onnxAvailable
71
+ }
72
+
73
+ /**
74
+ * Load the manifest.json from the models directory
75
+ */
76
+ private loadManifest(): ModelManifest | null {
77
+ if (this.manifest) return this.manifest
78
+ const manifestPath = join(this.modelsDir, 'manifest.json')
79
+ if (!existsSync(manifestPath)) {
80
+ this.logger.debug({ modelsDir: this.modelsDir }, 'No manifest.json found')
81
+ return null
82
+ }
83
+ try {
84
+ const raw = readFileSync(manifestPath, 'utf-8')
85
+ this.manifest = JSON.parse(raw) as ModelManifest
86
+ return this.manifest
87
+ } catch (error) {
88
+ this.logger.warn({ error }, 'Failed to parse manifest.json')
89
+ return null
90
+ }
91
+ }
92
+
93
+ /**
94
+ * Check if a model file exists for a given task (without loading it)
95
+ */
96
+ hasModel(task: ModelTask): boolean {
97
+ const manifest = this.loadManifest()
98
+ if (!manifest?.models[task]) return false
99
+ const entry = manifest.models[task]!
100
+ return existsSync(join(this.modelsDir, entry.file))
101
+ }
102
+
103
+ /**
104
+ * Get manifest entry for a task
105
+ */
106
+ getManifestEntry(task: ModelTask): ModelManifestEntry | null {
107
+ const manifest = this.loadManifest()
108
+ return manifest?.models[task] ?? null
109
+ }
110
+
111
+ /**
112
+ * Lazy-load a model on first use. Returns null if unavailable.
113
+ */
114
+ async loadModel(task: ModelTask): Promise<LoadedModel | null> {
115
+ // Return cached model
116
+ if (this.loadedModels.has(task)) return this.loadedModels.get(task)!
117
+
118
+ // Check prerequisites
119
+ if (!(await this.checkOnnxRuntime())) return null
120
+ const manifest = this.loadManifest()
121
+ if (!manifest?.models[task]) return null
122
+
123
+ const entry = manifest.models[task]!
124
+ const modelPath = join(this.modelsDir, entry.file)
125
+ if (!existsSync(modelPath)) {
126
+ this.logger.debug({ task, file: entry.file }, 'Model file not found')
127
+ return null
128
+ }
129
+
130
+ try {
131
+ const startMs = Date.now()
132
+ const session = await this.onnxRuntime.InferenceSession.create(modelPath)
133
+ const loaded: LoadedModel = {
134
+ session,
135
+ manifest: entry,
136
+ loadedAt: Date.now(),
137
+ }
138
+ this.loadedModels.set(task, loaded)
139
+ this.logger.info({ task, file: entry.file, loadMs: Date.now() - startMs }, 'Model loaded')
140
+ return loaded
141
+ } catch (error) {
142
+ this.logger.warn({ error, task, file: entry.file }, 'Failed to load model')
143
+ return null
144
+ }
145
+ }
146
+
147
+ /**
148
+ * Run inference on a loaded model. Returns raw output tensor data.
149
+ * Callers (InferenceRouter) handle task-specific pre/post processing.
150
+ * Automatically detects which inputs the model accepts (input_ids, attention_mask).
151
+ */
152
+ async infer(task: ModelTask, inputIds: number[], attentionMask?: number[]): Promise<Float32Array | null> {
153
+ const model = await this.loadModel(task)
154
+ if (!model) return null
155
+
156
+ try {
157
+ const OrtTensor = this.onnxRuntime.Tensor
158
+ const inputTensor = new OrtTensor('int64', BigInt64Array.from(inputIds.map(BigInt)), [1, inputIds.length])
159
+ const feeds: Record<string, any> = { input_ids: inputTensor }
160
+
161
+ // Only pass attention_mask if the model actually accepts it
162
+ const modelInputNames = model.session.inputNames ?? []
163
+ if (attentionMask && modelInputNames.includes('attention_mask')) {
164
+ const maskTensor = new OrtTensor('int64', BigInt64Array.from(attentionMask.map(BigInt)), [1, attentionMask.length])
165
+ feeds.attention_mask = maskTensor
166
+ }
167
+
168
+ const results = await model.session.run(feeds)
169
+
170
+ // Most classification models output 'logits'
171
+ const outputKey = Object.keys(results)[0]
172
+ if (!outputKey) return null
173
+ return results[outputKey].data as Float32Array
174
+ } catch (error) {
175
+ this.logger.warn({ error, task }, 'Inference failed')
176
+ return null
177
+ }
178
+ }
179
+
180
+ /**
181
+ * Get status of all models (for CLI and health checks)
182
+ */
183
+ getStatus(): Record<ModelTask, { available: boolean; loaded: boolean; version?: string; accuracy?: number }> {
184
+ const tasks: ModelTask[] = ['intent', 'entity', 'query', 'knowledge', 'compress', 'pattern']
185
+ const status = {} as Record<ModelTask, { available: boolean; loaded: boolean; version?: string; accuracy?: number }>
186
+
187
+ for (const task of tasks) {
188
+ const entry = this.getManifestEntry(task)
189
+ status[task] = {
190
+ available: this.hasModel(task),
191
+ loaded: this.loadedModels.has(task),
192
+ version: entry?.version,
193
+ accuracy: entry?.accuracy,
194
+ }
195
+ }
196
+ return status
197
+ }
198
+
199
+ /**
200
+ * Unload all models (for cleanup/testing)
201
+ */
202
+ unloadAll(): void {
203
+ this.loadedModels.clear()
204
+ this.logger.debug('All models unloaded')
205
+ }
206
+ }
@@ -0,0 +1,118 @@
1
+ /**
2
+ * Tokenizer — SLM Upgrade Phase 4B
3
+ * GPT-2 BPE tokenizer for ONNX model inference.
4
+ *
5
+ * Strategy:
6
+ * 1. Try to dynamically import `tiktoken` (JS package)
7
+ * 2. Fall back to a simple whitespace tokenizer with hash-based IDs
8
+ *
9
+ * The tokenizer pads/truncates sequences to a fixed length and
10
+ * returns both input_ids and attention_mask arrays.
11
+ */
12
+
13
+ import type { Logger } from 'pino'
14
+
15
+ export interface TokenizerOutput {
16
+ inputIds: number[]
17
+ attentionMask: number[]
18
+ }
19
+
20
+ export interface Tokenizer {
21
+ encode(text: string, maxLength: number): TokenizerOutput
22
+ decode(tokenIds: number[]): string
23
+ }
24
+
25
+ /** Singleton cache */
26
+ let cachedTokenizer: Tokenizer | null = null
27
+
28
+ /**
29
+ * Get a tokenizer instance. Tries tiktoken first, falls back to hash-based.
30
+ */
31
+ export async function getTokenizer(logger?: Logger): Promise<Tokenizer> {
32
+ if (cachedTokenizer) return cachedTokenizer
33
+
34
+ // Try tiktoken (JS binding for GPT-2 BPE)
35
+ try {
36
+ const tiktoken = await import('tiktoken')
37
+ const enc = tiktoken.encoding_for_model('gpt2')
38
+
39
+ cachedTokenizer = {
40
+ encode(text: string, maxLength: number): TokenizerOutput {
41
+ const tokens = Array.from(enc.encode(text))
42
+
43
+ // Truncate if necessary
44
+ const truncated = tokens.slice(0, maxLength)
45
+
46
+ // Pad to maxLength
47
+ const inputIds = new Array(maxLength).fill(0)
48
+ const attentionMask = new Array(maxLength).fill(0)
49
+
50
+ for (let i = 0; i < truncated.length; i++) {
51
+ inputIds[i] = truncated[i]
52
+ attentionMask[i] = 1
53
+ }
54
+
55
+ return { inputIds, attentionMask }
56
+ },
57
+ decode(tokenIds: number[]): string {
58
+ // Filter out padding (0) tokens
59
+ const filtered = tokenIds.filter(id => id !== 0)
60
+ return new TextDecoder().decode(enc.decode(new Uint32Array(filtered)))
61
+ }
62
+ }
63
+
64
+ logger?.debug('Using tiktoken GPT-2 tokenizer')
65
+ return cachedTokenizer
66
+ } catch {
67
+ // tiktoken not available
68
+ }
69
+
70
+ // Fallback: simple whitespace tokenizer with hash-based IDs
71
+ logger?.warn('tiktoken not available — using fallback hash-based tokenizer (reduced accuracy)')
72
+
73
+ cachedTokenizer = {
74
+ encode(text: string, maxLength: number): TokenizerOutput {
75
+ // Split on whitespace and punctuation, filter empties
76
+ const tokens = text
77
+ .toLowerCase()
78
+ .split(/(\s+|[.,!?;:'"()\[\]{}<>\/\\@#$%^&*+=~`|_-]+)/)
79
+ .filter(t => t.trim().length > 0)
80
+
81
+ // Hash each token to a stable ID in [1, 50256] range (GPT-2 vocab size)
82
+ const VOCAB_SIZE = 50256
83
+ const tokenIds = tokens.map(t => {
84
+ let hash = 5381
85
+ for (let i = 0; i < t.length; i++) {
86
+ hash = ((hash << 5) + hash + t.charCodeAt(i)) & 0x7fffffff
87
+ }
88
+ return (hash % (VOCAB_SIZE - 1)) + 1 // avoid 0 (used for padding)
89
+ })
90
+
91
+ // Truncate
92
+ const truncated = tokenIds.slice(0, maxLength)
93
+
94
+ // Pad to maxLength
95
+ const inputIds = new Array(maxLength).fill(0)
96
+ const attentionMask = new Array(maxLength).fill(0)
97
+
98
+ for (let i = 0; i < truncated.length; i++) {
99
+ inputIds[i] = truncated[i]
100
+ attentionMask[i] = 1
101
+ }
102
+
103
+ return { inputIds, attentionMask }
104
+ },
105
+ decode(_tokenIds: number[]): string {
106
+ // Hash-based tokenizer is one-way; decode is not possible.
107
+ // Compression will fall back to returning original text.
108
+ return ''
109
+ }
110
+ }
111
+
112
+ return cachedTokenizer
113
+ }
114
+
115
+ /** Reset cached tokenizer (for testing) */
116
+ export function _resetTokenizerForTesting(): void {
117
+ cachedTokenizer = null
118
+ }
@@ -222,6 +222,7 @@ export class EntityExtractor {
222
222
  }
223
223
 
224
224
  extract(text: string): ExtractedEntity[] {
225
+ const startTime = Date.now()
225
226
  const entities: Map<string, ExtractedEntity> = new Map()
226
227
 
227
228
  this.extractFromDictionary(text, entities)
@@ -233,8 +234,37 @@ export class EntityExtractor {
233
234
  this.extractWithNlp(text, entities)
234
235
  }
235
236
 
236
- return Array.from(entities.values())
237
+ const result = Array.from(entities.values())
237
238
  .sort((a, b) => b.confidence - a.confidence)
239
+
240
+ // SLM Phase 1A: Log extraction for training data collection
241
+ this._logTraining(text, result, startTime)
242
+
243
+ return result
244
+ }
245
+
246
  /**
   * SLM Phase 1A: Log entity extraction result for training data.
   * Best-effort: any failure (e.g. the training store being unavailable) is
   * swallowed so that data capture can never break extraction itself.
   *
   * @param text      The raw input text that was analyzed
   * @param entities  Extracted entities (already sorted by confidence in extract())
   * @param startTime Epoch ms when extraction began, used for elapsed_ms metadata
   */
  private _logTraining(text: string, entities: ExtractedEntity[], startTime: number): void {
    try {
      // Lazy require keeps the training store an optional, runtime-only dependency
      const { logTrainingData } = require('@/training/data-store')
      logTrainingData({
        task: 'entity' as const,
        input: text,
        // Output is the JSON-serialized entity list in the training schema
        output: JSON.stringify(entities.map(e => ({
          text: e.name,
          type: e.type,
          normalized: e.normalizedName,
          confidence: e.confidence,
          source: e.source,
          positions: e.positions,
        }))),
        metadata: JSON.stringify({ count: entities.length, elapsed_ms: Date.now() - startTime }),
      })
    } catch {
      // Non-critical — training capture is best-effort
    }
  }
238
268
  }
239
269
 
240
270
  extractBatch(texts: string[]): ExtractedEntity[][] {
@@ -42,8 +42,24 @@ Drop: filler words, repetition, context that's obvious from the category.
42
42
  Observation: ${content}`
43
43
 
44
44
  try {
45
+ const startTime = Date.now()
45
46
  const response = await this.callLLM(prompt)
46
- return { summary: response, original: content, compressed: true }
47
+ const result: CompressedObservation = { summary: response, original: content, compressed: true }
48
+
49
+ // SLM Phase 1A: Log compression input/output pair as gold training data
50
+ try {
51
+ const { logTrainingData } = require('@/training/data-store')
52
+ logTrainingData({
53
+ task: 'compress' as const,
54
+ input: content,
55
+ output: JSON.stringify({ summary: response }),
56
+ metadata: JSON.stringify({ category, elapsed_ms: Date.now() - startTime, provider: this.config.provider }),
57
+ })
58
+ } catch {
59
+ // Non-critical
60
+ }
61
+
62
+ return result
47
63
  } catch (error) {
48
64
  this.logger.warn({ error }, 'LLM compression failed, storing original')
49
65
  return { summary: content, compressed: false }
@@ -34,11 +34,19 @@ export class PatternRecognizer {
34
34
  private memory: MemoryManager
35
35
  private patterns: Map<string, Pattern> = new Map()
36
36
 
37
+ /** SLM Upgrade: Optional inference router for model-based pattern classification */
38
+ private inferenceRouter: any = null
39
+
37
40
  /**
   * @param logger Parent pino logger; a tagged child logger is derived from it.
   * @param memory Memory manager backing pattern analysis and storage.
   */
  constructor(logger: Logger, memory: MemoryManager) {
    this.logger = logger.child({ component: 'pattern-recognizer' })
    this.memory = memory
  }
41
44
 
45
  /**
   * SLM Upgrade: Set optional inference router.
   * When set, analyzeCluster tries router.classifyPatternType() first and
   * falls back to the regex heuristics on failure; when unset, heuristics only.
   * Typed `any` to keep the router loosely coupled — TODO: tighten to an
   * InferenceRouter interface once its shape stabilizes.
   */
  setInferenceRouter(router: any): void {
    this.inferenceRouter = router
  }
49
+
42
50
  /**
43
51
  * Analyze all decisions to find patterns
44
52
  */
@@ -58,7 +66,7 @@ export class PatternRecognizer {
58
66
  const clusters = await this.clusterDecisions(allDecisions)
59
67
 
60
68
  // Extract patterns from clusters
61
- const patterns = this.extractPatterns(clusters)
69
+ const patterns = await this.extractPatterns(clusters)
62
70
 
63
71
  // Store patterns
64
72
  for (const pattern of patterns) {
@@ -159,14 +167,16 @@ export class PatternRecognizer {
159
167
  /**
160
168
  * Extract patterns from decision clusters
161
169
  */
162
- private extractPatterns(clusters: DecisionWithProject[][]): Pattern[] {
170
+ private async extractPatterns(clusters: DecisionWithProject[][]): Promise<Pattern[]> {
163
171
  const patterns: Pattern[] = []
164
172
 
165
173
  for (const cluster of clusters) {
166
- const pattern = this.analyzeCluster(cluster)
174
+ const pattern = await this.analyzeCluster(cluster)
167
175
 
168
176
  if (pattern) {
169
177
  patterns.push(pattern)
178
+ // SLM Phase 1A: Log classification for training data
179
+ this._logPatternTraining(pattern.description, pattern.type)
170
180
  }
171
181
  }
172
182
 
@@ -176,7 +186,7 @@ export class PatternRecognizer {
176
186
  /**
177
187
  * Analyze a cluster to find the common pattern
178
188
  */
179
- private analyzeCluster(cluster: DecisionWithProject[]): Pattern | null {
189
+ private async analyzeCluster(cluster: DecisionWithProject[]): Promise<Pattern | null> {
180
190
  if (cluster.length < 3) return null
181
191
 
182
192
  // Extract common keywords
@@ -184,8 +194,21 @@ export class PatternRecognizer {
184
194
 
185
195
  if (keywords.length === 0) return null
186
196
 
187
- // Determine pattern type
188
- const type = this.determinePatternType(cluster)
197
+ // Determine pattern type — SLM model first, regex fallback
198
+ let type: Pattern['type']
199
+ if (this.inferenceRouter) {
200
+ try {
201
+ // Use representative text from cluster for model classification
202
+ const representative = cluster.slice(0, 3)
203
+ .map(d => d.decision + (d.reasoning ? ` ${d.reasoning}` : ''))
204
+ .join('. ')
205
+ type = await this.inferenceRouter.classifyPatternType(representative)
206
+ } catch {
207
+ type = this.determinePatternType(cluster)
208
+ }
209
+ } else {
210
+ type = this.determinePatternType(cluster)
211
+ }
189
212
 
190
213
  // Create pattern description
191
214
  const description = this.generatePatternDescription(keywords, type)
@@ -284,6 +307,23 @@ export class PatternRecognizer {
284
307
  return 'solution'
285
308
  }
286
309
 
310
  /**
   * SLM Phase 1A: Log pattern type classification for training data.
   * Called from extractPatterns after a pattern is created (not from
   * analyzeCluster, as a previous comment stated).
   * Best-effort: any failure is swallowed so capture never breaks analysis.
   *
   * @param description Generated pattern description, used as training input
   * @param patternType The classified pattern type, logged as the label
   */
  private _logPatternTraining(description: string, patternType: Pattern['type']): void {
    try {
      // Lazy require keeps the training store an optional, runtime-only dependency
      const { logTrainingData } = require('@/training/data-store')
      logTrainingData({
        task: 'pattern' as const,
        input: description,
        output: JSON.stringify({ label: patternType }),
      })
    } catch {
      // Non-critical — training capture is best-effort
    }
  }
326
+
287
327
  /**
288
328
  * Generate pattern description
289
329
  */
@@ -77,6 +77,7 @@ const LOW_CONFIDENCE = 0.5
77
77
  * Classify the intent of a query
78
78
  */
79
79
  export function classifyIntent(query: string): QueryIntent {
80
+ const startTime = Date.now()
80
81
  const normalizedQuery = query.toLowerCase().trim()
81
82
 
82
83
  // Track matches for each intent
@@ -132,11 +133,26 @@ export function classifyIntent(query: string): QueryIntent {
132
133
  metadata.comparisonTerms = extractComparisonTerms(query)
133
134
  }
134
135
 
135
- return {
136
+ const result: QueryIntent = {
136
137
  type: bestIntent as QueryIntent['type'],
137
138
  confidence,
138
139
  metadata
139
140
  }
141
+
142
+ // SLM Phase 1A: Log query classification for training data collection
143
+ try {
144
+ const { logTrainingData } = require('@/training/data-store')
145
+ logTrainingData({
146
+ task: 'query' as const,
147
+ input: query,
148
+ output: JSON.stringify({ label: result.type }),
149
+ metadata: JSON.stringify({ confidence, scores, elapsed_ms: Date.now() - startTime }),
150
+ })
151
+ } catch {
152
+ // Non-critical
153
+ }
154
+
155
+ return result
140
156
  }
141
157
 
142
158
  /**
@@ -112,6 +112,14 @@ export interface BrainExtractedEntities {
112
112
  }
113
113
 
114
114
  export class BrainEntityExtractor {
115
  /** SLM Upgrade: Optional inference router for model-based entity extraction */
  private inferenceRouter: any = null

  /**
   * SLM Upgrade: Set optional inference router for model-based NER.
   * When set, extractTechnologies tries router.extractEntities() first and
   * merges its technology entities with the dictionary scan; when unset,
   * only the dictionary scan runs.
   */
  setInferenceRouter(router: any): void {
    this.inferenceRouter = router
  }
122
+
115
123
  /**
116
124
  * Extract all entities from a natural language message
117
125
  */
@@ -122,7 +130,7 @@ export class BrainEntityExtractor {
122
130
 
123
131
  // Use provided project or try to detect from message
124
132
  entities.project = knownProject || await this.extractProject(message)
125
- entities.technologies = this.extractTechnologies(message)
133
+ entities.technologies = await this.extractTechnologies(message)
126
134
  entities.topic = this.extractTopic(message)
127
135
 
128
136
  // Extract decision components if present
@@ -203,12 +211,30 @@ export class BrainEntityExtractor {
203
211
  }
204
212
 
205
213
  /**
206
- * Extract technology mentions
214
+ * Extract technology mentions.
215
+ * SLM Upgrade: tries model-based NER first if InferenceRouter is available,
216
+ * merges with dictionary fallback for comprehensive coverage.
207
217
  */
208
- private extractTechnologies(message: string): string[] {
218
+ private async extractTechnologies(message: string): Promise<string[]> {
219
+ const found = new Set<string>()
220
+
221
+ // SLM: Try model-based entity extraction first
222
+ if (this.inferenceRouter) {
223
+ try {
224
+ const modelEntities = await this.inferenceRouter.extractEntities(message)
225
+ for (const entity of modelEntities) {
226
+ if (entity.type === 'technology') {
227
+ found.add(entity.normalizedName)
228
+ }
229
+ }
230
+ } catch {
231
+ // Model failed, fall through to dictionary
232
+ }
233
+ }
234
+
235
+ // Dictionary fallback (always runs — merges with model results)
209
236
  const lower = message.toLowerCase()
210
237
  const words = lower.split(/[\s,;:()[\]{}"'`|/\\]+/)
211
- const found = new Set<string>()
212
238
 
213
239
  for (const word of words) {
214
240
  const cleaned = word.replace(/^[^a-z0-9]+|[^a-z0-9]+$/g, '')