@mastra/pg 0.1.6-alpha.1 → 0.1.6-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs ADDED
@@ -0,0 +1,1050 @@
1
+ 'use strict';
2
+
3
+ var vector = require('@mastra/core/vector');
4
+ var pg = require('pg');
5
+ var filter = require('@mastra/core/vector/filter');
6
+ var storage = require('@mastra/core/storage');
7
+ var pgPromise = require('pg-promise');
8
+
9
/**
 * Bundler interop helper.
 *
 * Returns the argument unchanged when it is flagged as an ES module
 * (`__esModule` is truthy); otherwise wraps it as `{ default: value }` so
 * callers can always read `.default`.
 */
function _interopDefault(mod) {
  if (mod && mod.__esModule) {
    return mod;
  }
  return { default: mod };
}
10
+
11
+ var pg__default = /*#__PURE__*/_interopDefault(pg);
12
+ var pgPromise__default = /*#__PURE__*/_interopDefault(pgPromise);
13
+
14
+ // src/vector/index.ts
15
/**
 * Translates a Mongo-style metadata filter into the normalized operator tree
 * that the SQL filter builder in this file consumes.
 */
var PGFilterTranslator = class extends filter.BaseFilterTranslator {
  /**
   * Operators accepted by this translator: the base defaults plus the
   * custom `$contains` and `$size` operators.
   */
  getSupportedOperators() {
    const { DEFAULT_OPERATORS } = filter.BaseFilterTranslator;
    return { ...DEFAULT_OPERATORS, custom: ["$contains", "$size"] };
  }

  /**
   * Validates and translates a filter. Empty filters are returned untouched.
   */
  translate(filter) {
    if (!this.isEmpty(filter)) {
      this.validateFilter(filter);
      return this.translateNode(filter);
    }
    return filter;
  }

  /**
   * Recursively translates one filter node.
   *
   * @param node        Primitive, array, RegExp or operator/field object.
   * @param currentPath Dotted path accumulated from enclosing plain objects.
   */
  translateNode(node, currentPath = "") {
    // Results are keyed by the accumulated path when there is one.
    const wrap = (translated) => {
      if (currentPath) {
        return { [currentPath]: translated };
      }
      return translated;
    };

    if (this.isPrimitive(node)) {
      return wrap({ $eq: this.normalizeComparisonValue(node) });
    }
    if (Array.isArray(node)) {
      return wrap({ $in: this.normalizeArrayValues(node) });
    }
    if (node instanceof RegExp) {
      return wrap(this.translateRegexPattern(node.source, node.flags));
    }

    const entries = Object.entries(node);
    const translated = {};

    if ("$options" in node && !("$regex" in node)) {
      throw new Error("$options is not valid without $regex");
    }
    if ("$regex" in node) {
      // A regex node short-circuits: only `$regex`/`$options` are used.
      return wrap(this.translateRegexPattern(node.$regex, node.$options || ""));
    }

    for (const [key, value] of entries) {
      if (key === "$options") {
        continue;
      }
      const path = currentPath ? `${currentPath}.${key}` : key;

      if (this.isLogicalOperator(key)) {
        translated[key] = Array.isArray(value)
          ? value.map((child) => this.translateNode(child))
          : this.translateNode(value);
        continue;
      }

      if (this.isOperator(key)) {
        let operand = value;
        if (this.isArrayOperator(key) && !Array.isArray(value) && key !== "$elemMatch") {
          // Array operators always receive an array operand.
          operand = [value];
        } else if (this.isBasicOperator(key) && Array.isArray(value)) {
          // Basic comparisons against an array compare its JSON text.
          operand = JSON.stringify(value);
        }
        translated[key] = operand;
        continue;
      }

      const isObjectValue = typeof value === "object" && value !== null;
      if (!isObjectValue) {
        translated[path] = this.translateNode(value);
        continue;
      }

      const hasOperators = Object.keys(value).some((k) => this.isOperator(k));
      if (hasOperators) {
        translated[path] = this.translateNode(value);
      } else {
        // Plain nested object: flatten its fields under the dotted path.
        Object.assign(translated, this.translateNode(value, path));
      }
    }
    return translated;
  }

  /**
   * Builds a `$regex` node, prefixing the pattern with an embedded
   * `(?flags)` group for the option letters found in "imsux".
   */
  translateRegexPattern(pattern, options = "") {
    if (!options) {
      return { $regex: pattern };
    }
    const supported = "imsux";
    const flags = options
      .split("")
      .filter((flag) => supported.includes(flag))
      .join("");
    if (!flags) {
      return { $regex: pattern };
    }
    return { $regex: `(?${flags})${pattern}` };
  }
};
81
+
82
+ // src/vector/sql-builder.ts
83
/**
 * Creates a text comparison operator (`=` or `!=`) over a metadata path.
 *
 * The generated SQL is a CASE expression: when the bound parameter is NULL
 * it tests the metadata value with IS NULL (for "=") or IS NOT NULL
 * (otherwise); in all other cases it compares the value as text using
 * `symbol`.
 */
var createBasicOperator = (symbol) => {
  const nullKeyword = symbol === "=" ? "" : "NOT";
  return (key, paramIndex) => {
    const path = handleKey(key);
    const param = `$${paramIndex}`;
    return {
      sql: `CASE
      WHEN ${param}::text IS NULL THEN metadata#>>'{${path}}' IS ${nullKeyword} NULL
      ELSE metadata#>>'{${path}}' ${symbol} ${param}::text
    END`,
      needsValue: true
    };
  };
};
92
/**
 * Creates a numeric comparison operator: the metadata value at the path is
 * cast to `numeric` and compared with the bound parameter using `symbol`.
 */
var createNumericOperator = (symbol) => {
  return function numericOperator(key, paramIndex) {
    const column = `(metadata#>>'{${handleKey(key)}}')::numeric`;
    return {
      sql: `${column} ${symbol} $${paramIndex}`,
      needsValue: true
    };
  };
};
98
/**
 * Builds the WHERE fragment used inside an `$elemMatch` sub-query, where each
 * array element is exposed under the alias `elem`.
 *
 * Each entry of `value` is one of:
 *   - `$op: operand`            operator applied to the element itself
 *   - `field: { $op: operand }` operator applied to a field of the element
 *   - `field: literal`          equality on a field (including `null`)
 *
 * @param {object} value      Conditions object passed to `$elemMatch`.
 * @param {number} paramIndex Index of the first free SQL placeholder.
 * @returns {{ sql: string, values: unknown[] }} Conditions joined with AND
 *   and the values to bind, in placeholder order.
 * @throws {Error} If `value` is not a plain object or an operator is unknown.
 */
function buildElemMatchConditions(value, paramIndex) {
  // `typeof null` is "object", so null needs its own check.
  if (value === null || typeof value !== "object" || Array.isArray(value)) {
    throw new Error("$elemMatch requires an object with conditions");
  }

  const conditions = [];
  const values = [];

  for (const [field, val] of Object.entries(value)) {
    const nextParamIndex = paramIndex + values.length;
    let paramOperator;
    let paramKey;
    let paramValue;

    if (field.startsWith("$")) {
      paramOperator = field;
      paramKey = "";
      paramValue = val;
    } else if (val !== null && typeof val === "object" && !Array.isArray(val)) {
      // Only the first operator of the object is applied.
      [paramOperator, paramValue] = Object.entries(val)[0] ?? [];
      paramKey = field;
    } else {
      // Literals (including null and arrays) are compared with $eq.
      paramOperator = "$eq";
      paramKey = field;
      paramValue = val;
    }

    // Own-property lookup so names such as "toString" are not treated as operators.
    if (!Object.hasOwn(FILTER_OPERATORS, paramOperator)) {
      throw new Error(`Invalid operator: ${paramOperator}`);
    }
    const result = FILTER_OPERATORS[paramOperator](paramKey, nextParamIndex, paramValue);

    // Operators are written against `metadata`; retarget them at the element.
    conditions.push(result.sql.replaceAll("metadata#>>", "elem#>>"));

    if (result.needsValue) {
      const bound = result.transformValue ? result.transformValue(paramValue) : paramValue;
      if (Array.isArray(bound) && paramOperator === "$elemMatch") {
        values.push(...bound);
      } else {
        values.push(bound);
      }
    }
  }

  return {
    sql: conditions.join(" AND "),
    values
  };
}
139
/**
 * Operator table used by the filter builder.
 *
 * Every entry is called as `(key, paramIndex, value)` and returns:
 *   - `sql`            SQL fragment for the condition
 *   - `needsValue`     whether a value must be bound for `$paramIndex`
 *   - `transformValue` optional hook producing the value(s) to bind
 *
 * For the logical operators the `key` argument is the already-joined SQL of
 * the child conditions.
 */
var FILTER_OPERATORS = {
  // Comparison operators
  $eq: createBasicOperator("="),
  $ne: createBasicOperator("!="),
  $gt: createNumericOperator(">"),
  $gte: createNumericOperator(">="),
  $lt: createNumericOperator("<"),
  $lte: createNumericOperator("<="),

  // Array operators
  $in: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' = ANY($${paramIndex}::text[])`,
    needsValue: true
  }),
  $nin: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' != ALL($${paramIndex}::text[])`,
    needsValue: true
  }),
  $all: (key, paramIndex) => ({
    sql: `CASE WHEN array_length($${paramIndex}::text[], 1) IS NULL THEN false
      ELSE (metadata#>'{${handleKey(key)}}')::jsonb ?& $${paramIndex}::text[] END`,
    needsValue: true
  }),
  $elemMatch: (key, paramIndex, value) => {
    const { sql, values } = buildElemMatchConditions(value, paramIndex);
    // `#>` takes a path, so dotted keys reach nested arrays; `->` would look
    // for a single top-level key literally named "a,b".
    const path = handleKey(key);
    return {
      sql: `(
        CASE
          WHEN jsonb_typeof(metadata#>'{${path}}') = 'array' THEN
            EXISTS (
              SELECT 1
              FROM jsonb_array_elements(metadata#>'{${path}}') as elem
              WHERE ${sql}
            )
          ELSE FALSE
        END
      )`,
      needsValue: true,
      transformValue: () => values
    };
  },

  // Element operators
  // Path lookup handles dotted keys the same way as the other operators and
  // keeps the raw key out of the SQL text. `#>` yields SQL NULL only when the
  // path is absent, so IS [NOT] NULL is an existence test. An explicit
  // `false` operand asks for absence.
  $exists: (key, _paramIndex, value) => ({
    sql: `metadata#>'{${handleKey(key)}}' IS ${value === false ? "" : "NOT "}NULL`,
    needsValue: false
  }),

  // Logical operators
  $and: (key) => ({ sql: `(${key})`, needsValue: false }),
  $or: (key) => ({ sql: `(${key})`, needsValue: false }),
  $not: (key) => ({ sql: `NOT (${key})`, needsValue: false }),
  $nor: (key) => ({ sql: `NOT (${key})`, needsValue: false }),

  // Regex operators
  $regex: (key, paramIndex) => ({
    sql: `metadata#>>'{${handleKey(key)}}' ~ $${paramIndex}`,
    needsValue: true
  }),

  // Containment: the operand is nested under the dotted key and bound as JSON.
  $contains: (key, paramIndex) => ({
    sql: `metadata @> $${paramIndex}::jsonb`,
    needsValue: true,
    transformValue: (value) => {
      const nested = key.split(".").reduceRight((inner, part) => ({ [part]: inner }), value);
      return JSON.stringify(nested);
    }
  }),

  // Array length; false when the value at the path is not a JSON array.
  $size: (key, paramIndex) => ({
    sql: `(
      CASE
        WHEN jsonb_typeof(metadata#>'{${handleKey(key)}}') = 'array' THEN
          jsonb_array_length(metadata#>'{${handleKey(key)}}') = $${paramIndex}
        ELSE FALSE
      END
    )`,
    needsValue: true
  })
};
212
/**
 * Converts a dotted metadata key ("a.b.c") into the comma-separated form
 * used inside the SQL path literals built in this file ("a,b,c").
 *
 * The result is spliced into single-quoted SQL string literals, so single
 * quotes in the key are doubled to keep the key from terminating the literal.
 */
var handleKey = (key) => {
  return key.replaceAll(".", ",").replaceAll("'", "''");
};
215
/**
 * Builds the WHERE clause and bound values for a translated metadata filter.
 *
 * `minScore` is always the first bound value (`$1`); filter values follow in
 * the order their placeholders are generated.
 *
 * @param {object|undefined|null} filter Translated filter, or a falsy value
 *   for "no filter".
 * @param {number} minScore Value bound as `$1`.
 * @returns {{ sql: string, values: unknown[] }} `sql` is either empty or
 *   starts with `WHERE`.
 * @throws {Error} On unknown operators or malformed operator objects.
 */
function buildFilterQuery(filter, minScore) {
  const values = [minScore];
  const LOGICAL_OPERATORS = ["$and", "$or", "$not", "$nor"];

  // Runs one operator for `key`, binds its value(s) and returns its SQL.
  function applyOperator(key, operator, operatorValue, errorPrefix = "Invalid operator") {
    // Own-property lookup so inherited names are never treated as operators.
    if (!Object.hasOwn(FILTER_OPERATORS, operator)) {
      throw new Error(`${errorPrefix}: ${operator}`);
    }
    const result = FILTER_OPERATORS[operator](key, values.length + 1, operatorValue);
    if (result.needsValue) {
      const bound = result.transformValue ? result.transformValue(operatorValue) : operatorValue;
      if (Array.isArray(bound) && operator === "$elemMatch") {
        // $elemMatch consumes one placeholder per nested condition.
        values.push(...bound);
      } else {
        values.push(bound);
      }
    }
    return result.sql;
  }

  // Builds the condition for a single `key: value` entry of a filter object.
  function buildCondition(key, value) {
    if (LOGICAL_OPERATORS.includes(key)) {
      return handleLogicalOperator(key, value);
    }
    if (!value || typeof value !== "object") {
      values.push(value);
      return `metadata#>>'{${handleKey(key)}}' = $${values.length}`;
    }

    const operatorEntries = Object.entries(value);
    if (operatorEntries.length === 0) {
      throw new Error(`Invalid operator: ${undefined}`);
    }

    // Every operator on the field applies, e.g. { $gt: 1, $lt: 5 }.
    const parts = operatorEntries.map(([operator, operatorValue]) => {
      if (operator !== "$not") {
        return applyOperator(key, operator, operatorValue);
      }
      if (operatorValue === null || typeof operatorValue !== "object") {
        throw new Error("$not requires an object with conditions");
      }
      const negated = Object.entries(operatorValue)
        .map(([nestedOp, nestedValue]) =>
          applyOperator(key, nestedOp, nestedValue, "Invalid operator in $not condition")
        )
        .join(" AND ");
      return `NOT (${negated})`;
    });
    return parts.length === 1 ? parts[0] : `(${parts.join(" AND ")})`;
  }

  // Renders one element of a logical operator's list. The fields of a single
  // filter object must all hold, so they are joined with AND.
  function buildSubFilter(subFilter) {
    const entries = Object.entries(subFilter || {});
    if (entries.length === 0) {
      // An empty filter matches every row.
      return "true";
    }
    const parts = entries.map(([k, v]) => buildCondition(k, v));
    return parts.length === 1 ? parts[0] : `(${parts.join(" AND ")})`;
  }

  function handleLogicalOperator(key, value) {
    if (key === "$not") {
      const negated = Object.entries(value)
        .map(([fieldKey, fieldValue]) => buildCondition(fieldKey, fieldValue))
        .join(" AND ");
      return `NOT (${negated})`;
    }
    if (!value || value.length === 0) {
      // Empty $and/$nor match everything; empty $or matches nothing.
      return key === "$or" ? "false" : "true";
    }
    const subFilters = Array.isArray(value) ? value : [value];
    const joinOperator = key === "$or" || key === "$nor" ? "OR" : "AND";
    const joined = subFilters.map(buildSubFilter).join(` ${joinOperator} `);
    return FILTER_OPERATORS[key](joined, 0, value).sql;
  }

  if (!filter) {
    return { sql: "", values };
  }
  const conditions = Object.entries(filter)
    .map(([key, value]) => buildCondition(key, value))
    .filter(Boolean)
    .join(" AND ");
  return { sql: conditions ? `WHERE ${conditions}` : "", values };
}
292
+
293
+ // src/vector/index.ts
294
/**
 * Vector store backed by PostgreSQL tables with a pgvector `embedding`
 * column. Each "index" is a table named `indexName` with columns
 * `vector_id`, `embedding` and `metadata`, plus an optional
 * `<indexName>_vector_idx` index.
 *
 * NOTE(review): `indexName` is interpolated into SQL as an identifier in
 * most methods and is only format-checked in `createIndex`; callers must not
 * pass untrusted index names.
 */
var PgVector = class extends vector.MastraVector {
  pool;
  // indexName -> result of describeIndex(); cleared when the index changes.
  indexCache = /* @__PURE__ */ new Map();

  /**
   * @param {string} connectionString PostgreSQL connection string.
   */
  constructor(connectionString) {
    super();
    const basePool = new pg__default.default.Pool({
      connectionString,
      max: 20,
      // Maximum number of clients in the pool
      idleTimeoutMillis: 3e4,
      // Close idle connections after 30 seconds
      connectionTimeoutMillis: 2e3
      // Fail fast if can't connect
    });
    const telemetry = this.__getTelemetry();
    this.pool = telemetry?.traceClass(basePool, {
      spanNamePrefix: "pg-vector",
      attributes: {
        "vector.type": "postgres"
      }
    }) ?? basePool;
  }

  /** Translates a user filter into the shape expected by buildFilterQuery. */
  transformFilter(filter) {
    const translator = new PGFilterTranslator();
    return translator.translate(filter);
  }

  /** Returns describeIndex() for `indexName`, cached per instance. */
  async getIndexInfo(indexName) {
    if (!this.indexCache.has(indexName)) {
      this.indexCache.set(indexName, await this.describeIndex(indexName));
    }
    return this.indexCache.get(indexName);
  }

  /**
   * Similarity search.
   *
   * Returns up to `topK` rows whose score (`1 - cosine distance` to
   * `queryVector`) is greater than `minScore`, highest score first.
   *
   * @returns {Promise<Array<{ id, score, metadata, vector? }>>}
   */
  async query(...args) {
    const params = this.normalizeArgs("query", args, ["minScore", "ef", "probes"]);
    const { indexName, queryVector, topK = 10, filter, includeVector = false, minScore = 0, ef, probes } = params;

    const translatedFilter = this.transformFilter(filter);
    const { sql: filterQuery, values: queryValues } = buildFilterQuery(translatedFilter, minScore);
    // Resolved before taking a client so one query never holds two pool
    // connections (describeIndex connects on its own).
    const indexInfo = await this.getIndexInfo(indexName);

    // Bind the vector and the limit instead of splicing them into the SQL.
    queryValues.push(`[${queryVector.join(",")}]`);
    const vectorParam = queryValues.length;
    queryValues.push(topK);
    const limitParam = queryValues.length;

    const client = await this.pool.connect();
    try {
      // Transaction-local settings only take effect inside a transaction, so
      // the tuning parameters and the search share one.
      await client.query("BEGIN");
      if (indexInfo.type === "hnsw") {
        const calculatedEf = ef ?? Math.max(topK, (indexInfo?.config?.m ?? 16) * topK);
        const searchEf = Math.min(1e3, Math.max(1, calculatedEf));
        await client.query("SELECT set_config('hnsw.ef_search', $1, true)", [String(searchEf)]);
      }
      if (indexInfo.type === "ivfflat" && probes) {
        await client.query("SELECT set_config('ivfflat.probes', $1, true)", [String(probes)]);
      }
      const query = `
        WITH vector_scores AS (
          SELECT
            vector_id as id,
            1 - (embedding <=> $${vectorParam}::vector) as score,
            metadata
            ${includeVector ? ", embedding" : ""}
          FROM ${indexName}
          ${filterQuery}
        )
        SELECT *
        FROM vector_scores
        WHERE score > $1
        ORDER BY score DESC
        LIMIT $${limitParam}`;
      const result = await client.query(query, queryValues);
      await client.query("COMMIT");
      return result.rows.map(({ id, score, metadata, embedding }) => ({
        id,
        score,
        metadata,
        ...includeVector && embedding && { vector: JSON.parse(embedding) }
      }));
    } catch (error) {
      await client.query("ROLLBACK");
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Inserts or replaces vectors (matched on `vector_id`) in one transaction.
   * Ids are generated with `crypto.randomUUID()` when `ids` is not given.
   *
   * @returns {Promise<string[]>} The ids used, in input order.
   */
  async upsert(...args) {
    const params = this.normalizeArgs("upsert", args);
    const { indexName, vectors, metadata, ids } = params;
    const client = await this.pool.connect();
    try {
      await client.query("BEGIN");
      const vectorIds = ids || vectors.map(() => crypto.randomUUID());
      const query = `
        INSERT INTO ${indexName} (vector_id, embedding, metadata)
        VALUES ($1, $2::vector, $3::jsonb)
        ON CONFLICT (vector_id)
        DO UPDATE SET
          embedding = $2::vector,
          metadata = $3::jsonb
        RETURNING embedding::text
      `;
      for (let i = 0; i < vectors.length; i++) {
        await client.query(query, [vectorIds[i], `[${vectors[i]?.join(",")}]`, JSON.stringify(metadata?.[i] || {})]);
      }
      await client.query("COMMIT");
      return vectorIds;
    } catch (error) {
      await client.query("ROLLBACK");
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Creates the backing table (and the `vector` extension if needed), then
   * builds the vector index unless `buildIndex` is false.
   *
   * @throws {Error} If the name or dimension is invalid or the extension is
   *   not available.
   */
  async createIndex(...args) {
    const params = this.normalizeArgs("createIndex", args, ["indexConfig", "buildIndex"]);
    const { indexName, dimension, metric = "cosine", indexConfig = {}, buildIndex = true } = params;
    const client = await this.pool.connect();
    try {
      if (!indexName.match(/^[a-zA-Z_][a-zA-Z0-9_]*$/)) {
        throw new Error("Invalid index name format");
      }
      if (!Number.isInteger(dimension) || dimension <= 0) {
        throw new Error("Dimension must be a positive integer");
      }
      const extensionCheck = await client.query(`
        SELECT EXISTS (
          SELECT 1 FROM pg_available_extensions WHERE name = 'vector'
        );
      `);
      if (!extensionCheck.rows[0].exists) {
        throw new Error("PostgreSQL vector extension is not available. Please install it first.");
      }
      await client.query("CREATE EXTENSION IF NOT EXISTS vector");
      await client.query(`
        CREATE TABLE IF NOT EXISTS ${indexName} (
          id SERIAL PRIMARY KEY,
          vector_id TEXT UNIQUE NOT NULL,
          embedding vector(${dimension}),
          metadata JSONB DEFAULT '{}'::jsonb
        );
      `);
      if (buildIndex) {
        await this.buildIndex({ indexName, metric, indexConfig });
      }
    } catch (error) {
      console.error("Failed to create vector table:", error);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * @deprecated This function is deprecated. Use buildIndex instead
   */
  async defineIndex(indexName, metric = "cosine", indexConfig) {
    return this.buildIndex({ indexName, metric, indexConfig });
  }

  /**
   * Drops and rebuilds `<indexName>_vector_idx`.
   *
   * `indexConfig.type` selects the index: "flat" leaves the table without an
   * index, "hnsw" builds an HNSW index, anything else builds IVFFlat. When
   * IVFFlat `lists` is not configured it is derived from the row count.
   */
  async buildIndex(...args) {
    const params = this.normalizeArgs("buildIndex", args, ["metric", "indexConfig"]);
    const { indexName, metric = "cosine" } = params;
    // defineIndex() may pass `indexConfig: undefined`; treat missing as empty.
    const indexConfig = params.indexConfig ?? {};
    const client = await this.pool.connect();
    try {
      await client.query(`DROP INDEX IF EXISTS ${indexName}_vector_idx`);
      if (indexConfig.type === "flat") return;
      const metricOp = metric === "cosine" ? "vector_cosine_ops" : metric === "euclidean" ? "vector_l2_ops" : "vector_ip_ops";
      let indexSQL;
      if (indexConfig.type === "hnsw") {
        const m = indexConfig.hnsw?.m ?? 8;
        const efConstruction = indexConfig.hnsw?.efConstruction ?? 32;
        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING hnsw (embedding ${metricOp})
          WITH (
            m = ${m},
            ef_construction = ${efConstruction}
          )
        `;
      } else {
        let lists;
        if (indexConfig.ivf?.lists) {
          lists = indexConfig.ivf.lists;
        } else {
          const countResult = await client.query(`SELECT COUNT(*) FROM ${indexName}`);
          // COUNT(*) is read as text in describeIndex; convert explicitly.
          const size = Number(countResult.rows[0].count);
          lists = Math.max(100, Math.min(4e3, Math.floor(Math.sqrt(size) * 2)));
        }
        indexSQL = `
          CREATE INDEX ${indexName}_vector_idx
          ON ${indexName}
          USING ivfflat (embedding ${metricOp})
          WITH (lists = ${lists});
        `;
      }
      await client.query(indexSQL);
    } finally {
      // The index may have been dropped or replaced on any path above, so the
      // cached description is dropped as well.
      this.indexCache.delete(indexName);
      client.release();
    }
  }

  /** Lists tables in the `public` schema that have a `vector` column. */
  async listIndexes() {
    const client = await this.pool.connect();
    try {
      const vectorTablesQuery = `
        SELECT DISTINCT table_name
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND udt_name = 'vector';
      `;
      const vectorTables = await client.query(vectorTablesQuery);
      return vectorTables.rows.map((row) => row.table_name);
    } finally {
      client.release();
    }
  }

  /**
   * Describes an index: embedding dimension, row count, metric, index type
   * ("flat" when no `<indexName>_vector_idx` exists) and index parameters
   * parsed from the index definition.
   *
   * @throws {Error} Wrapping any failure, with the original error as `cause`.
   */
  async describeIndex(indexName) {
    const client = await this.pool.connect();
    try {
      const dimensionQuery = `
        SELECT atttypmod as dimension
        FROM pg_attribute
        WHERE attrelid = $1::regclass
          AND attname = 'embedding';
      `;
      const countQuery = `
        SELECT COUNT(*) as count
        FROM ${indexName};
      `;
      const indexQuery = `
        SELECT
          am.amname as index_method,
          pg_get_indexdef(i.indexrelid) as index_def,
          opclass.opcname as operator_class
        FROM pg_index i
        JOIN pg_class c ON i.indexrelid = c.oid
        JOIN pg_am am ON c.relam = am.oid
        JOIN pg_opclass opclass ON i.indclass[0] = opclass.oid
        WHERE c.relname = $1;
      `;
      // The three statements share one client, so they are awaited in turn
      // instead of being issued concurrently.
      const dimResult = await client.query(dimensionQuery, [indexName]);
      const countResult = await client.query(countQuery);
      const indexResult = await client.query(indexQuery, [`${indexName}_vector_idx`]);

      const { index_method, index_def, operator_class } = indexResult.rows[0] || {
        index_method: "flat",
        index_def: "",
        operator_class: "cosine"
      };
      const metric = operator_class.includes("l2") ? "euclidean" : operator_class.includes("ip") ? "dotproduct" : "cosine";
      const config = {};
      if (index_method === "hnsw") {
        const m = index_def.match(/m\s*=\s*'?(\d+)'?/)?.[1];
        const efConstruction = index_def.match(/ef_construction\s*=\s*'?(\d+)'?/)?.[1];
        if (m) config.m = Number.parseInt(m, 10);
        if (efConstruction) config.efConstruction = Number.parseInt(efConstruction, 10);
      } else if (index_method === "ivfflat") {
        const lists = index_def.match(/lists\s*=\s*'?(\d+)'?/)?.[1];
        if (lists) config.lists = Number.parseInt(lists, 10);
      }
      return {
        dimension: dimResult.rows[0].dimension,
        count: Number.parseInt(countResult.rows[0].count, 10),
        metric,
        type: index_method,
        config
      };
    } catch (e) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to describe vector table: ${e.message}`, { cause: e });
    } finally {
      client.release();
    }
  }

  /** Drops the index table (CASCADE) and forgets its cached description. */
  async deleteIndex(indexName) {
    const client = await this.pool.connect();
    try {
      await client.query(`DROP TABLE IF EXISTS ${indexName} CASCADE`);
      this.indexCache.delete(indexName);
    } catch (error) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to delete vector table: ${error.message}`, { cause: error });
    } finally {
      client.release();
    }
  }

  /** Removes all rows from the index table and forgets its cached description. */
  async truncateIndex(indexName) {
    const client = await this.pool.connect();
    try {
      await client.query(`TRUNCATE ${indexName}`);
      this.indexCache.delete(indexName);
    } catch (e) {
      // No transaction is open here, so there is nothing to roll back.
      throw new Error(`Failed to truncate vector table: ${e.message}`, { cause: e });
    } finally {
      client.release();
    }
  }

  /** Closes the connection pool. */
  async disconnect() {
    await this.pool.end();
  }
};
583
/**
 * Builds a positional INSERT for one record: column names are double-quoted
 * and values are bound as $1..$n in the record's key order.
 */
function buildInsertStatement(tableName, record) {
  const columns = Object.keys(record);
  const values = Object.values(record);
  const placeholders = values.map((_, i) => `$${i + 1}`).join(", ");
  const columnList = columns.map((c) => `"${c}"`).join(", ");
  return {
    sql: `INSERT INTO ${tableName} (${columnList}) VALUES (${placeholders})`,
    values
  };
}

/**
 * Storage implementation on top of pg-promise: threads, messages, traces and
 * workflow snapshots.
 *
 * NOTE(review): table names are interpolated into SQL as identifiers; they
 * are expected to come from the storage table-name constants, not from
 * untrusted input.
 */
var PostgresStore = class extends storage.MastraStorage {
  db;
  pgp;

  /**
   * @param {object} config Either `{ connectionString }` or
   *   `{ host, port, database, user, password }`.
   */
  constructor(config) {
    super({ name: "PostgresStore" });
    this.pgp = pgPromise__default.default();
    this.db = this.pgp(
      `connectionString` in config ? { connectionString: config.connectionString } : {
        host: config.host,
        port: config.port,
        database: config.database,
        user: config.user,
        password: config.password
      }
    );
  }

  getEvalsByAgentName(_agentName, _type) {
    throw new Error("Method not implemented.");
  }

  /**
   * Inserts all records atomically.
   *
   * Uses `db.tx` so every statement runs on the same connection; separate
   * BEGIN/COMMIT statements sent through the pooled `db` object are not tied
   * to the connection that performs the inserts.
   */
  async batchInsert({ tableName, records }) {
    try {
      await this.db.tx(async (t) => {
        for (const record of records) {
          const { sql, values } = buildInsertStatement(tableName, record);
          await t.none(sql, values);
        }
      });
    } catch (error) {
      console.error(`Error inserting into ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Returns one page of traces, newest first.
   *
   * @param {object} options
   * @param {string} [options.name]       Prefix match on the trace name.
   * @param {string} [options.scope]      Exact match on scope.
   * @param {number} options.page         Zero-based page number.
   * @param {number} options.perPage      Page size.
   * @param {object} [options.attributes] Exact matches on `attributes` keys.
   */
  async getTraces({
    name,
    scope,
    page,
    perPage,
    attributes
  }) {
    const args = [];
    const conditions = [];
    // Registers a value and returns its placeholder, keeping the two in step.
    const bind = (value) => {
      args.push(value);
      return `$${args.length}`;
    };
    if (name) {
      conditions.push(`name LIKE CONCAT(${bind(name)}, '%')`);
    }
    if (scope) {
      conditions.push(`scope = ${bind(scope)}`);
    }
    if (attributes) {
      for (const [key, value] of Object.entries(attributes)) {
        // The attribute key is bound as well, never spliced into the SQL.
        conditions.push(`attributes->>${bind(key)} = ${bind(value)}`);
      }
    }
    const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(" AND ")}` : "";
    const limitParam = bind(perPage);
    const offsetParam = bind(page * perPage);
    const result = await this.db.manyOrNone(
      `SELECT * FROM ${storage.MastraStorage.TABLE_TRACES} ${whereClause} ORDER BY "createdAt" DESC LIMIT ${limitParam} OFFSET ${offsetParam}`,
      args
    );
    if (!result) {
      return [];
    }
    return result.map((row) => ({
      id: row.id,
      parentSpanId: row.parentSpanId,
      traceId: row.traceId,
      name: row.name,
      scope: row.scope,
      kind: row.kind,
      status: row.status,
      events: row.events,
      links: row.links,
      attributes: row.attributes,
      startTime: row.startTime,
      endTime: row.endTime,
      other: row.other,
      createdAt: row.createdAt
    }));
  }

  /**
   * Creates the table if it does not exist. For the workflow snapshot table
   * it also adds the unique `(workflow_name, run_id)` constraint that
   * persistWorkflowSnapshot's ON CONFLICT clause relies on.
   */
  async createTable({
    tableName,
    schema
  }) {
    try {
      const columns = Object.entries(schema).map(([name, def]) => {
        const constraints = [];
        if (def.primaryKey) constraints.push("PRIMARY KEY");
        if (!def.nullable) constraints.push("NOT NULL");
        return `"${name}" ${def.type.toUpperCase()} ${constraints.join(" ")}`;
      }).join(",\n");
      const sql = `
        CREATE TABLE IF NOT EXISTS ${tableName} (
          ${columns}
        );
        ${tableName === storage.MastraStorage.TABLE_WORKFLOW_SNAPSHOT ? `
        DO $$ BEGIN
          IF NOT EXISTS (
            SELECT 1 FROM pg_constraint WHERE conname = 'mastra_workflow_snapshot_workflow_name_run_id_key'
          ) THEN
            ALTER TABLE ${tableName}
            ADD CONSTRAINT mastra_workflow_snapshot_workflow_name_run_id_key
            UNIQUE (workflow_name, run_id);
          END IF;
        END $$;
        ` : ""}
      `;
      await this.db.none(sql);
    } catch (error) {
      console.error(`Error creating table ${tableName}:`, error);
      throw error;
    }
  }

  /** Removes every row from the table (TRUNCATE ... CASCADE). */
  async clearTable({ tableName }) {
    try {
      await this.db.none(`TRUNCATE TABLE ${tableName} CASCADE`);
    } catch (error) {
      console.error(`Error clearing table ${tableName}:`, error);
      throw error;
    }
  }

  /** Inserts a single record; keys are column names. */
  async insert({ tableName, record }) {
    try {
      const { sql, values } = buildInsertStatement(tableName, record);
      await this.db.none(sql, values);
    } catch (error) {
      console.error(`Error inserting into ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Loads the row matching all `keys` (column -> value), or null.
   * Workflow snapshot rows get their `snapshot` column parsed when it comes
   * back as a string.
   */
  async load({ tableName, keys }) {
    try {
      const keyEntries = Object.entries(keys);
      const conditions = keyEntries.map(([key], index) => `"${key}" = $${index + 1}`).join(" AND ");
      const values = keyEntries.map(([, value]) => value);
      const result = await this.db.oneOrNone(`SELECT * FROM ${tableName} WHERE ${conditions}`, values);
      if (!result) {
        return null;
      }
      if (tableName === storage.MastraStorage.TABLE_WORKFLOW_SNAPSHOT) {
        const snapshot = result;
        if (typeof snapshot.snapshot === "string") {
          snapshot.snapshot = JSON.parse(snapshot.snapshot);
        }
        return snapshot;
      }
      return result;
    } catch (error) {
      console.error(`Error loading from ${tableName}:`, error);
      throw error;
    }
  }

  /** Returns the thread with the given id, or null when there is none. */
  async getThreadById({ threadId }) {
    try {
      const thread = await this.db.oneOrNone(
        `SELECT
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        FROM "${storage.MastraStorage.TABLE_THREADS}"
        WHERE id = $1`,
        [threadId]
      );
      if (!thread) {
        return null;
      }
      return {
        ...thread,
        metadata: typeof thread.metadata === "string" ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAt,
        updatedAt: thread.updatedAt
      };
    } catch (error) {
      console.error(`Error getting thread ${threadId}:`, error);
      throw error;
    }
  }

  /** Returns all threads belonging to a resource. */
  async getThreadsByResourceId({ resourceId }) {
    try {
      const threads = await this.db.manyOrNone(
        `SELECT
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        FROM "${storage.MastraStorage.TABLE_THREADS}"
        WHERE "resourceId" = $1`,
        [resourceId]
      );
      return threads.map((thread) => ({
        ...thread,
        metadata: typeof thread.metadata === "string" ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAt,
        updatedAt: thread.updatedAt
      }));
    } catch (error) {
      console.error(`Error getting threads for resource ${resourceId}:`, error);
      throw error;
    }
  }

  /** Inserts the thread, or overwrites the existing row with the same id. */
  async saveThread({ thread }) {
    try {
      await this.db.none(
        `INSERT INTO "${storage.MastraStorage.TABLE_THREADS}" (
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        ) VALUES ($1, $2, $3, $4, $5, $6)
        ON CONFLICT (id) DO UPDATE SET
          "resourceId" = EXCLUDED."resourceId",
          title = EXCLUDED.title,
          metadata = EXCLUDED.metadata,
          "createdAt" = EXCLUDED."createdAt",
          "updatedAt" = EXCLUDED."updatedAt"`,
        [
          thread.id,
          thread.resourceId,
          thread.title,
          thread.metadata ? JSON.stringify(thread.metadata) : null,
          thread.createdAt,
          thread.updatedAt
        ]
      );
      return thread;
    } catch (error) {
      console.error("Error saving thread:", error);
      throw error;
    }
  }

  /**
   * Updates a thread's title and shallow-merges `metadata` into the stored
   * metadata.
   *
   * @throws {Error} If the thread does not exist.
   */
  async updateThread({
    id,
    title,
    metadata
  }) {
    try {
      const existingThread = await this.getThreadById({ threadId: id });
      if (!existingThread) {
        throw new Error(`Thread ${id} not found`);
      }
      const mergedMetadata = {
        ...existingThread.metadata,
        ...metadata
      };
      const thread = await this.db.one(
        `UPDATE "${storage.MastraStorage.TABLE_THREADS}"
        SET title = $1,
            metadata = $2,
            "updatedAt" = $3
        WHERE id = $4
        RETURNING *`,
        [title, mergedMetadata, (/* @__PURE__ */ new Date()).toISOString(), id]
      );
      return {
        ...thread,
        metadata: typeof thread.metadata === "string" ? JSON.parse(thread.metadata) : thread.metadata,
        createdAt: thread.createdAt,
        updatedAt: thread.updatedAt
      };
    } catch (error) {
      console.error("Error updating thread:", error);
      throw error;
    }
  }

  /** Deletes a thread and its messages in one transaction. */
  async deleteThread({ threadId }) {
    try {
      await this.db.tx(async (t) => {
        await t.none(`DELETE FROM "${storage.MastraStorage.TABLE_MESSAGES}" WHERE thread_id = $1`, [threadId]);
        await t.none(`DELETE FROM "${storage.MastraStorage.TABLE_THREADS}" WHERE id = $1`, [threadId]);
      });
    } catch (error) {
      console.error("Error deleting thread:", error);
      throw error;
    }
  }

  /**
   * Returns messages of a thread in ascending `createdAt` order.
   *
   * The result combines:
   *   - the messages named in `selectBy.include`, with surrounding messages
   *     (the largest `withPreviousMessages` / `withNextMessages` among the
   *     entries is applied to all of them), and
   *   - the most recent `selectBy.last` other messages (40 by default).
   */
  async getMessages({ threadId, selectBy }) {
    try {
      const messages = [];
      const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
      const include = selectBy?.include || [];
      if (include.length) {
        const includeResult = await this.db.manyOrNone(
          `
          WITH ordered_messages AS (
            SELECT
              *,
              ROW_NUMBER() OVER (ORDER BY "createdAt" DESC) as row_num
            FROM "${storage.MastraStorage.TABLE_MESSAGES}"
            WHERE thread_id = $1
          )
          SELECT
            m.id,
            m.content,
            m.role,
            m.type,
            m."createdAt",
            m.thread_id AS "threadId"
          FROM ordered_messages m
          WHERE m.id = ANY($2)
          OR EXISTS (
            SELECT 1 FROM ordered_messages target
            WHERE target.id = ANY($2)
            AND (
              -- Get previous messages based on the max withPreviousMessages
              (m.row_num <= target.row_num + $3 AND m.row_num > target.row_num)
              OR
              -- Get next messages based on the max withNextMessages
              (m.row_num >= target.row_num - $4 AND m.row_num < target.row_num)
            )
          )
          ORDER BY m."createdAt" DESC
          `,
          [
            threadId,
            include.map((i) => i.id),
            Math.max(...include.map((i) => i.withPreviousMessages || 0)),
            Math.max(...include.map((i) => i.withNextMessages || 0))
          ]
        );
        messages.push(...includeResult);
      }
      const result = await this.db.manyOrNone(
        `
        SELECT
          id,
          content,
          role,
          type,
          "createdAt",
          thread_id AS "threadId"
        FROM "${storage.MastraStorage.TABLE_MESSAGES}"
        WHERE thread_id = $1
        AND id != ALL($2)
        ORDER BY "createdAt" DESC
        LIMIT $3
        `,
        [threadId, messages.map((m) => m.id), limit]
      );
      messages.push(...result);
      messages.sort((a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime());
      for (const message of messages) {
        if (typeof message.content === "string") {
          try {
            message.content = JSON.parse(message.content);
          } catch {
            // Content that is not valid JSON is kept as the stored string.
          }
        }
      }
      return messages;
    } catch (error) {
      console.error("Error getting messages:", error);
      throw error;
    }
  }

  /**
   * Saves messages in one transaction. All messages are stored under the
   * thread id of the first message, which must refer to an existing thread.
   *
   * @returns The input `messages` array.
   */
  async saveMessages({ messages }) {
    if (messages.length === 0) return messages;
    try {
      const threadId = messages[0]?.threadId;
      if (!threadId) {
        throw new Error("Thread ID is required");
      }
      const thread = await this.getThreadById({ threadId });
      if (!thread) {
        throw new Error(`Thread ${threadId} not found`);
      }
      await this.db.tx(async (t) => {
        for (const message of messages) {
          await t.none(
            `INSERT INTO "${storage.MastraStorage.TABLE_MESSAGES}" (id, thread_id, content, "createdAt", role, type)
            VALUES ($1, $2, $3, $4, $5, $6)`,
            [
              message.id,
              threadId,
              typeof message.content === "string" ? message.content : JSON.stringify(message.content),
              message.createdAt || (/* @__PURE__ */ new Date()).toISOString(),
              message.role,
              message.type
            ]
          );
        }
      });
      return messages;
    } catch (error) {
      console.error("Error saving messages:", error);
      throw error;
    }
  }

  /** Inserts or updates the snapshot for `(workflowName, runId)`. */
  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot
  }) {
    try {
      const now = (/* @__PURE__ */ new Date()).toISOString();
      await this.db.none(
        `INSERT INTO "${storage.MastraStorage.TABLE_WORKFLOW_SNAPSHOT}" (
          workflow_name,
          run_id,
          snapshot,
          "createdAt",
          "updatedAt"
        ) VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (workflow_name, run_id) DO UPDATE
        SET snapshot = EXCLUDED.snapshot,
            "updatedAt" = EXCLUDED."updatedAt"`,
        [workflowName, runId, JSON.stringify(snapshot), now, now]
      );
    } catch (error) {
      console.error("Error persisting workflow snapshot:", error);
      throw error;
    }
  }

  /** Returns the stored snapshot for `(workflowName, runId)`, or null. */
  async loadWorkflowSnapshot({
    workflowName,
    runId
  }) {
    try {
      const result = await this.load({
        tableName: storage.MastraStorage.TABLE_WORKFLOW_SNAPSHOT,
        keys: {
          workflow_name: workflowName,
          run_id: runId
        }
      });
      if (!result) {
        return null;
      }
      return result.snapshot;
    } catch (error) {
      console.error("Error loading workflow snapshot:", error);
      throw error;
    }
  }

  /** Shuts down the pg-promise library instance created by this store. */
  async close() {
    this.pgp.end();
  }
};
1048
+
1049
+ exports.PgVector = PgVector;
1050
+ exports.PostgresStore = PostgresStore;