codebasesearch 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.thornsignore +419 -0
- package/README.md +129 -0
- package/bin/code-search.js +27 -0
- package/mcp.js +276 -0
- package/package.json +35 -0
- package/scripts/patch-transformers.js +42 -0
- package/src/cli.js +113 -0
- package/src/embeddings.js +151 -0
- package/src/ignore-parser.js +129 -0
- package/src/patch-sharp.js +38 -0
- package/src/scanner.js +172 -0
- package/src/search.js +45 -0
- package/src/store.js +166 -0
package/mcp.js
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
|
|
3
|
+
// MUST patch sharp before any other imports
|
|
4
|
+
import fs from 'fs';
|
|
5
|
+
import path from 'path';
|
|
6
|
+
import { fileURLToPath } from 'url';
|
|
7
|
+
|
|
8
|
+
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
|
9
|
+
const distPath = path.join(__dirname, 'node_modules', '@huggingface', 'transformers', 'dist', 'transformers.node.mjs');
|
|
10
|
+
|
|
11
|
+
if (fs.existsSync(distPath)) {
|
|
12
|
+
let content = fs.readFileSync(distPath, 'utf-8');
|
|
13
|
+
if (!content.includes('SHARP_REMOVED_FOR_WINDOWS_COMPATIBILITY')) {
|
|
14
|
+
content = content.replace(/import \* as __WEBPACK_EXTERNAL_MODULE_sharp__ from "sharp";\n/, '// SHARP_REMOVED_FOR_WINDOWS_COMPATIBILITY\n');
|
|
15
|
+
content = content.replace(/module\.exports = __WEBPACK_EXTERNAL_MODULE_sharp__;/g, 'module.exports = {};');
|
|
16
|
+
content = content.replace(/} else \{\s*throw new Error\('Unable to load image processing library\.'\);\s*\}/, '} else {\n loadImageFunction = async () => { throw new Error(\'Image processing unavailable\'); };\n}');
|
|
17
|
+
try { fs.writeFileSync(distPath, content); } catch (e) {}
|
|
18
|
+
}
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
|
22
|
+
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
|
23
|
+
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
|
|
24
|
+
import { cwd } from 'process';
|
|
25
|
+
import { join, resolve } from 'path';
|
|
26
|
+
import { existsSync, readFileSync, appendFileSync, writeFileSync } from 'fs';
|
|
27
|
+
import { loadIgnorePatterns } from './src/ignore-parser.js';
|
|
28
|
+
import { scanRepository } from './src/scanner.js';
|
|
29
|
+
import { generateEmbeddings } from './src/embeddings.js';
|
|
30
|
+
import { initStore, upsertChunks, closeStore } from './src/store.js';
|
|
31
|
+
import { executeSearch } from './src/search.js';
|
|
32
|
+
|
|
33
|
+
/**
 * Make sure the local index directory (`.code-search/`) is listed in the
 * repository's .gitignore so the generated index is never committed.
 * Failures are deliberately swallowed: search should proceed even when the
 * .gitignore cannot be written.
 */
async function ensureIgnoreEntry(rootPath) {
  const ignoreFile = join(rootPath, '.gitignore');
  const indexEntry = '.code-search/';

  try {
    if (!existsSync(ignoreFile)) {
      writeFileSync(ignoreFile, `${indexEntry}\n`);
      return;
    }
    const current = readFileSync(ignoreFile, 'utf8');
    if (current.includes(indexEntry)) {
      return;
    }
    appendFileSync(ignoreFile, `\n${indexEntry}`);
  } catch (e) {
    // Best effort only -- proceed with search regardless.
  }
}
|
|
50
|
+
|
|
51
|
+
/**
 * Indexes a repository and runs a semantic search over it.
 *
 * Every search() call re-scans and re-embeds the whole repository before
 * querying, so results are always fresh at the cost of repeated indexing.
 */
class CodeSearchManager {
  /**
   * @param {string} repositoryPath - directory to index and search.
   * @param {string} query - natural-language search query.
   * @returns {Promise<object>} result object: either `{ error, results: [] }`
   *   or `{ query, repository, resultsCount, results }`.
   */
  async search(repositoryPath, query) {
    const absolutePath = resolve(repositoryPath);

    if (!existsSync(absolutePath)) {
      return {
        error: `Repository path not found: ${absolutePath}`,
        results: [],
      };
    }

    try {
      // Keep the index directory out of version control.
      await ensureIgnoreEntry(absolutePath);

      const ignorePatterns = loadIgnorePatterns(absolutePath);
      await initStore(join(absolutePath, '.code-search'));

      const chunks = scanRepository(absolutePath, ignorePatterns);

      if (chunks.length === 0) {
        await closeStore();
        return {
          query,
          results: [],
          message: 'No code chunks found in repository',
        };
      }

      // Embed chunk contents in batches of 32 to bound memory use.
      const BATCH = 32;
      const texts = chunks.map((c) => c.content);
      const vectors = [];
      for (let start = 0; start < texts.length; start += BATCH) {
        const embedded = await generateEmbeddings(texts.slice(start, start + BATCH));
        vectors.push(...embedded);
      }

      // Pair each chunk with its embedding and store everything.
      const indexed = chunks.map((chunk, i) => ({ ...chunk, vector: vectors[i] }));
      await upsertChunks(indexed);

      const hits = await executeSearch(query);

      // Shape results for the MCP tool response (3-line snippet preview).
      const formattedResults = hits.map((hit, i) => ({
        rank: i + 1,
        file: hit.file_path,
        lines: `${hit.line_start}-${hit.line_end}`,
        score: (hit.score * 100).toFixed(1),
        snippet: hit.content.split('\n').slice(0, 3).join('\n'),
      }));

      await closeStore();

      return {
        query,
        repository: absolutePath,
        resultsCount: formattedResults.length,
        results: formattedResults,
      };
    } catch (error) {
      // Always release the store, even when indexing/search failed.
      await closeStore().catch(() => {});
      return {
        error: error.message,
        results: [],
      };
    }
  }
}
|
|
134
|
+
|
|
135
|
+
const manager = new CodeSearchManager();

// MCP server exposing a single `search` tool over stdio.
const server = new Server(
  { name: 'code-search-mcp', version: '0.1.0' },
  { capabilities: { tools: {} } }
);

// Advertise the `search` tool and its JSON-schema input contract.
server.setRequestHandler(ListToolsRequestSchema, async () => ({
  tools: [
    {
      name: 'search',
      description:
        'Search through a code repository using semantic search with Jina embeddings. Automatically indexes the repository before searching.',
      inputSchema: {
        type: 'object',
        properties: {
          repository_path: {
            type: 'string',
            description:
              'Absolute or relative path to the repository to search in (defaults to current directory)',
          },
          query: {
            type: 'string',
            description:
              'Natural language search query (e.g., "authentication middleware", "database connection")',
          },
        },
        required: ['query'],
      },
    },
  ],
}));
|
|
176
|
+
|
|
177
|
+
// Handle `search` tool invocations; all failures are reported as
// isError text responses rather than thrown.
server.setRequestHandler(CallToolRequestSchema, async (request) => {
  const { name, arguments: args } = request.params;

  // Uniform error payload helper.
  const errorResponse = (text) => ({
    content: [{ type: 'text', text }],
    isError: true,
  });

  if (name !== 'search') {
    return errorResponse(`Unknown tool: ${name}`);
  }

  const query = args?.query;
  const repositoryPath = args?.repository_path || cwd();

  if (!query || typeof query !== 'string') {
    return errorResponse('Error: query is required and must be a string');
  }

  try {
    const result = await manager.search(repositoryPath, query);

    if (result.error) {
      return errorResponse(`Error: ${result.error}`);
    }

    let text;
    if (result.resultsCount === 0) {
      text = `No results found for query: "${query}"\nRepository: ${result.repository}`;
    } else {
      // One entry per hit: rank, location, score, then the indented snippet.
      const body = result.results
        .map((r) => {
          const indented = r.snippet
            .split('\n')
            .map((line) => ` ${line}`)
            .join('\n');
          return `${r.rank}. ${r.file}:${r.lines} (score: ${r.score}%)\n${indented}`;
        })
        .join('\n\n');
      text = `Found ${result.resultsCount} result${result.resultsCount !== 1 ? 's' : ''} for query: "${query}"\nRepository: ${result.repository}\n\n${body}`;
    }

    return {
      content: [
        {
          type: 'text',
          text,
        },
      ],
    };
  } catch (error) {
    return errorResponse(`Error: ${error.message}`);
  }
});
|
|
255
|
+
|
|
256
|
+
/** Connect the MCP server to a stdio transport and start serving requests. */
export async function startMcpServer() {
  const transport = new StdioServerTransport();
  await server.connect(transport);
}

// Heuristic "am I the entry point?" check: direct execution of this file,
// or invocation through an mcp.js path or the `code-search-mcp` bin shim.
// NOTE(review): endsWith matching could misfire for unrelated scripts whose
// path happens to end in 'mcp.js' -- kept as-is for compatibility.
const isMain =
  process.argv[1] &&
  (process.argv[1] === fileURLToPath(import.meta.url) ||
    process.argv[1].endsWith('mcp.js') ||
    process.argv[1].endsWith('code-search-mcp'));

/** Entry point: run the MCP server until the transport closes. */
async function main() {
  await startMcpServer();
}

if (isMain) {
  main().catch((error) => {
    console.error('Server error:', error);
    process.exit(1);
  });
}
|
package/package.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "codebasesearch",
|
|
3
|
+
"version": "0.1.4",
|
|
4
|
+
"description": "Ultra-simple code search tool with Jina embeddings, LanceDB, and MCP protocol support",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"bin": {
|
|
7
|
+
"code-search": "./bin/code-search.js",
|
|
8
|
+
"code-search-mcp": "./mcp.js"
|
|
9
|
+
},
|
|
10
|
+
"main": "src/cli.js",
|
|
11
|
+
"engines": {
|
|
12
|
+
"node": ">=18.0.0"
|
|
13
|
+
},
|
|
14
|
+
"scripts": {
|
|
15
|
+
"start": "node ./bin/code-search.js",
|
|
16
|
+
"postinstall": "node scripts/patch-transformers.js"
|
|
17
|
+
},
|
|
18
|
+
"dependencies": {
|
|
19
|
+
"@huggingface/transformers": "^3.8.1",
|
|
20
|
+
"@modelcontextprotocol/sdk": "1.0.0",
|
|
21
|
+
"apache-arrow": "^14.0.0",
|
|
22
|
+
"fast-glob": "^3.3.2",
|
|
23
|
+
"onnxruntime-node": "^1.23.2",
|
|
24
|
+
"onnxruntime-web": "^1.19.0",
|
|
25
|
+
"vectordb": "^0.21.2"
|
|
26
|
+
},
|
|
27
|
+
"keywords": [
|
|
28
|
+
"code-search",
|
|
29
|
+
"embeddings",
|
|
30
|
+
"vector-db",
|
|
31
|
+
"jina",
|
|
32
|
+
"lancedb",
|
|
33
|
+
"mcp"
|
|
34
|
+
]
|
|
35
|
+
}
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
// Postinstall script: patch the @huggingface/transformers dist bundle so it
// no longer imports the native `sharp` module (which fails to install/load
// on some Windows setups). Idempotent via the SHARP_PATCHED_FOR_WINDOWS
// marker written into the file.
import fs from 'fs';
import path from 'path';
import { fileURLToPath } from 'url';

const __dirname = path.dirname(fileURLToPath(import.meta.url));

// Patch @huggingface/transformers dist file for Windows compatibility.
// Resolved relative to this script (package-root/node_modules).
const distPath = path.join(__dirname, '..', 'node_modules', '@huggingface', 'transformers', 'dist', 'transformers.node.mjs');

if (!fs.existsSync(distPath)) {
  console.log('transformers.node.mjs not found, skipping patch');
  process.exit(0);
}

let distContent = fs.readFileSync(distPath, 'utf-8');

// Check if already patched (marker inserted by the import replacement below).
if (distContent.includes('SHARP_PATCHED_FOR_WINDOWS')) {
  console.log('transformers.node.mjs already patched');
  process.exit(0);
}

// Remove sharp import line, leaving the marker comment in its place.
distContent = distContent.replace(
  /import \* as __WEBPACK_EXTERNAL_MODULE_sharp__ from "sharp";\n/,
  '// SHARP_PATCHED_FOR_WINDOWS: sharp removed\n'
);

// Replace sharp module exports with an empty stub object.
distContent = distContent.replace(
  /module\.exports = __WEBPACK_EXTERNAL_MODULE_sharp__;/g,
  'module.exports = {};'
);

// Replace the hard image-processing startup error with a lazy fallback that
// only throws if image processing is actually used.
distContent = distContent.replace(
  /} else \{\s*throw new Error\('Unable to load image processing library\.'\);\s*\}/,
  '} else {\n loadImageFunction = async () => { throw new Error(\'Image processing unavailable\'); };\n}'
);

fs.writeFileSync(distPath, distContent);
console.log('Successfully patched transformers.node.mjs for Windows compatibility');
|
package/src/cli.js
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import { cwd } from 'process';
|
|
2
|
+
import { join } from 'path';
|
|
3
|
+
import { existsSync, readFileSync, appendFileSync, writeFileSync } from 'fs';
|
|
4
|
+
import { loadIgnorePatterns } from './ignore-parser.js';
|
|
5
|
+
import { scanRepository } from './scanner.js';
|
|
6
|
+
import { generateEmbeddings } from './embeddings.js';
|
|
7
|
+
import { initStore, upsertChunks, closeStore } from './store.js';
|
|
8
|
+
import { executeSearch, formatResults } from './search.js';
|
|
9
|
+
import { startMcpServer } from '../mcp.js';
|
|
10
|
+
|
|
11
|
+
/**
 * Report whether `rootPath` looks like a git repository (has a `.git` entry).
 * Any filesystem error is treated as "not a repo".
 */
async function isGitRepository(rootPath) {
  try {
    return existsSync(join(rootPath, '.git'));
  } catch {
    return false;
  }
}
|
|
19
|
+
|
|
20
|
+
/**
 * Make sure the local index directory (`.code-search/`) is listed in the
 * repository's .gitignore so the generated index is never committed.
 * Failures are deliberately swallowed: the CLI should proceed even when the
 * .gitignore cannot be written.
 */
async function ensureIgnoreEntry(rootPath) {
  const ignoreFile = join(rootPath, '.gitignore');
  const indexEntry = '.code-search/';

  try {
    if (!existsSync(ignoreFile)) {
      writeFileSync(ignoreFile, `${indexEntry}\n`);
      return;
    }
    const current = readFileSync(ignoreFile, 'utf8');
    if (current.includes(indexEntry)) {
      return;
    }
    appendFileSync(ignoreFile, `\n${indexEntry}`);
  } catch (e) {
    // Best effort only -- proceed with search regardless.
  }
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
/**
 * CLI entry point.
 *
 * With no arguments, runs as an MCP server on stdio. Otherwise treats the
 * arguments as a search query, (re)indexes the current working directory,
 * runs the search, prints results, and exits (0 on success, 1 on failure).
 *
 * @param {string[]} args - command-line arguments (already shifted past node/script).
 */
export async function run(args) {
  try {
    // Start MCP server if no arguments provided
    if (args.length === 0) {
      await startMcpServer();
      return;
    }

    const query = args.join(' ');
    const rootPath = cwd();

    console.log(`Code Search Tool`);
    console.log(`Root: ${rootPath}\n`);

    // Check if git repo (informational only; indexing proceeds regardless).
    const isGit = await isGitRepository(rootPath);
    if (!isGit) {
      console.warn('Warning: Not a git repository. Indexing current directory anyway.\n');
    }

    // Ensure .code-search/ is in .gitignore
    await ensureIgnoreEntry(rootPath);

    // Load ignore patterns
    const ignorePatterns = loadIgnorePatterns(rootPath);
    const dbPath = join(rootPath, '.code-search');

    // Initialize store
    await initStore(dbPath);

    // Scan repository
    console.log('Scanning repository...');
    const chunks = scanRepository(rootPath, ignorePatterns);
    console.log(`Found ${chunks.length} code chunks\n`);

    // Guard: nothing to index. Matches the MCP server's behavior; previously
    // an empty repository fell through to upsertChunks([]) and a pointless
    // search against an empty index.
    if (chunks.length === 0) {
      console.log('No code chunks found in repository');
      await closeStore();
      process.exit(0);
    }

    // Always reindex to ensure freshness
    console.log('Generating embeddings and indexing...');

    // Generate embeddings in batches to bound memory use.
    const batchSize = 32;
    const chunkTexts = chunks.map(c => c.content);
    const allEmbeddings = [];

    for (let i = 0; i < chunkTexts.length; i += batchSize) {
      const batchTexts = chunkTexts.slice(i, i + batchSize);
      const batchEmbeddings = await generateEmbeddings(batchTexts);
      allEmbeddings.push(...batchEmbeddings);
    }

    // Pair each chunk with its embedding vector.
    const chunksWithEmbeddings = chunks.map((chunk, idx) => ({
      ...chunk,
      vector: allEmbeddings[idx]
    }));

    // Upsert to store
    await upsertChunks(chunksWithEmbeddings);
    console.log('Index created\n');

    // Execute search
    const results = await executeSearch(query);

    // Format and display results
    const output = formatResults(results);
    console.log(output);

    // Clean shutdown
    await closeStore();
    process.exit(0);
  } catch (error) {
    console.error('Error:', error.message);
    // Release the store before exiting; ignore secondary close failures.
    await closeStore().catch(() => {});
    process.exit(1);
  }
}
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import { pipeline, env } from '@huggingface/transformers';
|
|
2
|
+
import { rmSync, existsSync } from 'fs';
|
|
3
|
+
import { homedir } from 'os';
|
|
4
|
+
import { join } from 'path';
|
|
5
|
+
|
|
6
|
+
// Force WASM backend only - disable onnxruntime-node to avoid memory issues on Windows
try {
  env.backends.onnx.wasm.numThreads = 1;
  env.backends.onnx.ort = null;
} catch (e) {
  // Continue even if env config fails
}

// Cached feature-extraction pipeline; created lazily by getModel().
let modelCache = null;
// Set to true once a corrupted-cache recovery has already been attempted,
// so getModel() retries at most once.
let cacheCleared = false;
|
|
16
|
+
|
|
17
|
+
/**
 * Delete every known on-disk model cache location so a corrupted download
 * can be re-fetched. Individual deletion failures are ignored.
 */
function clearModelCache() {
  // Candidate locations where transformers.js may have cached model files.
  const candidates = [
    join(homedir(), '.cache', 'huggingface', 'transformers'),
    join(process.cwd(), 'node_modules', '@huggingface', 'transformers', '.cache'),
  ];

  for (const dir of candidates) {
    try {
      if (existsSync(dir)) {
        rmSync(dir, { recursive: true, force: true });
      }
    } catch (e) {
      // Ignore and try the next location.
    }
  }

  console.error('Cleared corrupted model cache');
}
|
|
34
|
+
|
|
35
|
+
/**
 * Lazily load (and memoize) the feature-extraction pipeline.
 *
 * Guards the download with a 5-minute timeout and, on a Protobuf/parsing
 * error (symptom of a corrupted cache), clears the cache once and retries.
 *
 * @param {boolean} retryOnError - whether a corrupted-cache retry is allowed.
 * @returns {Promise<object>} the transformers.js pipeline.
 * @throws re-throws the load error when retry is not applicable.
 */
async function getModel(retryOnError = true) {
  if (modelCache) {
    return modelCache;
  }

  console.error('Loading embeddings model (this may take a moment on first run)...');

  const modelLoadPromise = pipeline(
    'feature-extraction',
    'Xenova/all-minilm-l6-v2'
  );

  // Keep a handle on the timer so it can be cleared: previously the 5-minute
  // setTimeout was never cancelled, leaving a pending timer that kept the
  // Node process alive long after a successful load.
  let timeoutId;
  const timeoutPromise = new Promise((_, reject) => {
    timeoutId = setTimeout(() => reject(new Error('Model loading timeout after 5 minutes')), 300000);
  });

  try {
    modelCache = await Promise.race([modelLoadPromise, timeoutPromise]);
  } catch (e) {
    // Protobuf/parsing errors indicate a corrupted cached model file.
    if (retryOnError && !cacheCleared && (e.message.includes('Protobuf') || e.message.includes('parsing'))) {
      console.error('Detected corrupted cache, clearing and retrying...');
      cacheCleared = true;
      clearModelCache();
      modelCache = null;
      return getModel(false);
    }
    console.error('Error loading model:', e.message);
    throw e;
  } finally {
    // Runs on success, failure, and the retry path alike.
    clearTimeout(timeoutId);
  }

  return modelCache;
}
|
|
67
|
+
|
|
68
|
+
/**
 * Attention-mask-aware mean pooling over token embeddings.
 *
 * @param {{data: ArrayLike<number>, dims: number[]}} modelOutput - flat token
 *   embeddings with dims [batchSize, seqLength, embeddingDim].
 * @param {ArrayLike<number>|undefined} attentionMask - flat mask with one
 *   value per token ([batchSize * seqLength]); tokens whose mask is 0 are
 *   excluded from the mean. Missing entries (or a missing mask) default to 1.
 * @returns {Promise<number[][]>} one pooled vector per batch item.
 */
async function meanPooling(modelOutput, attentionMask) {
  // Get token embeddings from model output
  const tokenEmbeddings = modelOutput.data;
  const embeddingDim = modelOutput.dims[modelOutput.dims.length - 1];
  const batchSize = modelOutput.dims[0];
  const seqLength = modelOutput.dims[1];

  const pooled = [];

  for (let b = 0; b < batchSize; b++) {
    const sum = new Array(embeddingDim).fill(0);
    let count = 0;

    for (let s = 0; s < seqLength; s++) {
      const tokenIdx = b * seqLength + s;
      // Default to 1 only when the mask entry is *missing*. The previous
      // `attentionMask[tokenIdx] || 1` turned a mask value of 0 into 1,
      // re-including padded tokens and defeating the masking entirely.
      const maskValue = attentionMask?.[tokenIdx] ?? 1;

      if (maskValue > 0) {
        const tokenStart = tokenIdx * embeddingDim;
        for (let d = 0; d < embeddingDim; d++) {
          sum[d] += tokenEmbeddings[tokenStart + d] * maskValue;
        }
        count += maskValue;
      }
    }

    // Divide by the total mask weight; the epsilon guards against an
    // all-masked sequence (count === 0).
    pooled.push(sum.map(v => v / Math.max(count, 1e-9)));
  }

  return pooled;
}
|
|
100
|
+
|
|
101
|
+
/**
 * Embed one or more texts with the shared feature-extraction pipeline.
 *
 * @param {string|string[]} texts - a single text or a list of texts.
 * @returns {Promise<number[][]>} one embedding vector per input text.
 */
export async function generateEmbeddings(texts) {
  const model = await getModel();

  // Normalize to a list without mutating the caller's argument.
  const inputs = Array.isArray(texts) ? texts : [texts];

  // Mean-pooled, L2-normalized sentence embeddings.
  const output = await model(inputs, {
    pooling: 'mean',
    normalize: true
  });

  const vectors = [];

  if (output && output.data) {
    // Tensor-like result: flat data plus a dims shape.
    const flat = Array.from(output.data);
    const shape = output.dims;

    if (shape && shape.length === 2) {
      // Shape is [batchSize, embeddingDim]; slice out one row per input.
      const [rows, width] = shape;
      for (let row = 0; row < rows; row++) {
        const offset = row * width;
        vectors.push(flat.slice(offset, offset + width));
      }
    } else {
      // Fallback: treat the whole buffer as a single embedding.
      vectors.push(flat);
    }
  } else if (Array.isArray(output)) {
    // Already a list of per-text results (tensor-like or plain arrays).
    for (const item of output) {
      vectors.push(item.data ? Array.from(item.data) : Array.from(item));
    }
  }

  return vectors;
}
|
|
147
|
+
|
|
148
|
+
/**
 * Convenience wrapper: embed a single text and return its vector.
 *
 * @param {string} text - the text to embed.
 * @returns {Promise<number[]>} the embedding vector.
 */
export async function generateSingleEmbedding(text) {
  const [vector] = await generateEmbeddings([text]);
  return vector;
}
|