@soulcraft/brainy 6.5.0 → 6.6.1

Files changed (37)
  1. package/assets/models/all-MiniLM-L6-v2-q8/config.json +25 -0
  2. package/assets/models/all-MiniLM-L6-v2-q8/model.onnx +0 -0
  3. package/assets/models/all-MiniLM-L6-v2-q8/tokenizer.json +30686 -0
  4. package/assets/models/all-MiniLM-L6-v2-q8/vocab.json +1 -0
  5. package/dist/brainy.js +0 -6
  6. package/dist/config/index.d.ts +1 -3
  7. package/dist/config/index.js +2 -4
  8. package/dist/config/modelAutoConfig.d.ts +10 -17
  9. package/dist/config/modelAutoConfig.js +15 -88
  10. package/dist/config/sharedConfigManager.d.ts +1 -2
  11. package/dist/config/zeroConfig.d.ts +2 -13
  12. package/dist/config/zeroConfig.js +7 -15
  13. package/dist/critical/model-guardian.d.ts +5 -22
  14. package/dist/critical/model-guardian.js +38 -210
  15. package/dist/embeddings/EmbeddingManager.d.ts +7 -17
  16. package/dist/embeddings/EmbeddingManager.js +28 -136
  17. package/dist/embeddings/wasm/AssetLoader.d.ts +67 -0
  18. package/dist/embeddings/wasm/AssetLoader.js +238 -0
  19. package/dist/embeddings/wasm/EmbeddingPostProcessor.d.ts +60 -0
  20. package/dist/embeddings/wasm/EmbeddingPostProcessor.js +123 -0
  21. package/dist/embeddings/wasm/ONNXInferenceEngine.d.ts +55 -0
  22. package/dist/embeddings/wasm/ONNXInferenceEngine.js +154 -0
  23. package/dist/embeddings/wasm/WASMEmbeddingEngine.d.ts +82 -0
  24. package/dist/embeddings/wasm/WASMEmbeddingEngine.js +231 -0
  25. package/dist/embeddings/wasm/WordPieceTokenizer.d.ts +71 -0
  26. package/dist/embeddings/wasm/WordPieceTokenizer.js +264 -0
  27. package/dist/embeddings/wasm/index.d.ts +13 -0
  28. package/dist/embeddings/wasm/index.js +15 -0
  29. package/dist/embeddings/wasm/types.d.ts +114 -0
  30. package/dist/embeddings/wasm/types.js +25 -0
  31. package/dist/setup.d.ts +11 -11
  32. package/dist/setup.js +17 -31
  33. package/dist/types/brainy.types.d.ts +0 -5
  34. package/dist/utils/embedding.d.ts +45 -62
  35. package/dist/utils/embedding.js +61 -440
  36. package/package.json +10 -3
  37. package/scripts/download-model.cjs +175 -0
package/dist/embeddings/wasm/AssetLoader.d.ts
@@ -0,0 +1,67 @@
+/**
+ * Asset Loader
+ *
+ * Resolves paths to model files (ONNX model, vocabulary) across environments.
+ * Handles Node.js, Bun, and bundled scenarios.
+ *
+ * Asset Resolution Order:
+ * 1. Environment variable: BRAINY_MODEL_PATH
+ * 2. Package-relative: node_modules/@soulcraft/brainy/assets/models/
+ * 3. Project-relative: ./assets/models/
+ */
+/**
+ * Asset loader for model files
+ */
+export declare class AssetLoader {
+    private modelDir;
+    /**
+     * Get the model directory path
+     */
+    getModelDir(): Promise<string>;
+    /**
+     * Resolve the model directory across environments
+     */
+    private resolveModelDir;
+    /**
+     * Get package root path (Node.js/Bun only)
+     */
+    private getPackageRootPath;
+    /**
+     * Check if path exists (works in Node.js/Bun)
+     */
+    private pathExists;
+    /**
+     * Get path to ONNX model file
+     */
+    getModelPath(): Promise<string>;
+    /**
+     * Get path to vocabulary file
+     */
+    getVocabPath(): Promise<string>;
+    /**
+     * Load vocabulary from JSON file
+     */
+    loadVocab(): Promise<Record<string, number>>;
+    /**
+     * Load model as ArrayBuffer (for ONNX session)
+     */
+    loadModel(): Promise<ArrayBuffer>;
+    /**
+     * Verify all required assets exist
+     */
+    verifyAssets(): Promise<{
+        valid: boolean;
+        modelPath: string;
+        vocabPath: string;
+        errors: string[];
+    }>;
+    /**
+     * Clear cached paths (for testing)
+     */
+    clearCache(): void;
+}
+/**
+ * Create asset loader instance
+ */
+export declare function createAssetLoader(): AssetLoader;
+export declare function getAssetLoader(): AssetLoader;
package/dist/embeddings/wasm/AssetLoader.js
@@ -0,0 +1,238 @@
+/**
+ * Asset Loader
+ *
+ * Resolves paths to model files (ONNX model, vocabulary) across environments.
+ * Handles Node.js, Bun, and bundled scenarios.
+ *
+ * Asset Resolution Order:
+ * 1. Environment variable: BRAINY_MODEL_PATH
+ * 2. Package-relative: node_modules/@soulcraft/brainy/assets/models/
+ * 3. Project-relative: ./assets/models/
+ */
+import { MODEL_CONSTANTS } from './types.js';
+// Cache resolved paths
+let cachedModelDir = null;
+let cachedVocab = null;
+/**
+ * Asset loader for model files
+ */
+export class AssetLoader {
+    constructor() {
+        this.modelDir = null;
+    }
+    /**
+     * Get the model directory path
+     */
+    async getModelDir() {
+        if (this.modelDir) {
+            return this.modelDir;
+        }
+        if (cachedModelDir) {
+            this.modelDir = cachedModelDir;
+            return cachedModelDir;
+        }
+        // Try to resolve model directory
+        const resolved = await this.resolveModelDir();
+        this.modelDir = resolved;
+        cachedModelDir = resolved;
+        return resolved;
+    }
+    /**
+     * Resolve the model directory across environments
+     */
+    async resolveModelDir() {
+        // 1. Check environment variable
+        if (typeof process !== 'undefined' && process.env?.BRAINY_MODEL_PATH) {
+            const envPath = process.env.BRAINY_MODEL_PATH;
+            if (await this.pathExists(envPath)) {
+                return envPath;
+            }
+        }
+        // 2. Try common locations
+        const modelName = MODEL_CONSTANTS.MODEL_NAME + '-q8';
+        const possiblePaths = [
+            // Package assets (when installed as dependency)
+            `./assets/models/${modelName}`,
+            `./node_modules/@soulcraft/brainy/assets/models/${modelName}`,
+            // Development paths
+            `../assets/models/${modelName}`,
+            // Absolute from package root
+            this.getPackageRootPath(`assets/models/${modelName}`),
+        ].filter(Boolean);
+        for (const path of possiblePaths) {
+            if (await this.pathExists(path)) {
+                return path;
+            }
+        }
+        // If no path found, return default (will error on use)
+        return `./assets/models/${modelName}`;
+    }
+    /**
+     * Get package root path (Node.js/Bun only)
+     */
+    getPackageRootPath(relativePath) {
+        if (typeof process === 'undefined') {
+            return null;
+        }
+        try {
+            // Use __dirname equivalent
+            const url = new URL(import.meta.url);
+            const currentDir = url.pathname.replace(/\/[^/]*$/, '');
+            // Go up from src/embeddings/wasm to package root
+            const packageRoot = currentDir.replace(/\/src\/embeddings\/wasm$/, '');
+            return `${packageRoot}/${relativePath}`;
+        }
+        catch {
+            return null;
+        }
+    }
+    /**
+     * Check if path exists (works in Node.js/Bun)
+     */
+    async pathExists(path) {
+        if (typeof process === 'undefined') {
+            // Browser - check via fetch
+            try {
+                const response = await fetch(path, { method: 'HEAD' });
+                return response.ok;
+            }
+            catch {
+                return false;
+            }
+        }
+        // Node.js/Bun
+        try {
+            const fs = await import('node:fs/promises');
+            await fs.access(path);
+            return true;
+        }
+        catch {
+            return false;
+        }
+    }
+    /**
+     * Get path to ONNX model file
+     */
+    async getModelPath() {
+        const dir = await this.getModelDir();
+        return `${dir}/model.onnx`;
+    }
+    /**
+     * Get path to vocabulary file
+     */
+    async getVocabPath() {
+        const dir = await this.getModelDir();
+        return `${dir}/vocab.json`;
+    }
+    /**
+     * Load vocabulary from JSON file
+     */
+    async loadVocab() {
+        if (cachedVocab) {
+            return cachedVocab;
+        }
+        const vocabPath = await this.getVocabPath();
+        if (typeof process !== 'undefined') {
+            // Node.js/Bun - read from filesystem
+            try {
+                const fs = await import('node:fs/promises');
+                const content = await fs.readFile(vocabPath, 'utf-8');
+                cachedVocab = JSON.parse(content);
+                return cachedVocab;
+            }
+            catch (error) {
+                throw new Error(`Failed to load vocabulary from ${vocabPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+        else {
+            // Browser - fetch
+            try {
+                const response = await fetch(vocabPath);
+                if (!response.ok) {
+                    throw new Error(`HTTP ${response.status}`);
+                }
+                cachedVocab = await response.json();
+                return cachedVocab;
+            }
+            catch (error) {
+                throw new Error(`Failed to fetch vocabulary from ${vocabPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+    }
+    /**
+     * Load model as ArrayBuffer (for ONNX session)
+     */
+    async loadModel() {
+        const modelPath = await this.getModelPath();
+        if (typeof process !== 'undefined') {
+            // Node.js/Bun - read from filesystem
+            try {
+                const fs = await import('node:fs/promises');
+                const buffer = await fs.readFile(modelPath);
+                // Convert Node.js Buffer to ArrayBuffer
+                return new Uint8Array(buffer).buffer;
+            }
+            catch (error) {
+                throw new Error(`Failed to load model from ${modelPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+        else {
+            // Browser - fetch
+            try {
+                const response = await fetch(modelPath);
+                if (!response.ok) {
+                    throw new Error(`HTTP ${response.status}`);
+                }
+                return await response.arrayBuffer();
+            }
+            catch (error) {
+                throw new Error(`Failed to fetch model from ${modelPath}: ${error instanceof Error ? error.message : String(error)}`);
+            }
+        }
+    }
+    /**
+     * Verify all required assets exist
+     */
+    async verifyAssets() {
+        const errors = [];
+        const modelPath = await this.getModelPath();
+        const vocabPath = await this.getVocabPath();
+        if (!(await this.pathExists(modelPath))) {
+            errors.push(`Model file not found: ${modelPath}`);
+        }
+        if (!(await this.pathExists(vocabPath))) {
+            errors.push(`Vocabulary file not found: ${vocabPath}`);
+        }
+        return {
+            valid: errors.length === 0,
+            modelPath,
+            vocabPath,
+            errors,
+        };
+    }
+    /**
+     * Clear cached paths (for testing)
+     */
+    clearCache() {
+        this.modelDir = null;
+        cachedModelDir = null;
+        cachedVocab = null;
+    }
+}
+/**
+ * Create asset loader instance
+ */
+export function createAssetLoader() {
+    return new AssetLoader();
+}
+/**
+ * Singleton asset loader
+ */
+let singletonLoader = null;
+export function getAssetLoader() {
+    if (!singletonLoader) {
+        singletonLoader = new AssetLoader();
+    }
+    return singletonLoader;
+}
+//# sourceMappingURL=AssetLoader.js.map
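
A minimal sketch of how the loader above can be exercised, written against the API shown in this hunk. The override directory is illustrative; note that BRAINY_MODEL_PATH is only honored when the path actually exists, and it must point at a directory containing model.onnx and vocab.json (getModelPath/getVocabPath append those names).

    import { getAssetLoader } from './AssetLoader.js';

    // Checked first by resolveModelDir(); illustrative path.
    process.env.BRAINY_MODEL_PATH = '/opt/models/all-MiniLM-L6-v2-q8';

    const loader = getAssetLoader();
    const { valid, modelPath, vocabPath, errors } = await loader.verifyAssets();
    if (!valid) {
        // verifyAssets reports each missing file separately
        throw new Error(`Brainy assets missing:\n${errors.join('\n')}`);
    }
    console.log(`model: ${modelPath}, vocab: ${vocabPath}`);

Since resolved paths are cached at module level, tests that change BRAINY_MODEL_PATH between cases need loader.clearCache() to see the new value.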
package/dist/embeddings/wasm/EmbeddingPostProcessor.d.ts
@@ -0,0 +1,60 @@
+/**
+ * Embedding Post-Processor
+ *
+ * Converts raw ONNX model output to final embedding vectors.
+ * Implements mean pooling and L2 normalization as used by sentence-transformers.
+ *
+ * Pipeline:
+ * 1. Mean Pooling: Average token embeddings (weighted by attention mask)
+ * 2. L2 Normalization: Normalize to unit length for cosine similarity
+ */
+/**
+ * Post-processor for converting ONNX output to sentence embeddings
+ */
+export declare class EmbeddingPostProcessor {
+    private hiddenSize;
+    constructor(hiddenSize?: number);
+    /**
+     * Mean pool token embeddings weighted by attention mask
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize] flattened
+     * @param attentionMask - Attention mask [seqLen] (1 for real tokens, 0 for padding)
+     * @param seqLen - Sequence length
+     * @returns Mean-pooled embedding [hiddenSize]
+     */
+    meanPool(hiddenStates: Float32Array, attentionMask: number[], seqLen: number): Float32Array;
+    /**
+     * L2 normalize embedding to unit length
+     *
+     * @param embedding - Input embedding
+     * @returns Normalized embedding with ||x|| = 1
+     */
+    normalize(embedding: Float32Array): Float32Array;
+    /**
+     * Full post-processing pipeline: mean pool then normalize
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize]
+     * @param attentionMask - Attention mask [seqLen]
+     * @param seqLen - Sequence length
+     * @returns Final normalized embedding [hiddenSize]
+     */
+    process(hiddenStates: Float32Array, attentionMask: number[], seqLen: number): Float32Array;
+    /**
+     * Process batch of embeddings
+     *
+     * @param hiddenStates - Raw model output [batchSize * seqLen * hiddenSize]
+     * @param attentionMasks - Attention masks [batchSize][seqLen]
+     * @param batchSize - Number of sequences in batch
+     * @param seqLen - Sequence length (same for all in batch due to padding)
+     * @returns Array of normalized embeddings
+     */
+    processBatch(hiddenStates: Float32Array, attentionMasks: number[][], batchSize: number, seqLen: number): Float32Array[];
+    /**
+     * Convert Float32Array to number array
+     */
+    toNumberArray(embedding: Float32Array): number[];
+}
+/**
+ * Create a post-processor with default configuration
+ */
+export declare function createPostProcessor(): EmbeddingPostProcessor;
package/dist/embeddings/wasm/EmbeddingPostProcessor.js
@@ -0,0 +1,123 @@
+/**
+ * Embedding Post-Processor
+ *
+ * Converts raw ONNX model output to final embedding vectors.
+ * Implements mean pooling and L2 normalization as used by sentence-transformers.
+ *
+ * Pipeline:
+ * 1. Mean Pooling: Average token embeddings (weighted by attention mask)
+ * 2. L2 Normalization: Normalize to unit length for cosine similarity
+ */
+import { MODEL_CONSTANTS } from './types.js';
+/**
+ * Post-processor for converting ONNX output to sentence embeddings
+ */
+export class EmbeddingPostProcessor {
+    constructor(hiddenSize = MODEL_CONSTANTS.HIDDEN_SIZE) {
+        this.hiddenSize = hiddenSize;
+    }
+    /**
+     * Mean pool token embeddings weighted by attention mask
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize] flattened
+     * @param attentionMask - Attention mask [seqLen] (1 for real tokens, 0 for padding)
+     * @param seqLen - Sequence length
+     * @returns Mean-pooled embedding [hiddenSize]
+     */
+    meanPool(hiddenStates, attentionMask, seqLen) {
+        const result = new Float32Array(this.hiddenSize);
+        // Sum of attention mask (number of real tokens)
+        let maskSum = 0;
+        for (let i = 0; i < seqLen; i++) {
+            maskSum += attentionMask[i];
+        }
+        // Avoid division by zero
+        if (maskSum === 0) {
+            maskSum = 1;
+        }
+        // Compute weighted sum for each dimension
+        for (let dim = 0; dim < this.hiddenSize; dim++) {
+            let sum = 0;
+            for (let pos = 0; pos < seqLen; pos++) {
+                // Get hidden state at [pos, dim]
+                const value = hiddenStates[pos * this.hiddenSize + dim];
+                // Weight by attention mask
+                sum += value * attentionMask[pos];
+            }
+            // Mean pool
+            result[dim] = sum / maskSum;
+        }
+        return result;
+    }
+    /**
+     * L2 normalize embedding to unit length
+     *
+     * @param embedding - Input embedding
+     * @returns Normalized embedding with ||x|| = 1
+     */
+    normalize(embedding) {
+        // Compute L2 norm
+        let sumSquares = 0;
+        for (let i = 0; i < embedding.length; i++) {
+            sumSquares += embedding[i] * embedding[i];
+        }
+        const norm = Math.sqrt(sumSquares);
+        // Avoid division by zero
+        if (norm === 0) {
+            return embedding;
+        }
+        // Normalize
+        const result = new Float32Array(embedding.length);
+        for (let i = 0; i < embedding.length; i++) {
+            result[i] = embedding[i] / norm;
+        }
+        return result;
+    }
+    /**
+     * Full post-processing pipeline: mean pool then normalize
+     *
+     * @param hiddenStates - Raw model output [seqLen * hiddenSize]
+     * @param attentionMask - Attention mask [seqLen]
+     * @param seqLen - Sequence length
+     * @returns Final normalized embedding [hiddenSize]
+     */
+    process(hiddenStates, attentionMask, seqLen) {
+        const pooled = this.meanPool(hiddenStates, attentionMask, seqLen);
+        return this.normalize(pooled);
+    }
+    /**
+     * Process batch of embeddings
+     *
+     * @param hiddenStates - Raw model output [batchSize * seqLen * hiddenSize]
+     * @param attentionMasks - Attention masks [batchSize][seqLen]
+     * @param batchSize - Number of sequences in batch
+     * @param seqLen - Sequence length (same for all in batch due to padding)
+     * @returns Array of normalized embeddings
+     */
+    processBatch(hiddenStates, attentionMasks, batchSize, seqLen) {
+        const results = [];
+        const sequenceSize = seqLen * this.hiddenSize;
+        for (let b = 0; b < batchSize; b++) {
+            // Extract this sequence's hidden states
+            const start = b * sequenceSize;
+            const seqHiddenStates = hiddenStates.slice(start, start + sequenceSize);
+            // Process
+            const embedding = this.process(seqHiddenStates, attentionMasks[b], seqLen);
+            results.push(embedding);
+        }
+        return results;
+    }
+    /**
+     * Convert Float32Array to number array
+     */
+    toNumberArray(embedding) {
+        return Array.from(embedding);
+    }
+}
+/**
+ * Create a post-processor with default configuration
+ */
+export function createPostProcessor() {
+    return new EmbeddingPostProcessor(MODEL_CONSTANTS.HIDDEN_SIZE);
+}
+//# sourceMappingURL=EmbeddingPostProcessor.js.map
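
The pooling and normalization above are plain loops, so the arithmetic is easy to check by hand. A toy example (not from the package) with hiddenSize = 2 and one padding token:

    import { EmbeddingPostProcessor } from './EmbeddingPostProcessor.js';

    const pp = new EmbeddingPostProcessor(2);

    // Three token vectors, flattened row-major as [pos * hiddenSize + dim]:
    // [1, 2], [3, 4], and padding [9, 9] that the mask zeroes out.
    const hidden = new Float32Array([1, 2, 3, 4, 9, 9]);
    const mask = [1, 1, 0];

    const pooled = pp.meanPool(hidden, mask, 3); // [2, 3]: only real tokens counted
    const unit = pp.normalize(pooled);           // [2, 3] / sqrt(13)
    console.log(pp.toNumberArray(unit));         // ≈ [0.5547, 0.8321]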
package/dist/embeddings/wasm/ONNXInferenceEngine.d.ts
@@ -0,0 +1,55 @@
+/**
+ * ONNX Inference Engine
+ *
+ * Direct ONNX Runtime Web wrapper for running model inference.
+ * Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
+ *
+ * This replaces transformers.js dependency with direct ONNX control.
+ */
+import { InferenceConfig } from './types.js';
+/**
+ * ONNX Inference Engine using onnxruntime-web
+ */
+export declare class ONNXInferenceEngine {
+    private session;
+    private initialized;
+    private modelPath;
+    private config;
+    constructor(config?: Partial<InferenceConfig>);
+    /**
+     * Initialize the ONNX session
+     */
+    initialize(modelPath?: string): Promise<void>;
+    /**
+     * Run inference on tokenized input
+     *
+     * @param inputIds - Token IDs [batchSize, seqLen]
+     * @param attentionMask - Attention mask [batchSize, seqLen]
+     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
+     * @returns Hidden states [batchSize, seqLen, hiddenSize]
+     */
+    infer(inputIds: number[][], attentionMask: number[][], tokenTypeIds?: number[][]): Promise<Float32Array>;
+    /**
+     * Infer single sequence (convenience method)
+     */
+    inferSingle(inputIds: number[], attentionMask: number[], tokenTypeIds?: number[]): Promise<Float32Array>;
+    /**
+     * Check if initialized
+     */
+    isInitialized(): boolean;
+    /**
+     * Get model input/output names (for debugging)
+     */
+    getModelInfo(): {
+        inputs: readonly string[];
+        outputs: readonly string[];
+    } | null;
+    /**
+     * Dispose of the session and free resources
+     */
+    dispose(): Promise<void>;
+}
+/**
+ * Create an inference engine with default configuration
+ */
+export declare function createInferenceEngine(modelPath: string): ONNXInferenceEngine;
package/dist/embeddings/wasm/ONNXInferenceEngine.js
@@ -0,0 +1,154 @@
+/**
+ * ONNX Inference Engine
+ *
+ * Direct ONNX Runtime Web wrapper for running model inference.
+ * Uses WASM backend for universal compatibility (Node.js, Bun, Browser).
+ *
+ * This replaces transformers.js dependency with direct ONNX control.
+ */
+import * as ort from 'onnxruntime-web';
+// Configure ONNX Runtime for WASM-only
+ort.env.wasm.numThreads = 1; // Single-threaded for stability
+ort.env.wasm.simd = true; // Enable SIMD where available
+/**
+ * ONNX Inference Engine using onnxruntime-web
+ */
+export class ONNXInferenceEngine {
+    constructor(config = {}) {
+        this.session = null;
+        this.initialized = false;
+        this.modelPath = config.modelPath ?? '';
+        this.config = {
+            modelPath: this.modelPath,
+            numThreads: config.numThreads ?? 1,
+            enableSimd: config.enableSimd ?? true,
+            enableCpuMemArena: config.enableCpuMemArena ?? false,
+        };
+    }
+    /**
+     * Initialize the ONNX session
+     */
+    async initialize(modelPath) {
+        if (this.initialized && this.session) {
+            return;
+        }
+        const path = modelPath ?? this.modelPath;
+        if (!path) {
+            throw new Error('Model path is required');
+        }
+        try {
+            // Configure session options
+            const sessionOptions = {
+                executionProviders: ['wasm'],
+                graphOptimizationLevel: 'all',
+                enableCpuMemArena: this.config.enableCpuMemArena,
+                // Additional WASM-specific options
+                executionMode: 'sequential',
+            };
+            // Load model from file path or URL
+            this.session = await ort.InferenceSession.create(path, sessionOptions);
+            this.initialized = true;
+        }
+        catch (error) {
+            this.initialized = false;
+            this.session = null;
+            throw new Error(`Failed to initialize ONNX session: ${error instanceof Error ? error.message : String(error)}`);
+        }
+    }
+    /**
+     * Run inference on tokenized input
+     *
+     * @param inputIds - Token IDs [batchSize, seqLen]
+     * @param attentionMask - Attention mask [batchSize, seqLen]
+     * @param tokenTypeIds - Token type IDs [batchSize, seqLen] (optional, defaults to zeros)
+     * @returns Hidden states [batchSize, seqLen, hiddenSize]
+     */
+    async infer(inputIds, attentionMask, tokenTypeIds) {
+        if (!this.session) {
+            throw new Error('Session not initialized. Call initialize() first.');
+        }
+        const batchSize = inputIds.length;
+        const seqLen = inputIds[0].length;
+        // Convert to BigInt64Array (ONNX int64 type)
+        const inputIdsFlat = new BigInt64Array(batchSize * seqLen);
+        const attentionMaskFlat = new BigInt64Array(batchSize * seqLen);
+        const tokenTypeIdsFlat = new BigInt64Array(batchSize * seqLen);
+        for (let b = 0; b < batchSize; b++) {
+            for (let s = 0; s < seqLen; s++) {
+                const idx = b * seqLen + s;
+                inputIdsFlat[idx] = BigInt(inputIds[b][s]);
+                attentionMaskFlat[idx] = BigInt(attentionMask[b][s]);
+                tokenTypeIdsFlat[idx] = tokenTypeIds
+                    ? BigInt(tokenTypeIds[b][s])
+                    : BigInt(0);
+            }
+        }
+        // Create ONNX tensors
+        const inputIdsTensor = new ort.Tensor('int64', inputIdsFlat, [batchSize, seqLen]);
+        const attentionMaskTensor = new ort.Tensor('int64', attentionMaskFlat, [batchSize, seqLen]);
+        const tokenTypeIdsTensor = new ort.Tensor('int64', tokenTypeIdsFlat, [batchSize, seqLen]);
+        try {
+            // Run inference
+            const feeds = {
+                input_ids: inputIdsTensor,
+                attention_mask: attentionMaskTensor,
+                token_type_ids: tokenTypeIdsTensor,
+            };
+            const results = await this.session.run(feeds);
+            // Extract last_hidden_state (the output we need for mean pooling)
+            // Model outputs: last_hidden_state [batch, seq, hidden] and pooler_output [batch, hidden]
+            const output = results.last_hidden_state ?? results.token_embeddings;
+            if (!output) {
+                throw new Error('Model did not return expected output tensor');
+            }
+            return output.data;
+        }
+        finally {
+            // Dispose tensors to free memory
+            inputIdsTensor.dispose();
+            attentionMaskTensor.dispose();
+            tokenTypeIdsTensor.dispose();
+        }
+    }
+    /**
+     * Infer single sequence (convenience method)
+     */
+    async inferSingle(inputIds, attentionMask, tokenTypeIds) {
+        return this.infer([inputIds], [attentionMask], tokenTypeIds ? [tokenTypeIds] : undefined);
+    }
+    /**
+     * Check if initialized
+     */
+    isInitialized() {
+        return this.initialized;
+    }
+    /**
+     * Get model input/output names (for debugging)
+     */
+    getModelInfo() {
+        if (!this.session) {
+            return null;
+        }
+        return {
+            inputs: this.session.inputNames,
+            outputs: this.session.outputNames,
+        };
+    }
+    /**
+     * Dispose of the session and free resources
+     */
+    async dispose() {
+        if (this.session) {
+            // Release the session
+            this.session = null;
+        }
+        this.initialized = false;
+    }
+}
+/**
+ * Create an inference engine with default configuration
+ */
+export function createInferenceEngine(modelPath) {
+    return new ONNXInferenceEngine({ modelPath });
+}
+//# sourceMappingURL=ONNXInferenceEngine.js.map
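
Taken together, the three modules in this release form a self-contained embedding path: AssetLoader locates the bundled model, ONNXInferenceEngine runs it on the WASM backend, and EmbeddingPostProcessor pools and normalizes the output. A sketch of that flow, wiring only APIs shown in this diff; the token IDs are hard-coded placeholders, since in the real pipeline they would come from WordPieceTokenizer, whose API is not shown in this excerpt:

    import { getAssetLoader } from './AssetLoader.js';
    import { createInferenceEngine } from './ONNXInferenceEngine.js';
    import { createPostProcessor } from './EmbeddingPostProcessor.js';

    const engine = createInferenceEngine(await getAssetLoader().getModelPath());
    await engine.initialize();

    // Placeholder BERT-style ids ([CLS] ... [SEP]); normally produced by the tokenizer.
    const inputIds = [101, 7592, 2088, 102];
    const attentionMask = [1, 1, 1, 1];

    // Flattened [1, seqLen, hiddenSize] hidden states from the WASM session
    const hidden = await engine.inferSingle(inputIds, attentionMask);

    // Mean pool + L2 normalize -> unit-length sentence embedding
    const embedding = createPostProcessor().process(hidden, attentionMask, inputIds.length);

    await engine.dispose();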