@mastra/pg 0.0.0-commonjs-20250227130920
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +23 -0
- package/CHANGELOG.md +503 -0
- package/LICENSE +44 -0
- package/README.md +161 -0
- package/dist/_tsup-dts-rollup.d.cts +304 -0
- package/dist/_tsup-dts-rollup.d.ts +304 -0
- package/dist/index.cjs +1043 -0
- package/dist/index.d.cts +4 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.js +1035 -0
- package/docker-compose.perf.yaml +21 -0
- package/docker-compose.yaml +14 -0
- package/eslint.config.js +6 -0
- package/package.json +50 -0
- package/src/index.ts +2 -0
- package/src/storage/index.test.ts +380 -0
- package/src/storage/index.ts +592 -0
- package/src/vector/filter.test.ts +967 -0
- package/src/vector/filter.ts +107 -0
- package/src/vector/index.test.ts +1302 -0
- package/src/vector/index.ts +391 -0
- package/src/vector/performance.helpers.ts +286 -0
- package/src/vector/sql-builder.ts +285 -0
- package/src/vector/types.ts +16 -0
- package/src/vector/vector.performance.test.ts +370 -0
- package/tsconfig.json +5 -0
- package/vitest.config.ts +12 -0
- package/vitest.perf.config.ts +8 -0
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
import type { Filter } from '@mastra/core/filter';
|
|
2
|
+
import { MastraVector } from '@mastra/core/vector';
|
|
3
|
+
import type { IndexStats, QueryResult } from '@mastra/core/vector';
|
|
4
|
+
import pg from 'pg';
|
|
5
|
+
|
|
6
|
+
import { PGFilterTranslator } from './filter';
|
|
7
|
+
import { buildFilterQuery } from './sql-builder';
|
|
8
|
+
import type { IndexConfig, IndexType } from './types';
|
|
9
|
+
|
|
10
|
+
/**
 * Statistics returned by `PgVector.describeIndex`: the generic `IndexStats`
 * fields plus the pgvector access method and the parameters parsed from the
 * index definition.
 */
export interface PGIndexStats extends IndexStats {
  /** Access method of `<indexName>_vector_idx`; `'flat'` when no such index exists. */
  type: IndexType;
  /** Index parameters; only the ones relevant to `type` are populated. */
  config: {
    /** HNSW `m` (max connections per layer). */
    m?: number;
    /** HNSW `ef_construction` (candidate list size while building). */
    efConstruction?: number;
    /** IVFFlat `lists` (number of inverted lists). */
    lists?: number;
    /** IVFFlat query-time probes; not populated by `describeIndex`. */
    probes?: number;
  };
}
|
|
19
|
+
|
|
20
|
+
/**
 * `MastraVector` implementation backed by PostgreSQL + pgvector.
 *
 * Every "index" is a table named `indexName` with columns `id`, `vector_id`
 * (unique caller-facing id), `embedding vector(dimension)` and `metadata JSONB`,
 * optionally accelerated by an ANN index named `<indexName>_vector_idx`.
 *
 * NOTE(review): `indexName` is interpolated into SQL unquoted in every method;
 * only `createIndex` validates its format. Do not pass untrusted names.
 */
export class PgVector extends MastraVector {
  private pool: pg.Pool;
  /** Memoized `describeIndex` results, used by `query` to choose search parameters. */
  private indexCache: Map<string, PGIndexStats> = new Map();

  /**
   * @param connectionString - PostgreSQL connection string for the underlying `pg.Pool`.
   */
  constructor(connectionString: string) {
    super();

    const basePool = new pg.Pool({
      connectionString,
      max: 20, // Maximum number of clients in the pool
      idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2000, // Fail fast if can't connect
    });

    const telemetry = this.__getTelemetry();

    // Use a traced wrapper when telemetry is configured, otherwise the raw pool.
    this.pool =
      telemetry?.traceClass(basePool, {
        spanNamePrefix: 'pg-vector',
        attributes: {
          'vector.type': 'postgres',
        },
      }) ?? basePool;
  }

  /** Translates a Mastra filter (an empty one when omitted) into the PG filter form. */
  transformFilter(filter?: Filter) {
    const pgFilter = new PGFilterTranslator();
    const translatedFilter = pgFilter.translate(filter ?? {});
    return translatedFilter;
  }

  /** Returns `describeIndex(indexName)`, memoized until the index is rebuilt, truncated or dropped. */
  async getIndexInfo(indexName: string): Promise<PGIndexStats> {
    if (!this.indexCache.has(indexName)) {
      this.indexCache.set(indexName, await this.describeIndex(indexName));
    }
    return this.indexCache.get(indexName)!;
  }

  /**
   * Returns up to `topK` rows of `indexName` scored by `1 - (embedding <=> query)`
   * (pgvector's cosine-distance operator), highest score first.
   *
   * @param indexName - Table to search (interpolated into the SQL).
   * @param queryVector - Query embedding.
   * @param topK - Maximum number of results.
   * @param filter - Optional metadata filter.
   * @param includeVector - When true, each result also carries its stored vector.
   * @param minScore - Only rows whose score is strictly greater are returned.
   * @param options - `ef` overrides HNSW `ef_search`; `probes` sets IVFFlat probes.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Filter,
    includeVector: boolean = false,
    minScore: number = 0, // Optional minimum score threshold
    options?: {
      ef?: number; // For HNSW
      probes?: number; // For IVF
    },
  ): Promise<QueryResult[]> {
    const vectorStr = `[${queryVector.join(',')}]`;
    const translatedFilter = this.transformFilter(filter);
    // NOTE(review): the outer `WHERE score > $1` relies on buildFilterQuery binding
    // minScore as the first value - confirm against sql-builder.
    const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);

    // Resolve index metadata BEFORE checking out a client: on a cache miss
    // describeIndex acquires its own pool client, and holding one while waiting
    // for another can deadlock once concurrent queries exhaust the pool.
    const indexInfo = await this.getIndexInfo(indexName);

    const client = await this.pool.connect();
    try {
      // SET LOCAL only lasts until the end of the current transaction; outside a
      // transaction block Postgres ignores it with a warning. Open one so the
      // search parameters below actually apply to the SELECT.
      await client.query('BEGIN');

      if (indexInfo.type === 'hnsw') {
        // Default ef = max(topK, m * topK) (m defaults to 16), clamped to [1, 1000].
        const calculatedEf = options?.ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
        const searchEf = Math.min(1000, Math.max(1, calculatedEf));
        await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
      }

      if (indexInfo.type === 'ivfflat' && options?.probes) {
        await client.query(`SET LOCAL ivfflat.probes = ${options.probes}`);
      }

      const query = `
        WITH vector_scores AS (
          SELECT
            vector_id as id,
            1 - (embedding <=> '${vectorStr}'::vector) as score,
            metadata
            ${includeVector ? ', embedding' : ''}
          FROM ${indexName}
          ${filterQuery}
        )
        SELECT *
        FROM vector_scores
        WHERE score > $1
        ORDER BY score DESC
        LIMIT ${topK}`;
      const result = await client.query(query, filterValues);
      await client.query('COMMIT');

      return result.rows.map(({ id, score, metadata, embedding }) => ({
        id,
        score,
        metadata,
        // embedding comes back in pgvector's text form (e.g. "[1,2,3]"), which JSON.parse accepts.
        ...(includeVector && embedding && { vector: JSON.parse(embedding) }),
      }));
    } catch (error) {
      await this.rollbackQuietly(client);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Inserts or updates vectors in `indexName` inside a single transaction.
   *
   * @param vectors - Embeddings to store.
   * @param metadata - Optional per-vector metadata (`{}` when missing).
   * @param ids - Optional ids; random UUIDs are generated when omitted.
   * @returns The ids used, in input order.
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    // Start a transaction
    const client = await this.pool.connect();
    try {
      await client.query('BEGIN');
      const vectorIds = ids || vectors.map(() => crypto.randomUUID());

      // The statement is identical for every row; only the parameters change.
      const query = `
        INSERT INTO ${indexName} (vector_id, embedding, metadata)
        VALUES ($1, $2::vector, $3::jsonb)
        ON CONFLICT (vector_id)
        DO UPDATE SET
          embedding = $2::vector,
          metadata = $3::jsonb
        RETURNING embedding::text
      `;

      for (let i = 0; i < vectors.length; i++) {
        await client.query(query, [vectorIds[i], `[${vectors[i]?.join(',')}]`, JSON.stringify(metadata?.[i] || {})]);
      }

      await client.query('COMMIT');
      return vectorIds;
    } catch (error) {
      await this.rollbackQuietly(client);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Creates the backing table (and the pgvector extension if needed), then
   * optionally builds the ANN index.
   *
   * @param indexName - Must match `^[a-zA-Z_][a-zA-Z0-9_]*$`.
   * @param dimension - Positive integer vector dimension.
   * @param metric - Distance metric for the ANN index.
   * @param indexConfig - Index type and parameters passed to `defineIndex`.
   * @param defineIndex - When false, only the table is created.
   * @throws Error on invalid input, missing extension, or any SQL failure (also logged).
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig = {},
    defineIndex: boolean = true,
  ): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Validate inputs (indexName is later interpolated into SQL unquoted)
      if (!indexName.match(/^[a-zA-Z_][a-zA-Z0-9_]*$/)) {
        throw new Error('Invalid index name format');
      }
      if (!Number.isInteger(dimension) || dimension <= 0) {
        throw new Error('Dimension must be a positive integer');
      }

      // First check if vector extension is available
      const extensionCheck = await client.query(`
        SELECT EXISTS (
          SELECT 1 FROM pg_available_extensions WHERE name = 'vector'
        );
      `);

      if (!extensionCheck.rows[0].exists) {
        throw new Error('PostgreSQL vector extension is not available. Please install it first.');
      }

      // Try to create extension
      await client.query('CREATE EXTENSION IF NOT EXISTS vector');

      // Create the table with explicit schema
      await client.query(`
        CREATE TABLE IF NOT EXISTS ${indexName} (
          id SERIAL PRIMARY KEY,
          vector_id TEXT UNIQUE NOT NULL,
          embedding vector(${dimension}),
          metadata JSONB DEFAULT '{}'::jsonb
        );
      `);
    } catch (error: any) {
      console.error('Failed to create vector table:', error);
      throw error;
    } finally {
      client.release();
    }

    // Build the ANN index only after releasing our client: buildIndex checks out
    // its own, and nesting acquisitions can exhaust the pool.
    if (defineIndex) {
      try {
        await this.defineIndex(indexName, metric, indexConfig);
      } catch (error: any) {
        console.error('Failed to create vector table:', error);
        throw error;
      }
    }
  }

  /**
   * @deprecated This function is deprecated. Use buildIndex instead
   */
  async defineIndex(
    indexName: string,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig,
  ): Promise<void> {
    return this.buildIndex(indexName, metric, indexConfig);
  }

  /**
   * (Re)builds `<indexName>_vector_idx`. Any existing index is dropped first.
   * `type: 'flat'` leaves the table unindexed, `'hnsw'` builds HNSW (m=8,
   * ef_construction=32 by default), and anything else builds IVFFlat.
   */
  async buildIndex(
    indexName: string,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig,
  ): Promise<void> {
    const client = await this.pool.connect();
    try {
      await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
      // The cached stats describe the index just dropped; forget them even when
      // we stop here for 'flat', otherwise query() keeps using the old type.
      this.indexCache.delete(indexName);

      if (indexConfig.type === 'flat') return;

      const metricOp =
        metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';

      let indexSQL: string;
      if (indexConfig.type === 'hnsw') {
        const m = indexConfig.hnsw?.m ?? 8;
        const efConstruction = indexConfig.hnsw?.efConstruction ?? 32;

        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING hnsw (embedding ${metricOp})
          WITH (
            m = ${m},
            ef_construction = ${efConstruction}
          )
        `;
      } else {
        let lists: number;
        if (indexConfig.ivf?.lists) {
          lists = indexConfig.ivf.lists;
        } else {
          // Heuristic: 2 * sqrt(rowCount), clamped to [100, 4000].
          const size = (await client.query(`SELECT COUNT(*) FROM ${indexName}`)).rows[0].count;
          lists = Math.max(100, Math.min(4000, Math.floor(Math.sqrt(size) * 2)));
        }
        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING ivfflat (embedding ${metricOp})
          WITH (lists = ${lists});
        `;
      }

      await client.query(indexSQL);
      // Invalidate again in case a concurrent query re-cached stats between DROP and CREATE.
      this.indexCache.delete(indexName);
    } finally {
      client.release();
    }
  }

  /** Lists tables in the `public` schema that have at least one `vector` column. */
  async listIndexes(): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      // Then let's see which ones have vector columns
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
        AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map(row => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Reports dimension (the `embedding` column's atttypmod), row count, metric
   * (derived from the operator class), index type and parsed index parameters.
   *
   * @throws Error prefixed with "Failed to describe vector table:" on any failure.
   */
  async describeIndex(indexName: string): Promise<PGIndexStats> {
    const client = await this.pool.connect();
    try {
      // Get vector dimension
      const dimensionQuery = `
        SELECT atttypmod as dimension
        FROM pg_attribute
        WHERE attrelid = $1::regclass
        AND attname = 'embedding';
      `;

      // Get row count
      const countQuery = `
        SELECT COUNT(*) as count
        FROM ${indexName};
      `;

      // Get index metric type
      const indexQuery = `
        SELECT
          am.amname as index_method,
          pg_get_indexdef(i.indexrelid) as index_def,
          opclass.opcname as operator_class
        FROM pg_index i
        JOIN pg_class c ON i.indexrelid = c.oid
        JOIN pg_am am ON c.relam = am.oid
        JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
        WHERE c.relname = $1;
      `;

      // buildIndex creates the index with an unquoted name, which Postgres stores
      // with ASCII letters folded to lower case; match that form (bound, not interpolated).
      const vectorIndexName = `${indexName}_vector_idx`.replace(/[A-Z]/g, ch => ch.toLowerCase());

      // A single client runs queries one at a time anyway. Awaiting in sequence also
      // guarantees nothing is still queued on the client when it is released after a failure.
      const dimResult = await client.query(dimensionQuery, [indexName]);
      const countResult = await client.query(countQuery);
      const indexResult = await client.query(indexQuery, [vectorIndexName]);

      const { index_method, index_def, operator_class } = indexResult.rows[0] || {
        index_method: 'flat',
        index_def: '',
        operator_class: 'cosine',
      };

      // Convert pg_vector index method to our metric type
      const metric = operator_class.includes('l2')
        ? 'euclidean'
        : operator_class.includes('ip')
          ? 'dotproduct'
          : 'cosine';

      // Parse index configuration
      const config: { m?: number; efConstruction?: number; lists?: number } = {};

      if (index_method === 'hnsw') {
        const m = index_def.match(/m\s*=\s*'?(\d+)'?/)?.[1];
        const efConstruction = index_def.match(/ef_construction\s*=\s*'?(\d+)'?/)?.[1];
        if (m) config.m = parseInt(m);
        if (efConstruction) config.efConstruction = parseInt(efConstruction);
      } else if (index_method === 'ivfflat') {
        const lists = index_def.match(/lists\s*=\s*'?(\d+)'?/)?.[1];
        if (lists) config.lists = parseInt(lists);
      }

      return {
        dimension: dimResult.rows[0].dimension,
        count: parseInt(countResult.rows[0].count),
        metric,
        type: index_method as 'flat' | 'hnsw' | 'ivfflat',
        config,
      };
    } catch (e: any) {
      // No transaction was opened here, so there is nothing to roll back.
      throw new Error(`Failed to describe vector table: ${e.message}`);
    } finally {
      client.release();
    }
  }

  /** Drops the table backing `indexName` (CASCADE) and forgets its cached stats. */
  async deleteIndex(indexName: string): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Drop the table
      await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
      // A later table with the same name must be re-described, not served from cache.
      this.indexCache.delete(indexName);
    } catch (error: any) {
      // Autocommit statement: no transaction to roll back.
      throw new Error(`Failed to delete vector table: ${error.message}`);
    } finally {
      client.release();
    }
  }

  /** Removes all rows from `indexName`; cached stats (including count) are dropped. */
  async truncateIndex(indexName: string) {
    const client = await this.pool.connect();
    try {
      await client.query(`TRUNCATE ${indexName}`);
      this.indexCache.delete(indexName);
    } catch (e: any) {
      // Autocommit statement: no transaction to roll back.
      throw new Error(`Failed to truncate vector table: ${e.message}`);
    } finally {
      client.release();
    }
  }

  /** Closes every connection in the pool. */
  async disconnect() {
    await this.pool.end();
  }

  /** Issues ROLLBACK, ignoring its own failure so the caller can rethrow the original error. */
  private async rollbackQuietly(client: pg.PoolClient): Promise<void> {
    try {
      await client.query('ROLLBACK');
    } catch {
      // A secondary ROLLBACK failure (e.g. broken connection) must not mask the real error.
    }
  }
}
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
import type { IndexConfig, IndexType } from './types';
|
|
2
|
+
|
|
3
|
+
import type { PgVector } from '.';
|
|
4
|
+
|
|
5
|
+
/**
 * One row of performance-test output: the scenario that was run plus
 * whatever metrics were collected for it.
 */
export interface TestResult {
  /** Label of the vector distribution used (free-form string). */
  distribution: string;
  /** Vector dimension. */
  dimension: number;
  /** Index type under test. */
  type: IndexType;
  /** Number of vectors in the index. */
  size: number;
  /** Top-k used for the queries, when applicable. */
  k?: number;
  /** Collected measurements; every group is optional. */
  metrics: {
    recall?: number;
    minRecall?: number;
    maxRecall?: number;
    /** Latency percentiles alongside the index parameters they were measured with. */
    latency?: {
      p50: number;
      p95: number;
      lists?: number;
      vectorsPerList?: number;
      m?: number;
      ef?: number;
    };
    /** IVF clustering statistics. */
    clustering?: {
      numLists?: number;
      avgVectorsPerList?: number;
      recommendedLists?: number;
      distribution?: string;
    };
  };
}
|
|
31
|
+
|
|
32
|
+
/**
 * Generates `count` vectors of length `dim` whose components are drawn
 * uniformly from [-1, 1).
 */
export const generateRandomVectors = (count: number, dim: number) => {
  const uniformComponent = () => Math.random() * 2 - 1;
  const makeVector = () => Array.from({ length: dim }, uniformComponent);
  return Array.from({ length: count }, makeVector);
};
|
|
37
|
+
|
|
38
|
+
/**
 * Generates `count` vectors of length `dim` grouped around `numClusters`
 * random centers (components in [-1, 1)).
 *
 * Lower-numbered clusters are picked more often (index = floor(r^2 * numClusters)).
 * About 80% of vectors get per-component noise in [-0.05, 0.05); the rest get
 * noise in [-0.25, 0.25).
 */
export const generateClusteredVectors = (count: number, dim: number, numClusters: number = 10) => {
  const uniform = () => Math.random() * 2 - 1;

  const centers: number[][] = [];
  for (let clusterNo = 0; clusterNo < numClusters; clusterNo++) {
    centers.push(Array.from({ length: dim }, uniform));
  }

  const result: number[][] = [];
  for (let n = 0; n < count; n++) {
    // Squaring the random draw skews selection toward low cluster indices.
    const pick = Math.floor(Math.pow(Math.random(), 2) * numClusters);
    const anchor = centers[pick] as number[];

    const width = Math.random() < 0.8 ? 0.1 : 0.5; // 80% close, 20% far
    result.push(anchor.map(component => component + (Math.random() * width - width / 2)));
  }
  return result;
};
|
|
53
|
+
|
|
54
|
+
/**
 * Generates `count` vectors of length `dim` with a deliberately uneven layout
 * (more extreme than generateClusteredVectors):
 * - floor(60%) are packed within +-0.05 of one random center whose components lie in [0, 0.2);
 * - the rest are drawn uniformly from [-1, 1) per component.
 * The output order is uniformly shuffled.
 */
export const generateSkewedVectors = (count: number, dim: number) => {
  const vectors: number[][] = [];

  const denseCount = Math.floor(count * 0.6);
  const sparseCount = count - denseCount;

  // Dense cluster (60% of vectors)
  const denseCenter = Array.from({ length: dim }, () => Math.random() * 0.2);
  for (let i = 0; i < denseCount; i++) {
    vectors.push(denseCenter.map(c => c + (Math.random() * 0.1 - 0.05)));
  }

  // Scattered vectors (40%)
  for (let i = 0; i < sparseCount; i++) {
    vectors.push(Array.from({ length: dim }, () => Math.random() * 2 - 1));
  }

  // Fisher-Yates shuffle. The previous `sort(() => Math.random() - 0.5)` is biased
  // and uses an inconsistent comparator, leaving dense vectors clumped together.
  for (let i = vectors.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    const tmp = vectors[i] as number[];
    vectors[i] = vectors[j] as number[];
    vectors[j] = tmp;
  }
  return vectors;
};
|
|
75
|
+
|
|
76
|
+
/**
 * Exact k-nearest-neighbour search by cosine similarity.
 *
 * @returns Indices into `vectors` of the `k` most similar entries, most similar
 *   first (ties keep input order because Array.prototype.sort is stable).
 */
export const findNearestBruteForce = (query: number[], vectors: number[][], k: number) => {
  const scored = vectors.map((vector, idx) => ({ idx, dist: cosineSimilarity(query, vector) }));
  scored.sort((left, right) => right.dist - left.dist);
  return scored.slice(0, k).map(({ idx }) => idx);
};
|
|
85
|
+
|
|
86
|
+
/**
 * Position-aware recall over the first `k` results.
 *
 * Each position scores 1 when `actual[i] === expected[i]`, 0.5 when `actual[i]`
 * appears elsewhere in `expected`, and 0 otherwise. The total is divided by `k`
 * (k should be positive). Positions beyond the end of `actual` (fewer than k
 * results returned) score 0.
 */
export const calculateRecall = (actual: number[], expected: number[], k: number): number => {
  let score = 0;
  for (let i = 0; i < k; i++) {
    const candidate = actual[i];
    // A missing result earns nothing. The old `?? 0` counted it as id 0, and two
    // missing slots compared equal, inflating recall.
    if (candidate === undefined) continue;
    if (candidate === expected[i]) {
      score += 1;
    } else if (expected.includes(candidate)) {
      score += 0.5;
    }
  }
  return score / k;
};
|
|
97
|
+
|
|
98
|
+
/**
 * Cosine similarity of `a` and `b`, in [-1, 1].
 *
 * The dot product iterates over `a`; missing entries of a shorter `b` count as 0.
 * Returns 0 when either vector has zero norm (the ratio would otherwise be NaN).
 */
export function cosineSimilarity(a: number[], b: number[]): number {
  const dotProduct = a.reduce((sum, val, i) => sum + (val ?? 0) * (b[i] ?? 0), 0);
  const normA = Math.sqrt(a.reduce((sum, val) => sum + val * val, 0));
  const normB = Math.sqrt(b.reduce((sum, val) => sum + val * val, 0));
  // A zero vector has no direction. NaN would make sort comparators (e.g. in
  // findNearestBruteForce) inconsistent, so report "no similarity" instead.
  if (normA === 0 || normB === 0) return 0;
  return dotProduct / (normA * normB);
}
|
|
104
|
+
|
|
105
|
+
/**
 * Renders `data` as a box-drawn text table with one column per entry in `columns`.
 *
 * Cells show `value.toString()`, or `-` for null/undefined. Each column is as wide
 * as its longest cell or header.
 *
 * @returns The table as a newline-joined string.
 */
export const formatTable = (data: any[], columns: string[]) => {
  const cellText = (value: any): string => (value === undefined || value === null ? '-' : value.toString());

  // reduce instead of Math.max(...spread): spreading a large `data` array can hit
  // the engine's argument-count limit and throw RangeError.
  const colWidths = columns.map(col => data.reduce((width, row) => Math.max(width, cellText(row[col]).length), col.length));
  const widthAt = (i: number): number => colWidths[i] ?? 0;

  const border = (left: string, junction: string, right: string) =>
    left + colWidths.map(w => '─'.repeat(w)).join(junction) + right;
  const renderRow = (cells: string[]) => '│' + cells.map((cell, i) => cell.padEnd(widthAt(i))).join('│') + '│';

  const header = renderRow(columns);
  const rows = data.map(row => renderRow(columns.map(col => cellText(row[col]))));

  return [border('┌', '┬', '┐'), header, border('├', '┼', '┤'), ...rows, border('└', '┴', '┘')].join('\n');
};
|
|
136
|
+
|
|
137
|
+
/**
 * Groups `array` by a property name or by a key function, preserving item order
 * within each group.
 *
 * @param key - Property whose (stringified) value is the group key, or a function
 *   computing the key.
 * @param reducer - Optional transform applied to each group's items.
 * @returns A plain object mapping group key to its items (or to `reducer(items)`).
 */
export const groupBy = <T, K extends keyof T>(
  array: T[],
  key: K | ((item: T) => string),
  reducer?: (group: T[]) => any,
): Record<string, any> => {
  // A Map avoids clashing with Object.prototype members: with a `{}` accumulator a
  // key such as "toString" or "constructor" looked pre-populated and the push threw.
  const groups = new Map<string, T[]>();
  for (const item of array) {
    const groupKey = String(typeof key === 'function' ? key(item) : item[key]);
    const bucket = groups.get(groupKey);
    if (bucket) {
      bucket.push(item);
    } else {
      groups.set(groupKey, [item]);
    }
  }

  // Object.fromEntries defines own properties, so every key is stored safely.
  return Object.fromEntries(Array.from(groups, ([groupKey, group]) => [groupKey, reducer ? reducer(group) : group]));
};
|
|
158
|
+
|
|
159
|
+
/**
 * Test timeout in milliseconds for a perf scenario.
 *
 * Starts from 600 000 ms and scales it: x3 for dimension >= 1024 (x1.5 for >= 384),
 * x2 for size >= 10 000, x1.5 for k >= 75, then x5 overall.
 */
export const calculateTimeout = (dimension: number, size: number, k: number) => {
  const dimensionFactor = dimension >= 1024 ? 3 : dimension >= 384 ? 1.5 : 1;
  const sizeFactor = size >= 10000 ? 2 : 1;
  const kFactor = k >= 75 ? 1.5 : 1;
  // Multiply in the same order as a running product starting at the base timeout.
  return [dimensionFactor, sizeFactor, kFactor, 5].reduce((ms, factor) => ms * factor, 600000);
};
|
|
167
|
+
|
|
168
|
+
/** Expands `[size, k]` pairs into test configs sharing one dimension and query count. */
const configSweep = (dimension: number, queryCount: number, sizeK: ReadonlyArray<readonly [number, number]>) =>
  sizeK.map(([size, k]) => ({ dimension, size, k, queryCount }));

/**
 * Performance-test scenarios, grouped by suite: a smoke test, per-dimension
 * sweeps over (size, k), and a few heavy stress cases.
 */
export const baseTestConfigs = {
  smokeTests: configSweep(384, 10, [[1_000, 10]]),
  '64': configSweep(64, 30, [
    [100, 10],
    [100, 25],
    [100, 50],
    [100, 100],
    [1_000, 10],
    [1_000, 25],
    [1_000, 50],
    [1_000, 100],
    [10_000, 10],
    [100_000, 10],
    [100_000, 25],
    [100_000, 50],
    [100_000, 100],
    [500_000, 10],
    [1_000_000, 10],
  ]),
  '384': configSweep(384, 30, [
    [100, 10],
    [100, 25],
    [100, 50],
    [100, 100],
    [1_000, 10],
    [1_000, 25],
    [1_000, 50],
    [1_000, 100],
    [10_000, 10],
    [100_000, 10],
    [100_000, 25],
    [100_000, 50],
    [100_000, 100],
    [500_000, 10],
  ]),
  '1024': configSweep(1024, 30, [
    [100, 10],
    [100, 25],
    [100, 50],
    [100, 100],
    [1_000, 10],
    [1_000, 25],
    [1_000, 50],
    [1_000, 100],
    [10_000, 10],
    [10_000, 25],
    [10_000, 50],
    [10_000, 100],
    [50_000, 10],
    [50_000, 25],
  ]),
  stressTests: [
    // Maximum load
    { dimension: 512, size: 1_000_000, k: 50, queryCount: 5 },

    // Dense search
    { dimension: 256, size: 1_000_000, k: 100, queryCount: 5 },

    { dimension: 1024, size: 500_000, k: 50, queryCount: 5 },
  ],
};
|
|
229
|
+
|
|
230
|
+
/** Shape of one performance-test scenario (the entries of `baseTestConfigs`). */
export interface TestConfig {
  /** Vector dimension. */
  dimension: number;
  /** Number of vectors in the test index (presumably; confirm in the perf test). */
  size: number;
  /** Top-k requested per query. */
  k: number;
  /** How many queries the scenario runs. */
  queryCount: number;
}
|
|
236
|
+
|
|
237
|
+
/**
 * Runs one throwaway top-`k` query with a random vector so later timed queries
 * do not pay first-query costs. The result is discarded.
 */
export async function warmupQuery(vectorDB: PgVector, indexName: string, dimension: number, k: number) {
  const [probe] = generateRandomVectors(1, dimension) as [number[]];
  await vectorDB.query(indexName, probe, k);
}
|
|
241
|
+
|
|
242
|
+
/**
 * Awaits `fn` and measures its wall-clock duration with `performance.now()`.
 *
 * @returns `[elapsedMs, result]`.
 */
export async function measureLatency<T>(fn: () => Promise<T>): Promise<[number, T]> {
  const startedAt = performance.now();
  const value = await fn();
  const elapsedMs = performance.now() - startedAt;
  return [elapsedMs, value];
}
|
|
248
|
+
|
|
249
|
+
/**
 * Number of IVF lists used for an index config, mirroring `PgVector.buildIndex`:
 * the explicit `ivf.lists` when set, otherwise 2 * sqrt(size) clamped to [100, 4000].
 *
 * @returns `undefined` unless `indexConfig.type` is `'ivfflat'`.
 */
export const getListCount = (indexConfig: IndexConfig, size: number): number | undefined => {
  if (indexConfig.type === 'ivfflat') {
    const explicitLists = indexConfig.ivf?.lists;
    if (explicitLists) return explicitLists;
    const heuristic = Math.floor(Math.sqrt(size) * 2);
    return Math.min(Math.max(heuristic, 100), 4000);
  }
  return undefined;
};
|
|
254
|
+
|
|
255
|
+
/**
 * HNSW build parameters for a config, using the same defaults as
 * `PgVector.buildIndex` (m = 8, efConstruction = 32) when unset.
 */
export const getHNSWConfig = (indexConfig: IndexConfig): { m: number; efConstruction: number } => {
  const DEFAULT_M = 8;
  const DEFAULT_EF_CONSTRUCTION = 32;
  const hnsw = indexConfig.hnsw;
  const m = hnsw?.m ?? DEFAULT_M;
  const efConstruction = hnsw?.efConstruction ?? DEFAULT_EF_CONSTRUCTION;
  return { m, efConstruction };
};
|
|
261
|
+
|
|
262
|
+
/**
 * Candidate HNSW `ef_search` values for top-`k` queries on an index built with `m`.
 * Each is at least `k`: `default` = m*k, `lower` = m*k/2 (faster, lower quality),
 * `higher` = 2*m*k (slower, higher quality).
 */
export function getSearchEf(k: number, m: number) {
  const base = m * k;
  const atLeastK = (ef: number) => Math.max(k, ef);
  return {
    default: atLeastK(base), // Default calculation
    lower: atLeastK(base / 2), // Lower quality, faster
    higher: atLeastK(base * 2), // Higher quality, slower
  };
}
|
|
269
|
+
|
|
270
|
+
/**
 * Short human-readable label for an index type, e.g. `HNSW(m=8,ef=32)`, `IVF`
 * or `Flat`. The `hnsw` parameters are only used for `'hnsw'`.
 */
export function getIndexDescription({
  type,
  hnsw,
}: {
  type: IndexType;
  hnsw: { m: number; efConstruction: number };
}): string {
  switch (type) {
    case 'hnsw':
      return `HNSW(m=${hnsw.m},ef=${hnsw.efConstruction})`;
    case 'ivfflat':
      return `IVF`;
    default:
      return 'Flat';
  }
}
|