@zuvia-software-solutions/code-mapper 2.2.3 → 2.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cli/analyze.js +5 -1
- package/dist/core/db/adapter.d.ts +6 -0
- package/dist/core/db/adapter.js +102 -2
- package/dist/core/db/schema.d.ts +3 -1
- package/dist/core/db/schema.js +7 -5
- package/models/jina-code-0.5b-mlx/config.json +73 -0
- package/models/jina-code-0.5b-mlx/model.py +127 -0
- package/models/mlx-embedder.py +49 -10
- package/package.json +3 -3
- package/models/jina-v5-small-mlx/config.json +0 -19
- package/models/jina-v5-small-mlx/model.py +0 -260
package/dist/cli/analyze.js
CHANGED
|
@@ -5,7 +5,7 @@ import { execFileSync } from 'child_process';
|
|
|
5
5
|
import v8 from 'v8';
|
|
6
6
|
import cliProgress from 'cli-progress';
|
|
7
7
|
import { runPipelineFromRepo } from '../core/ingestion/pipeline.js';
|
|
8
|
-
import { openDb, closeDb, resetDb, getStats, insertEmbeddingsBatch, countEmbeddings } from '../core/db/adapter.js';
|
|
8
|
+
import { openDb, closeDb, resetDb, getStats, insertEmbeddingsBatch, countEmbeddings, populateSearchText } from '../core/db/adapter.js';
|
|
9
9
|
import { loadGraphToDb } from '../core/db/graph-loader.js';
|
|
10
10
|
import { stitchRoutes } from '../core/ingestion/route-stitcher.js';
|
|
11
11
|
import { toNodeId } from '../core/db/schema.js';
|
|
@@ -200,6 +200,10 @@ export const analyzeCommand = async (inputPath, options) => {
|
|
|
200
200
|
const dbWarnings = dbResult.warnings;
|
|
201
201
|
// Phase 2.5: HTTP route stitching (post-DB-load, needs content field)
|
|
202
202
|
stitchRoutes(db);
|
|
203
|
+
// Phase 2.6: Populate searchText for BM25 concept matching
|
|
204
|
+
// Uses first comment + callers + module — must run after edges are loaded
|
|
205
|
+
updateBar(84, 'Building search index...');
|
|
206
|
+
populateSearchText(db);
|
|
203
207
|
// Phase 3: FTS (85-90%)
|
|
204
208
|
// FTS5 is auto-created by schema triggers — no manual index creation needed
|
|
205
209
|
updateBar(85, 'Search indexes ready');
|
|
@@ -90,6 +90,12 @@ export declare function getStats(db: Database.Database): {
|
|
|
90
90
|
export declare function insertNodesBatch(db: Database.Database, nodes: readonly NodeInsert[]): void;
|
|
91
91
|
/** Batch insert edges in a single transaction. */
|
|
92
92
|
export declare function insertEdgesBatch(db: Database.Database, edges: readonly EdgeInsert[]): void;
|
|
93
|
+
/**
|
|
94
|
+
* Populate the searchText column for all nodes with semantic summaries.
|
|
95
|
+
* Uses first comment + callers + module to enable BM25 concept matching.
|
|
96
|
+
* Call AFTER edges are loaded (needs CALLS and MEMBER_OF edges).
|
|
97
|
+
*/
|
|
98
|
+
export declare function populateSearchText(db: Database.Database): void;
|
|
93
99
|
/** Batch insert embeddings in a single transaction. */
|
|
94
100
|
export declare function insertEmbeddingsBatch(db: Database.Database, items: readonly {
|
|
95
101
|
nodeId: NodeId;
|
package/dist/core/db/adapter.js
CHANGED
|
@@ -145,12 +145,12 @@ const INSERT_NODE_SQL = `
|
|
|
145
145
|
id, label, name, filePath, startLine, endLine, isExported, content, description,
|
|
146
146
|
heuristicLabel, cohesion, symbolCount, keywords, enrichedBy,
|
|
147
147
|
processType, stepCount, communities, entryPointId, terminalId,
|
|
148
|
-
parameterCount, returnType, nameExpanded
|
|
148
|
+
parameterCount, returnType, nameExpanded, searchText
|
|
149
149
|
) VALUES (
|
|
150
150
|
@id, @label, @name, @filePath, @startLine, @endLine, @isExported, @content, @description,
|
|
151
151
|
@heuristicLabel, @cohesion, @symbolCount, @keywords, @enrichedBy,
|
|
152
152
|
@processType, @stepCount, @communities, @entryPointId, @terminalId,
|
|
153
|
-
@parameterCount, @returnType, @nameExpanded
|
|
153
|
+
@parameterCount, @returnType, @nameExpanded, @searchText
|
|
154
154
|
)
|
|
155
155
|
`;
|
|
156
156
|
/** Insert or replace a node. Automatically expands name for FTS natural language matching. */
|
|
@@ -178,6 +178,7 @@ export function insertNode(db, node) {
|
|
|
178
178
|
parameterCount: node.parameterCount ?? null,
|
|
179
179
|
returnType: node.returnType ?? null,
|
|
180
180
|
nameExpanded: node.nameExpanded ?? expandIdentifier(node.name ?? ''),
|
|
181
|
+
searchText: node.searchText ?? '',
|
|
181
182
|
});
|
|
182
183
|
}
|
|
183
184
|
/** Get a node by ID. Returns undefined if not found. */
|
|
@@ -373,6 +374,7 @@ export function insertNodesBatch(db, nodes) {
|
|
|
373
374
|
terminalId: node.terminalId ?? null, parameterCount: node.parameterCount ?? null,
|
|
374
375
|
returnType: node.returnType ?? null,
|
|
375
376
|
nameExpanded: node.nameExpanded ?? expandIdentifier(node.name ?? ''),
|
|
377
|
+
searchText: node.searchText ?? '',
|
|
376
378
|
});
|
|
377
379
|
}
|
|
378
380
|
});
|
|
@@ -392,6 +394,104 @@ export function insertEdgesBatch(db, edges) {
|
|
|
392
394
|
});
|
|
393
395
|
txn(edges);
|
|
394
396
|
}
|
|
397
|
+
/**
|
|
398
|
+
* Populate the searchText column for all nodes with semantic summaries.
|
|
399
|
+
* Uses first comment + callers + module to enable BM25 concept matching.
|
|
400
|
+
* Call AFTER edges are loaded (needs CALLS and MEMBER_OF edges).
|
|
401
|
+
*/
|
|
402
|
+
export function populateSearchText(db) {
|
|
403
|
+
// Extract first comment from content
|
|
404
|
+
function extractComment(content) {
|
|
405
|
+
if (!content)
|
|
406
|
+
return '';
|
|
407
|
+
const lines = content.split('\n');
|
|
408
|
+
const out = [];
|
|
409
|
+
let inBlock = false;
|
|
410
|
+
for (const l of lines) {
|
|
411
|
+
const t = l.trim();
|
|
412
|
+
if (t.startsWith('/**') || t.startsWith('/*')) {
|
|
413
|
+
inBlock = true;
|
|
414
|
+
const inner = t.replace(/^\/\*\*?\s*/, '').replace(/\*\/\s*$/, '').trim();
|
|
415
|
+
if (inner && !inner.startsWith('@'))
|
|
416
|
+
out.push(inner);
|
|
417
|
+
if (t.includes('*/'))
|
|
418
|
+
inBlock = false;
|
|
419
|
+
continue;
|
|
420
|
+
}
|
|
421
|
+
if (inBlock) {
|
|
422
|
+
if (t.includes('*/')) {
|
|
423
|
+
inBlock = false;
|
|
424
|
+
continue;
|
|
425
|
+
}
|
|
426
|
+
const inner = t.replace(/^\*\s?/, '').trim();
|
|
427
|
+
if (inner && !inner.startsWith('@'))
|
|
428
|
+
out.push(inner);
|
|
429
|
+
if (out.length >= 3)
|
|
430
|
+
break;
|
|
431
|
+
continue;
|
|
432
|
+
}
|
|
433
|
+
if (t.startsWith('//')) {
|
|
434
|
+
const inner = t.slice(2).trim();
|
|
435
|
+
if (inner)
|
|
436
|
+
out.push(inner);
|
|
437
|
+
if (out.length >= 3)
|
|
438
|
+
break;
|
|
439
|
+
continue;
|
|
440
|
+
}
|
|
441
|
+
if (out.length > 0)
|
|
442
|
+
break;
|
|
443
|
+
}
|
|
444
|
+
return out.join(' ');
|
|
445
|
+
}
|
|
446
|
+
const nodes = db.prepare("SELECT id, name, nameExpanded, content FROM nodes WHERE label IN ('Function','Class','Method','Interface','Const','TypeAlias','Enum')").all();
|
|
447
|
+
if (nodes.length === 0)
|
|
448
|
+
return;
|
|
449
|
+
// Batch fetch callers + module
|
|
450
|
+
const callerMap = new Map();
|
|
451
|
+
const moduleMap = new Map();
|
|
452
|
+
const ids = nodes.map(n => n.id);
|
|
453
|
+
for (let i = 0; i < ids.length; i += 900) {
|
|
454
|
+
const chunk = ids.slice(i, i + 900);
|
|
455
|
+
const ph = chunk.map(() => '?').join(',');
|
|
456
|
+
const callerRows = db.prepare(`SELECT e.targetId AS nid, n.name FROM edges e JOIN nodes n ON n.id = e.sourceId WHERE e.targetId IN (${ph}) AND e.type = 'CALLS' AND e.confidence >= 0.7`).all(...chunk);
|
|
457
|
+
for (const r of callerRows) {
|
|
458
|
+
if (!callerMap.has(r.nid))
|
|
459
|
+
callerMap.set(r.nid, []);
|
|
460
|
+
callerMap.get(r.nid).push(r.name);
|
|
461
|
+
}
|
|
462
|
+
const modRows = db.prepare(`SELECT e.sourceId AS nid, c.heuristicLabel AS module FROM edges e JOIN nodes c ON c.id = e.targetId WHERE e.sourceId IN (${ph}) AND e.type = 'MEMBER_OF' AND c.label = 'Community'`).all(...chunk);
|
|
463
|
+
for (const r of modRows)
|
|
464
|
+
moduleMap.set(r.nid, r.module);
|
|
465
|
+
}
|
|
466
|
+
// Build searchText and update
|
|
467
|
+
// Drop FTS triggers temporarily to avoid column-count issues during bulk update,
|
|
468
|
+
// then rebuild the FTS index in one pass (faster than per-row trigger updates)
|
|
469
|
+
db.exec("DROP TRIGGER IF EXISTS nodes_fts_au");
|
|
470
|
+
const txn = db.transaction(() => {
|
|
471
|
+
for (const node of nodes) {
|
|
472
|
+
const parts = [];
|
|
473
|
+
if (node.nameExpanded)
|
|
474
|
+
parts.push(node.nameExpanded);
|
|
475
|
+
const comment = extractComment(node.content);
|
|
476
|
+
if (comment)
|
|
477
|
+
parts.push(comment);
|
|
478
|
+
const callers = callerMap.get(node.id)?.slice(0, 5);
|
|
479
|
+
if (callers && callers.length > 0)
|
|
480
|
+
parts.push(callers.map(c => expandIdentifier(c)).join(' '));
|
|
481
|
+
const mod = moduleMap.get(node.id);
|
|
482
|
+
if (mod)
|
|
483
|
+
parts.push(mod);
|
|
484
|
+
db.prepare('UPDATE nodes SET searchText = ? WHERE id = ?').run(parts.join(' | '), node.id);
|
|
485
|
+
}
|
|
486
|
+
});
|
|
487
|
+
txn();
|
|
488
|
+
// Rebuild FTS index from scratch and recreate the trigger
|
|
489
|
+
db.exec("INSERT INTO nodes_fts(nodes_fts) VALUES('rebuild')");
|
|
490
|
+
db.exec(`CREATE TRIGGER IF NOT EXISTS nodes_fts_au AFTER UPDATE ON nodes BEGIN
|
|
491
|
+
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.searchText, old.filePath, old.content);
|
|
492
|
+
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.searchText, new.filePath, new.content);
|
|
493
|
+
END`);
|
|
494
|
+
}
|
|
395
495
|
/** Batch insert embeddings in a single transaction. */
|
|
396
496
|
export function insertEmbeddingsBatch(db, items) {
|
|
397
497
|
const stmt = db.prepare('INSERT OR REPLACE INTO embeddings (nodeId, embedding, textHash) VALUES (?, ?, ?)');
|
package/dist/core/db/schema.d.ts
CHANGED
|
@@ -49,6 +49,7 @@ export interface NodeRow {
|
|
|
49
49
|
readonly parameterCount: number | null;
|
|
50
50
|
readonly returnType: string | null;
|
|
51
51
|
readonly nameExpanded: string;
|
|
52
|
+
readonly searchText: string;
|
|
52
53
|
}
|
|
53
54
|
/** An edge row as stored in the `edges` table */
|
|
54
55
|
export interface EdgeRow {
|
|
@@ -91,6 +92,7 @@ export interface NodeInsert {
|
|
|
91
92
|
readonly parameterCount?: number | null;
|
|
92
93
|
readonly returnType?: string | null;
|
|
93
94
|
readonly nameExpanded?: string;
|
|
95
|
+
readonly searchText?: string;
|
|
94
96
|
}
|
|
95
97
|
/** Fields required to insert an edge */
|
|
96
98
|
export interface EdgeInsert {
|
|
@@ -105,4 +107,4 @@ export interface EdgeInsert {
|
|
|
105
107
|
}
|
|
106
108
|
/** Legacy edge table name constant (kept for compatibility) */
|
|
107
109
|
export declare const REL_TABLE_NAME = "CodeRelation";
|
|
108
|
-
export declare const SCHEMA_SQL = "\n-- Nodes: unified table for all code elements\nCREATE TABLE IF NOT EXISTS nodes (\n id TEXT PRIMARY KEY,\n label TEXT NOT NULL,\n name TEXT NOT NULL DEFAULT '',\n filePath TEXT NOT NULL DEFAULT '',\n startLine INTEGER,\n endLine INTEGER,\n isExported INTEGER,\n content TEXT NOT NULL DEFAULT '',\n description TEXT NOT NULL DEFAULT '',\n heuristicLabel TEXT,\n cohesion REAL,\n symbolCount INTEGER,\n keywords TEXT,\n enrichedBy TEXT,\n processType TEXT,\n stepCount INTEGER,\n communities TEXT,\n entryPointId TEXT,\n terminalId TEXT,\n parameterCount INTEGER,\n returnType TEXT,\n nameExpanded TEXT DEFAULT ''\n);\n\nCREATE INDEX IF NOT EXISTS idx_nodes_label ON nodes(label);\nCREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);\nCREATE INDEX IF NOT EXISTS idx_nodes_filePath ON nodes(filePath);\nCREATE INDEX IF NOT EXISTS idx_nodes_label_name ON nodes(label, name);\nCREATE INDEX IF NOT EXISTS idx_nodes_filePath_lines ON nodes(filePath, startLine, endLine);\n\n-- Edges: single table for all relationships\nCREATE TABLE IF NOT EXISTS edges (\n id TEXT PRIMARY KEY,\n sourceId TEXT NOT NULL,\n targetId TEXT NOT NULL,\n type TEXT NOT NULL,\n confidence REAL NOT NULL DEFAULT 1.0,\n reason TEXT NOT NULL DEFAULT '',\n step INTEGER NOT NULL DEFAULT 0,\n callLine INTEGER\n);\n\nCREATE INDEX IF NOT EXISTS idx_edges_sourceId ON edges(sourceId);\nCREATE INDEX IF NOT EXISTS idx_edges_targetId ON edges(targetId);\nCREATE INDEX IF NOT EXISTS idx_edges_type ON edges(type);\nCREATE INDEX IF NOT EXISTS idx_edges_source_type ON edges(sourceId, type);\nCREATE INDEX IF NOT EXISTS idx_edges_target_type ON edges(targetId, type);\n\n-- Embeddings: vector storage\nCREATE TABLE IF NOT EXISTS embeddings (\n nodeId TEXT PRIMARY KEY,\n embedding BLOB NOT NULL,\n textHash TEXT\n);\n\n-- FTS5 virtual table (auto-updated via triggers)\nCREATE VIRTUAL TABLE IF NOT EXISTS nodes_fts USING fts5(\n name,\n nameExpanded,\n filePath,\n content,\n content='nodes',\n 
content_rowid='rowid'\n);\n\nCREATE TRIGGER IF NOT EXISTS nodes_fts_ai AFTER INSERT ON nodes BEGIN\n INSERT INTO nodes_fts(rowid, name, nameExpanded, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.filePath, new.content);\nEND;\nCREATE TRIGGER IF NOT EXISTS nodes_fts_ad AFTER DELETE ON nodes BEGIN\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.filePath, old.content);\nEND;\nCREATE TRIGGER IF NOT EXISTS nodes_fts_au AFTER UPDATE ON nodes BEGIN\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.filePath, old.content);\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.filePath, new.content);\nEND;\n";
|
|
110
|
+
export declare const SCHEMA_SQL = "\n-- Nodes: unified table for all code elements\nCREATE TABLE IF NOT EXISTS nodes (\n id TEXT PRIMARY KEY,\n label TEXT NOT NULL,\n name TEXT NOT NULL DEFAULT '',\n filePath TEXT NOT NULL DEFAULT '',\n startLine INTEGER,\n endLine INTEGER,\n isExported INTEGER,\n content TEXT NOT NULL DEFAULT '',\n description TEXT NOT NULL DEFAULT '',\n heuristicLabel TEXT,\n cohesion REAL,\n symbolCount INTEGER,\n keywords TEXT,\n enrichedBy TEXT,\n processType TEXT,\n stepCount INTEGER,\n communities TEXT,\n entryPointId TEXT,\n terminalId TEXT,\n parameterCount INTEGER,\n returnType TEXT,\n nameExpanded TEXT DEFAULT '',\n searchText TEXT DEFAULT ''\n);\n\nCREATE INDEX IF NOT EXISTS idx_nodes_label ON nodes(label);\nCREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name);\nCREATE INDEX IF NOT EXISTS idx_nodes_filePath ON nodes(filePath);\nCREATE INDEX IF NOT EXISTS idx_nodes_label_name ON nodes(label, name);\nCREATE INDEX IF NOT EXISTS idx_nodes_filePath_lines ON nodes(filePath, startLine, endLine);\n\n-- Edges: single table for all relationships\nCREATE TABLE IF NOT EXISTS edges (\n id TEXT PRIMARY KEY,\n sourceId TEXT NOT NULL,\n targetId TEXT NOT NULL,\n type TEXT NOT NULL,\n confidence REAL NOT NULL DEFAULT 1.0,\n reason TEXT NOT NULL DEFAULT '',\n step INTEGER NOT NULL DEFAULT 0,\n callLine INTEGER\n);\n\nCREATE INDEX IF NOT EXISTS idx_edges_sourceId ON edges(sourceId);\nCREATE INDEX IF NOT EXISTS idx_edges_targetId ON edges(targetId);\nCREATE INDEX IF NOT EXISTS idx_edges_type ON edges(type);\nCREATE INDEX IF NOT EXISTS idx_edges_source_type ON edges(sourceId, type);\nCREATE INDEX IF NOT EXISTS idx_edges_target_type ON edges(targetId, type);\n\n-- Embeddings: vector storage\nCREATE TABLE IF NOT EXISTS embeddings (\n nodeId TEXT PRIMARY KEY,\n embedding BLOB NOT NULL,\n textHash TEXT\n);\n\n-- FTS5 virtual table (auto-updated via triggers)\nCREATE VIRTUAL TABLE IF NOT EXISTS nodes_fts USING fts5(\n name,\n nameExpanded,\n searchText,\n 
filePath,\n content,\n content='nodes',\n content_rowid='rowid'\n);\n\nCREATE TRIGGER IF NOT EXISTS nodes_fts_ai AFTER INSERT ON nodes BEGIN\n INSERT INTO nodes_fts(rowid, name, nameExpanded, searchText, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.searchText, new.filePath, new.content);\nEND;\nCREATE TRIGGER IF NOT EXISTS nodes_fts_ad AFTER DELETE ON nodes BEGIN\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.searchText, old.filePath, old.content);\nEND;\nCREATE TRIGGER IF NOT EXISTS nodes_fts_au AFTER UPDATE ON nodes BEGIN\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.searchText, old.filePath, old.content);\n INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.searchText, new.filePath, new.content);\nEND;\n";
|
package/dist/core/db/schema.js
CHANGED
|
@@ -79,7 +79,8 @@ CREATE TABLE IF NOT EXISTS nodes (
|
|
|
79
79
|
terminalId TEXT,
|
|
80
80
|
parameterCount INTEGER,
|
|
81
81
|
returnType TEXT,
|
|
82
|
-
nameExpanded TEXT DEFAULT ''
|
|
82
|
+
nameExpanded TEXT DEFAULT '',
|
|
83
|
+
searchText TEXT DEFAULT ''
|
|
83
84
|
);
|
|
84
85
|
|
|
85
86
|
CREATE INDEX IF NOT EXISTS idx_nodes_label ON nodes(label);
|
|
@@ -117,6 +118,7 @@ CREATE TABLE IF NOT EXISTS embeddings (
|
|
|
117
118
|
CREATE VIRTUAL TABLE IF NOT EXISTS nodes_fts USING fts5(
|
|
118
119
|
name,
|
|
119
120
|
nameExpanded,
|
|
121
|
+
searchText,
|
|
120
122
|
filePath,
|
|
121
123
|
content,
|
|
122
124
|
content='nodes',
|
|
@@ -124,13 +126,13 @@ CREATE VIRTUAL TABLE IF NOT EXISTS nodes_fts USING fts5(
|
|
|
124
126
|
);
|
|
125
127
|
|
|
126
128
|
CREATE TRIGGER IF NOT EXISTS nodes_fts_ai AFTER INSERT ON nodes BEGIN
|
|
127
|
-
INSERT INTO nodes_fts(rowid, name, nameExpanded, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.filePath, new.content);
|
|
129
|
+
INSERT INTO nodes_fts(rowid, name, nameExpanded, searchText, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.searchText, new.filePath, new.content);
|
|
128
130
|
END;
|
|
129
131
|
CREATE TRIGGER IF NOT EXISTS nodes_fts_ad AFTER DELETE ON nodes BEGIN
|
|
130
|
-
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.filePath, old.content);
|
|
132
|
+
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.searchText, old.filePath, old.content);
|
|
131
133
|
END;
|
|
132
134
|
CREATE TRIGGER IF NOT EXISTS nodes_fts_au AFTER UPDATE ON nodes BEGIN
|
|
133
|
-
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.filePath, old.content);
|
|
134
|
-
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.filePath, new.content);
|
|
135
|
+
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES ('delete', old.rowid, old.name, old.nameExpanded, old.searchText, old.filePath, old.content);
|
|
136
|
+
INSERT INTO nodes_fts(nodes_fts, rowid, name, nameExpanded, searchText, filePath, content) VALUES (new.rowid, new.name, new.nameExpanded, new.searchText, new.filePath, new.content);
|
|
135
137
|
END;
|
|
136
138
|
`;
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
{
|
|
2
|
+
"architectures": [
|
|
3
|
+
"Qwen2ForCausalLM"
|
|
4
|
+
],
|
|
5
|
+
"attention_dropout": 0.0,
|
|
6
|
+
"bos_token_id": 151643,
|
|
7
|
+
"eos_token_id": 151643,
|
|
8
|
+
"hidden_act": "silu",
|
|
9
|
+
"hidden_size": 896,
|
|
10
|
+
"initializer_range": 0.02,
|
|
11
|
+
"intermediate_size": 4864,
|
|
12
|
+
"layer_types": [
|
|
13
|
+
"full_attention",
|
|
14
|
+
"full_attention",
|
|
15
|
+
"full_attention",
|
|
16
|
+
"full_attention",
|
|
17
|
+
"full_attention",
|
|
18
|
+
"full_attention",
|
|
19
|
+
"full_attention",
|
|
20
|
+
"full_attention",
|
|
21
|
+
"full_attention",
|
|
22
|
+
"full_attention",
|
|
23
|
+
"full_attention",
|
|
24
|
+
"full_attention",
|
|
25
|
+
"full_attention",
|
|
26
|
+
"full_attention",
|
|
27
|
+
"full_attention",
|
|
28
|
+
"full_attention",
|
|
29
|
+
"full_attention",
|
|
30
|
+
"full_attention",
|
|
31
|
+
"full_attention",
|
|
32
|
+
"full_attention",
|
|
33
|
+
"full_attention",
|
|
34
|
+
"full_attention",
|
|
35
|
+
"full_attention",
|
|
36
|
+
"full_attention"
|
|
37
|
+
],
|
|
38
|
+
"matryoshka_dims": [
|
|
39
|
+
64,
|
|
40
|
+
128,
|
|
41
|
+
256,
|
|
42
|
+
512,
|
|
43
|
+
896
|
|
44
|
+
],
|
|
45
|
+
"max_position_embeddings": 32768,
|
|
46
|
+
"max_window_layers": 24,
|
|
47
|
+
"model_type": "qwen2",
|
|
48
|
+
"num_attention_heads": 14,
|
|
49
|
+
"num_hidden_layers": 24,
|
|
50
|
+
"num_key_value_heads": 2,
|
|
51
|
+
"prompt_names": [
|
|
52
|
+
"query",
|
|
53
|
+
"passage"
|
|
54
|
+
],
|
|
55
|
+
"rms_norm_eps": 1e-06,
|
|
56
|
+
"rope_scaling": null,
|
|
57
|
+
"rope_theta": 1000000.0,
|
|
58
|
+
"sliding_window": null,
|
|
59
|
+
"task_names": [
|
|
60
|
+
"nl2code",
|
|
61
|
+
"qa",
|
|
62
|
+
"code2code",
|
|
63
|
+
"code2nl",
|
|
64
|
+
"code2completion"
|
|
65
|
+
],
|
|
66
|
+
"tie_word_embeddings": true,
|
|
67
|
+
"tokenizer_class": "Qwen2TokenizerFast",
|
|
68
|
+
"torch_dtype": "bfloat16",
|
|
69
|
+
"transformers_version": "4.53.0",
|
|
70
|
+
"use_cache": true,
|
|
71
|
+
"use_sliding_window": false,
|
|
72
|
+
"vocab_size": 151936
|
|
73
|
+
}
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
import mlx.nn as nn
|
|
6
|
+
|
|
7
|
+
@dataclass
|
|
8
|
+
class ModelArgs:
|
|
9
|
+
hidden_size: int
|
|
10
|
+
num_hidden_layers: int
|
|
11
|
+
intermediate_size: int
|
|
12
|
+
num_attention_heads: int
|
|
13
|
+
rms_norm_eps: float
|
|
14
|
+
vocab_size: int
|
|
15
|
+
num_key_value_heads: int
|
|
16
|
+
max_position_embeddings: int
|
|
17
|
+
rope_theta: float = 1000000.0
|
|
18
|
+
tie_word_embeddings: bool = True
|
|
19
|
+
|
|
20
|
+
class Attention(nn.Module):
|
|
21
|
+
def __init__(self, args):
|
|
22
|
+
super().__init__()
|
|
23
|
+
dim = args.hidden_size
|
|
24
|
+
self.n_heads = args.num_attention_heads
|
|
25
|
+
self.n_kv_heads = args.num_key_value_heads
|
|
26
|
+
self.head_dim = dim // self.n_heads
|
|
27
|
+
self.scale = self.head_dim ** -0.5
|
|
28
|
+
self.rope_theta = args.rope_theta
|
|
29
|
+
self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=True)
|
|
30
|
+
self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
|
|
31
|
+
self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
|
|
32
|
+
self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)
|
|
33
|
+
|
|
34
|
+
def __call__(self, x, mask=None):
|
|
35
|
+
B, L, D = x.shape
|
|
36
|
+
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
37
|
+
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
38
|
+
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
39
|
+
q = mx.fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
40
|
+
k = mx.fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
41
|
+
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask.astype(q.dtype) if mask is not None else None, scale=self.scale)
|
|
42
|
+
return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
|
43
|
+
|
|
44
|
+
class MLP(nn.Module):
|
|
45
|
+
def __init__(self, dim, hidden):
|
|
46
|
+
super().__init__()
|
|
47
|
+
self.gate_proj = nn.Linear(dim, hidden, bias=False)
|
|
48
|
+
self.down_proj = nn.Linear(hidden, dim, bias=False)
|
|
49
|
+
self.up_proj = nn.Linear(dim, hidden, bias=False)
|
|
50
|
+
def __call__(self, x):
|
|
51
|
+
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
52
|
+
|
|
53
|
+
class TransformerBlock(nn.Module):
|
|
54
|
+
def __init__(self, args):
|
|
55
|
+
super().__init__()
|
|
56
|
+
self.self_attn = Attention(args)
|
|
57
|
+
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
|
58
|
+
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
59
|
+
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
60
|
+
def __call__(self, x, mask=None):
|
|
61
|
+
h = x + self.self_attn(self.input_layernorm(x), mask)
|
|
62
|
+
return h + self.mlp(self.post_attention_layernorm(h))
|
|
63
|
+
|
|
64
|
+
class Qwen2Model(nn.Module):
|
|
65
|
+
def __init__(self, args):
|
|
66
|
+
super().__init__()
|
|
67
|
+
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
68
|
+
self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
|
69
|
+
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
70
|
+
def __call__(self, inputs, mask=None):
|
|
71
|
+
h = self.embed_tokens(inputs)
|
|
72
|
+
for layer in self.layers:
|
|
73
|
+
h = layer(h, mask)
|
|
74
|
+
return self.norm(h)
|
|
75
|
+
|
|
76
|
+
class JinaCodeEmbeddingModel(nn.Module):
|
|
77
|
+
def __init__(self, config):
|
|
78
|
+
super().__init__()
|
|
79
|
+
args = ModelArgs(
|
|
80
|
+
hidden_size=config["hidden_size"],
|
|
81
|
+
num_hidden_layers=config["num_hidden_layers"],
|
|
82
|
+
intermediate_size=config["intermediate_size"],
|
|
83
|
+
num_attention_heads=config["num_attention_heads"],
|
|
84
|
+
rms_norm_eps=config["rms_norm_eps"],
|
|
85
|
+
vocab_size=config["vocab_size"],
|
|
86
|
+
num_key_value_heads=config["num_key_value_heads"],
|
|
87
|
+
max_position_embeddings=config["max_position_embeddings"],
|
|
88
|
+
rope_theta=config.get("rope_theta", 1000000.0),
|
|
89
|
+
)
|
|
90
|
+
self.model = Qwen2Model(args)
|
|
91
|
+
self.config = config
|
|
92
|
+
|
|
93
|
+
def __call__(self, input_ids, attention_mask=None):
|
|
94
|
+
B, L = input_ids.shape
|
|
95
|
+
causal = mx.tril(mx.ones((L, L)))
|
|
96
|
+
causal = mx.where(causal == 0, -1e4, 0.0)[None, None, :, :]
|
|
97
|
+
if attention_mask is not None:
|
|
98
|
+
pad = mx.where(attention_mask == 0, -1e4, 0.0)[:, None, None, :]
|
|
99
|
+
mask = causal + pad
|
|
100
|
+
else:
|
|
101
|
+
mask = causal
|
|
102
|
+
h = self.model(input_ids, mask)
|
|
103
|
+
if attention_mask is not None:
|
|
104
|
+
seq_lens = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
|
|
105
|
+
embs = h[mx.arange(B), seq_lens]
|
|
106
|
+
else:
|
|
107
|
+
embs = h[:, -1, :]
|
|
108
|
+
norms = mx.linalg.norm(embs, axis=1, keepdims=True)
|
|
109
|
+
return embs / norms
|
|
110
|
+
|
|
111
|
+
def encode(self, texts, tokenizer, max_length=8192, truncate_dim=None, task="nl2code", prompt_type="query"):
|
|
112
|
+
PREFIXES = {"nl2code": {"query": "Find the most relevant code snippet given the following query:\n", "passage": "Candidate code snippet:\n"}}
|
|
113
|
+
prefix = PREFIXES.get(task, {}).get(prompt_type, "")
|
|
114
|
+
if prefix:
|
|
115
|
+
texts = [prefix + t for t in texts]
|
|
116
|
+
encodings = tokenizer.encode_batch(texts)
|
|
117
|
+
ml = min(max_length, max(len(e.ids) for e in encodings))
|
|
118
|
+
iids, amask = [], []
|
|
119
|
+
for e in encodings:
|
|
120
|
+
ids = e.ids[:ml]; m = e.attention_mask[:ml]; p = ml - len(ids)
|
|
121
|
+
if p > 0: ids = ids + [0]*p; m = m + [0]*p
|
|
122
|
+
iids.append(ids); amask.append(m)
|
|
123
|
+
embs = self(mx.array(iids), mx.array(amask))
|
|
124
|
+
if truncate_dim:
|
|
125
|
+
embs = embs[:, :truncate_dim]
|
|
126
|
+
embs = embs / mx.linalg.norm(embs, axis=1, keepdims=True)
|
|
127
|
+
return embs
|
package/models/mlx-embedder.py
CHANGED
|
@@ -27,7 +27,7 @@ import mlx.core as mx
|
|
|
27
27
|
import mlx.nn as nn
|
|
28
28
|
from tokenizers import Tokenizer
|
|
29
29
|
|
|
30
|
-
MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/jina-v5-small-mlx"
|
|
30
|
+
MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/jina-code-0.5b-mlx"
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
|
|
@@ -35,27 +35,54 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/jina-v5-small-mlx"
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
def ensure_model_downloaded():
|
|
38
|
-
"""Download model weights from HuggingFace if not present."""
|
|
38
|
+
"""Download and convert model weights from HuggingFace if not present."""
|
|
39
39
|
weights_path = os.path.join(MODEL_DIR, "model.safetensors")
|
|
40
40
|
if os.path.exists(weights_path):
|
|
41
41
|
return
|
|
42
42
|
|
|
43
|
-
|
|
43
|
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
44
|
+
print(json.dumps({"phase": "downloading", "message": "Downloading Jina Code 0.5B embedding model (~940MB, first time only)..."}), flush=True)
|
|
45
|
+
|
|
44
46
|
try:
|
|
45
47
|
from huggingface_hub import hf_hub_download
|
|
46
48
|
import shutil
|
|
47
|
-
|
|
48
|
-
|
|
49
|
+
|
|
50
|
+
repo = "jinaai/jina-code-embeddings-0.5b"
|
|
51
|
+
|
|
52
|
+
# Download tokenizer from the v5 MLX model (same Qwen tokenizer, has tokenizer.json)
|
|
53
|
+
tokenizer_repo = "jinaai/jina-embeddings-v5-text-small-retrieval-mlx"
|
|
54
|
+
for fname in ["tokenizer.json"]:
|
|
55
|
+
dest = os.path.join(MODEL_DIR, fname)
|
|
56
|
+
if not os.path.exists(dest):
|
|
57
|
+
path = hf_hub_download(tokenizer_repo, fname)
|
|
58
|
+
shutil.copy(path, dest)
|
|
59
|
+
|
|
60
|
+
# Download vocab/merges from the 0.5B model
|
|
61
|
+
for fname in ["vocab.json", "merges.txt"]:
|
|
49
62
|
dest = os.path.join(MODEL_DIR, fname)
|
|
50
63
|
if not os.path.exists(dest):
|
|
51
64
|
path = hf_hub_download(repo, fname)
|
|
52
65
|
shutil.copy(path, dest)
|
|
53
|
-
|
|
66
|
+
|
|
67
|
+
# Download and convert weights: bf16 → fp16, add 'model.' key prefix
|
|
68
|
+
print(json.dumps({"phase": "converting", "message": "Converting model weights (bf16 → fp16)..."}), flush=True)
|
|
69
|
+
raw_path = hf_hub_download(repo, "model.safetensors")
|
|
70
|
+
raw_weights = mx.load(raw_path)
|
|
71
|
+
converted = {}
|
|
72
|
+
for k, v in raw_weights.items():
|
|
73
|
+
new_key = "model." + k
|
|
74
|
+
if v.dtype == mx.bfloat16:
|
|
75
|
+
v = v.astype(mx.float16)
|
|
76
|
+
converted[new_key] = v
|
|
77
|
+
mx.save_safetensors(weights_path, converted)
|
|
78
|
+
|
|
79
|
+
print(json.dumps({"phase": "downloaded", "message": f"Model ready ({len(converted)} weights converted)"}), flush=True)
|
|
80
|
+
|
|
54
81
|
except ImportError:
|
|
55
82
|
raise RuntimeError(
|
|
56
83
|
"Model weights not found. Install huggingface_hub to auto-download:\n"
|
|
57
84
|
" pip3 install huggingface_hub\n"
|
|
58
|
-
"Or manually download from: https://huggingface.co/jinaai/jina-embeddings-v5-text-small-retrieval-mlx"
|
|
85
|
+
"Or manually download from: https://huggingface.co/jinaai/jina-code-embeddings-0.5b"
|
|
59
86
|
)
|
|
60
87
|
|
|
61
88
|
|
|
@@ -64,7 +91,10 @@ def load_model():
|
|
|
64
91
|
ensure_model_downloaded()
|
|
65
92
|
|
|
66
93
|
sys.path.insert(0, MODEL_DIR)
|
|
67
|
-
|
|
94
|
+
import importlib
|
|
95
|
+
model_module = importlib.import_module("model")
|
|
96
|
+
# Support both model class names (v5 = JinaEmbeddingModel, code-0.5b = JinaCodeEmbeddingModel)
|
|
97
|
+
JinaEmbeddingModel = getattr(model_module, "JinaEmbeddingModel", None) or getattr(model_module, "JinaCodeEmbeddingModel")
|
|
68
98
|
|
|
69
99
|
with open(os.path.join(MODEL_DIR, "config.json")) as f:
|
|
70
100
|
config = json.load(f)
|
|
@@ -99,8 +129,17 @@ def embed_tiered(model, tokenizer, texts, task_type="retrieval.passage", truncat
|
|
|
99
129
|
if not texts:
|
|
100
130
|
return []
|
|
101
131
|
|
|
102
|
-
# Add task prefix
|
|
103
|
-
|
|
132
|
+
# Add task prefix — auto-detect based on model type
|
|
133
|
+
# v5 (Qwen3): "Query: " / "Document: "
|
|
134
|
+
# code-0.5b (Qwen2): "Find the most relevant code snippet...\n" / "Candidate code snippet:\n"
|
|
135
|
+
is_code_model = "jina-code" in MODEL_DIR
|
|
136
|
+
if is_code_model:
|
|
137
|
+
prefix_map = {
|
|
138
|
+
"retrieval.query": "Find the most relevant code snippet given the following query:\n",
|
|
139
|
+
"retrieval.passage": "Candidate code snippet:\n",
|
|
140
|
+
}
|
|
141
|
+
else:
|
|
142
|
+
prefix_map = {"retrieval.query": "Query: ", "retrieval.passage": "Document: "}
|
|
104
143
|
prefix = prefix_map.get(task_type, "")
|
|
105
144
|
prefixed = [prefix + t for t in texts] if prefix else texts
|
|
106
145
|
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@zuvia-software-solutions/code-mapper",
|
|
3
|
-
"version": "2.
|
|
3
|
+
"version": "2.3.1",
|
|
4
4
|
"description": "Graph-powered code intelligence for AI agents. Index any codebase, query via MCP or CLI.",
|
|
5
5
|
"author": "Abhigyan Patwari",
|
|
6
6
|
"license": "PolyForm-Noncommercial-1.0.0",
|
|
@@ -36,8 +36,8 @@
|
|
|
36
36
|
"skills",
|
|
37
37
|
"vendor",
|
|
38
38
|
"models/mlx-embedder.py",
|
|
39
|
-
"models/jina-
|
|
40
|
-
"models/jina-
|
|
39
|
+
"models/jina-code-0.5b-mlx/model.py",
|
|
40
|
+
"models/jina-code-0.5b-mlx/config.json"
|
|
41
41
|
],
|
|
42
42
|
"scripts": {
|
|
43
43
|
"build": "tsc",
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
{
|
|
2
|
-
"model_type": "qwen3",
|
|
3
|
-
"hidden_size": 1024,
|
|
4
|
-
"num_hidden_layers": 28,
|
|
5
|
-
"intermediate_size": 3072,
|
|
6
|
-
"num_attention_heads": 16,
|
|
7
|
-
"num_key_value_heads": 8,
|
|
8
|
-
"rms_norm_eps": 1e-06,
|
|
9
|
-
"vocab_size": 151936,
|
|
10
|
-
"max_position_embeddings": 32768,
|
|
11
|
-
"rope_theta": 3500000,
|
|
12
|
-
"rope_parameters": {
|
|
13
|
-
"rope_theta": 3500000,
|
|
14
|
-
"rope_type": "default"
|
|
15
|
-
},
|
|
16
|
-
"head_dim": 128,
|
|
17
|
-
"tie_word_embeddings": true,
|
|
18
|
-
"rope_scaling": null
|
|
19
|
-
}
|
|
@@ -1,260 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Jina Embeddings v5 Text Small - MLX Implementation
|
|
3
|
-
|
|
4
|
-
Pure MLX port of jina-embeddings-v5-text-small (Qwen3-0.6B backbone).
|
|
5
|
-
Zero dependency on PyTorch or transformers.
|
|
6
|
-
|
|
7
|
-
Features:
|
|
8
|
-
- Causal attention (decoder architecture)
|
|
9
|
-
- QKNorm (q_norm/k_norm per head)
|
|
10
|
-
- Last-token pooling
|
|
11
|
-
- L2 normalization
|
|
12
|
-
- Matryoshka embedding dimensions: [32, 64, 128, 256, 512, 768, 1024]
|
|
13
|
-
- Max sequence length: 32768 tokens
|
|
14
|
-
- Embedding dimension: 1024
|
|
15
|
-
|
|
16
|
-
Architecture:
|
|
17
|
-
- RoPE (rope_theta from config)
|
|
18
|
-
- SwiGLU MLP
|
|
19
|
-
- RMSNorm
|
|
20
|
-
- QKNorm (RMSNorm on Q/K per head)
|
|
21
|
-
- No attention bias
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
from dataclasses import dataclass
|
|
25
|
-
from typing import Any, Dict, Optional, Union
|
|
26
|
-
|
|
27
|
-
import mlx.core as mx
|
|
28
|
-
import mlx.nn as nn
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
@dataclass
|
|
32
|
-
class ModelArgs:
|
|
33
|
-
model_type: str
|
|
34
|
-
hidden_size: int
|
|
35
|
-
num_hidden_layers: int
|
|
36
|
-
intermediate_size: int
|
|
37
|
-
num_attention_heads: int
|
|
38
|
-
rms_norm_eps: float
|
|
39
|
-
vocab_size: int
|
|
40
|
-
num_key_value_heads: int
|
|
41
|
-
max_position_embeddings: int
|
|
42
|
-
head_dim: int
|
|
43
|
-
tie_word_embeddings: bool
|
|
44
|
-
rope_parameters: Optional[Dict[str, Union[float, str]]] = None
|
|
45
|
-
rope_theta: Optional[float] = None
|
|
46
|
-
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class Attention(nn.Module):
|
|
50
|
-
def __init__(self, args: ModelArgs):
|
|
51
|
-
super().__init__()
|
|
52
|
-
|
|
53
|
-
dim = args.hidden_size
|
|
54
|
-
self.n_heads = n_heads = args.num_attention_heads
|
|
55
|
-
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
|
56
|
-
|
|
57
|
-
head_dim = args.head_dim
|
|
58
|
-
self.scale = head_dim**-0.5
|
|
59
|
-
self.head_dim = head_dim
|
|
60
|
-
|
|
61
|
-
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
|
62
|
-
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
|
63
|
-
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
|
64
|
-
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
|
65
|
-
|
|
66
|
-
# Qwen3 has QKNorm
|
|
67
|
-
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
|
68
|
-
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
|
69
|
-
|
|
70
|
-
# Resolve rope_theta from config
|
|
71
|
-
if args.rope_parameters and 'rope_theta' in args.rope_parameters:
|
|
72
|
-
rope_theta = float(args.rope_parameters['rope_theta'])
|
|
73
|
-
elif args.rope_theta:
|
|
74
|
-
rope_theta = float(args.rope_theta)
|
|
75
|
-
else:
|
|
76
|
-
rope_theta = 10000.0
|
|
77
|
-
self.rope_theta = rope_theta
|
|
78
|
-
|
|
79
|
-
def __call__(
|
|
80
|
-
self,
|
|
81
|
-
x: mx.array,
|
|
82
|
-
mask: Optional[mx.array] = None,
|
|
83
|
-
) -> mx.array:
|
|
84
|
-
B, L, D = x.shape
|
|
85
|
-
|
|
86
|
-
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
|
87
|
-
|
|
88
|
-
# Reshape and apply QKNorm
|
|
89
|
-
queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3)
|
|
90
|
-
keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
|
|
91
|
-
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
92
|
-
|
|
93
|
-
# RoPE via mx.fast
|
|
94
|
-
queries = mx.fast.rope(queries, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
95
|
-
keys = mx.fast.rope(keys, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
96
|
-
|
|
97
|
-
# Scaled dot-product attention (handles GQA natively)
|
|
98
|
-
output = mx.fast.scaled_dot_product_attention(
|
|
99
|
-
queries, keys, values,
|
|
100
|
-
mask=mask.astype(queries.dtype) if mask is not None else None,
|
|
101
|
-
scale=self.scale,
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
105
|
-
return self.o_proj(output)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
class MLP(nn.Module):
|
|
109
|
-
def __init__(self, dim, hidden_dim):
|
|
110
|
-
super().__init__()
|
|
111
|
-
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
|
112
|
-
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
|
113
|
-
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
|
114
|
-
|
|
115
|
-
def __call__(self, x) -> mx.array:
|
|
116
|
-
gate = nn.silu(self.gate_proj(x))
|
|
117
|
-
return self.down_proj(gate * self.up_proj(x))
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
class TransformerBlock(nn.Module):
|
|
121
|
-
def __init__(self, args: ModelArgs):
|
|
122
|
-
super().__init__()
|
|
123
|
-
self.self_attn = Attention(args)
|
|
124
|
-
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
|
125
|
-
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
126
|
-
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
127
|
-
|
|
128
|
-
def __call__(
|
|
129
|
-
self,
|
|
130
|
-
x: mx.array,
|
|
131
|
-
mask: Optional[mx.array] = None,
|
|
132
|
-
) -> mx.array:
|
|
133
|
-
r = self.self_attn(self.input_layernorm(x), mask)
|
|
134
|
-
h = x + r
|
|
135
|
-
r = self.mlp(self.post_attention_layernorm(h))
|
|
136
|
-
out = h + r
|
|
137
|
-
return out
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class Qwen3Model(nn.Module):
|
|
141
|
-
def __init__(self, args: ModelArgs):
|
|
142
|
-
super().__init__()
|
|
143
|
-
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
144
|
-
self.layers = [TransformerBlock(args=args) for _ in range(args.num_hidden_layers)]
|
|
145
|
-
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
146
|
-
|
|
147
|
-
def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
|
|
148
|
-
h = self.embed_tokens(inputs)
|
|
149
|
-
for layer in self.layers:
|
|
150
|
-
h = layer(h, mask)
|
|
151
|
-
return self.norm(h)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
class JinaEmbeddingModel(nn.Module):
|
|
155
|
-
"""Jina v5-text-small embedding model with last-token pooling."""
|
|
156
|
-
|
|
157
|
-
def __init__(self, config: dict):
|
|
158
|
-
super().__init__()
|
|
159
|
-
args = ModelArgs(**config)
|
|
160
|
-
self.model = Qwen3Model(args)
|
|
161
|
-
self.config = config
|
|
162
|
-
|
|
163
|
-
def __call__(
|
|
164
|
-
self,
|
|
165
|
-
input_ids: mx.array,
|
|
166
|
-
attention_mask: Optional[mx.array] = None,
|
|
167
|
-
):
|
|
168
|
-
batch_size, seq_len = input_ids.shape
|
|
169
|
-
|
|
170
|
-
# Causal mask (Qwen3 is a decoder model)
|
|
171
|
-
causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
|
|
172
|
-
causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
|
|
173
|
-
causal_mask = causal_mask[None, None, :, :]
|
|
174
|
-
|
|
175
|
-
# Combine with padding mask
|
|
176
|
-
if attention_mask is not None:
|
|
177
|
-
padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)
|
|
178
|
-
padding_mask = padding_mask[:, None, None, :]
|
|
179
|
-
mask = causal_mask + padding_mask
|
|
180
|
-
else:
|
|
181
|
-
mask = causal_mask
|
|
182
|
-
|
|
183
|
-
hidden_states = self.model(input_ids, mask)
|
|
184
|
-
|
|
185
|
-
# Last token pooling
|
|
186
|
-
if attention_mask is not None:
|
|
187
|
-
sequence_lengths = mx.sum(attention_mask, axis=1) - 1
|
|
188
|
-
batch_indices = mx.arange(hidden_states.shape[0])
|
|
189
|
-
embeddings = hidden_states[batch_indices, sequence_lengths]
|
|
190
|
-
else:
|
|
191
|
-
embeddings = hidden_states[:, -1, :]
|
|
192
|
-
|
|
193
|
-
# L2 normalization
|
|
194
|
-
norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
195
|
-
embeddings = embeddings / norms
|
|
196
|
-
|
|
197
|
-
return embeddings
|
|
198
|
-
|
|
199
|
-
def encode(
|
|
200
|
-
self,
|
|
201
|
-
texts: list[str],
|
|
202
|
-
tokenizer,
|
|
203
|
-
max_length: int = 8192,
|
|
204
|
-
truncate_dim: Optional[int] = None,
|
|
205
|
-
task_type: str = "retrieval.query",
|
|
206
|
-
):
|
|
207
|
-
"""
|
|
208
|
-
Encode texts to embeddings.
|
|
209
|
-
|
|
210
|
-
Args:
|
|
211
|
-
texts: List of input texts
|
|
212
|
-
tokenizer: Tokenizer instance (from tokenizers library)
|
|
213
|
-
max_length: Maximum sequence length
|
|
214
|
-
truncate_dim: Optional Matryoshka dimension [32, 64, 128, 256, 512, 768, 1024]
|
|
215
|
-
task_type: Task prefix ("retrieval.query", "retrieval.passage", etc.)
|
|
216
|
-
|
|
217
|
-
Returns:
|
|
218
|
-
Embeddings array [batch, dim]
|
|
219
|
-
"""
|
|
220
|
-
prefix_map = {
|
|
221
|
-
"retrieval.query": "Query: ",
|
|
222
|
-
"retrieval.passage": "Document: ",
|
|
223
|
-
"classification": "Document: ",
|
|
224
|
-
"text-matching": "Document: ",
|
|
225
|
-
"clustering": "Document: ",
|
|
226
|
-
}
|
|
227
|
-
prefix = prefix_map.get(task_type, "")
|
|
228
|
-
|
|
229
|
-
if prefix:
|
|
230
|
-
texts = [prefix + text for text in texts]
|
|
231
|
-
|
|
232
|
-
encodings = tokenizer.encode_batch(texts)
|
|
233
|
-
|
|
234
|
-
max_len = min(max_length, max(len(enc.ids) for enc in encodings))
|
|
235
|
-
input_ids = []
|
|
236
|
-
attention_mask = []
|
|
237
|
-
|
|
238
|
-
for encoding in encodings:
|
|
239
|
-
ids = encoding.ids[:max_len]
|
|
240
|
-
mask = encoding.attention_mask[:max_len]
|
|
241
|
-
|
|
242
|
-
pad_len = max_len - len(ids)
|
|
243
|
-
if pad_len > 0:
|
|
244
|
-
ids = ids + [0] * pad_len
|
|
245
|
-
mask = mask + [0] * pad_len
|
|
246
|
-
|
|
247
|
-
input_ids.append(ids)
|
|
248
|
-
attention_mask.append(mask)
|
|
249
|
-
|
|
250
|
-
input_ids = mx.array(input_ids)
|
|
251
|
-
attention_mask = mx.array(attention_mask)
|
|
252
|
-
|
|
253
|
-
embeddings = self(input_ids, attention_mask)
|
|
254
|
-
|
|
255
|
-
if truncate_dim is not None:
|
|
256
|
-
embeddings = embeddings[:, :truncate_dim]
|
|
257
|
-
norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
258
|
-
embeddings = embeddings / norms
|
|
259
|
-
|
|
260
|
-
return embeddings
|