@mastra/pg 0.0.0-commonjs-20250227130920

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,391 @@
1
+ import type { Filter } from '@mastra/core/filter';
2
+ import { MastraVector } from '@mastra/core/vector';
3
+ import type { IndexStats, QueryResult } from '@mastra/core/vector';
4
+ import pg from 'pg';
5
+
6
+ import { PGFilterTranslator } from './filter';
7
+ import { buildFilterQuery } from './sql-builder';
8
+ import type { IndexConfig, IndexType } from './types';
9
+
10
/**
 * Index statistics reported for a pgvector-backed table: the base
 * `IndexStats` fields plus the index type and its tuning parameters.
 */
export interface PGIndexStats extends IndexStats {
  /** Index method in use for the table. */
  type: IndexType;
  /** Tuning parameters; which fields are populated depends on `type`. */
  config: {
    /** HNSW `m` parameter. */
    m?: number;
    /** HNSW `ef_construction` parameter. */
    efConstruction?: number;
    /** IVFFlat `lists` parameter. */
    lists?: number;
    /** IVFFlat probe count. */
    probes?: number;
  };
}
19
+
20
/**
 * Vector store backed by PostgreSQL with the `vector` extension.
 *
 * Each "index" is a table named `indexName` with the columns `vector_id`,
 * `embedding` and `metadata`, plus an optional ANN index named
 * `${indexName}_vector_idx`.
 *
 * NOTE(review): table names are interpolated into SQL as identifiers. Only
 * `createIndex` validates the name format, so callers must not pass untrusted
 * input as `indexName` to the other methods.
 */
export class PgVector extends MastraVector {
  private pool: pg.Pool;
  /** Cached `describeIndex` results, keyed by index (table) name. */
  private indexCache: Map<string, PGIndexStats> = new Map();

  /**
   * @param connectionString PostgreSQL connection string used for the pool.
   */
  constructor(connectionString: string) {
    super();

    const basePool = new pg.Pool({
      connectionString,
      max: 20, // Maximum number of clients in the pool
      idleTimeoutMillis: 30000, // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2000, // Fail fast if can't connect
    });

    const telemetry = this.__getTelemetry();

    // Use the traced pool when telemetry is configured, otherwise the raw pool.
    this.pool =
      telemetry?.traceClass(basePool, {
        spanNamePrefix: 'pg-vector',
        attributes: {
          'vector.type': 'postgres',
        },
      }) ?? basePool;
  }

  /** Extracts a printable message from an arbitrary thrown value. */
  private static errorMessage(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * Issues ROLLBACK and swallows any failure of the ROLLBACK itself, so the
   * error that triggered the rollback is the one the caller sees.
   */
  private async rollbackQuietly(client: pg.PoolClient): Promise<void> {
    try {
      await client.query('ROLLBACK');
    } catch {
      // Ignored on purpose: the original error is more useful to the caller.
    }
  }

  /** Translates a generic filter (or an empty one) with `PGFilterTranslator`. */
  transformFilter(filter?: Filter) {
    const pgFilter = new PGFilterTranslator();
    return pgFilter.translate(filter ?? {});
  }

  /**
   * Returns the stats for `indexName`, calling `describeIndex` on a cache miss
   * and caching the result.
   */
  async getIndexInfo(indexName: string): Promise<PGIndexStats> {
    const cached = this.indexCache.get(indexName);
    if (cached) return cached;

    const described = await this.describeIndex(indexName);
    this.indexCache.set(indexName, described);
    return described;
  }

  /**
   * Runs a similarity search and returns up to `topK` rows ordered by
   * descending score, where score is `1 - (embedding <=> queryVector)`.
   *
   * @param indexName Table to search.
   * @param queryVector Query embedding.
   * @param topK Maximum number of rows to return.
   * @param filter Optional metadata filter.
   * @param includeVector When true, each result also carries its `vector`.
   * @param minScore Score threshold handed to `buildFilterQuery`.
   * @param options Per-query tuning: `ef` for HNSW, `probes` for IVFFlat.
   * @throws Error if `options.probes` is given for an IVFFlat index but is not
   *   a positive integer.
   */
  async query(
    indexName: string,
    queryVector: number[],
    topK: number = 10,
    filter?: Filter,
    includeVector: boolean = false,
    minScore: number = 0, // Optional minimum score threshold
    options?: {
      ef?: number; // For HNSW
      probes?: number; // For IVF
    },
  ): Promise<QueryResult[]> {
    const vectorStr = `[${queryVector.join(',')}]`;
    const translatedFilter = this.transformFilter(filter);
    const { sql: filterQuery, values: filterValues } = buildFilterQuery(translatedFilter, minScore);

    // Resolve index metadata BEFORE taking a pooled client: a cache miss makes
    // describeIndex acquire its own client, and holding one while waiting for
    // another can starve the pool under concurrency.
    const indexInfo = await this.getIndexInfo(indexName);

    const probes = indexInfo.type === 'ivfflat' && options?.probes ? options.probes : undefined;
    if (probes !== undefined && (!Number.isInteger(probes) || probes < 1)) {
      throw new Error('probes must be a positive integer');
    }

    const client = await this.pool.connect();
    try {
      // SET LOCAL only has an effect inside a transaction block, so the tuning
      // parameters and the search must share one transaction.
      await client.query('BEGIN');
      try {
        if (indexInfo.type === 'hnsw') {
          // Calculate ef and clamp between 1 and 1000
          const calculatedEf = options?.ef ?? Math.max(topK, (indexInfo.config.m ?? 16) * topK);
          const searchEf = Math.min(1000, Math.max(1, calculatedEf));
          await client.query(`SET LOCAL hnsw.ef_search = ${searchEf}`);
        }

        if (probes !== undefined) {
          await client.query(`SET LOCAL ivfflat.probes = ${probes}`);
        }

        // NOTE(review): `$1` is expected to be minScore, supplied through
        // `filterValues` by buildFilterQuery — confirm against sql-builder.
        const query = `
          WITH vector_scores AS (
            SELECT
              vector_id as id,
              1 - (embedding <=> '${vectorStr}'::vector) as score,
              metadata
              ${includeVector ? ', embedding' : ''}
            FROM ${indexName}
            ${filterQuery}
          )
          SELECT *
          FROM vector_scores
          WHERE score > $1
          ORDER BY score DESC
          LIMIT ${topK}`;
        const result = await client.query(query, filterValues);
        await client.query('COMMIT');

        return result.rows.map(({ id, score, metadata, embedding }) => ({
          id,
          score,
          metadata,
          ...(includeVector && embedding && { vector: JSON.parse(embedding) }),
        }));
      } catch (error) {
        await this.rollbackQuietly(client);
        throw error;
      }
    } finally {
      client.release();
    }
  }

  /**
   * Inserts or updates vectors in a single transaction.
   *
   * @param indexName Target table.
   * @param vectors Embeddings to store.
   * @param metadata Optional per-vector metadata (`{}` when missing).
   * @param ids Optional ids; random UUIDs are generated when omitted.
   * @returns The ids used, in input order.
   */
  async upsert(
    indexName: string,
    vectors: number[][],
    metadata?: Record<string, any>[],
    ids?: string[],
  ): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      await client.query('BEGIN');
      const vectorIds = ids ?? vectors.map(() => crypto.randomUUID());

      // The statement text is loop-invariant; only the parameters change.
      const query = `
        INSERT INTO ${indexName} (vector_id, embedding, metadata)
        VALUES ($1, $2::vector, $3::jsonb)
        ON CONFLICT (vector_id)
        DO UPDATE SET
          embedding = $2::vector,
          metadata = $3::jsonb
        RETURNING embedding::text
      `;

      for (let i = 0; i < vectors.length; i++) {
        await client.query(query, [vectorIds[i], `[${vectors[i]?.join(',')}]`, JSON.stringify(metadata?.[i] || {})]);
      }

      await client.query('COMMIT');
      return vectorIds;
    } catch (error) {
      // A failing ROLLBACK must not hide the error that caused it.
      await this.rollbackQuietly(client);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Creates the backing table (and the `vector` extension if needed) and,
   * unless `defineIndex` is false, builds the ANN index.
   *
   * @throws Error if `indexName` is not a plain identifier, if `dimension` is
   *   not a positive integer, or if the `vector` extension is unavailable.
   */
  async createIndex(
    indexName: string,
    dimension: number,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig = {},
    defineIndex: boolean = true,
  ): Promise<void> {
    try {
      // Validate inputs before taking a connection.
      if (!indexName.match(/^[a-zA-Z_][a-zA-Z0-9_]*$/)) {
        throw new Error('Invalid index name format');
      }
      if (!Number.isInteger(dimension) || dimension <= 0) {
        throw new Error('Dimension must be a positive integer');
      }

      const client = await this.pool.connect();
      try {
        // First check if vector extension is available
        const extensionCheck = await client.query(`
          SELECT EXISTS (
            SELECT 1 FROM pg_available_extensions WHERE name = 'vector'
          );
        `);

        if (!extensionCheck.rows[0].exists) {
          throw new Error('PostgreSQL vector extension is not available. Please install it first.');
        }

        // Try to create extension
        await client.query('CREATE EXTENSION IF NOT EXISTS vector');

        // Create the table with explicit schema
        await client.query(`
          CREATE TABLE IF NOT EXISTS ${indexName} (
            id SERIAL PRIMARY KEY,
            vector_id TEXT UNIQUE NOT NULL,
            embedding vector(${dimension}),
            metadata JSONB DEFAULT '{}'::jsonb
          );
        `);
      } finally {
        // Release before building the index: buildIndex takes its own client,
        // and nesting acquisitions can starve the pool under concurrency.
        client.release();
      }

      if (defineIndex) {
        await this.defineIndex(indexName, metric, indexConfig);
      }
    } catch (error: unknown) {
      console.error('Failed to create vector table:', error);
      throw error;
    }
  }

  /**
   * @deprecated This function is deprecated. Use buildIndex instead
   */
  async defineIndex(
    indexName: string,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig,
  ): Promise<void> {
    return this.buildIndex(indexName, metric, indexConfig);
  }

  /**
   * Drops `${indexName}_vector_idx` and recreates it per `indexConfig`:
   * nothing for `flat`, HNSW for `hnsw`, IVFFlat otherwise. Cached stats for
   * the index are always invalidated.
   */
  async buildIndex(
    indexName: string,
    metric: 'cosine' | 'euclidean' | 'dotproduct' = 'cosine',
    indexConfig: IndexConfig,
  ): Promise<void> {
    const client = await this.pool.connect();
    try {
      await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);

      if (indexConfig.type === 'flat') return;

      const metricOp =
        metric === 'cosine' ? 'vector_cosine_ops' : metric === 'euclidean' ? 'vector_l2_ops' : 'vector_ip_ops';

      let indexSQL: string;
      if (indexConfig.type === 'hnsw') {
        const m = indexConfig.hnsw?.m ?? 8;
        const efConstruction = indexConfig.hnsw?.efConstruction ?? 32;

        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING hnsw (embedding ${metricOp})
          WITH (
            m = ${m},
            ef_construction = ${efConstruction}
          )
        `;
      } else {
        let lists: number;
        if (indexConfig.ivf?.lists) {
          lists = indexConfig.ivf.lists;
        } else {
          // COUNT(*) comes back as a string; convert before doing arithmetic.
          const size = Number((await client.query(`SELECT COUNT(*) FROM ${indexName}`)).rows[0].count);
          lists = Math.max(100, Math.min(4000, Math.floor(Math.sqrt(size) * 2)));
        }
        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING ivfflat (embedding ${metricOp})
          WITH (lists = ${lists});
        `;
      }

      await client.query(indexSQL);
    } finally {
      // Invalidate on every exit path (flat early-return and failures too):
      // the old index is gone once the DROP ran, so cached stats are stale.
      this.indexCache.delete(indexName);
      client.release();
    }
  }

  /** Lists tables in the `public` schema that have a column of type `vector`. */
  async listIndexes(): Promise<string[]> {
    const client = await this.pool.connect();
    try {
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
        AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map(row => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Reads dimension, row count, metric, index type and index parameters for
   * `indexName` from the catalog. When no `${indexName}_vector_idx` exists the
   * type is reported as `flat` with the `cosine` metric.
   *
   * @throws Error prefixed with "Failed to describe vector table:" on failure.
   */
  async describeIndex(indexName: string): Promise<PGIndexStats> {
    const client = await this.pool.connect();
    try {
      // Get vector dimension
      const dimensionQuery = `
        SELECT atttypmod as dimension
        FROM pg_attribute
        WHERE attrelid = $1::regclass
        AND attname = 'embedding';
      `;

      // Get row count
      const countQuery = `
        SELECT COUNT(*) as count
        FROM ${indexName};
      `;

      // Get index metric type; the index name is bound as a parameter rather
      // than spliced into a string literal.
      const indexQuery = `
        SELECT
          am.amname as index_method,
          pg_get_indexdef(i.indexrelid) as index_def,
          opclass.opcname as operator_class
        FROM pg_index i
        JOIN pg_class c ON i.indexrelid = c.oid
        JOIN pg_am am ON c.relam = am.oid
        JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
        WHERE c.relname = $1;
      `;

      // One client executes one query at a time, so run these sequentially
      // instead of stacking them on the connection with Promise.all.
      const dimResult = await client.query(dimensionQuery, [indexName]);
      const countResult = await client.query(countQuery);
      const indexResult = await client.query(indexQuery, [`${indexName}_vector_idx`]);

      const { index_method, index_def, operator_class } = indexResult.rows[0] || {
        index_method: 'flat',
        index_def: '',
        operator_class: 'cosine',
      };

      // Convert pg_vector operator class to our metric type
      const metric = operator_class.includes('l2')
        ? 'euclidean'
        : operator_class.includes('ip')
          ? 'dotproduct'
          : 'cosine';

      // Parse index configuration out of the index definition text
      const config: { m?: number; efConstruction?: number; lists?: number } = {};

      if (index_method === 'hnsw') {
        const m = index_def.match(/m\s*=\s*'?(\d+)'?/)?.[1];
        const efConstruction = index_def.match(/ef_construction\s*=\s*'?(\d+)'?/)?.[1];
        if (m) config.m = parseInt(m, 10);
        if (efConstruction) config.efConstruction = parseInt(efConstruction, 10);
      } else if (index_method === 'ivfflat') {
        const lists = index_def.match(/lists\s*=\s*'?(\d+)'?/)?.[1];
        if (lists) config.lists = parseInt(lists, 10);
      }

      return {
        dimension: dimResult.rows[0].dimension,
        count: parseInt(countResult.rows[0].count, 10),
        metric,
        type: index_method as 'flat' | 'hnsw' | 'ivfflat',
        config,
      };
    } catch (e: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to describe vector table: ${PgVector.errorMessage(e)}`);
    } finally {
      client.release();
    }
  }

  /**
   * Drops the table for `indexName` (CASCADE) and forgets its cached stats.
   *
   * @throws Error prefixed with "Failed to delete vector table:" on failure.
   */
  async deleteIndex(indexName: string): Promise<void> {
    const client = await this.pool.connect();
    try {
      // Drop the table
      await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
    } catch (error: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to delete vector table: ${PgVector.errorMessage(error)}`);
    } finally {
      // Cached stats would otherwise survive into a re-created table.
      this.indexCache.delete(indexName);
      client.release();
    }
  }

  /**
   * Removes every row from the table for `indexName`.
   *
   * @throws Error prefixed with "Failed to truncate vector table:" on failure.
   */
  async truncateIndex(indexName: string) {
    const client = await this.pool.connect();
    try {
      await client.query(`TRUNCATE ${indexName}`);
    } catch (e: unknown) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to truncate vector table: ${PgVector.errorMessage(e)}`);
    } finally {
      // The cached row count is stale after a truncate.
      this.indexCache.delete(indexName);
      client.release();
    }
  }

  /** Closes the connection pool. */
  async disconnect() {
    await this.pool.end();
  }
}
@@ -0,0 +1,286 @@
1
+ import type { IndexConfig, IndexType } from './types';
2
+
3
+ import type { PgVector } from '.';
4
+
5
/**
 * One row of benchmark output: the scenario that was run plus whichever
 * metric groups were collected for it.
 */
export interface TestResult {
  /** Label of the vector distribution used for the run. */
  distribution: string;
  /** Vector dimensionality. */
  dimension: number;
  /** Index type under test. */
  type: IndexType;
  /** Number of vectors in the data set. */
  size: number;
  /** Number of neighbours requested per query, when applicable. */
  k?: number;
  /** Collected measurements; every group is optional. */
  metrics: {
    recall?: number;
    minRecall?: number;
    maxRecall?: number;
    /** Latency percentiles together with the index parameters in effect. */
    latency?: {
      p50: number;
      p95: number;
      lists?: number;
      vectorsPerList?: number;
      m?: number;
      ef?: number;
    };
    /** IVF list statistics. */
    clustering?: {
      numLists?: number;
      avgVectorsPerList?: number;
      recommendedLists?: number;
      distribution?: string;
    };
  };
}
31
+
32
/**
 * Builds `count` vectors of length `dim` whose components are drawn
 * independently from `Math.random() * 2 - 1`, i.e. the range [-1, 1).
 *
 * @param count Number of vectors to generate.
 * @param dim Number of components per vector.
 * @returns An array of `count` arrays, each holding `dim` numbers.
 */
export const generateRandomVectors = (count: number, dim: number) => {
  const randomComponent = () => Math.random() * 2 - 1;
  const randomVector = () => Array.from({ length: dim }, randomComponent);
  return Array.from({ length: count }, randomVector);
};
37
+
38
/**
 * Builds `count` vectors scattered around `numClusters` random centers.
 *
 * The cluster for each vector is picked with a squared uniform draw, which
 * favours low-numbered clusters. Each vector then gets per-component noise
 * with a spread of 0.1 (80% of vectors) or 0.5 (the remaining 20%).
 *
 * @param count Number of vectors to generate.
 * @param dim Number of components per vector.
 * @param numClusters Number of cluster centers (default 10).
 */
export const generateClusteredVectors = (count: number, dim: number, numClusters: number = 10) => {
  const uniform = () => Math.random() * 2 - 1;

  // One random center per cluster.
  const centers = Array.from({ length: numClusters }, () => Array.from({ length: dim }, uniform));

  const sampleAroundCenter = () => {
    // Squaring the draw skews the pick towards the first clusters.
    const skewedDraw = Math.pow(Math.random(), 2);
    const center = centers[Math.floor(skewedDraw * numClusters)] as number[];

    // 80% of vectors stay close to the center, 20% land further away.
    const spread = Math.random() < 0.8 ? 0.1 : 0.5;
    return center.map(c => c + (Math.random() * spread - spread / 2));
  };

  return Array.from({ length: count }, sampleAroundCenter);
};
53
+
54
+ // Or even more extreme:
55
/**
 * Builds a deliberately skewed data set: 60% of the vectors (rounded down)
 * sit in one tight cluster whose center has components in [0, 0.2), and the
 * rest are spread uniformly over [-1, 1). The result is shuffled.
 *
 * @param count Total number of vectors to generate.
 * @param dim Number of components per vector.
 * @returns `count` vectors in random order.
 */
export const generateSkewedVectors = (count: number, dim: number) => {
  const vectors: number[][] = [];

  const denseCount = Math.floor(count * 0.6);
  const sparseCount = count - denseCount;

  // Dense cluster (60% of vectors)
  const denseCenter = Array.from({ length: dim }, () => Math.random() * 0.2);
  for (let i = 0; i < denseCount; i++) {
    vectors.push(denseCenter.map(c => c + (Math.random() * 0.1 - 0.05)));
  }

  // Scattered vectors (40%)
  for (let i = 0; i < sparseCount; i++) {
    vectors.push(Array.from({ length: dim }, () => Math.random() * 2 - 1));
  }

  // Fisher–Yates shuffle. Sorting with a random comparator is biased and
  // hands `sort` an inconsistent ordering, so it is not a valid shuffle.
  for (let i = vectors.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    const swapped = vectors[i] as number[];
    vectors[i] = vectors[j] as number[];
    vectors[j] = swapped;
  }

  return vectors;
};
75
+
76
/**
 * Exhaustive nearest-neighbour search: scores every vector against `query`
 * with `cosineSimilarity` and returns the positions of the `k` best matches,
 * most similar first.
 *
 * @param query Query vector.
 * @param vectors Candidate vectors.
 * @param k Number of positions to return.
 * @returns Indices into `vectors`, ordered by descending similarity.
 */
export const findNearestBruteForce = (query: number[], vectors: number[][], k: number) => {
  // `dist` holds the cosine similarity, so a larger value means a closer match.
  const scored = vectors.map((vector, idx) => ({ idx, dist: cosineSimilarity(query, vector) }));
  scored.sort((left, right) => right.dist - left.dist);
  return scored.slice(0, k).map(({ idx }) => idx);
};
85
+
86
/**
 * Position-aware recall over the first `k` results.
 *
 * Each of the first `k` entries of `actual` scores 1 when it equals the entry
 * at the same position in `expected`, or 0.5 when it appears anywhere else in
 * `expected`. The total is divided by `k`.
 *
 * Positions where `actual` has no entry score nothing: a missing result is
 * neither a positional match nor a stand-in for index 0.
 *
 * @param actual Result indices returned by the system under test.
 * @param expected Ground-truth indices.
 * @param k Number of leading positions to evaluate.
 * @returns A value in [0, 1]; 0 when `k` is not a positive number.
 */
export const calculateRecall = (actual: number[], expected: number[], k: number): number => {
  // Guards k <= 0 and NaN, which would otherwise yield NaN or -0.
  if (!(k > 0)) return 0;

  const expectedSet = new Set(expected);
  let score = 0;
  for (let i = 0; i < k; i++) {
    const found = actual[i];
    if (found === undefined) continue;

    if (found === expected[i]) {
      score += 1;
    } else if (expectedSet.has(found)) {
      score += 0.5;
    }
  }
  return score / k;
};
97
+
98
/**
 * Cosine similarity of two vectors.
 *
 * The dot product runs over the indices of `a`; components missing from `b`
 * count as 0. Each norm is computed over that vector's own components.
 *
 * @param a First vector.
 * @param b Second vector.
 * @returns The similarity, or 0 when either vector has zero magnitude (the
 *   quotient would otherwise be NaN and corrupt any sort that uses it).
 */
export function cosineSimilarity(a: number[], b: number[]): number {
  let dotProduct = 0;
  let sumSquaresA = 0;
  a.forEach((val, i) => {
    dotProduct += val * (b[i] ?? 0);
    sumSquaresA += val * val;
  });

  let sumSquaresB = 0;
  for (const val of b) {
    sumSquaresB += val * val;
  }

  const denominator = Math.sqrt(sumSquaresA) * Math.sqrt(sumSquaresB);
  if (denominator === 0) return 0;
  return dotProduct / denominator;
}
104
+
105
/**
 * Renders rows as a box-drawn text table.
 *
 * Each column is as wide as its longest cell or its header, whichever is
 * longer. `undefined` and `null` cells are shown as `-`; every other value is
 * rendered with its `toString()`.
 *
 * @param data Row objects.
 * @param columns Property names to show, in display order; also the headers.
 * @returns The table as one newline-joined string.
 */
export const formatTable = (data: any[], columns: string[]) => {
  const cellText = (row: any, col: string): string => {
    const value = row[col];
    return value === undefined || value === null ? '-' : value.toString();
  };

  const colWidths = columns.map(col => Math.max(col.length, ...data.map(row => cellText(row, col).length)));

  // Horizontal rule with the given corner and junction characters.
  const rule = (left: string, junction: string, right: string) =>
    left + colWidths.map(w => '─'.repeat(w)).join(junction) + right;

  const line = (cells: string[]) => '│' + cells.join('│') + '│';

  const header = line(columns.map((col, i) => col.padEnd(colWidths[i] ?? 0)));
  const rows = data.map(row => line(columns.map((col, i) => cellText(row, col).padEnd(colWidths[i]))));

  return [rule('┌', '┬', '┐'), header, rule('├', '┼', '┤'), ...rows, rule('└', '┴', '┘')].join('\n');
};
136
+
137
/**
 * Groups `array` by a property or by a key function.
 *
 * @param array Items to group.
 * @param key Property name whose value is the group key, or a function that
 *   returns the group key for an item. Keys are converted to strings.
 * @param reducer Optional function applied to each group; when given, the
 *   result maps each key to the reducer's return value instead of the items.
 * @returns A plain object keyed by group key.
 */
export const groupBy = <T, K extends keyof T>(
  array: T[],
  key: K | ((item: T) => string),
  reducer?: (group: T[]) => any,
): Record<string, any> => {
  // Collect into a Map, not a plain object: a plain-object accumulator
  // inherits `constructor`, `toString`, ... from Object.prototype, so those
  // group keys looked "already present" and broke the push.
  const groups = new Map<string, T[]>();
  for (const item of array) {
    const groupKey = String(typeof key === 'function' ? key(item) : item[key]);
    const bucket = groups.get(groupKey);
    if (bucket) {
      bucket.push(item);
    } else {
      groups.set(groupKey, [item]);
    }
  }

  return Object.fromEntries(
    [...groups].map(([groupKey, group]) => [groupKey, reducer ? reducer(group) : group]),
  );
};
158
+
159
/**
 * Computes a timeout for a benchmark scenario. Starts from 600000, scales it
 * by dimension (x3 at >= 1024, x1.5 at >= 384), size (x2 at >= 10000) and `k`
 * (x1.5 at >= 75), then multiplies the result by 5.
 *
 * @param dimension Vector dimensionality.
 * @param size Number of vectors in the data set.
 * @param k Number of neighbours requested per query.
 * @returns The scaled timeout value.
 */
export const calculateTimeout = (dimension: number, size: number, k: number) => {
  const dimensionFactor = dimension >= 1024 ? 3 : dimension >= 384 ? 1.5 : 1;
  const sizeFactor = size >= 10000 ? 2 : 1;
  const kFactor = k >= 75 ? 1.5 : 1;
  return 600000 * dimensionFactor * sizeFactor * kFactor * 5;
};
167
+
168
/**
 * Benchmark scenarios grouped by suite: a smoke test, one suite per vector
 * dimension ('64', '384', '1024') and a set of stress tests. Every entry has
 * the shape `{ dimension, size, k, queryCount }`.
 */
export const baseTestConfigs = (() => {
  type Scenario = { dimension: number; size: number; k: number; queryCount: number };

  // The k values that most data-set sizes are swept over.
  const STANDARD_KS = [10, 25, 50, 100];

  const scenario = (dimension: number, size: number, k: number, queryCount: number = 30): Scenario => ({
    dimension,
    size,
    k,
    queryCount,
  });

  // One scenario per k for a fixed dimension and size, 30 queries each.
  const sweep = (dimension: number, size: number, ks: number[] = STANDARD_KS): Scenario[] =>
    ks.map(k => scenario(dimension, size, k));

  return {
    smokeTests: [scenario(384, 1_000, 10, 10)],
    '64': [
      ...sweep(64, 100),
      ...sweep(64, 1_000),
      ...sweep(64, 10_000, [10]),
      ...sweep(64, 100_000),
      ...sweep(64, 500_000, [10]),
      ...sweep(64, 1_000_000, [10]),
    ],
    '384': [
      ...sweep(384, 100),
      ...sweep(384, 1_000),
      ...sweep(384, 10_000, [10]),
      ...sweep(384, 100_000),
      ...sweep(384, 500_000, [10]),
    ],
    '1024': [
      ...sweep(1024, 100),
      ...sweep(1024, 1_000),
      ...sweep(1024, 10_000),
      ...sweep(1024, 50_000, [10, 25]),
    ],
    stressTests: [
      // Maximum load
      scenario(512, 1_000_000, 50, 5),

      // Dense search
      scenario(256, 1_000_000, 100, 5),

      scenario(1024, 500_000, 50, 5),
    ],
  };
})();
229
+
230
/** A single benchmark scenario. */
export interface TestConfig {
  /** Vector dimensionality. */
  dimension: number;
  /** Number of vectors in the data set. */
  size: number;
  /** Number of neighbours requested per query. */
  k: number;
  /** Number of queries to run for the scenario. */
  queryCount: number;
}
236
+
237
/**
 * Issues one throwaway query with a random vector against `indexName` and
 * discards the result.
 *
 * @param vectorDB Store to query.
 * @param indexName Index (table) to query.
 * @param dimension Dimensionality of the random query vector.
 * @param k Number of neighbours to request.
 */
export async function warmupQuery(vectorDB: PgVector, indexName: string, dimension: number, k: number) {
  const [warmupVector] = generateRandomVectors(1, dimension);
  await vectorDB.query(indexName, warmupVector as number[], k);
}
241
+
242
/**
 * Awaits `fn` and reports how long it took according to `performance.now()`.
 *
 * @param fn Async operation to time.
 * @returns A tuple of `[elapsed, result]`, where `elapsed` is the difference
 *   between the two `performance.now()` readings and `result` is what `fn`
 *   resolved to.
 */
export async function measureLatency<T>(fn: () => Promise<T>): Promise<[number, T]> {
  const startedAt = performance.now();
  const value = await fn();
  const finishedAt = performance.now();
  return [finishedAt - startedAt, value];
}
248
+
249
/**
 * Returns the IVFFlat list count that applies to `indexConfig`.
 *
 * @param indexConfig Index configuration.
 * @param size Number of vectors in the data set.
 * @returns `undefined` unless the type is `ivfflat`; otherwise the configured
 *   `ivf.lists` when it is set and non-zero, else `floor(sqrt(size) * 2)`
 *   clamped to the range 100..4000.
 */
export const getListCount = (indexConfig: IndexConfig, size: number): number | undefined => {
  if (indexConfig.type !== 'ivfflat') {
    return undefined;
  }

  const configuredLists = indexConfig.ivf?.lists;
  if (configuredLists) {
    return configuredLists;
  }

  const heuristic = Math.floor(Math.sqrt(size) * 2);
  return Math.max(100, Math.min(4000, heuristic));
};
254
+
255
/**
 * Resolves the HNSW build parameters from `indexConfig`, falling back to
 * `m = 8` and `efConstruction = 32` for anything not set.
 *
 * @param indexConfig Index configuration.
 * @returns The effective `m` and `efConstruction`.
 */
export const getHNSWConfig = (indexConfig: IndexConfig): { m: number; efConstruction: number } => {
  const hnsw = indexConfig.hnsw;
  const m = hnsw?.m ?? 8;
  const efConstruction = hnsw?.efConstruction ?? 32;
  return { m, efConstruction };
};
261
+
262
/**
 * Derives three candidate `ef` search values from `k` and `m`. Each is based
 * on `m * k` (halved for `lower`, doubled for `higher`) and is never smaller
 * than `k`.
 *
 * @param k Number of neighbours requested per query.
 * @param m HNSW `m` parameter.
 * @returns The `default`, `lower` and `higher` candidates.
 */
export function getSearchEf(k: number, m: number) {
  const base = m * k;
  const atLeastK = (candidate: number) => Math.max(k, candidate);

  return {
    default: atLeastK(base), // Default calculation
    lower: atLeastK(base / 2), // Lower quality, faster
    higher: atLeastK(base * 2), // Higher quality, slower
  };
}
269
+
270
/**
 * Produces a short label for an index configuration:
 * `HNSW(m=…,ef=…)` for `hnsw`, `IVF` for `ivfflat`, and `Flat` otherwise.
 *
 * @param type Index type.
 * @param hnsw HNSW parameters; only read when `type` is `hnsw`.
 * @returns The label.
 */
export function getIndexDescription({
  type,
  hnsw,
}: {
  type: IndexType;
  hnsw: { m: number; efConstruction: number };
}): string {
  switch (type) {
    case 'hnsw':
      return `HNSW(m=${hnsw.m},ef=${hnsw.efConstruction})`;
    case 'ivfflat':
      return `IVF`;
    default:
      return 'Flat';
  }
}