code-graph-context 2.8.0 → 2.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +101 -26
- package/dist/cli/cli.js +250 -10
- package/dist/core/embeddings/embedding-sidecar.js +244 -0
- package/dist/core/embeddings/embeddings.service.js +60 -132
- package/dist/core/embeddings/local-embeddings.service.js +41 -0
- package/dist/core/embeddings/openai-embeddings.service.js +114 -0
- package/dist/mcp/constants.js +24 -1
- package/dist/mcp/handlers/graph-generator.handler.js +6 -5
- package/dist/mcp/mcp.server.js +2 -0
- package/dist/mcp/service-init.js +24 -3
- package/dist/mcp/tools/index.js +3 -0
- package/dist/mcp/tools/search-codebase.tool.js +37 -13
- package/dist/mcp/tools/session-note.tool.js +5 -6
- package/dist/mcp/tools/swarm-claim-task.tool.js +35 -0
- package/dist/mcp/tools/swarm-cleanup.tool.js +55 -3
- package/dist/mcp/tools/swarm-constants.js +28 -0
- package/dist/mcp/tools/swarm-message.tool.js +362 -0
- package/dist/storage/neo4j/neo4j.service.js +4 -4
- package/package.json +3 -1
- package/sidecar/embedding_server.py +147 -0
- package/sidecar/requirements.txt +5 -0
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Swarm Message Tool
|
|
3
|
+
* Direct agent-to-agent messaging via Neo4j for reliable coordination.
|
|
4
|
+
*
|
|
5
|
+
* Unlike pheromones (passive, decay-based stigmergy), messages are explicit
|
|
6
|
+
* and delivered to agents when they claim tasks. This ensures critical
|
|
7
|
+
* coordination signals (blocked, conflict, findings) are reliably received.
|
|
8
|
+
*/
|
|
9
|
+
import { z } from 'zod';
|
|
10
|
+
import { Neo4jService } from '../../storage/neo4j/neo4j.service.js';
|
|
11
|
+
import { TOOL_NAMES, TOOL_METADATA } from '../constants.js';
|
|
12
|
+
import { createErrorResponse, createSuccessResponse, resolveProjectIdOrError, debugLog } from '../utils.js';
|
|
13
|
+
import { MESSAGE_CATEGORY_KEYS, MESSAGE_DEFAULT_TTL_MS, generateMessageId } from './swarm-constants.js';
|
|
14
|
+
// ============================================================================
|
|
15
|
+
// NEO4J QUERIES
|
|
16
|
+
// ============================================================================
|
|
17
|
+
/**
 * Cypher: create one SwarmMessage node.
 *
 * Parameters: $messageId, $projectId, $swarmId, $fromAgentId, $toAgentId,
 * $category, $content, $taskId, $filePaths, $ttlMs.
 *
 * - `timestamp` is set from the database clock (`timestamp()`), and
 *   `expiresAt` is `timestamp() + $ttlMs`.
 * - `readBy` starts as an empty list; the acknowledge queries in this file
 *   append agent ids to it.
 * - The read queries in this file treat a message whose `toAgentId` is null as
 *   a broadcast to the whole swarm.
 *
 * NOTE: because `expiresAt` is computed by plain addition, a `$ttlMs` of 0 (or
 * less) produces a message that never satisfies the `m.expiresAt > timestamp()`
 * filter used by the read queries below. Callers must pass a positive TTL.
 *
 * Returns: id, swarmId, fromAgentId, toAgentId, category, timestamp, expiresAt.
 */
const SEND_MESSAGE_QUERY = `
  CREATE (m:SwarmMessage {
    id: $messageId,
    projectId: $projectId,
    swarmId: $swarmId,
    fromAgentId: $fromAgentId,
    toAgentId: $toAgentId,
    category: $category,
    content: $content,
    taskId: $taskId,
    filePaths: $filePaths,
    timestamp: timestamp(),
    expiresAt: timestamp() + $ttlMs,
    readBy: []
  })
  RETURN m.id as id,
         m.swarmId as swarmId,
         m.fromAgentId as fromAgentId,
         m.toAgentId as toAgentId,
         m.category as category,
         m.timestamp as timestamp,
         m.expiresAt as expiresAt
`;
|
|
44
|
+
/**
 * Cypher: read messages visible to one agent within a project + swarm.
 *
 * A message is returned when ALL of the following hold:
 *   1. it is addressed to $agentId, or is a broadcast (toAgentId IS NULL);
 *   2. it has not expired (expiresAt is later than the database clock);
 *   3. $unreadOnly is false, or $agentId is not yet in its readBy list;
 *   4. $categories is null/empty, or the message category is in $categories;
 *   5. $fromAgentId is null, or the message was sent by $fromAgentId.
 *
 * Parameters: $projectId, $swarmId, $agentId, $unreadOnly, $categories,
 * $fromAgentId, $limit ($limit is coerced with toInteger() inside the query).
 *
 * Returns the message fields plus `isUnread` (true when $agentId is not in
 * readBy), newest first, capped at $limit rows.
 */
const READ_MESSAGES_QUERY = `
  MATCH (m:SwarmMessage)
  WHERE m.projectId = $projectId
    AND m.swarmId = $swarmId
    AND m.expiresAt > timestamp()
    AND (m.toAgentId IS NULL OR m.toAgentId = $agentId)
    AND ($unreadOnly = false OR NOT $agentId IN m.readBy)
    AND ($categories IS NULL OR size($categories) = 0 OR m.category IN $categories)
    AND ($fromAgentId IS NULL OR m.fromAgentId = $fromAgentId)
  RETURN m.id as id,
         m.swarmId as swarmId,
         m.fromAgentId as fromAgentId,
         m.toAgentId as toAgentId,
         m.category as category,
         m.content as content,
         m.taskId as taskId,
         m.filePaths as filePaths,
         m.timestamp as timestamp,
         m.expiresAt as expiresAt,
         m.readBy as readBy,
         NOT $agentId IN m.readBy as isUnread
  ORDER BY m.timestamp DESC
  LIMIT toInteger($limit)
`;
|
|
75
|
+
/**
 * Cypher: acknowledge (mark as read) specific messages for an agent.
 *
 * For each id in $messageIds, matches the SwarmMessage with that id in
 * $projectId and appends $agentId to its `readBy` list using plain Cypher list
 * concatenation (no APOC procedures are involved). Messages that already list
 * $agentId in `readBy` are skipped, so they are not returned.
 *
 * The match is scoped by id and projectId only; it does not filter on swarmId,
 * recipient, or expiry.
 *
 * Parameters: $messageIds, $projectId, $agentId.
 * Returns one row (id, category) per message that was newly acknowledged.
 */
const ACKNOWLEDGE_MESSAGES_QUERY = `
  UNWIND $messageIds as msgId
  MATCH (m:SwarmMessage {id: msgId, projectId: $projectId})
  WHERE NOT $agentId IN m.readBy
  SET m.readBy = m.readBy + $agentId
  RETURN m.id as id, m.category as category
`;
|
|
86
|
+
/**
 * Cypher: acknowledge ALL unread messages for an agent in a swarm.
 *
 * Targets non-expired messages in $projectId/$swarmId that are addressed to
 * $agentId or broadcast (toAgentId IS NULL) and do not yet list $agentId in
 * `readBy`, then appends $agentId to each message's `readBy`.
 *
 * Parameters: $projectId, $swarmId, $agentId.
 * Returns a single row: `acknowledged` = number of messages updated.
 */
const ACKNOWLEDGE_ALL_QUERY = `
  MATCH (m:SwarmMessage)
  WHERE m.projectId = $projectId
    AND m.swarmId = $swarmId
    AND (m.toAgentId IS NULL OR m.toAgentId = $agentId)
    AND NOT $agentId IN m.readBy
    AND m.expiresAt > timestamp()
  SET m.readBy = m.readBy + $agentId
  RETURN count(m) as acknowledged
`;
|
|
99
|
+
/**
 * Cypher: fetch pending messages for delivery during task claim.
 * Used internally by swarm_claim_task integration.
 *
 * Selects non-expired messages in $projectId/$swarmId that are addressed to
 * $agentId or broadcast, and that $agentId has not yet acknowledged.
 *
 * Ordering: by category priority first
 *   alert(0) < conflict(1) < blocked(2) < request(3) < finding(4) < handoff(5)
 *   < anything else(6),
 * then newest first within the same priority. At most 10 rows are returned
 * (the limit is fixed in the query text, not a parameter).
 *
 * Parameters: $projectId, $swarmId, $agentId.
 * Returns: id, fromAgentId, category, content, taskId, filePaths, timestamp.
 */
export const PENDING_MESSAGES_FOR_AGENT_QUERY = `
  MATCH (m:SwarmMessage)
  WHERE m.projectId = $projectId
    AND m.swarmId = $swarmId
    AND m.expiresAt > timestamp()
    AND (m.toAgentId IS NULL OR m.toAgentId = $agentId)
    AND NOT $agentId IN m.readBy
  RETURN m.id as id,
         m.fromAgentId as fromAgentId,
         m.category as category,
         m.content as content,
         m.taskId as taskId,
         m.filePaths as filePaths,
         m.timestamp as timestamp
  ORDER BY
    CASE m.category
      WHEN 'alert' THEN 0
      WHEN 'conflict' THEN 1
      WHEN 'blocked' THEN 2
      WHEN 'request' THEN 3
      WHEN 'finding' THEN 4
      WHEN 'handoff' THEN 5
      ELSE 6
    END,
    m.timestamp DESC
  LIMIT 10
`;
|
|
131
|
+
/**
 * Cypher: auto-acknowledge messages that were delivered during claim.
 *
 * For each id in $messageIds, appends $agentId to the message's `readBy` list
 * unless it is already present.
 *
 * NOTE(review): unlike ACKNOWLEDGE_MESSAGES_QUERY, this match is by message id
 * only; it does not filter on projectId or swarmId. Confirm message ids are
 * globally unique before relying on this across projects.
 *
 * Parameters: $messageIds, $agentId.
 * Returns a single row: `acknowledged` = number of messages updated.
 */
export const AUTO_ACKNOWLEDGE_QUERY = `
  UNWIND $messageIds as msgId
  MATCH (m:SwarmMessage {id: msgId})
  WHERE NOT $agentId IN m.readBy
  SET m.readBy = m.readBy + $agentId
  RETURN count(m) as acknowledged
`;
|
|
141
|
+
/**
 * Cypher: delete expired messages.
 *
 * Deletes every SwarmMessage in $projectId whose `expiresAt` is earlier than
 * the database clock. When $swarmId is null the deletion covers all swarms in
 * the project; otherwise it is restricted to that swarm.
 *
 * Parameters: $projectId, $swarmId (nullable).
 * Returns a single row: `cleaned` = number of messages deleted.
 */
const CLEANUP_EXPIRED_QUERY = `
  MATCH (m:SwarmMessage)
  WHERE m.projectId = $projectId
    AND ($swarmId IS NULL OR m.swarmId = $swarmId)
    AND m.expiresAt < timestamp()
  DELETE m
  RETURN count(m) as cleaned
`;
|
|
152
|
+
// ============================================================================
|
|
153
|
+
// TOOL CREATION
|
|
154
|
+
// ============================================================================
|
|
155
|
+
export const createSwarmMessageTool = (server) => {
|
|
156
|
+
server.registerTool(TOOL_NAMES.swarmMessage, {
|
|
157
|
+
title: TOOL_METADATA[TOOL_NAMES.swarmMessage].title,
|
|
158
|
+
description: TOOL_METADATA[TOOL_NAMES.swarmMessage].description,
|
|
159
|
+
inputSchema: {
|
|
160
|
+
projectId: z.string().describe('Project ID, name, or path'),
|
|
161
|
+
swarmId: z.string().describe('Swarm ID for scoping messages'),
|
|
162
|
+
agentId: z.string().describe('Your unique agent identifier'),
|
|
163
|
+
action: z
|
|
164
|
+
.enum(['send', 'read', 'acknowledge'])
|
|
165
|
+
.describe('Action: send (post message), read (get messages), acknowledge (mark as read)'),
|
|
166
|
+
// Send parameters
|
|
167
|
+
toAgentId: z
|
|
168
|
+
.string()
|
|
169
|
+
.optional()
|
|
170
|
+
.describe('Target agent ID. Omit for broadcast to all swarm agents.'),
|
|
171
|
+
category: z
|
|
172
|
+
.enum(MESSAGE_CATEGORY_KEYS)
|
|
173
|
+
.optional()
|
|
174
|
+
.describe('Message category: blocked (need help), conflict (resource clash), finding (important discovery), ' +
|
|
175
|
+
'request (direct ask), alert (urgent notification), handoff (context transfer)'),
|
|
176
|
+
content: z
|
|
177
|
+
.string()
|
|
178
|
+
.optional()
|
|
179
|
+
.describe('Message content (required for send action)'),
|
|
180
|
+
taskId: z
|
|
181
|
+
.string()
|
|
182
|
+
.optional()
|
|
183
|
+
.describe('Related task ID for context'),
|
|
184
|
+
filePaths: z
|
|
185
|
+
.array(z.string())
|
|
186
|
+
.optional()
|
|
187
|
+
.describe('File paths relevant to this message'),
|
|
188
|
+
ttlMs: z
|
|
189
|
+
.number()
|
|
190
|
+
.int()
|
|
191
|
+
.optional()
|
|
192
|
+
.describe(`Time-to-live in ms (default: ${MESSAGE_DEFAULT_TTL_MS / 3600000}h). Set 0 for swarm lifetime.`),
|
|
193
|
+
// Read parameters
|
|
194
|
+
unreadOnly: z
|
|
195
|
+
.boolean()
|
|
196
|
+
.optional()
|
|
197
|
+
.default(true)
|
|
198
|
+
.describe('Only return unread messages (default: true)'),
|
|
199
|
+
categories: z
|
|
200
|
+
.array(z.enum(MESSAGE_CATEGORY_KEYS))
|
|
201
|
+
.optional()
|
|
202
|
+
.describe('Filter by message categories'),
|
|
203
|
+
fromAgentId: z
|
|
204
|
+
.string()
|
|
205
|
+
.optional()
|
|
206
|
+
.describe('Filter messages from a specific agent'),
|
|
207
|
+
limit: z
|
|
208
|
+
.number()
|
|
209
|
+
.int()
|
|
210
|
+
.min(1)
|
|
211
|
+
.max(100)
|
|
212
|
+
.optional()
|
|
213
|
+
.default(20)
|
|
214
|
+
.describe('Maximum messages to return (default: 20)'),
|
|
215
|
+
// Acknowledge parameters
|
|
216
|
+
messageIds: z
|
|
217
|
+
.array(z.string())
|
|
218
|
+
.optional()
|
|
219
|
+
.describe('Specific message IDs to acknowledge. Omit to acknowledge all unread.'),
|
|
220
|
+
// Maintenance
|
|
221
|
+
cleanup: z
|
|
222
|
+
.boolean()
|
|
223
|
+
.optional()
|
|
224
|
+
.default(false)
|
|
225
|
+
.describe('Also clean up expired messages'),
|
|
226
|
+
},
|
|
227
|
+
}, async ({ projectId, swarmId, agentId, action, toAgentId, category, content, taskId, filePaths, ttlMs, unreadOnly = true, categories, fromAgentId, limit = 20, messageIds, cleanup = false, }) => {
|
|
228
|
+
const neo4jService = new Neo4jService();
|
|
229
|
+
const projectResult = await resolveProjectIdOrError(projectId, neo4jService);
|
|
230
|
+
if (!projectResult.success) {
|
|
231
|
+
await neo4jService.close();
|
|
232
|
+
return projectResult.error;
|
|
233
|
+
}
|
|
234
|
+
const resolvedProjectId = projectResult.projectId;
|
|
235
|
+
try {
|
|
236
|
+
// Optional cleanup of expired messages
|
|
237
|
+
let cleanedCount = 0;
|
|
238
|
+
if (cleanup) {
|
|
239
|
+
const cleanResult = await neo4jService.run(CLEANUP_EXPIRED_QUERY, {
|
|
240
|
+
projectId: resolvedProjectId,
|
|
241
|
+
swarmId,
|
|
242
|
+
});
|
|
243
|
+
cleanedCount = cleanResult[0]?.cleaned ?? 0;
|
|
244
|
+
}
|
|
245
|
+
// ── SEND ──────────────────────────────────────────────────────
|
|
246
|
+
if (action === 'send') {
|
|
247
|
+
if (!category) {
|
|
248
|
+
return createErrorResponse('category is required for send action');
|
|
249
|
+
}
|
|
250
|
+
if (!content) {
|
|
251
|
+
return createErrorResponse('content is required for send action');
|
|
252
|
+
}
|
|
253
|
+
const messageId = generateMessageId();
|
|
254
|
+
const effectiveTtl = ttlMs ?? MESSAGE_DEFAULT_TTL_MS;
|
|
255
|
+
const result = await neo4jService.run(SEND_MESSAGE_QUERY, {
|
|
256
|
+
messageId,
|
|
257
|
+
projectId: resolvedProjectId,
|
|
258
|
+
swarmId,
|
|
259
|
+
fromAgentId: agentId,
|
|
260
|
+
toAgentId: toAgentId ?? null,
|
|
261
|
+
category,
|
|
262
|
+
content,
|
|
263
|
+
taskId: taskId ?? null,
|
|
264
|
+
filePaths: filePaths ?? [],
|
|
265
|
+
ttlMs: effectiveTtl,
|
|
266
|
+
});
|
|
267
|
+
if (result.length === 0) {
|
|
268
|
+
return createErrorResponse('Failed to create message');
|
|
269
|
+
}
|
|
270
|
+
const msg = result[0];
|
|
271
|
+
const ts = typeof msg.timestamp === 'object' && msg.timestamp?.toNumber ? msg.timestamp.toNumber() : msg.timestamp;
|
|
272
|
+
return createSuccessResponse(JSON.stringify({
|
|
273
|
+
action: 'sent',
|
|
274
|
+
message: {
|
|
275
|
+
id: messageId,
|
|
276
|
+
swarmId: msg.swarmId,
|
|
277
|
+
from: msg.fromAgentId,
|
|
278
|
+
to: msg.toAgentId ?? 'broadcast',
|
|
279
|
+
category: msg.category,
|
|
280
|
+
expiresIn: effectiveTtl > 0 ? `${Math.round(effectiveTtl / 60000)} minutes` : 'never',
|
|
281
|
+
},
|
|
282
|
+
...(cleanedCount > 0 && { expiredCleaned: cleanedCount }),
|
|
283
|
+
}));
|
|
284
|
+
}
|
|
285
|
+
// ── READ ──────────────────────────────────────────────────────
|
|
286
|
+
if (action === 'read') {
|
|
287
|
+
const result = await neo4jService.run(READ_MESSAGES_QUERY, {
|
|
288
|
+
projectId: resolvedProjectId,
|
|
289
|
+
swarmId,
|
|
290
|
+
agentId,
|
|
291
|
+
unreadOnly,
|
|
292
|
+
categories: categories ?? null,
|
|
293
|
+
fromAgentId: fromAgentId ?? null,
|
|
294
|
+
limit: Math.floor(limit),
|
|
295
|
+
});
|
|
296
|
+
const messages = result.map((m) => {
|
|
297
|
+
const ts = typeof m.timestamp === 'object' && m.timestamp?.toNumber ? m.timestamp.toNumber() : m.timestamp;
|
|
298
|
+
return {
|
|
299
|
+
id: m.id,
|
|
300
|
+
from: m.fromAgentId,
|
|
301
|
+
to: m.toAgentId ?? 'broadcast',
|
|
302
|
+
category: m.category,
|
|
303
|
+
content: m.content,
|
|
304
|
+
taskId: m.taskId ?? undefined,
|
|
305
|
+
filePaths: m.filePaths?.length > 0 ? m.filePaths : undefined,
|
|
306
|
+
isUnread: m.isUnread,
|
|
307
|
+
age: ts ? `${Math.round((Date.now() - ts) / 1000)}s ago` : null,
|
|
308
|
+
};
|
|
309
|
+
});
|
|
310
|
+
return createSuccessResponse(JSON.stringify({
|
|
311
|
+
action: 'read',
|
|
312
|
+
swarmId,
|
|
313
|
+
forAgent: agentId,
|
|
314
|
+
count: messages.length,
|
|
315
|
+
messages,
|
|
316
|
+
...(cleanedCount > 0 && { expiredCleaned: cleanedCount }),
|
|
317
|
+
}));
|
|
318
|
+
}
|
|
319
|
+
// ── ACKNOWLEDGE ───────────────────────────────────────────────
|
|
320
|
+
if (action === 'acknowledge') {
|
|
321
|
+
if (messageIds && messageIds.length > 0) {
|
|
322
|
+
// Acknowledge specific messages
|
|
323
|
+
const result = await neo4jService.run(ACKNOWLEDGE_MESSAGES_QUERY, {
|
|
324
|
+
messageIds,
|
|
325
|
+
projectId: resolvedProjectId,
|
|
326
|
+
agentId,
|
|
327
|
+
});
|
|
328
|
+
return createSuccessResponse(JSON.stringify({
|
|
329
|
+
action: 'acknowledged',
|
|
330
|
+
count: result.length,
|
|
331
|
+
messageIds: result.map((r) => r.id),
|
|
332
|
+
...(cleanedCount > 0 && { expiredCleaned: cleanedCount }),
|
|
333
|
+
}));
|
|
334
|
+
}
|
|
335
|
+
else {
|
|
336
|
+
// Acknowledge all unread
|
|
337
|
+
const result = await neo4jService.run(ACKNOWLEDGE_ALL_QUERY, {
|
|
338
|
+
projectId: resolvedProjectId,
|
|
339
|
+
swarmId,
|
|
340
|
+
agentId,
|
|
341
|
+
});
|
|
342
|
+
const count = typeof result[0]?.acknowledged === 'object'
|
|
343
|
+
? result[0].acknowledged.toNumber()
|
|
344
|
+
: result[0]?.acknowledged ?? 0;
|
|
345
|
+
return createSuccessResponse(JSON.stringify({
|
|
346
|
+
action: 'acknowledged_all',
|
|
347
|
+
count,
|
|
348
|
+
...(cleanedCount > 0 && { expiredCleaned: cleanedCount }),
|
|
349
|
+
}));
|
|
350
|
+
}
|
|
351
|
+
}
|
|
352
|
+
return createErrorResponse(`Unknown action: ${action}`);
|
|
353
|
+
}
|
|
354
|
+
catch (error) {
|
|
355
|
+
await debugLog('Swarm message error', { error: String(error) });
|
|
356
|
+
return createErrorResponse(error instanceof Error ? error : String(error));
|
|
357
|
+
}
|
|
358
|
+
finally {
|
|
359
|
+
await neo4jService.close();
|
|
360
|
+
}
|
|
361
|
+
});
|
|
362
|
+
};
|
|
@@ -136,19 +136,19 @@ export const QUERIES = {
|
|
|
136
136
|
RETURN labels(n)[0] as nodeType, count(*) as count
|
|
137
137
|
ORDER BY count DESC
|
|
138
138
|
`,
|
|
139
|
-
CREATE_EMBEDDED_VECTOR_INDEX: `
|
|
139
|
+
CREATE_EMBEDDED_VECTOR_INDEX: (dimensions) => `
|
|
140
140
|
CREATE VECTOR INDEX embedded_nodes_idx IF NOT EXISTS
|
|
141
141
|
FOR (n:Embedded) ON (n.embedding)
|
|
142
142
|
OPTIONS {indexConfig: {
|
|
143
|
-
\`vector.dimensions\`:
|
|
143
|
+
\`vector.dimensions\`: ${dimensions},
|
|
144
144
|
\`vector.similarity_function\`: 'cosine'
|
|
145
145
|
}}
|
|
146
146
|
`,
|
|
147
|
-
CREATE_SESSION_NOTES_VECTOR_INDEX: `
|
|
147
|
+
CREATE_SESSION_NOTES_VECTOR_INDEX: (dimensions) => `
|
|
148
148
|
CREATE VECTOR INDEX session_notes_idx IF NOT EXISTS
|
|
149
149
|
FOR (n:SessionNote) ON (n.embedding)
|
|
150
150
|
OPTIONS {indexConfig: {
|
|
151
|
-
\`vector.dimensions\`:
|
|
151
|
+
\`vector.dimensions\`: ${dimensions},
|
|
152
152
|
\`vector.similarity_function\`: 'cosine'
|
|
153
153
|
}}
|
|
154
154
|
`,
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "code-graph-context",
|
|
3
|
-
"version": "2.
|
|
3
|
+
"version": "2.10.0",
|
|
4
4
|
"description": "MCP server that builds code graphs to provide rich context to LLMs",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"homepage": "https://github.com/drewdrewH/code-graph-context#readme",
|
|
@@ -34,6 +34,8 @@
|
|
|
34
34
|
},
|
|
35
35
|
"files": [
|
|
36
36
|
"dist/**/*",
|
|
37
|
+
"sidecar/embedding_server.py",
|
|
38
|
+
"sidecar/requirements.txt",
|
|
37
39
|
"README.md",
|
|
38
40
|
"LICENSE",
|
|
39
41
|
".env.example"
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local embedding server for code-graph-context.
|
|
3
|
+
Uses Qodo-Embed-1-1.5B for high-quality code embeddings without OpenAI dependency.
|
|
4
|
+
Runs as a sidecar process managed by the Node.js MCP server.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import gc
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import signal
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
from fastapi import FastAPI, HTTPException
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
|
|
16
|
+
# Send all log output to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    stream=sys.stderr,
)
logger = logging.getLogger("embedding-sidecar")

app = FastAPI(title="code-graph-context embedding sidecar")

# Embedding model instance; stays None until load_model() assigns it at startup.
model = None
# Model identifier, overridable through the EMBEDDING_MODEL environment variable.
model_name = os.environ.get("EMBEDDING_MODEL", "Qodo/Qodo-Embed-1-1.5B")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    # Texts to embed; an empty list yields an empty response.
    texts: list[str]
    # Batch size passed to the model's encode() call (default 8).
    batch_size: int = 8
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class EmbedResponse(BaseModel):
    """Response body for POST /embed."""

    # One vector per input text, in input order.
    embeddings: list[list[float]]
    # Length of each vector (0 when no texts were supplied).
    dimensions: int
    # Name of the model that produced the embeddings.
    model: str
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@app.on_event("startup")
def load_model():
    """
    Load the embedding model into the module-level `model` when the app starts.

    Uses the "mps" device when torch reports it as available, otherwise "cpu".
    After loading, a one-item warm-up encode is run and its vector length is
    logged. Any failure is logged and re-raised.
    """
    global model
    try:
        # Imported here rather than at module import time.
        import torch
        from sentence_transformers import SentenceTransformer

        device = "mps" if torch.backends.mps.is_available() else "cpu"
        logger.info(f"Loading {model_name} on {device}...")
        model = SentenceTransformer(model_name, device=device)

        # Warm up with a test embedding
        test = model.encode(["warmup"], show_progress_bar=False)
        dims = len(test[0])
        logger.info(f"Model loaded: {dims} dimensions, device={device}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@app.get("/health")
def health():
    """
    Health probe.

    Responds 503 while the model has not been loaded. Otherwise returns the
    status, model name, embedding dimensions and the model's device.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # The dimension is measured by encoding a sample string on every call,
    # so each health check runs one model inference.
    sample = model.encode(["dim_check"], show_progress_bar=False)
    return {
        "status": "ok",
        "model": model_name,
        "dimensions": len(sample[0]),
        "device": str(model.device),
    }
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@app.post("/embed", response_model=EmbedResponse)
async def embed(req: EmbedRequest):
    """
    Embed a list of texts.

    Responds 503 while the model has not been loaded. An empty `texts` list
    returns an empty result with `dimensions` 0 without touching the model.
    Any exception raised while encoding is logged and returned as a 500 whose
    detail is the exception text.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not req.texts:
        return EmbedResponse(embeddings=[], dimensions=0, model=model_name)

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so it
        # presumably occupies the event loop for the duration of the encode; confirm
        # that is acceptable for callers that poll /health concurrently.
        embeddings = _encode_with_oom_fallback(req.texts, req.batch_size)
        # texts is non-empty here, so embeddings[0] exists.
        dims = len(embeddings[0])
        return EmbedResponse(
            embeddings=embeddings,
            dimensions=dims,
            model=model_name,
        )
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _encode_with_oom_fallback(texts: list[str], batch_size: int) -> list[list[float]]:
    """
    Encode texts with the loaded model, retrying once on CPU after an OOM.

    The first attempt runs on the model's current device with `batch_size`.
    If that raises an out-of-memory RuntimeError, cached MPS memory is released,
    the model is moved to CPU, and the texts are encoded again with a batch size
    of at most 4. The model is then moved back to its original device; if that
    move fails the model stays on CPU and a warning is logged.

    Args:
        texts: Texts to embed.
        batch_size: Requested batch size; values below 1 are treated as 1.

    Returns:
        One normalized embedding (list of floats) per input text.

    Raises:
        RuntimeError: For non-OOM runtime errors, or if the CPU retry also fails.
    """
    import torch

    # Guard against zero/negative batch sizes coming from the request body.
    batch_size = max(1, batch_size)

    try:
        result = model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return result.tolist()
    except RuntimeError as e:
        # Only RuntimeError is named here: torch reports MPS/CUDA out-of-memory
        # conditions as RuntimeError (or a subclass of it). Referencing an
        # attribute that torch.mps does not provide in the except clause would
        # raise AttributeError while handling and hide the real error.
        if "out of memory" not in str(e).lower():
            raise

        logger.warning(f"MPS OOM with batch_size={batch_size}, len={len(texts)}. Falling back to CPU.")

        # Free MPS memory
        mps = getattr(torch, "mps", None)
        if mps is not None and hasattr(mps, "empty_cache"):
            mps.empty_cache()
        gc.collect()

        # Fall back to CPU for this request
        original_device = model.device
        model.to("cpu")

        try:
            # Use smaller batches on CPU
            cpu_batch = min(batch_size, 4)
            result = model.encode(
                texts,
                batch_size=cpu_batch,
                show_progress_bar=False,
                normalize_embeddings=True,
            )
            return result.tolist()
        finally:
            # Move back to the original device for future requests
            try:
                model.to(original_device)
            except Exception:
                logger.warning("Could not move model back to MPS, staying on CPU")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def handle_signal(sig, _frame):
    """Log the received signal number and exit the process with status 0."""
    logger.info(f"Received signal {sig}, shutting down")
    sys.exit(0)


# Exit cleanly when the process receives SIGTERM.
signal.signal(signal.SIGTERM, handle_signal)
|