cozo-memory 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +533 -0
- package/dist/api_bridge.js +266 -0
- package/dist/benchmark-gpu-cpu.js +188 -0
- package/dist/benchmark-heavy.js +230 -0
- package/dist/benchmark.js +160 -0
- package/dist/clear-cache.js +29 -0
- package/dist/db-service.js +228 -0
- package/dist/download-model.js +48 -0
- package/dist/embedding-service.js +249 -0
- package/dist/full-system-test.js +45 -0
- package/dist/hybrid-search.js +337 -0
- package/dist/index.js +3106 -0
- package/dist/inference-engine.js +348 -0
- package/dist/memory-service.js +215 -0
- package/dist/test-advanced-filters.js +64 -0
- package/dist/test-advanced-search.js +82 -0
- package/dist/test-advanced-time.js +47 -0
- package/dist/test-embedding.js +22 -0
- package/dist/test-filter-expr.js +84 -0
- package/dist/test-fts.js +58 -0
- package/dist/test-functions.js +25 -0
- package/dist/test-gpu-check.js +16 -0
- package/dist/test-graph-algs-final.js +76 -0
- package/dist/test-graph-filters.js +88 -0
- package/dist/test-graph-rag.js +124 -0
- package/dist/test-graph-walking.js +138 -0
- package/dist/test-index.js +35 -0
- package/dist/test-int-filter.js +48 -0
- package/dist/test-integration.js +69 -0
- package/dist/test-lower.js +35 -0
- package/dist/test-lsh.js +67 -0
- package/dist/test-mcp-tool.js +40 -0
- package/dist/test-pagerank.js +31 -0
- package/dist/test-semantic-walk.js +145 -0
- package/dist/test-time-filter.js +66 -0
- package/dist/test-time-functions.js +38 -0
- package/dist/test-triggers.js +60 -0
- package/dist/test-ts-ort.js +48 -0
- package/dist/test-validity-access.js +35 -0
- package/dist/test-validity-body.js +42 -0
- package/dist/test-validity-decomp.js +37 -0
- package/dist/test-validity-extraction.js +45 -0
- package/dist/test-validity-json.js +35 -0
- package/dist/test-validity.js +38 -0
- package/dist/types.js +3 -0
- package/dist/verify-gpu.js +30 -0
- package/dist/verify_transaction_tool.js +46 -0
- package/package.json +75 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __importDefault = (this && this.__importDefault) || function (mod) {
|
|
3
|
+
return (mod && mod.__esModule) ? mod : { "default": mod };
|
|
4
|
+
};
|
|
5
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
6
|
+
const express_1 = __importDefault(require("express"));
|
|
7
|
+
const cors_1 = __importDefault(require("cors"));
|
|
8
|
+
const index_js_1 = require("./index.js");
|
|
9
|
+
const uuid_1 = require("uuid");
|
|
10
|
+
const app = (0, express_1.default)();
|
|
11
|
+
const port = process.env.PORT || 3001;
|
|
12
|
+
app.use((0, cors_1.default)());
|
|
13
|
+
app.use(express_1.default.json());
|
|
14
|
+
const memoryServer = new index_js_1.MemoryServer();
|
|
15
|
+
// --- Entities ---
|
|
16
|
+
// List every entity that is currently valid (time-travel anchor @ "NOW"),
// returning its metadata plus a numeric creation timestamp.
app.get("/api/entities", async (req, res) => {
    try {
        const result = await memoryServer.db.run('?[id, name, type, metadata, ts] := *entity{id, name, type, metadata, created_at, @ "NOW"}, ts = to_int(created_at)');
        // Rows come back positionally; reshape each into a keyed object.
        const entities = result.rows.map(([id, name, type, metadata, created_at]) => ({
            id,
            name,
            type,
            metadata,
            created_at
        }));
        res.json(entities);
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
31
|
+
// Create a new entity. Mirrors the create_entity MCP tool: the embedding is
// computed from "name type" and the row is asserted at the current validity.
app.post("/api/entities", async (req, res) => {
    const { name, type, metadata } = req.body;
    if (!name || !type)
        return res.status(400).json({ error: "Name and type are required" });
    try {
        // We use the same logic as in create_entity tool
        const id = (0, uuid_1.v4)();
        const embedding = await memoryServer.embeddingService.embed(name + " " + type);
        const params = { id, name, type, metadata: metadata || {} };
        // The embedding vector is spliced into the script as a literal list;
        // everything else is passed as a bound parameter.
        const script = `
            ?[id, created_at, name, type, embedding, metadata] <- [
                [$id, "ASSERT", $name, $type, [${embedding.join(",")}], $metadata]
            ] :put entity {id, created_at => name, type, embedding, metadata}
        `;
        await memoryServer.db.run(script, params);
        res.status(201).json({ id, name, type, metadata, status: "Entity created" });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
50
|
+
// Fetch one entity plus its observations and relationships (both directions).
// Mirrors the get_entity_details MCP tool.
app.get("/api/entities/:id", async (req, res) => {
    const { id } = req.params;
    try {
        // The three queries are independent of each other, so issue them in
        // parallel instead of awaiting them one at a time (they were serial).
        const [entityRes, obsRes, relRes] = await Promise.all([
            memoryServer.db.run('?[id, name, type, metadata, ts] := *entity{id, name, type, metadata, created_at, @ "NOW"}, id = $id, ts = to_int(created_at)', { id }),
            memoryServer.db.run('?[id, text, metadata, ts] := *observation{id, entity_id, text, metadata, created_at, @ "NOW"}, entity_id = $id, ts = to_int(created_at)', { id }),
            memoryServer.db.run(`
            ?[target_id, type, strength, metadata, direction] := *relationship{from_id, to_id, relation_type: type, strength, metadata, @ "NOW"}, from_id = $id, target_id = to_id, direction = 'outgoing'
            ?[target_id, type, strength, metadata, direction] := *relationship{from_id, to_id, relation_type: type, strength, metadata, @ "NOW"}, to_id = $id, target_id = from_id, direction = 'incoming'
        `, { id })
        ]);
        if (entityRes.rows.length === 0)
            return res.status(404).json({ error: "Entity not found" });
        const [entityRow] = entityRes.rows;
        res.json({
            entity: {
                id: entityRow[0],
                name: entityRow[1],
                type: entityRow[2],
                metadata: entityRow[3],
                created_at: entityRow[4]
            },
            observations: obsRes.rows.map((r) => ({ id: r[0], text: r[1], metadata: r[2], created_at: r[3] })),
            relations: relRes.rows.map((r) => ({ target_id: r[0], type: r[1], strength: r[2], metadata: r[3], direction: r[4] }))
        });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
78
|
+
// Delete an entity and everything hanging off it. The four :rm statements run
// inside one multi-statement transaction: observations first, relationships in
// both directions, and finally the entity row itself.
app.delete("/api/entities/:id", async (req, res) => {
    const { id } = req.params;
    // Fixed logic from delete_entity
    const cascadeDeleteScript = `
        { ?[id, created_at] := *observation{id, entity_id, created_at}, entity_id = $target_id :rm observation {id, created_at} }
        { ?[from_id, to_id, relation_type, created_at] := *relationship{from_id, to_id, relation_type, created_at}, from_id = $target_id :rm relationship {from_id, to_id, relation_type, created_at} }
        { ?[from_id, to_id, relation_type, created_at] := *relationship{from_id, to_id, relation_type, created_at}, to_id = $target_id :rm relationship {from_id, to_id, relation_type, created_at} }
        { ?[id, created_at] := *entity{id, created_at}, id = $target_id :rm entity {id, created_at} }
    `;
    try {
        await memoryServer.db.run(cascadeDeleteScript, { target_id: id });
        res.json({ status: "Entity and related data deleted" });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
94
|
+
// --- Observations ---
|
|
95
|
+
// --- Observations ---
// Attach a new observation to an existing entity; the observation text is
// embedded and the row is asserted at the current validity.
app.post("/api/observations", async (req, res) => {
    const { entity_id, text, metadata } = req.body;
    if (!entity_id || !text)
        return res.status(400).json({ error: "Entity ID and text are required" });
    try {
        const id = (0, uuid_1.v4)();
        const embedding = await memoryServer.embeddingService.embed(text);
        const params = { id, entity_id, text, metadata: metadata || {} };
        const script = `
            ?[id, created_at, entity_id, text, embedding, metadata] <- [
                [$id, "ASSERT", $entity_id, $text, [${embedding.join(",")}], $metadata]
            ] :put observation {id, created_at => entity_id, text, embedding, metadata}
        `;
        await memoryServer.db.run(script, params);
        res.status(201).json({ id, entity_id, text, metadata, status: "Observation added" });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
113
|
+
// --- Search / Context ---
|
|
114
|
+
// --- Search / Context ---
// Hybrid (semantic + full-text) search; `limit` defaults to 10 results.
app.get("/api/search", async (req, res) => {
    const { query, limit = 10 } = req.query;
    if (!query)
        return res.status(400).json({ error: "Query is required" });
    try {
        const searchParams = {
            query: query,
            limit: Number(limit)
        };
        const results = await memoryServer.hybridSearch.search(searchParams);
        res.json(results);
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
129
|
+
// Build a context bundle for a query: hybrid-search hits plus one hop of graph
// neighbourhood for every entity hit. Mirrors the get_context MCP tool.
app.get("/api/context", async (req, res) => {
    const { query, context_window = 20 } = req.query;
    if (!query)
        return res.status(400).json({ error: "Query is required" });
    try {
        // Logic from get_context
        const searchResults = await memoryServer.hybridSearch.search({
            query: query,
            limit: Number(context_window)
        });
        // (The old `observations` filter result was never used — removed.)
        const entities = searchResults.filter(r => r.type === 'entity');
        // The per-entity connection queries are independent of each other, so
        // run them concurrently instead of the previous serial N+1 loop.
        // Promise.all preserves the entity ordering in the output.
        const graphContext = await Promise.all(entities.map(async (entity) => {
            const connections = await memoryServer.db.run(`
            ?[target_name, rel_type] := *relationship{from_id, to_id, relation_type: rel_type, @ "NOW"}, from_id = $id, *entity{id: to_id, name: target_name, @ "NOW"}
            ?[target_name, rel_type] := *relationship{from_id, to_id, relation_type: rel_type, @ "NOW"}, to_id = $id, *entity{id: from_id, name: target_name, @ "NOW"}
        `, { id: entity.id });
            return { entity: entity.name, connections: connections.rows };
        }));
        res.json({
            search_results: searchResults,
            graph_context: graphContext
        });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
158
|
+
// Report how the relation(s) starting at :id evolved over time, optionally
// narrowed to one target and a [since, until] window. Delegates to MemoryServer.
// NOTE(review): falsy `since`/`until` (e.g. "0") intentionally map to
// undefined, matching the original ternary behaviour.
app.get("/api/evolution/:id", async (req, res) => {
    const { id } = req.params;
    const { to_id, since, until } = req.query;
    try {
        const evolution = await memoryServer.getRelationEvolution({
            from_id: id,
            to_id: to_id,
            since: since ? Number(since) : undefined,
            until: until ? Number(until) : undefined
        });
        res.json(evolution);
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
174
|
+
// --- Health & Maintenance ---
|
|
175
|
+
// --- Health & Maintenance ---
// Report current row counts for entities / observations / relationships.
app.get("/api/health", async (req, res) => {
    try {
        // The three count queries are independent — run them concurrently
        // instead of the previous one-at-a-time awaits.
        const [e, o, r] = await Promise.all([
            memoryServer.db.run('?[count(id)] := *entity{id, @ "NOW"}'),
            memoryServer.db.run('?[count(id)] := *observation{id, @ "NOW"}'),
            memoryServer.db.run('?[count(f)] := *relationship{from_id: f, @ "NOW"}')
        ]);
        res.json({
            // An empty result set (fresh DB) yields 0 rather than undefined.
            entities: e.rows[0]?.[0] ?? 0,
            observations: o.rows[0]?.[0] ?? 0,
            relationships: r.rows[0]?.[0] ?? 0,
            status: "healthy"
        });
    }
    catch (error) {
        console.error("Health endpoint error:", error);
        res.status(500).json({ error: error.message });
    }
});
|
|
192
|
+
// Detect entity communities via label propagation over the relationship graph
// and return them grouped by community id, with entity names resolved.
app.get("/api/communities", async (req, res) => {
    try {
        const query = `
            edges[f, t, s] := *relationship{from_id: f, to_id: t, strength: s, @ "NOW"}
            temp_communities[community_id, entity_id] <~ LabelPropagation(edges[f, t, s])
            ?[entity_id, community_id] := temp_communities[community_id, entity_id]
        `;
        // Community detection and the entity-name lookup are independent
        // queries — run them in parallel instead of serially.
        const [result, entitiesRes] = await Promise.all([
            memoryServer.db.run(query),
            memoryServer.db.run('?[id, name, type] := *entity{id, name, type, @ "NOW"}')
        ]);
        // id -> { name, type } lookup for resolving community members.
        const entityMap = new Map();
        entitiesRes.rows.forEach((r) => entityMap.set(r[0], { name: r[1], type: r[2] }));
        const communities = {};
        result.rows.forEach((r) => {
            const communityId = String(r[1]);
            const entityId = r[0];
            // Entities missing from the map (e.g. retracted) get a placeholder.
            const info = entityMap.get(entityId) || { name: "Unknown", type: "Unknown" };
            if (!communities[communityId])
                communities[communityId] = [];
            communities[communityId].push({ id: entityId, name: info.name, type: info.type });
        });
        res.json({ community_count: Object.keys(communities).length, communities });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
218
|
+
// --- Snapshots ---
|
|
219
|
+
// --- Snapshots ---
// List all stored memory snapshots with their per-table counts.
app.get("/api/snapshots", async (req, res) => {
    try {
        const result = await memoryServer.db.run("?[id, e, o, r, meta, created_at] := *memory_snapshot{snapshot_id: id, entity_count: e, observation_count: o, relation_count: r, metadata: meta, created_at}");
        const snapshots = result.rows.map(([snapshotId, entityCount, obsCount, relCount, meta, createdAt]) => ({
            snapshot_id: snapshotId,
            entity_count: entityCount,
            observation_count: obsCount,
            relation_count: relCount,
            metadata: meta,
            created_at: createdAt
        }));
        res.json(snapshots);
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
235
|
+
// Record a snapshot of the current memory sizes (entities / observations /
// relationships) together with optional caller-supplied metadata.
app.post("/api/snapshots", async (req, res) => {
    const { metadata } = req.body;
    try {
        // Fetch the three currently-valid row sets in parallel; only their
        // cardinalities are stored.
        const [entityResult, obsResult, relResult] = await Promise.all([
            memoryServer.db.run('?[id] := *entity{id, @ "NOW"}'),
            memoryServer.db.run('?[id] := *observation{id, @ "NOW"}'),
            memoryServer.db.run('?[from_id, to_id] := *relationship{from_id, to_id, @ "NOW"}')
        ]);
        const counts = {
            entities: entityResult.rows.length,
            observations: obsResult.rows.length,
            relations: relResult.rows.length
        };
        const snapshot_id = (0, uuid_1.v4)();
        const now = Date.now();
        const insertParams = {
            id: snapshot_id,
            e: counts.entities,
            o: counts.observations,
            r: counts.relations,
            meta: metadata || {},
            now
        };
        await memoryServer.db.run("?[snapshot_id, entity_count, observation_count, relation_count, metadata, created_at] <- [[$id, $e, $o, $r, $meta, $now]]:put memory_snapshot {snapshot_id => entity_count, observation_count, relation_count, metadata, created_at}", insertParams);
        res.status(201).json({ snapshot_id, ...counts, status: "Snapshot created" });
    }
    catch (error) {
        res.status(500).json({ error: error.message });
    }
});
|
|
264
|
+
// Start the HTTP bridge; port comes from $PORT or defaults to 3001 (above).
app.listen(port, () => {
    console.log(`API Bridge listening at http://localhost:${port}`);
});
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
+
if (k2 === undefined) k2 = k;
|
|
4
|
+
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
+
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
+
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
+
}
|
|
8
|
+
Object.defineProperty(o, k2, desc);
|
|
9
|
+
}) : (function(o, m, k, k2) {
|
|
10
|
+
if (k2 === undefined) k2 = k;
|
|
11
|
+
o[k2] = m[k];
|
|
12
|
+
}));
|
|
13
|
+
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
+
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
+
}) : function(o, v) {
|
|
16
|
+
o["default"] = v;
|
|
17
|
+
});
|
|
18
|
+
var __importStar = (this && this.__importStar) || (function () {
|
|
19
|
+
var ownKeys = function(o) {
|
|
20
|
+
ownKeys = Object.getOwnPropertyNames || function (o) {
|
|
21
|
+
var ar = [];
|
|
22
|
+
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
|
|
23
|
+
return ar;
|
|
24
|
+
};
|
|
25
|
+
return ownKeys(o);
|
|
26
|
+
};
|
|
27
|
+
return function (mod) {
|
|
28
|
+
if (mod && mod.__esModule) return mod;
|
|
29
|
+
var result = {};
|
|
30
|
+
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
|
|
31
|
+
__setModuleDefault(result, mod);
|
|
32
|
+
return result;
|
|
33
|
+
};
|
|
34
|
+
})();
|
|
35
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
36
|
+
const transformers_1 = require("@xenova/transformers");
|
|
37
|
+
const ort = require('onnxruntime-node');
|
|
38
|
+
const path = __importStar(require("path"));
|
|
39
|
+
const fs = __importStar(require("fs"));
|
|
40
|
+
// Configure cache path
|
|
41
|
+
const CACHE_DIR = path.resolve('./.cache');
|
|
42
|
+
transformers_1.env.cacheDir = CACHE_DIR;
|
|
43
|
+
const MODEL_ID = "Xenova/bge-m3";
|
|
44
|
+
const EMBEDDING_DIM = 1024;
|
|
45
|
+
const QUANTIZED = false;
|
|
46
|
+
const MODEL_FILE = QUANTIZED ? "model_quantized.onnx" : "model.onnx";
|
|
47
|
+
const BATCH_SIZE = 10;
|
|
48
|
+
const NUM_TEXTS = 200;
|
|
49
|
+
const texts = Array.from({ length: NUM_TEXTS }, (_, i) => `This is a complex test sentence number ${i} for the performance comparison of CPU and GPU embeddings. The longer the text, the more apparent the advantage of parallel processing on the graphics card should be, especially with transformer models like BGE-M3. We are testing the DirectML integration on Windows with an RTX 2080 here.`);
|
|
50
|
+
// End-to-end embedding benchmark: tokenize NUM_TEXTS sentences in batches and
// run them through the ONNX model once on CPU and once on GPU (DirectML),
// then report the throughput ratio. Expects the model file to be cached
// locally already (CACHE_DIR / MODEL_FILE).
async function runBenchmark() {
    console.log("==========================================");
    console.log(`Starting Benchmark: CPU vs GPU (DirectML) [FP32 Mode]`);
    console.log(`Batch Size: ${BATCH_SIZE}, Total Texts: ${NUM_TEXTS}`);
    // 1. Prepare Model Path — bail out early if the model was never downloaded.
    const modelPath = path.join(CACHE_DIR, 'Xenova', 'bge-m3', 'onnx', MODEL_FILE);
    if (!fs.existsSync(modelPath)) {
        console.error(`Model not found at: ${modelPath}`);
        return;
    }
    console.log(`Model path: ${modelPath}`);
    // 2. Load Tokenizer
    console.log("Loading Tokenizer...");
    const tokenizer = await transformers_1.AutoTokenizer.from_pretrained(MODEL_ID);
    // Shared helper: convert tokenizer output into ORT tensor feeds.
    // Extracted so the warmup and the benchmark loop no longer duplicate it.
    // Note: most transformer ONNX graphs declare int64 inputs; int32 only
    // works if the model accepts it (DirectML sometimes prefers int32 indices).
    function buildFeeds(model_inputs, useInt32) {
        const feeds = {};
        for (const [key, value] of Object.entries(model_inputs)) {
            if (key === 'input_ids' || key === 'attention_mask' || key === 'token_type_ids') {
                // @ts-ignore
                let data = value.data || value.cpuData;
                // @ts-ignore
                const dims = value.dims; // [batch_size, seq_len]
                let type = 'int64';
                if (useInt32) {
                    data = Int32Array.from(data);
                    type = 'int32';
                }
                else if (!(data instanceof BigInt64Array)) {
                    // Ensure BigInt64Array for the int64 path.
                    data = BigInt64Array.from(data);
                }
                try {
                    feeds[key] = new ort.Tensor(type, data, dims);
                }
                catch (err) {
                    console.error(`Error creating tensor for ${key}:`, err.message);
                    throw err;
                }
            }
        }
        return feeds;
    }
    // 3. Timed inference over all texts in batches; returns embeddings/second.
    async function runInference(session, label, useInt32 = false) {
        console.log(`Starting ${label} inference (Int32 Inputs: ${useInt32})...`);
        const start = performance.now();
        for (let i = 0; i < texts.length; i += BATCH_SIZE) {
            const batchTexts = texts.slice(i, i + BATCH_SIZE);
            const model_inputs = await tokenizer(batchTexts, { padding: true, truncation: true });
            const feeds = buildFeeds(model_inputs, useInt32);
            // Run Inference
            await session.run(feeds);
            process.stdout.write(".");
        }
        const end = performance.now();
        const duration = (end - start) / 1000;
        console.log(`\n${label} Time: ${duration.toFixed(2)}s`);
        const speed = texts.length / duration;
        console.log(`${label} Speed: ${speed.toFixed(2)} Embeddings/s`);
        return speed;
    }
    // --- Phase 1: CPU Benchmark ---
    console.log("\n--- Phase 1: CPU Benchmark ---");
    let speedCpu = 0;
    try {
        const sessionCpu = await ort.InferenceSession.create(modelPath, {
            executionProviders: ['cpu']
        });
        console.log("CPU Session created.");
        // CPU usually handles Int64 fine
        speedCpu = await runInference(sessionCpu, "CPU", false);
    }
    catch (e) {
        console.error("CPU Benchmark failed:", e.message);
    }
    // --- Phase 2: GPU Benchmark ---
    console.log("\n--- Phase 2: GPU (DirectML) Benchmark ---");
    try {
        const sessionOptions = {
            executionProviders: ['dml', 'cpu'],
            enableCpuMemArena: false // Sometimes helps with DML memory management
        };
        const startGpuLoad = performance.now();
        const sessionGpu = await ort.InferenceSession.create(modelPath, sessionOptions);
        const endGpuLoad = performance.now();
        console.log(`GPU Session created in ${((endGpuLoad - startGpuLoad) / 1000).toFixed(2)}s`);
        // Warmup: one tiny batch so session initialization / shader compilation
        // does not distort the timed run.
        const warmupInputs = await tokenizer(["Warmup"], { padding: true, truncation: true });
        await sessionGpu.run(buildFeeds(warmupInputs, false));
        // Try with Int64 first (standard)
        const speedGpu = await runInference(sessionGpu, "GPU (Int64)", false);
        // Compare against the CPU baseline when it succeeded.
        if (speedCpu > 0) {
            const speedup = speedGpu / speedCpu;
            console.log("\n==========================================");
            console.log(`Result: GPU is ${speedup.toFixed(2)}x faster than CPU.`);
            console.log("==========================================");
        }
    }
    catch (e) {
        console.error("GPU Benchmark failed:", e.message);
        console.error(e.stack);
    }
}
runBenchmark();
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
+
if (k2 === undefined) k2 = k;
|
|
4
|
+
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
+
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
+
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
+
}
|
|
8
|
+
Object.defineProperty(o, k2, desc);
|
|
9
|
+
}) : (function(o, m, k, k2) {
|
|
10
|
+
if (k2 === undefined) k2 = k;
|
|
11
|
+
o[k2] = m[k];
|
|
12
|
+
}));
|
|
13
|
+
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
+
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
+
}) : function(o, v) {
|
|
16
|
+
o["default"] = v;
|
|
17
|
+
});
|
|
18
|
+
var __importStar = (this && this.__importStar) || (function () {
|
|
19
|
+
var ownKeys = function(o) {
|
|
20
|
+
ownKeys = Object.getOwnPropertyNames || function (o) {
|
|
21
|
+
var ar = [];
|
|
22
|
+
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
|
|
23
|
+
return ar;
|
|
24
|
+
};
|
|
25
|
+
return ownKeys(o);
|
|
26
|
+
};
|
|
27
|
+
return function (mod) {
|
|
28
|
+
if (mod && mod.__esModule) return mod;
|
|
29
|
+
var result = {};
|
|
30
|
+
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
|
|
31
|
+
__setModuleDefault(result, mod);
|
|
32
|
+
return result;
|
|
33
|
+
};
|
|
34
|
+
})();
|
|
35
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
36
|
+
const transformers_1 = require("@xenova/transformers");
|
|
37
|
+
const ort = __importStar(require("onnxruntime-node"));
|
|
38
|
+
const path = __importStar(require("path"));
|
|
39
|
+
const fs = __importStar(require("fs"));
|
|
40
|
+
// Configure cache path
|
|
41
|
+
const CACHE_DIR = path.resolve('./.cache');
|
|
42
|
+
transformers_1.env.cacheDir = CACHE_DIR;
|
|
43
|
+
const EMBEDDING_DIM = 1024; // bge-m3
|
|
44
|
+
const MODEL_ID = "Xenova/bge-m3";
|
|
45
|
+
const QUANTIZED = false;
|
|
46
|
+
const MODEL_FILE = QUANTIZED ? "model_quantized.onnx" : "model.onnx";
|
|
47
|
+
// Massive Load Configuration
|
|
48
|
+
const BATCH_SIZE = 10; // Process 10 items concurrently
|
|
49
|
+
const NUM_TEXTS = 50;
|
|
50
|
+
const TEXT_LENGTH_MULTIPLIER = 5; // 5x longer texts
|
|
51
|
+
const baseText = `This is a very long complex test sentence for the extended performance comparison of CPU and GPU embeddings. We want to maximize the graphics card (RTX 2080) utilization to make the activity visible in the Task Manager. DirectML should show what it can do here. `;
|
|
52
|
+
const longText = baseText.repeat(TEXT_LENGTH_MULTIPLIER);
|
|
53
|
+
const texts = Array.from({ length: NUM_TEXTS }, (_, i) => `[${i}] ${longText}`);
|
|
54
|
+
async function runBenchmark() {
|
|
55
|
+
console.log("==========================================");
|
|
56
|
+
console.log(`Starting HEAVY Benchmark: CPU vs GPU (DirectML) [FP32 Mode]`);
|
|
57
|
+
console.log(`Batch Size: ${BATCH_SIZE}, Total Texts: ${NUM_TEXTS}`);
|
|
58
|
+
console.log(`Text Length: ~${longText.length} chars`);
|
|
59
|
+
console.log("==========================================");
|
|
60
|
+
// 1. Prepare Model Path
|
|
61
|
+
const modelPath = path.join(CACHE_DIR, 'Xenova', 'bge-m3', 'onnx', MODEL_FILE);
|
|
62
|
+
if (!fs.existsSync(modelPath)) {
|
|
63
|
+
console.error(`Model not found at: ${modelPath}`);
|
|
64
|
+
return;
|
|
65
|
+
}
|
|
66
|
+
console.log(`Model path: ${modelPath}`);
|
|
67
|
+
// 2. Load Tokenizer
|
|
68
|
+
console.log("Loading Tokenizer...");
|
|
69
|
+
const tokenizer = await transformers_1.AutoTokenizer.from_pretrained(MODEL_ID);
|
|
70
|
+
// 3. Define Helper for Inference
|
|
71
|
+
async function runInference(session, label) {
|
|
72
|
+
console.log(`Starting ${label} inference...`);
|
|
73
|
+
const start = performance.now();
|
|
74
|
+
let processed = 0;
|
|
75
|
+
for (let i = 0; i < texts.length; i += BATCH_SIZE) {
|
|
76
|
+
const batchTexts = texts.slice(i, i + BATCH_SIZE);
|
|
77
|
+
const model_inputs = await tokenizer(batchTexts, { padding: true, truncation: true, maxLength: 512 });
|
|
78
|
+
const feeds = {};
|
|
79
|
+
for (const [key, value] of Object.entries(model_inputs)) {
|
|
80
|
+
if (key === 'input_ids' || key === 'attention_mask' || key === 'token_type_ids') {
|
|
81
|
+
// @ts-ignore
|
|
82
|
+
let data = value.data || value.cpuData;
|
|
83
|
+
// @ts-ignore
|
|
84
|
+
const dims = value.dims; // [batch_size, seq_len]
|
|
85
|
+
// Ensure BigInt64Array
|
|
86
|
+
if (!(data instanceof BigInt64Array)) {
|
|
87
|
+
data = BigInt64Array.from(data);
|
|
88
|
+
}
|
|
89
|
+
try {
|
|
90
|
+
feeds[key] = new ort.Tensor('int64', data, dims);
|
|
91
|
+
}
|
|
92
|
+
catch (err) {
|
|
93
|
+
console.error(`Error creating tensor for ${key}:`, err.message);
|
|
94
|
+
throw err;
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
}
|
|
98
|
+
// Run Inference
|
|
99
|
+
await session.run(feeds);
|
|
100
|
+
processed += batchTexts.length;
|
|
101
|
+
if (processed % (BATCH_SIZE * 5) === 0) {
|
|
102
|
+
process.stdout.write(`\r${label}: ${processed}/${NUM_TEXTS} Embeddings...`);
|
|
103
|
+
}
|
|
104
|
+
}
|
|
105
|
+
const end = performance.now();
|
|
106
|
+
const duration = (end - start) / 1000;
|
|
107
|
+
console.log(`\n${label} Time: ${duration.toFixed(2)}s`);
|
|
108
|
+
const speed = texts.length / duration;
|
|
109
|
+
console.log(`${label} Speed: ${speed.toFixed(2)} Embeddings/s`);
|
|
110
|
+
return speed;
|
|
111
|
+
}
|
|
112
|
+
// --- Phase 1: GPU Benchmark (Priority) ---
console.log("\n--- Phase 1: GPU (DirectML) Benchmark ---");
let speedGpu = 0;
try {
    // Prefer DirectML; keep 'cpu' in the list so ORT can fall back if DML is unavailable.
    const gpuOptions = {
        executionProviders: [{ name: 'dml', device_id: 0 }, 'cpu'],
        graphOptimizationLevel: 'all',
        enableCpuMemArena: false
    };
    console.log("Creating GPU Session (This might take a moment)...");
    const loadStart = performance.now();
    const sessionGpu = await ort.InferenceSession.create(modelPath, gpuOptions);
    const loadEnd = performance.now();
    console.log(`GPU Session created in ${((loadEnd - loadStart) / 1000).toFixed(2)}s`);
    // @ts-ignore getProviders is not exposed in every ORT build's typings.
    console.log(`Providers: ${sessionGpu.getProviders ? sessionGpu.getProviders() : 'Unknown'}`);

    // One throwaway inference so first-run setup cost does not skew the timed phase.
    console.log("GPU Warmup...");
    {
        const warmupInputs = await tokenizer(["Warmup sentence to wake up the GPU."], { padding: true, truncation: true });
        const feeds = {};
        for (const [name, tensor] of Object.entries(warmupInputs)) {
            if (name !== 'input_ids' && name !== 'attention_mask' && name !== 'token_type_ids') continue;
            // @ts-ignore tokenizer tensors expose .data or .cpuData depending on backend
            let raw = tensor.data || tensor.cpuData;
            // ORT int64 tensors require BigInt64Array-backed storage.
            if (!(raw instanceof BigInt64Array)) raw = BigInt64Array.from(raw);
            // @ts-ignore
            feeds[name] = new ort.Tensor('int64', raw, tensor.dims);
        }
        await sessionGpu.run(feeds);
    }
    console.log("PLEASE WATCH TASK MANAGER GPU TAB NOW!");
    await new Promise(resolve => setTimeout(resolve, 2000)); // Give user time to switch
    speedGpu = await runInference(sessionGpu, "GPU");
}
catch (e) {
    console.error("GPU Benchmark failed:", e.message);
    console.error(e.stack);
}
// --- Phase 2: CPU Benchmark (Comparison) ---
// Only run if GPU succeeded to compare, or if user wants to see baseline
console.log("\n--- Phase 2: CPU Benchmark ---");
let speedCpu = 0;
try {
    const sessionCpu = await ort.InferenceSession.create(modelPath, {
        executionProviders: ['cpu'],
        graphOptimizationLevel: 'all'
    });
    console.log("CPU Session created.");

    // Running the full corpus on CPU could take minutes (e.g. 2000 texts at
    // ~5 emb/s ≈ 400s), so benchmark a subset; 200 texts is enough for a
    // reliable throughput measurement.
    const CPU_SUBSET_SIZE = 200;
    console.log(`(Using only ${CPU_SUBSET_SIZE} texts for CPU Benchmark to save time...)`);

    /**
     * Runs batched inference over `subset` on the given ORT session and
     * reports wall-clock throughput.
     * @param {object} session - ORT InferenceSession to benchmark.
     * @param {string} label - Name used in progress/log output (e.g. "CPU").
     * @param {string[]} subset - Texts to embed.
     * @returns {Promise<number>} Throughput in embeddings per second.
     */
    async function runInferenceSubset(session, label, subset) {
        console.log(`Starting ${label} inference (Subset: ${subset.length})...`);
        const start = performance.now();
        let processed = 0;
        for (let i = 0; i < subset.length; i += BATCH_SIZE) {
            const batchTexts = subset.slice(i, i + BATCH_SIZE);
            // NOTE(review): transformers.js tokenizers take snake_case
            // `max_length`; `maxLength` is presumably ignored and truncation
            // falls back to the model default — confirm before relying on 512.
            const model_inputs = await tokenizer(batchTexts, { padding: true, truncation: true, maxLength: 512 });
            const feeds = {};
            for (const [key, value] of Object.entries(model_inputs)) {
                if (key === 'input_ids' || key === 'attention_mask' || key === 'token_type_ids') {
                    // @ts-ignore tokenizer tensors expose .data or .cpuData depending on backend
                    let data = value.data || value.cpuData;
                    // @ts-ignore
                    const dims = value.dims;
                    // ORT int64 tensors require BigInt64Array-backed storage.
                    if (!(data instanceof BigInt64Array))
                        data = BigInt64Array.from(data);
                    feeds[key] = new ort.Tensor('int64', data, dims);
                }
            }
            await session.run(feeds);
            processed += batchTexts.length;
            process.stdout.write(`.`);
        }
        const end = performance.now();
        const duration = (end - start) / 1000;
        console.log(`\n${label} Time: ${duration.toFixed(2)}s`);
        const speed = subset.length / duration;
        console.log(`${label} Speed: ${speed.toFixed(2)} Embeddings/s`);
        return speed;
    }
    speedCpu = await runInferenceSubset(sessionCpu, "CPU", texts.slice(0, CPU_SUBSET_SIZE));
}
catch (e) {
    console.error("CPU Benchmark failed:", e.message);
}
// Final Result — print the speedup summary only when both phases produced a measurement.
if (speedGpu > 0 && speedCpu > 0) {
    const ratio = speedGpu / speedCpu;
    const divider = "==========================================";
    console.log(`\n${divider}`);
    console.log(`Result: GPU is ${ratio.toFixed(2)}x faster than CPU.`);
    console.log(`GPU Throughput: ${speedGpu.toFixed(2)} emb/s`);
    console.log(`CPU Throughput: ${speedCpu.toFixed(2)} emb/s`);
    console.log(divider);
}
}
|
|
230
|
+
runBenchmark();
|