@soulcraft/brainy 6.5.0 → 6.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/assets/models/all-MiniLM-L6-v2-q8/config.json +25 -0
- package/assets/models/all-MiniLM-L6-v2-q8/model.onnx +0 -0
- package/assets/models/all-MiniLM-L6-v2-q8/tokenizer.json +30686 -0
- package/assets/models/all-MiniLM-L6-v2-q8/vocab.json +1 -0
- package/dist/critical/model-guardian.d.ts +5 -22
- package/dist/critical/model-guardian.js +38 -210
- package/dist/embeddings/EmbeddingManager.d.ts +7 -17
- package/dist/embeddings/EmbeddingManager.js +28 -136
- package/dist/embeddings/wasm/AssetLoader.d.ts +67 -0
- package/dist/embeddings/wasm/AssetLoader.js +238 -0
- package/dist/embeddings/wasm/EmbeddingPostProcessor.d.ts +60 -0
- package/dist/embeddings/wasm/EmbeddingPostProcessor.js +123 -0
- package/dist/embeddings/wasm/ONNXInferenceEngine.d.ts +55 -0
- package/dist/embeddings/wasm/ONNXInferenceEngine.js +154 -0
- package/dist/embeddings/wasm/WASMEmbeddingEngine.d.ts +82 -0
- package/dist/embeddings/wasm/WASMEmbeddingEngine.js +231 -0
- package/dist/embeddings/wasm/WordPieceTokenizer.d.ts +71 -0
- package/dist/embeddings/wasm/WordPieceTokenizer.js +264 -0
- package/dist/embeddings/wasm/index.d.ts +13 -0
- package/dist/embeddings/wasm/index.js +15 -0
- package/dist/embeddings/wasm/types.d.ts +114 -0
- package/dist/embeddings/wasm/types.js +25 -0
- package/dist/setup.d.ts +11 -11
- package/dist/setup.js +17 -31
- package/dist/utils/embedding.d.ts +45 -62
- package/dist/utils/embedding.js +61 -440
- package/package.json +10 -3
- package/scripts/download-model.cjs +175 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Embedding Post-Processor
|
|
3
|
+
*
|
|
4
|
+
* Converts raw ONNX model output to final embedding vectors.
|
|
5
|
+
* Implements mean pooling and L2 normalization as used by sentence-transformers.
|
|
6
|
+
*
|
|
7
|
+
* Pipeline:
|
|
8
|
+
* 1. Mean Pooling: Average token embeddings (weighted by attention mask)
|
|
9
|
+
* 2. L2 Normalization: Normalize to unit length for cosine similarity
|
|
10
|
+
*/
|
|
11
|
+
import { MODEL_CONSTANTS } from './types.js';
|
|
12
|
+
/**
|
|
13
|
+
* Post-processor for converting ONNX output to sentence embeddings
|
|
14
|
+
*/
|
|
15
|
+
export class EmbeddingPostProcessor {
|
|
16
|
+
constructor(hiddenSize = MODEL_CONSTANTS.HIDDEN_SIZE) {
|
|
17
|
+
this.hiddenSize = hiddenSize;
|
|
18
|
+
}
|
|
19
|
+
/**
|
|
20
|
+
* Mean pool token embeddings weighted by attention mask
|
|
21
|
+
*
|
|
22
|
+
* @param hiddenStates - Raw model output [seqLen * hiddenSize] flattened
|
|
23
|
+
* @param attentionMask - Attention mask [seqLen] (1 for real tokens, 0 for padding)
|
|
24
|
+
* @param seqLen - Sequence length
|
|
25
|
+
* @returns Mean-pooled embedding [hiddenSize]
|
|
26
|
+
*/
|
|
27
|
+
meanPool(hiddenStates, attentionMask, seqLen) {
|
|
28
|
+
const result = new Float32Array(this.hiddenSize);
|
|
29
|
+
// Sum of attention mask (number of real tokens)
|
|
30
|
+
let maskSum = 0;
|
|
31
|
+
for (let i = 0; i < seqLen; i++) {
|
|
32
|
+
maskSum += attentionMask[i];
|
|
33
|
+
}
|
|
34
|
+
// Avoid division by zero
|
|
35
|
+
if (maskSum === 0) {
|
|
36
|
+
maskSum = 1;
|
|
37
|
+
}
|
|
38
|
+
// Compute weighted sum for each dimension
|
|
39
|
+
for (let dim = 0; dim < this.hiddenSize; dim++) {
|
|
40
|
+
let sum = 0;
|
|
41
|
+
for (let pos = 0; pos < seqLen; pos++) {
|
|
42
|
+
// Get hidden state at [pos, dim]
|
|
43
|
+
const value = hiddenStates[pos * this.hiddenSize + dim];
|
|
44
|
+
// Weight by attention mask
|
|
45
|
+
sum += value * attentionMask[pos];
|
|
46
|
+
}
|
|
47
|
+
// Mean pool
|
|
48
|
+
result[dim] = sum / maskSum;
|
|
49
|
+
}
|
|
50
|
+
return result;
|
|
51
|
+
}
|
|
52
|
+
/**
|
|
53
|
+
* L2 normalize embedding to unit length
|
|
54
|
+
*
|
|
55
|
+
* @param embedding - Input embedding
|
|
56
|
+
* @returns Normalized embedding with ||x|| = 1
|
|
57
|
+
*/
|
|
58
|
+
normalize(embedding) {
|
|
59
|
+
// Compute L2 norm
|
|
60
|
+
let sumSquares = 0;
|
|
61
|
+
for (let i = 0; i < embedding.length; i++) {
|
|
62
|
+
sumSquares += embedding[i] * embedding[i];
|
|
63
|
+
}
|
|
64
|
+
const norm = Math.sqrt(sumSquares);
|
|
65
|
+
// Avoid division by zero
|
|
66
|
+
if (norm === 0) {
|
|
67
|
+
return embedding;
|
|
68
|
+
}
|
|
69
|
+
// Normalize
|
|
70
|
+
const result = new Float32Array(embedding.length);
|
|
71
|
+
for (let i = 0; i < embedding.length; i++) {
|
|
72
|
+
result[i] = embedding[i] / norm;
|
|
73
|
+
}
|
|
74
|
+
return result;
|
|
75
|
+
}
|
|
76
|
+
/**
|
|
77
|
+
* Full post-processing pipeline: mean pool then normalize
|
|
78
|
+
*
|
|
79
|
+
* @param hiddenStates - Raw model output [seqLen * hiddenSize]
|
|
80
|
+
* @param attentionMask - Attention mask [seqLen]
|
|
81
|
+
* @param seqLen - Sequence length
|
|
82
|
+
* @returns Final normalized embedding [hiddenSize]
|
|
83
|
+
*/
|
|
84
|
+
process(hiddenStates, attentionMask, seqLen) {
|
|
85
|
+
const pooled = this.meanPool(hiddenStates, attentionMask, seqLen);
|
|
86
|
+
return this.normalize(pooled);
|
|
87
|
+
}
|
|
88
|
+
/**
|
|
89
|
+
* Process batch of embeddings
|
|
90
|
+
*
|
|
91
|
+
* @param hiddenStates - Raw model output [batchSize * seqLen * hiddenSize]
|
|
92
|
+
* @param attentionMasks - Attention masks [batchSize][seqLen]
|
|
93
|
+
* @param batchSize - Number of sequences in batch
|
|
94
|
+
* @param seqLen - Sequence length (same for all in batch due to padding)
|
|
95
|
+
* @returns Array of normalized embeddings
|
|
96
|
+
*/
|
|
97
|
+
processBatch(hiddenStates, attentionMasks, batchSize, seqLen) {
|
|
98
|
+
const results = [];
|
|
99
|
+
const sequenceSize = seqLen * this.hiddenSize;
|
|
100
|
+
for (let b = 0; b < batchSize; b++) {
|
|
101
|
+
// Extract this sequence's hidden states
|
|
102
|
+
const start = b * sequenceSize;
|
|
103
|
+
const seqHiddenStates = hiddenStates.slice(start, start + sequenceSize);
|
|
104
|
+
// Process
|
|
105
|
+
const embedding = this.process(seqHiddenStates, attentionMasks[b], seqLen);
|
|
106
|
+
results.push(embedding);
|
|
107
|
+
}
|
|
108
|
+
return results;
|
|
109
|
+
}
|
|
110
|
+
/**
|
|
111
|
+
* Convert Float32Array to number array
|
|
112
|
+
*/
|
|
113
|
+
toNumberArray(embedding) {
|
|
114
|
+
return Array.from(embedding);
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
/**
 * Build a post-processor wired to the model's default hidden size.
 */
export function createPostProcessor() {
    const processor = new EmbeddingPostProcessor(MODEL_CONSTANTS.HIDDEN_SIZE);
    return processor;
}
|
|
123
|
+
//# sourceMappingURL=EmbeddingPostProcessor.js.map
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ONNX Inference Engine
|
|
3
|
+
*
|
|
4
|
+
* Direct ONNX Runtime Web wrapper for running model inference.
|
|
5
|
+
* Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
|
|
6
|
+
*
|
|
7
|
+
* This replaces transformers.js dependency with direct ONNX control.
|
|
8
|
+
*/
|
|
9
|
+
import { InferenceConfig } from './types.js';
|
|
10
|
+
/**
 * ONNX Inference Engine using onnxruntime-web.
 *
 * Wraps a single InferenceSession configured for the WASM execution
 * provider so the same code path works in Node.js, Bun, and browsers.
 */
export declare class ONNXInferenceEngine {
    private session;
    private initialized;
    private modelPath;
    private config;
    constructor(config?: Partial<InferenceConfig>);
    /**
     * Initialize the ONNX session.
     *
     * @param modelPath - Path or URL of the .onnx file; falls back to the
     *   path supplied at construction. Rejects if neither is set or if
     *   session creation fails.
     */
    initialize(modelPath?: string): Promise<void>;
    /**
     * Run inference on tokenized input.
     *
     * @param inputIds - Token IDs [batchSize, seqLen]
     * @param attentionMask - Attention mask [batchSize, seqLen]
     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
     * @returns Hidden states flattened from [batchSize, seqLen, hiddenSize]
     * @throws If initialize() has not completed successfully.
     */
    infer(inputIds: number[][], attentionMask: number[][], tokenTypeIds?: number[][]): Promise<Float32Array>;
    /**
     * Infer a single sequence (convenience wrapper around infer()).
     */
    inferSingle(inputIds: number[], attentionMask: number[], tokenTypeIds?: number[]): Promise<Float32Array>;
    /**
     * Whether initialize() has completed successfully.
     */
    isInitialized(): boolean;
    /**
     * Model input/output tensor names (for debugging); null before the
     * session has been created.
     */
    getModelInfo(): {
        inputs: readonly string[];
        outputs: readonly string[];
    } | null;
    /**
     * Dispose of the session and free resources.
     */
    dispose(): Promise<void>;
}
|
|
52
|
+
/**
 * Create an inference engine with default configuration.
 *
 * The returned engine is not yet initialized; call initialize() before infer().
 */
export declare function createInferenceEngine(modelPath: string): ONNXInferenceEngine;
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ONNX Inference Engine
|
|
3
|
+
*
|
|
4
|
+
* Direct ONNX Runtime Web wrapper for running model inference.
|
|
5
|
+
* Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
|
|
6
|
+
*
|
|
7
|
+
* This replaces transformers.js dependency with direct ONNX control.
|
|
8
|
+
*/
|
|
9
|
+
import * as ort from 'onnxruntime-web';
|
|
10
|
+
// Configure ONNX Runtime for WASM-only
// NOTE(review): ort.env is process-wide mutable state shared by every
// session in this process — confirm no other module expects different
// thread/SIMD settings.
ort.env.wasm.numThreads = 1; // Single-threaded for stability
ort.env.wasm.simd = true; // Enable SIMD where available
|
|
13
|
+
/**
 * ONNX Inference Engine using onnxruntime-web.
 *
 * Wraps a single InferenceSession pinned to the WASM execution provider
 * so the same code path runs in Node.js, Bun, and browsers.
 */
export class ONNXInferenceEngine {
    /**
     * @param {object} [config] - Partial InferenceConfig; unset fields fall
     *   back to conservative defaults (1 thread, SIMD on, no CPU mem arena).
     */
    constructor(config = {}) {
        this.session = null;
        this.initialized = false;
        this.modelPath = config.modelPath ?? '';
        this.config = {
            modelPath: this.modelPath,
            numThreads: config.numThreads ?? 1,
            enableSimd: config.enableSimd ?? true,
            enableCpuMemArena: config.enableCpuMemArena ?? false,
        };
    }
    /**
     * Initialize the ONNX session. Idempotent once successful.
     *
     * @param {string} [modelPath] - Overrides the path given at construction.
     * @throws {Error} When no path is available or session creation fails.
     */
    async initialize(modelPath) {
        if (this.initialized && this.session) {
            return;
        }
        const path = modelPath ?? this.modelPath;
        if (!path) {
            throw new Error('Model path is required');
        }
        try {
            // Configure session options for the WASM backend.
            const sessionOptions = {
                executionProviders: ['wasm'],
                graphOptimizationLevel: 'all',
                enableCpuMemArena: this.config.enableCpuMemArena,
                executionMode: 'sequential',
            };
            // Load model from file path or URL
            this.session = await ort.InferenceSession.create(path, sessionOptions);
            this.initialized = true;
        }
        catch (error) {
            this.initialized = false;
            this.session = null;
            throw new Error(`Failed to initialize ONNX session: ${error instanceof Error ? error.message : String(error)}`);
        }
    }
    /**
     * Run inference on tokenized input.
     *
     * @param inputIds - Token IDs [batchSize, seqLen]
     * @param attentionMask - Attention mask [batchSize, seqLen]
     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
     * @returns Hidden states flattened from [batchSize, seqLen, hiddenSize]
     * @throws {Error} If initialize() has not completed successfully.
     */
    async infer(inputIds, attentionMask, tokenTypeIds) {
        if (!this.session) {
            throw new Error('Session not initialized. Call initialize() first.');
        }
        const batchSize = inputIds.length;
        // Robustness fix: an empty batch previously crashed reading
        // inputIds[0].length; return an empty result instead.
        if (batchSize === 0) {
            return new Float32Array(0);
        }
        const seqLen = inputIds[0].length;
        // The model's inputs are int64, hence BigInt64Array.
        const inputIdsFlat = new BigInt64Array(batchSize * seqLen);
        const attentionMaskFlat = new BigInt64Array(batchSize * seqLen);
        const tokenTypeIdsFlat = new BigInt64Array(batchSize * seqLen);
        for (let b = 0; b < batchSize; b++) {
            for (let s = 0; s < seqLen; s++) {
                const idx = b * seqLen + s;
                inputIdsFlat[idx] = BigInt(inputIds[b][s]);
                attentionMaskFlat[idx] = BigInt(attentionMask[b][s]);
                // Token types default to zeros (single-segment input).
                tokenTypeIdsFlat[idx] = tokenTypeIds ? BigInt(tokenTypeIds[b][s]) : 0n;
            }
        }
        // Create ONNX tensors
        const inputIdsTensor = new ort.Tensor('int64', inputIdsFlat, [batchSize, seqLen]);
        const attentionMaskTensor = new ort.Tensor('int64', attentionMaskFlat, [batchSize, seqLen]);
        const tokenTypeIdsTensor = new ort.Tensor('int64', tokenTypeIdsFlat, [batchSize, seqLen]);
        try {
            const feeds = {
                input_ids: inputIdsTensor,
                attention_mask: attentionMaskTensor,
                token_type_ids: tokenTypeIdsTensor,
            };
            const results = await this.session.run(feeds);
            // Prefer last_hidden_state; some model exports name the same
            // tensor token_embeddings.
            const output = results.last_hidden_state ?? results.token_embeddings;
            if (!output) {
                throw new Error('Model did not return expected output tensor');
            }
            return output.data;
        }
        finally {
            // Dispose tensors to free memory even when run() throws.
            inputIdsTensor.dispose();
            attentionMaskTensor.dispose();
            tokenTypeIdsTensor.dispose();
        }
    }
    /**
     * Infer a single sequence (convenience wrapper around infer()).
     */
    async inferSingle(inputIds, attentionMask, tokenTypeIds) {
        return this.infer([inputIds], [attentionMask], tokenTypeIds ? [tokenTypeIds] : undefined);
    }
    /**
     * Whether initialize() has completed successfully.
     */
    isInitialized() {
        return this.initialized;
    }
    /**
     * Model input/output tensor names (for debugging); null before init.
     */
    getModelInfo() {
        if (!this.session) {
            return null;
        }
        return {
            inputs: this.session.inputNames,
            outputs: this.session.outputNames,
        };
    }
    /**
     * Dispose of the session and free resources.
     *
     * Bug fix: the original only dropped the JS reference ("Release the
     * session" comment notwithstanding), leaking the native/WASM-side
     * session. InferenceSession.release() actually frees it; guarded for
     * older onnxruntime-web versions that lack the method.
     */
    async dispose() {
        if (this.session) {
            const session = this.session;
            this.session = null;
            if (typeof session.release === 'function') {
                await session.release();
            }
        }
        this.initialized = false;
    }
}
|
|
148
|
+
/**
 * Create an inference engine with default configuration.
 * The engine is not yet initialized; call initialize() before use.
 */
export function createInferenceEngine(modelPath) {
    const engine = new ONNXInferenceEngine({ modelPath });
    return engine;
}
|
|
154
|
+
//# sourceMappingURL=ONNXInferenceEngine.js.map
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WASM Embedding Engine
|
|
3
|
+
*
|
|
4
|
+
* The main embedding engine that combines all components:
|
|
5
|
+
* - WordPieceTokenizer: Text → Token IDs
|
|
6
|
+
* - ONNXInferenceEngine: Token IDs → Hidden States
|
|
7
|
+
* - EmbeddingPostProcessor: Hidden States → Normalized Embedding
|
|
8
|
+
*
|
|
9
|
+
* This replaces transformers.js with a clean, production-grade implementation.
|
|
10
|
+
*
|
|
11
|
+
* Features:
|
|
12
|
+
* - Singleton pattern (one model instance)
|
|
13
|
+
* - Lazy initialization
|
|
14
|
+
* - Batch processing support
|
|
15
|
+
* - Zero runtime dependencies
|
|
16
|
+
*/
|
|
17
|
+
import { EmbeddingResult, EngineStats } from './types.js';
|
|
18
|
+
/**
 * WASM-based embedding engine (singleton).
 *
 * Pipeline: WordPieceTokenizer → ONNXInferenceEngine → EmbeddingPostProcessor.
 * Construction is private; obtain the shared instance via getInstance().
 * Initialization is lazy — embed()/embedBatch() trigger it on first use.
 */
export declare class WASMEmbeddingEngine {
    private tokenizer;
    private inference;
    private postProcessor;
    private initialized;
    private embedCount;
    private totalProcessingTimeMs;
    private constructor();
    /**
     * Get the process-wide singleton instance.
     */
    static getInstance(): WASMEmbeddingEngine;
    /**
     * Initialize all components (tokenizer, ONNX session, post-processor).
     * Safe to call repeatedly; concurrent callers await the in-flight init.
     */
    initialize(): Promise<void>;
    /**
     * Perform actual initialization (asset verification, vocab, model).
     */
    private performInit;
    /**
     * Generate a normalized embedding for the given text, initializing the
     * engine first if needed.
     */
    embed(text: string): Promise<number[]>;
    /**
     * Generate an embedding plus token-count and timing metadata.
     */
    embedWithMetadata(text: string): Promise<EmbeddingResult>;
    /**
     * Batch embed multiple texts in a single model run.
     */
    embedBatch(texts: string[]): Promise<number[][]>;
    /**
     * Whether initialization has completed successfully.
     */
    isInitialized(): boolean;
    /**
     * Get engine usage statistics (counts and timings).
     */
    getStats(): EngineStats;
    /**
     * Dispose and free resources; a later call must re-initialize.
     */
    dispose(): Promise<void>;
    /**
     * Reset singleton (for testing).
     */
    static resetInstance(): void;
}
|
|
70
|
+
/** Shared singleton engine instance (model loading stays lazy until first use). */
export declare const wasmEmbeddingEngine: WASMEmbeddingEngine;
/**
 * Convenience function to get an embedding for one text via the shared
 * singleton engine.
 */
export declare function embed(text: string): Promise<number[]>;
/**
 * Convenience function for batch embeddings via the shared singleton engine.
 */
export declare function embedBatch(texts: string[]): Promise<number[][]>;
/**
 * Get usage statistics from the shared singleton engine.
 */
export declare function getEmbeddingStats(): EngineStats;
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WASM Embedding Engine
|
|
3
|
+
*
|
|
4
|
+
* The main embedding engine that combines all components:
|
|
5
|
+
* - WordPieceTokenizer: Text → Token IDs
|
|
6
|
+
* - ONNXInferenceEngine: Token IDs → Hidden States
|
|
7
|
+
* - EmbeddingPostProcessor: Hidden States → Normalized Embedding
|
|
8
|
+
*
|
|
9
|
+
* This replaces transformers.js with a clean, production-grade implementation.
|
|
10
|
+
*
|
|
11
|
+
* Features:
|
|
12
|
+
* - Singleton pattern (one model instance)
|
|
13
|
+
* - Lazy initialization
|
|
14
|
+
* - Batch processing support
|
|
15
|
+
* - Zero runtime dependencies
|
|
16
|
+
*/
|
|
17
|
+
import { WordPieceTokenizer } from './WordPieceTokenizer.js';
|
|
18
|
+
import { ONNXInferenceEngine } from './ONNXInferenceEngine.js';
|
|
19
|
+
import { EmbeddingPostProcessor } from './EmbeddingPostProcessor.js';
|
|
20
|
+
import { getAssetLoader } from './AssetLoader.js';
|
|
21
|
+
import { MODEL_CONSTANTS } from './types.js';
|
|
22
|
+
// Module-level singleton state shared by every getInstance() caller.
let globalInstance = null;
let globalInitPromise = null;
/**
 * WASM-based embedding engine (singleton).
 *
 * Pipeline: WordPieceTokenizer → ONNXInferenceEngine → EmbeddingPostProcessor.
 * Obtain the shared instance via getInstance(); initialization is lazy and
 * triggered by the first embed()/embedBatch() call.
 */
export class WASMEmbeddingEngine {
    constructor() {
        this.tokenizer = null;
        this.inference = null;
        this.postProcessor = null;
        this.initialized = false;
        this.embedCount = 0;
        this.totalProcessingTimeMs = 0;
        // Private constructor for singleton — use getInstance().
    }
    /**
     * Get the singleton instance.
     */
    static getInstance() {
        if (!globalInstance) {
            globalInstance = new WASMEmbeddingEngine();
        }
        return globalInstance;
    }
    /**
     * Initialize all components. Safe to call repeatedly; concurrent callers
     * share the single in-flight initialization promise.
     */
    async initialize() {
        // Already initialized
        if (this.initialized) {
            return;
        }
        // Initialization in progress — wait for it rather than starting twice.
        if (globalInitPromise) {
            await globalInitPromise;
            return;
        }
        // Start initialization; clear the shared promise even on failure so a
        // later call can retry.
        globalInitPromise = this.performInit();
        try {
            await globalInitPromise;
        }
        finally {
            globalInitPromise = null;
        }
    }
    /**
     * Perform actual initialization: verify assets, load the vocabulary and
     * ONNX model, and build the post-processor.
     *
     * @throws {Error} With a descriptive message when any step fails; all
     *   partially-created components are torn down first.
     */
    async performInit() {
        const startTime = Date.now();
        console.log('🚀 Initializing WASM Embedding Engine...');
        try {
            const assetLoader = getAssetLoader();
            // Verify assets exist before doing any heavy loading.
            const verification = await assetLoader.verifyAssets();
            if (!verification.valid) {
                throw new Error(`Missing model assets:\n${verification.errors.join('\n')}\n\n` +
                    `Expected model at: ${verification.modelPath}\n` +
                    `Expected vocab at: ${verification.vocabPath}\n\n` +
                    `Run 'npm run download-model' to download the model files.`);
            }
            // Load vocabulary and create tokenizer
            console.log('📖 Loading vocabulary...');
            const vocab = await assetLoader.loadVocab();
            this.tokenizer = new WordPieceTokenizer(vocab);
            console.log(`✅ Vocabulary loaded: ${this.tokenizer.vocabSize} tokens`);
            // Initialize ONNX inference engine
            console.log('🧠 Loading ONNX model...');
            const modelPath = await assetLoader.getModelPath();
            this.inference = new ONNXInferenceEngine({ modelPath });
            await this.inference.initialize(modelPath);
            console.log('✅ ONNX model loaded');
            // Create post-processor
            this.postProcessor = new EmbeddingPostProcessor(MODEL_CONSTANTS.HIDDEN_SIZE);
            this.initialized = true;
            const initTime = Date.now() - startTime;
            console.log(`✅ WASM Embedding Engine ready in ${initTime}ms`);
        }
        catch (error) {
            // Tear down any partially-built state so a retry starts clean.
            this.initialized = false;
            this.tokenizer = null;
            this.inference = null;
            this.postProcessor = null;
            throw new Error(`Failed to initialize WASM Embedding Engine: ${error instanceof Error ? error.message : String(error)}`);
        }
    }
    /**
     * Generate a normalized embedding for a single text.
     */
    async embed(text) {
        const result = await this.embedWithMetadata(text);
        return result.embedding;
    }
    /**
     * Generate an embedding plus token-count and timing metadata.
     */
    async embedWithMetadata(text) {
        // Ensure initialized (lazy init on first use).
        if (!this.initialized) {
            await this.initialize();
        }
        if (!this.tokenizer || !this.inference || !this.postProcessor) {
            throw new Error('Engine not properly initialized');
        }
        const startTime = Date.now();
        // 1. Tokenize
        const tokenized = this.tokenizer.encode(text);
        // 2. Run inference
        const hiddenStates = await this.inference.inferSingle(tokenized.inputIds, tokenized.attentionMask, tokenized.tokenTypeIds);
        // 3. Post-process (mean pool + normalize)
        const embedding = this.postProcessor.process(hiddenStates, tokenized.attentionMask, tokenized.inputIds.length);
        const processingTimeMs = Date.now() - startTime;
        this.embedCount++;
        this.totalProcessingTimeMs += processingTimeMs;
        return {
            embedding: Array.from(embedding),
            tokenCount: tokenized.tokenCount,
            processingTimeMs,
        };
    }
    /**
     * Batch embed multiple texts in a single model run.
     */
    async embedBatch(texts) {
        // Ensure initialized (lazy init on first use).
        if (!this.initialized) {
            await this.initialize();
        }
        if (!this.tokenizer || !this.inference || !this.postProcessor) {
            throw new Error('Engine not properly initialized');
        }
        if (texts.length === 0) {
            return [];
        }
        const startTime = Date.now();
        // Tokenize all texts (padded to a shared length).
        const batch = this.tokenizer.encodeBatch(texts);
        const seqLen = batch.inputIds[0].length;
        // Run batch inference
        const hiddenStates = await this.inference.infer(batch.inputIds, batch.attentionMask, batch.tokenTypeIds);
        // Post-process each result
        const embeddings = this.postProcessor.processBatch(hiddenStates, batch.attentionMask, texts.length, seqLen);
        this.embedCount += texts.length;
        // Bug fix: batch work previously never contributed to
        // totalProcessingTimeMs, deflating avgProcessingTimeMs in getStats().
        this.totalProcessingTimeMs += Date.now() - startTime;
        return embeddings.map((e) => Array.from(e));
    }
    /**
     * Whether initialization has completed successfully.
     */
    isInitialized() {
        return this.initialized;
    }
    /**
     * Get engine usage statistics (counts and timings).
     */
    getStats() {
        return {
            initialized: this.initialized,
            embedCount: this.embedCount,
            totalProcessingTimeMs: this.totalProcessingTimeMs,
            avgProcessingTimeMs: this.embedCount > 0
                ? this.totalProcessingTimeMs / this.embedCount
                : 0,
            modelName: MODEL_CONSTANTS.MODEL_NAME,
        };
    }
    /**
     * Dispose and free resources; a later call must re-initialize.
     */
    async dispose() {
        if (this.inference) {
            await this.inference.dispose();
            this.inference = null;
        }
        this.tokenizer = null;
        this.postProcessor = null;
        this.initialized = false;
    }
    /**
     * Reset singleton (for testing).
     */
    static resetInstance() {
        if (globalInstance) {
            // Bug fix: dispose() is async and was left floating, risking an
            // unhandled rejection. Reset must stay synchronous, so explicitly
            // void the promise and swallow best-effort-cleanup failures.
            void globalInstance.dispose().catch(() => { });
        }
        globalInstance = null;
        globalInitPromise = null;
    }
}
|
|
211
|
+
// Export singleton access. The engine object is created eagerly at import
// time, but model loading stays lazy until the first embed()/initialize().
export const wasmEmbeddingEngine = WASMEmbeddingEngine.getInstance();
|
|
213
|
+
/**
 * Convenience wrapper: embed a single text via the shared singleton engine.
 */
export async function embed(text) {
    const vector = await wasmEmbeddingEngine.embed(text);
    return vector;
}
|
|
219
|
+
/**
 * Convenience wrapper: embed several texts via the shared singleton engine.
 */
export async function embedBatch(texts) {
    const vectors = await wasmEmbeddingEngine.embedBatch(texts);
    return vectors;
}
|
|
225
|
+
/**
 * Snapshot of the shared singleton engine's usage statistics.
 */
export function getEmbeddingStats() {
    const stats = wasmEmbeddingEngine.getStats();
    return stats;
}
|
|
231
|
+
//# sourceMappingURL=WASMEmbeddingEngine.js.map
|