@soulcraft/brainy 6.5.0 → 6.6.1
This diff shows the changes between publicly available package versions as published to their public registry. It is provided for informational purposes only.
- package/assets/models/all-MiniLM-L6-v2-q8/config.json +25 -0
- package/assets/models/all-MiniLM-L6-v2-q8/model.onnx +0 -0
- package/assets/models/all-MiniLM-L6-v2-q8/tokenizer.json +30686 -0
- package/assets/models/all-MiniLM-L6-v2-q8/vocab.json +1 -0
- package/dist/brainy.js +0 -6
- package/dist/config/index.d.ts +1 -3
- package/dist/config/index.js +2 -4
- package/dist/config/modelAutoConfig.d.ts +10 -17
- package/dist/config/modelAutoConfig.js +15 -88
- package/dist/config/sharedConfigManager.d.ts +1 -2
- package/dist/config/zeroConfig.d.ts +2 -13
- package/dist/config/zeroConfig.js +7 -15
- package/dist/critical/model-guardian.d.ts +5 -22
- package/dist/critical/model-guardian.js +38 -210
- package/dist/embeddings/EmbeddingManager.d.ts +7 -17
- package/dist/embeddings/EmbeddingManager.js +28 -136
- package/dist/embeddings/wasm/AssetLoader.d.ts +67 -0
- package/dist/embeddings/wasm/AssetLoader.js +238 -0
- package/dist/embeddings/wasm/EmbeddingPostProcessor.d.ts +60 -0
- package/dist/embeddings/wasm/EmbeddingPostProcessor.js +123 -0
- package/dist/embeddings/wasm/ONNXInferenceEngine.d.ts +55 -0
- package/dist/embeddings/wasm/ONNXInferenceEngine.js +154 -0
- package/dist/embeddings/wasm/WASMEmbeddingEngine.d.ts +82 -0
- package/dist/embeddings/wasm/WASMEmbeddingEngine.js +231 -0
- package/dist/embeddings/wasm/WordPieceTokenizer.d.ts +71 -0
- package/dist/embeddings/wasm/WordPieceTokenizer.js +264 -0
- package/dist/embeddings/wasm/index.d.ts +13 -0
- package/dist/embeddings/wasm/index.js +15 -0
- package/dist/embeddings/wasm/types.d.ts +114 -0
- package/dist/embeddings/wasm/types.js +25 -0
- package/dist/setup.d.ts +11 -11
- package/dist/setup.js +17 -31
- package/dist/types/brainy.types.d.ts +0 -5
- package/dist/utils/embedding.d.ts +45 -62
- package/dist/utils/embedding.js +61 -440
- package/package.json +10 -3
- package/scripts/download-model.cjs +175 -0
package/dist/embeddings/wasm/AssetLoader.d.ts
@@ -0,0 +1,67 @@
+/**
+ * Asset Loader
+ *
+ * Resolves paths to model files (ONNX model, vocabulary) across environments.
+ * Handles Node.js, Bun, and bundled scenarios.
+ *
+ * Asset Resolution Order:
+ * 1. Environment variable: BRAINY_MODEL_PATH
+ * 2. Package-relative: node_modules/@soulcraft/brainy/assets/models/
+ * 3. Project-relative: ./assets/models/
+ */
+/**
+ * Asset loader for model files
+ */
+export declare class AssetLoader {
+    private modelDir;
+    /**
+     * Get the model directory path
+     */
+    getModelDir(): Promise<string>;
+    /**
+     * Resolve the model directory across environments
+     */
+    private resolveModelDir;
+    /**
+     * Get package root path (Node.js/Bun only)
+     */
+    private getPackageRootPath;
+    /**
+     * Check if path exists (works in Node.js/Bun)
+     */
+    private pathExists;
+    /**
+     * Get path to ONNX model file
+     */
+    getModelPath(): Promise<string>;
+    /**
+     * Get path to vocabulary file
+     */
+    getVocabPath(): Promise<string>;
+    /**
+     * Load vocabulary from JSON file
+     */
+    loadVocab(): Promise<Record<string, number>>;
+    /**
+     * Load model as ArrayBuffer (for ONNX session)
+     */
+    loadModel(): Promise<ArrayBuffer>;
+    /**
+     * Verify all required assets exist
+     */
+    verifyAssets(): Promise<{
+        valid: boolean;
+        modelPath: string;
+        vocabPath: string;
+        errors: string[];
+    }>;
+    /**
+     * Clear cached paths (for testing)
+     */
+    clearCache(): void;
+}
+/**
+ * Create asset loader instance
+ */
+export declare function createAssetLoader(): AssetLoader;
+export declare function getAssetLoader(): AssetLoader;
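
The new `AssetLoader` declaration above is the resolution entry point for the bundled model files. A minimal usage sketch, assuming the `wasm/index.js` barrel added in this release re-exports `getAssetLoader` and the deep `dist` import path matches the layout in the file list (neither is shown in this excerpt):

```typescript
// Sketch only: resolving and verifying the bundled MiniLM assets.
import { getAssetLoader } from '@soulcraft/brainy/dist/embeddings/wasm/index.js';

// Optional override, checked first in the documented resolution order:
process.env.BRAINY_MODEL_PATH = '/opt/models/all-MiniLM-L6-v2-q8';

const loader = getAssetLoader();
const { valid, modelPath, vocabPath, errors } = await loader.verifyAssets();
if (!valid) {
  throw new Error(`Missing model assets:\n${errors.join('\n')}`);
}

const vocab = await loader.loadVocab(); // Record<string, number>
const model = await loader.loadModel(); // ArrayBuffer for the ONNX session
console.log(modelPath, vocabPath, Object.keys(vocab).length, model.byteLength);
```
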
package/dist/embeddings/wasm/AssetLoader.js
@@ -0,0 +1,238 @@
+/**
+ * Asset Loader
+ *
+ * Resolves paths to model files (ONNX model, vocabulary) across environments.
+ * Handles Node.js, Bun, and bundled scenarios.
+ *
+ * Asset Resolution Order:
+ * 1. Environment variable: BRAINY_MODEL_PATH
+ * 2. Package-relative: node_modules/@soulcraft/brainy/assets/models/
+ * 3. Project-relative: ./assets/models/
+ */
+import { MODEL_CONSTANTS } from './types.js';
+// Cache resolved paths
+let cachedModelDir = null;
+let cachedVocab = null;
+/**
+ * Asset loader for model files
+ */
+export class AssetLoader {
+    constructor() {
+        this.modelDir = null;
+    }
+    /**
+     * Get the model directory path
+     */
+    async getModelDir() {
+        if (this.modelDir) {
+            return this.modelDir;
+        }
+        if (cachedModelDir) {
+            this.modelDir = cachedModelDir;
+            return cachedModelDir;
+        }
+        // Try to resolve model directory
+        const resolved = await this.resolveModelDir();
+        this.modelDir = resolved;
+        cachedModelDir = resolved;
+        return resolved;
+    }
+    /**
+     * Resolve the model directory across environments
+     */
+    async resolveModelDir() {
+        // 1. Check environment variable
+        if (typeof process !== 'undefined' && process.env?.BRAINY_MODEL_PATH) {
+            const envPath = process.env.BRAINY_MODEL_PATH;
+            if (await this.pathExists(envPath)) {
+                return envPath;
+            }
+        }
+        // 2. Try common locations
+        const modelName = MODEL_CONSTANTS.MODEL_NAME + '-q8';
+        const possiblePaths = [
+            // Package assets (when installed as dependency)
+            `./assets/models/${modelName}`,
+            `./node_modules/@soulcraft/brainy/assets/models/${modelName}`,
+            // Development paths
+            `../assets/models/${modelName}`,
+            // Absolute from package root
+            this.getPackageRootPath(`assets/models/${modelName}`),
+        ].filter(Boolean);
+        for (const path of possiblePaths) {
+            if (await this.pathExists(path)) {
+                return path;
+            }
+        }
+        // If no path found, return default (will error on use)
+        return `./assets/models/${modelName}`;
+    }
+    /**
+     * Get package root path (Node.js/Bun only)
+     */
+    getPackageRootPath(relativePath) {
+        if (typeof process === 'undefined') {
+            return null;
+        }
+        try {
+            // Use __dirname equivalent
+            const url = new URL(import.meta.url);
+            const currentDir = url.pathname.replace(/\/[^/]*$/, '');
+            // Go up from src/embeddings/wasm to package root
+            const packageRoot = currentDir.replace(/\/src\/embeddings\/wasm$/, '');
+            return `${packageRoot}/${relativePath}`;
+        }
+        catch {
+            return null;
+        }
+    }
+    /**
+     * Check if path exists (works in Node.js/Bun)
+     */
+    async pathExists(path) {
+        if (typeof process === 'undefined') {
+            // Browser - check via fetch
+            try {
+                const response = await fetch(path, { method: 'HEAD' });
+                return response.ok;
+            }
+            catch {
+                return false;
+            }
+        }
+        // Node.js/Bun
+        try {
+            const fs = await import('node:fs/promises');
+            await fs.access(path);
+            return true;
+        }
+        catch {
+            return false;
+        }
+    }
+    /**
+     * Get path to ONNX model file
+     */
+    async getModelPath() {
+        const dir = await this.getModelDir();
+        return `${dir}/model.onnx`;
+    }
+    /**
+     * Get path to vocabulary file
+     */
+    async getVocabPath() {
+        const dir = await this.getModelDir();
+        return `${dir}/vocab.json`;
+    }
+    /**
+     * Load vocabulary from JSON file
+     */
+    async loadVocab() {
+        if (cachedVocab) {
+            return cachedVocab;
+        }
+        const vocabPath = await this.getVocabPath();
+        if (typeof process !== 'undefined') {
+            // Node.js/Bun - read from filesystem
+            try {
+                const fs = await import('node:fs/promises');
+                const content = await fs.readFile(vocabPath, 'utf-8');
+                cachedVocab = JSON.parse(content);
+                return cachedVocab;
+            }
+            catch (error) {
+                throw new Error(`Failed to load vocabulary from ${vocabPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+        else {
+            // Browser - fetch
+            try {
+                const response = await fetch(vocabPath);
+                if (!response.ok) {
+                    throw new Error(`HTTP ${response.status}`);
+                }
+                cachedVocab = await response.json();
+                return cachedVocab;
+            }
+            catch (error) {
+                throw new Error(`Failed to fetch vocabulary from ${vocabPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+    }
+    /**
+     * Load model as ArrayBuffer (for ONNX session)
+     */
+    async loadModel() {
+        const modelPath = await this.getModelPath();
+        if (typeof process !== 'undefined') {
+            // Node.js/Bun - read from filesystem
+            try {
+                const fs = await import('node:fs/promises');
+                const buffer = await fs.readFile(modelPath);
+                // Convert Node.js Buffer to ArrayBuffer
+                return new Uint8Array(buffer).buffer;
+            }
+            catch (error) {
+                throw new Error(`Failed to load model from ${modelPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+        else {
+            // Browser - fetch
+            try {
+                const response = await fetch(modelPath);
+                if (!response.ok) {
+                    throw new Error(`HTTP ${response.status}`);
+                }
+                return await response.arrayBuffer();
+            }
+            catch (error) {
+                throw new Error(`Failed to fetch model from ${modelPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+    }
+    /**
+     * Verify all required assets exist
+     */
+    async verifyAssets() {
+        const errors = [];
+        const modelPath = await this.getModelPath();
+        const vocabPath = await this.getVocabPath();
+        if (!(await this.pathExists(modelPath))) {
+            errors.push(`Model file not found: ${modelPath}`);
+        }
+        if (!(await this.pathExists(vocabPath))) {
+            errors.push(`Vocabulary file not found: ${vocabPath}`);
+        }
+        return {
+            valid: errors.length === 0,
+            modelPath,
+            vocabPath,
+            errors,
+        };
+    }
+    /**
+     * Clear cached paths (for testing)
+     */
+    clearCache() {
+        this.modelDir = null;
+        cachedModelDir = null;
+        cachedVocab = null;
+    }
+}
+/**
+ * Create asset loader instance
+ */
+export function createAssetLoader() {
+    return new AssetLoader();
+}
+/**
+ * Singleton asset loader
+ */
+let singletonLoader = null;
+export function getAssetLoader() {
+    if (!singletonLoader) {
+        singletonLoader = new AssetLoader();
+    }
+    return singletonLoader;
+}
+//# sourceMappingURL=AssetLoader.js.map
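
Note the two cache layers in this implementation: the instance field `modelDir` plus the module-level `cachedModelDir`/`cachedVocab`, which persist across `AssetLoader` instances. `clearCache()` resets all three, which matters when tests change `BRAINY_MODEL_PATH` between runs. A sketch of that behavior, assuming both fixture directories exist on disk:

```typescript
import { AssetLoader } from '@soulcraft/brainy/dist/embeddings/wasm/AssetLoader.js';

const loader = new AssetLoader();
process.env.BRAINY_MODEL_PATH = '/tmp/fixtures/model-a';
await loader.getModelDir(); // resolves model-a and caches it at module level

process.env.BRAINY_MODEL_PATH = '/tmp/fixtures/model-b';
await loader.getModelDir(); // still model-a: served from the caches

loader.clearCache();        // resets this.modelDir, cachedModelDir, cachedVocab
await loader.getModelDir(); // now re-resolves and returns model-b
```
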
package/dist/embeddings/wasm/EmbeddingPostProcessor.d.ts
@@ -0,0 +1,60 @@
+/**
+ * Embedding Post-Processor
+ *
+ * Converts raw ONNX model output to final embedding vectors.
+ * Implements mean pooling and L2 normalization as used by sentence-transformers.
+ *
+ * Pipeline:
+ * 1. Mean Pooling: Average token embeddings (weighted by attention mask)
+ * 2. L2 Normalization: Normalize to unit length for cosine similarity
+ */
+/**
+ * Post-processor for converting ONNX output to sentence embeddings
+ */
+export declare class EmbeddingPostProcessor {
+    private hiddenSize;
+    constructor(hiddenSize?: number);
+    /**
+     * Mean pool token embeddings weighted by attention mask
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize] flattened
+     * @param attentionMask - Attention mask [seqLen] (1 for real tokens, 0 for padding)
+     * @param seqLen - Sequence length
+     * @returns Mean-pooled embedding [hiddenSize]
+     */
+    meanPool(hiddenStates: Float32Array, attentionMask: number[], seqLen: number): Float32Array;
+    /**
+     * L2 normalize embedding to unit length
+     *
+     * @param embedding - Input embedding
+     * @returns Normalized embedding with ||x|| = 1
+     */
+    normalize(embedding: Float32Array): Float32Array;
+    /**
+     * Full post-processing pipeline: mean pool then normalize
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize]
+     * @param attentionMask - Attention mask [seqLen]
+     * @param seqLen - Sequence length
+     * @returns Final normalized embedding [hiddenSize]
+     */
+    process(hiddenStates: Float32Array, attentionMask: number[], seqLen: number): Float32Array;
+    /**
+     * Process batch of embeddings
+     *
+     * @param hiddenStates - Raw model output [batchSize * seqLen * hiddenSize]
+     * @param attentionMasks - Attention masks [batchSize][seqLen]
+     * @param batchSize - Number of sequences in batch
+     * @param seqLen - Sequence length (same for all in batch due to padding)
+     * @returns Array of normalized embeddings
+     */
+    processBatch(hiddenStates: Float32Array, attentionMasks: number[][], batchSize: number, seqLen: number): Float32Array[];
+    /**
+     * Convert Float32Array to number array
+     */
+    toNumberArray(embedding: Float32Array): number[];
+}
+/**
+ * Create a post-processor with default configuration
+ */
+export declare function createPostProcessor(): EmbeddingPostProcessor;
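
The shapes in these signatures are easiest to check on a toy tensor. A worked sketch with `seqLen = 3` and `hiddenSize = 2`, where the third token is padding (mask 0) and so contributes nothing to the pooled mean:

```typescript
import { EmbeddingPostProcessor } from '@soulcraft/brainy/dist/embeddings/wasm/EmbeddingPostProcessor.js';

// 3 tokens x 2 dims, flattened row-major; token 2 is padding.
const hidden = new Float32Array([
  1, 2, // token 0
  3, 4, // token 1
  9, 9, // token 2 (masked out)
]);
const mask = [1, 1, 0];

const pp = new EmbeddingPostProcessor(2);    // hiddenSize = 2
const pooled = pp.meanPool(hidden, mask, 3); // [(1+3)/2, (2+4)/2] = [2, 3]
const unit = pp.normalize(pooled);           // [2, 3] / sqrt(13)
console.log(pp.toNumberArray(unit));         // ≈ [0.5547, 0.8321]
```
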
package/dist/embeddings/wasm/EmbeddingPostProcessor.js
@@ -0,0 +1,123 @@
+/**
+ * Embedding Post-Processor
+ *
+ * Converts raw ONNX model output to final embedding vectors.
+ * Implements mean pooling and L2 normalization as used by sentence-transformers.
+ *
+ * Pipeline:
+ * 1. Mean Pooling: Average token embeddings (weighted by attention mask)
+ * 2. L2 Normalization: Normalize to unit length for cosine similarity
+ */
+import { MODEL_CONSTANTS } from './types.js';
+/**
+ * Post-processor for converting ONNX output to sentence embeddings
+ */
+export class EmbeddingPostProcessor {
+    constructor(hiddenSize = MODEL_CONSTANTS.HIDDEN_SIZE) {
+        this.hiddenSize = hiddenSize;
+    }
+    /**
+     * Mean pool token embeddings weighted by attention mask
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize] flattened
+     * @param attentionMask - Attention mask [seqLen] (1 for real tokens, 0 for padding)
+     * @param seqLen - Sequence length
+     * @returns Mean-pooled embedding [hiddenSize]
+     */
+    meanPool(hiddenStates, attentionMask, seqLen) {
+        const result = new Float32Array(this.hiddenSize);
+        // Sum of attention mask (number of real tokens)
+        let maskSum = 0;
+        for (let i = 0; i < seqLen; i++) {
+            maskSum += attentionMask[i];
+        }
+        // Avoid division by zero
+        if (maskSum === 0) {
+            maskSum = 1;
+        }
+        // Compute weighted sum for each dimension
+        for (let dim = 0; dim < this.hiddenSize; dim++) {
+            let sum = 0;
+            for (let pos = 0; pos < seqLen; pos++) {
+                // Get hidden state at [pos, dim]
+                const value = hiddenStates[pos * this.hiddenSize + dim];
+                // Weight by attention mask
+                sum += value * attentionMask[pos];
+            }
+            // Mean pool
+            result[dim] = sum / maskSum;
+        }
+        return result;
+    }
+    /**
+     * L2 normalize embedding to unit length
+     *
+     * @param embedding - Input embedding
+     * @returns Normalized embedding with ||x|| = 1
+     */
+    normalize(embedding) {
+        // Compute L2 norm
+        let sumSquares = 0;
+        for (let i = 0; i < embedding.length; i++) {
+            sumSquares += embedding[i] * embedding[i];
+        }
+        const norm = Math.sqrt(sumSquares);
+        // Avoid division by zero
+        if (norm === 0) {
+            return embedding;
+        }
+        // Normalize
+        const result = new Float32Array(embedding.length);
+        for (let i = 0; i < embedding.length; i++) {
+            result[i] = embedding[i] / norm;
+        }
+        return result;
+    }
+    /**
+     * Full post-processing pipeline: mean pool then normalize
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize]
+     * @param attentionMask - Attention mask [seqLen]
+     * @param seqLen - Sequence length
+     * @returns Final normalized embedding [hiddenSize]
+     */
+    process(hiddenStates, attentionMask, seqLen) {
+        const pooled = this.meanPool(hiddenStates, attentionMask, seqLen);
+        return this.normalize(pooled);
+    }
+    /**
+     * Process batch of embeddings
+     *
+     * @param hiddenStates - Raw model output [batchSize * seqLen * hiddenSize]
+     * @param attentionMasks - Attention masks [batchSize][seqLen]
+     * @param batchSize - Number of sequences in batch
+     * @param seqLen - Sequence length (same for all in batch due to padding)
+     * @returns Array of normalized embeddings
+     */
+    processBatch(hiddenStates, attentionMasks, batchSize, seqLen) {
+        const results = [];
+        const sequenceSize = seqLen * this.hiddenSize;
+        for (let b = 0; b < batchSize; b++) {
+            // Extract this sequence's hidden states
+            const start = b * sequenceSize;
+            const seqHiddenStates = hiddenStates.slice(start, start + sequenceSize);
+            // Process
+            const embedding = this.process(seqHiddenStates, attentionMasks[b], seqLen);
+            results.push(embedding);
+        }
+        return results;
+    }
+    /**
+     * Convert Float32Array to number array
+     */
+    toNumberArray(embedding) {
+        return Array.from(embedding);
+    }
+}
+/**
+ * Create a post-processor with default configuration
+ */
+export function createPostProcessor() {
+    return new EmbeddingPostProcessor(MODEL_CONSTANTS.HIDDEN_SIZE);
+}
+//# sourceMappingURL=EmbeddingPostProcessor.js.map
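
For batched output, `processBatch` slices the flat `[batchSize * seqLen * hiddenSize]` buffer into per-sequence chunks and runs the same pool-then-normalize pipeline on each. A shape-level sketch (384 is all-MiniLM-L6-v2's hidden size; the zero-filled buffer is a placeholder for real `infer()` output):

```typescript
import { createPostProcessor } from '@soulcraft/brainy/dist/embeddings/wasm/EmbeddingPostProcessor.js';

const batchSize = 2, seqLen = 4, hiddenSize = 384;
const hiddenStates = new Float32Array(batchSize * seqLen * hiddenSize); // placeholder
const masks = [
  [1, 1, 1, 0], // sequence 0: three real tokens, one pad
  [1, 1, 0, 0], // sequence 1: two real tokens, two pads
];

const pp = createPostProcessor();
const embeddings = pp.processBatch(hiddenStates, masks, batchSize, seqLen);
console.log(embeddings.length, embeddings[0].length); // 2, 384
```
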
package/dist/embeddings/wasm/ONNXInferenceEngine.d.ts
@@ -0,0 +1,55 @@
+/**
+ * ONNX Inference Engine
+ *
+ * Direct ONNX Runtime Web wrapper for running model inference.
+ * Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
+ *
+ * This replaces transformers.js dependency with direct ONNX control.
+ */
+import { InferenceConfig } from './types.js';
+/**
+ * ONNX Inference Engine using onnxruntime-web
+ */
+export declare class ONNXInferenceEngine {
+    private session;
+    private initialized;
+    private modelPath;
+    private config;
+    constructor(config?: Partial<InferenceConfig>);
+    /**
+     * Initialize the ONNX session
+     */
+    initialize(modelPath?: string): Promise<void>;
+    /**
+     * Run inference on tokenized input
+     *
+     * @param inputIds - Token IDs [batchSize, seqLen]
+     * @param attentionMask - Attention mask [batchSize, seqLen]
+     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
+     * @returns Hidden states [batchSize, seqLen, hiddenSize]
+     */
+    infer(inputIds: number[][], attentionMask: number[][], tokenTypeIds?: number[][]): Promise<Float32Array>;
+    /**
+     * Infer single sequence (convenience method)
+     */
+    inferSingle(inputIds: number[], attentionMask: number[], tokenTypeIds?: number[]): Promise<Float32Array>;
+    /**
+     * Check if initialized
+     */
+    isInitialized(): boolean;
+    /**
+     * Get model input/output names (for debugging)
+     */
+    getModelInfo(): {
+        inputs: readonly string[];
+        outputs: readonly string[];
+    } | null;
+    /**
+     * Dispose of the session and free resources
+     */
+    dispose(): Promise<void>;
+}
+/**
+ * Create an inference engine with default configuration
+ */
+export declare function createInferenceEngine(modelPath: string): ONNXInferenceEngine;
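
Taking these declarations at face value, driving the engine on pre-tokenized input looks like this. The token IDs are illustrative BERT-style WordPiece IDs, not values checked against the bundled vocab:

```typescript
import { createInferenceEngine } from '@soulcraft/brainy/dist/embeddings/wasm/ONNXInferenceEngine.js';

const engine = createInferenceEngine('./assets/models/all-MiniLM-L6-v2-q8/model.onnx');
await engine.initialize();

// [CLS] hello world [SEP]
const inputIds = [101, 7592, 2088, 102];
const attentionMask = [1, 1, 1, 1];

const hidden = await engine.inferSingle(inputIds, attentionMask);
console.log(engine.getModelInfo()); // { inputs: [...], outputs: [...] }
console.log(hidden.length);         // seqLen * hiddenSize values, flattened

await engine.dispose();
```
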
package/dist/embeddings/wasm/ONNXInferenceEngine.js
@@ -0,0 +1,154 @@
+/**
+ * ONNX Inference Engine
+ *
+ * Direct ONNX Runtime Web wrapper for running model inference.
+ * Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
+ *
+ * This replaces transformers.js dependency with direct ONNX control.
+ */
+import * as ort from 'onnxruntime-web';
+// Configure ONNX Runtime for WASM-only
+ort.env.wasm.numThreads = 1; // Single-threaded for stability
+ort.env.wasm.simd = true; // Enable SIMD where available
+/**
+ * ONNX Inference Engine using onnxruntime-web
+ */
+export class ONNXInferenceEngine {
+    constructor(config = {}) {
+        this.session = null;
+        this.initialized = false;
+        this.modelPath = config.modelPath ?? '';
+        this.config = {
+            modelPath: this.modelPath,
+            numThreads: config.numThreads ?? 1,
+            enableSimd: config.enableSimd ?? true,
+            enableCpuMemArena: config.enableCpuMemArena ?? false,
+        };
+    }
+    /**
+     * Initialize the ONNX session
+     */
+    async initialize(modelPath) {
+        if (this.initialized && this.session) {
+            return;
+        }
+        const path = modelPath ?? this.modelPath;
+        if (!path) {
+            throw new Error('Model path is required');
+        }
+        try {
+            // Configure session options
+            const sessionOptions = {
+                executionProviders: ['wasm'],
+                graphOptimizationLevel: 'all',
+                enableCpuMemArena: this.config.enableCpuMemArena,
+                // Additional WASM-specific options
+                executionMode: 'sequential',
+            };
+            // Load model from file path or URL
+            this.session = await ort.InferenceSession.create(path, sessionOptions);
+            this.initialized = true;
+        }
+        catch (error) {
+            this.initialized = false;
+            this.session = null;
+            throw new Error(`Failed to initialize ONNX session: ${error instanceof Error ? error.message : String(error)}`);
+        }
+    }
+    /**
+     * Run inference on tokenized input
+     *
+     * @param inputIds - Token IDs [batchSize, seqLen]
+     * @param attentionMask - Attention mask [batchSize, seqLen]
+     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
+     * @returns Hidden states [batchSize, seqLen, hiddenSize]
+     */
+    async infer(inputIds, attentionMask, tokenTypeIds) {
+        if (!this.session) {
+            throw new Error('Session not initialized. Call initialize() first.');
+        }
+        const batchSize = inputIds.length;
+        const seqLen = inputIds[0].length;
+        // Convert to BigInt64Array (ONNX int64 type)
+        const inputIdsFlat = new BigInt64Array(batchSize * seqLen);
+        const attentionMaskFlat = new BigInt64Array(batchSize * seqLen);
+        const tokenTypeIdsFlat = new BigInt64Array(batchSize * seqLen);
+        for (let b = 0; b < batchSize; b++) {
+            for (let s = 0; s < seqLen; s++) {
+                const idx = b * seqLen + s;
+                inputIdsFlat[idx] = BigInt(inputIds[b][s]);
+                attentionMaskFlat[idx] = BigInt(attentionMask[b][s]);
+                tokenTypeIdsFlat[idx] = tokenTypeIds
+                    ? BigInt(tokenTypeIds[b][s])
+                    : BigInt(0);
+            }
+        }
+        // Create ONNX tensors
+        const inputIdsTensor = new ort.Tensor('int64', inputIdsFlat, [batchSize, seqLen]);
+        const attentionMaskTensor = new ort.Tensor('int64', attentionMaskFlat, [batchSize, seqLen]);
+        const tokenTypeIdsTensor = new ort.Tensor('int64', tokenTypeIdsFlat, [batchSize, seqLen]);
+        try {
+            // Run inference
+            const feeds = {
+                input_ids: inputIdsTensor,
+                attention_mask: attentionMaskTensor,
+                token_type_ids: tokenTypeIdsTensor,
+            };
+            const results = await this.session.run(feeds);
+            // Extract last_hidden_state (the output we need for mean pooling)
+            // Model outputs: last_hidden_state [batch, seq, hidden] and pooler_output [batch, hidden]
+            const output = results.last_hidden_state ?? results.token_embeddings;
+            if (!output) {
+                throw new Error('Model did not return expected output tensor');
+            }
+            return output.data;
+        }
+        finally {
+            // Dispose tensors to free memory
+            inputIdsTensor.dispose();
+            attentionMaskTensor.dispose();
+            tokenTypeIdsTensor.dispose();
+        }
+    }
+    /**
+     * Infer single sequence (convenience method)
+     */
+    async inferSingle(inputIds, attentionMask, tokenTypeIds) {
+        return this.infer([inputIds], [attentionMask], tokenTypeIds ? [tokenTypeIds] : undefined);
+    }
+    /**
+     * Check if initialized
+     */
+    isInitialized() {
+        return this.initialized;
+    }
+    /**
+     * Get model input/output names (for debugging)
+     */
+    getModelInfo() {
+        if (!this.session) {
+            return null;
+        }
+        return {
+            inputs: this.session.inputNames,
+            outputs: this.session.outputNames,
+        };
+    }
+    /**
+     * Dispose of the session and free resources
+     */
+    async dispose() {
+        if (this.session) {
+            // Release the session
+            this.session = null;
+        }
+        this.initialized = false;
+    }
+}
+/**
+ * Create an inference engine with default configuration
+ */
+export function createInferenceEngine(modelPath) {
+    return new ONNXInferenceEngine({ modelPath });
+}
+//# sourceMappingURL=ONNXInferenceEngine.js.map
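
Put together, the new wasm/ modules form a complete transformers.js-free embedding path. A sketch of that flow; `WordPieceTokenizer`'s constructor and `encode()` shape are assumptions (its files appear in the list above but not in this excerpt), as is the barrel re-export:

```typescript
// Hypothetical wiring of the new wasm/ modules; only the AssetLoader,
// ONNXInferenceEngine, and EmbeddingPostProcessor APIs are shown in this diff.
import {
  getAssetLoader,
  WordPieceTokenizer,
  createInferenceEngine,
  createPostProcessor,
} from '@soulcraft/brainy/dist/embeddings/wasm/index.js';

const loader = getAssetLoader();

// Assumed tokenizer API: vocab in, token IDs + attention mask out.
const tokenizer = new WordPieceTokenizer(await loader.loadVocab());
const { inputIds, attentionMask } = tokenizer.encode('hello world');

const engine = createInferenceEngine(await loader.getModelPath());
await engine.initialize();
const hidden = await engine.inferSingle(inputIds, attentionMask);

const embedding = createPostProcessor().process(hidden, attentionMask, inputIds.length);
console.log(embedding.length); // hiddenSize-dim unit vector, ready for cosine similarity
```
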