@mastra/pg 1.0.0-beta.1 → 1.0.0-beta.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +36 -0
- package/README.md +3 -0
- package/dist/index.cjs +336 -58
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +337 -59
- package/dist/index.js.map +1 -1
- package/dist/shared/config.d.ts +1 -1
- package/dist/shared/config.d.ts.map +1 -1
- package/dist/storage/domains/memory/index.d.ts.map +1 -1
- package/dist/storage/domains/scores/index.d.ts.map +1 -1
- package/dist/vector/index.d.ts +11 -3
- package/dist/vector/index.d.ts.map +1 -1
- package/dist/vector/sql-builder.d.ts +4 -0
- package/dist/vector/sql-builder.d.ts.map +1 -1
- package/package.json +7 -7
package/dist/index.js
CHANGED
|
@@ -5,7 +5,7 @@ import { Mutex } from 'async-mutex';
|
|
|
5
5
|
import * as pg from 'pg';
|
|
6
6
|
import xxhash from 'xxhash-wasm';
|
|
7
7
|
import { BaseFilterTranslator } from '@mastra/core/vector/filter';
|
|
8
|
-
import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage,
|
|
8
|
+
import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
|
|
9
9
|
import pgPromise from 'pg-promise';
|
|
10
10
|
import { MessageList } from '@mastra/core/agent';
|
|
11
11
|
import { saveScorePayloadSchema } from '@mastra/core/evals';
|
|
@@ -126,12 +126,20 @@ var createBasicOperator = (symbol) => {
|
|
|
126
126
|
};
|
|
127
127
|
};
|
|
128
128
|
var createNumericOperator = (symbol) => {
|
|
129
|
-
return (key, paramIndex) => {
|
|
129
|
+
return (key, paramIndex, value) => {
|
|
130
130
|
const jsonPathKey = parseJsonPathKey(key);
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
131
|
+
const isNumeric = typeof value === "number" || typeof value === "string" && !isNaN(Number(value)) && value.trim() !== "";
|
|
132
|
+
if (isNumeric) {
|
|
133
|
+
return {
|
|
134
|
+
sql: `(metadata#>>'{${jsonPathKey}}')::numeric ${symbol} $${paramIndex}::numeric`,
|
|
135
|
+
needsValue: true
|
|
136
|
+
};
|
|
137
|
+
} else {
|
|
138
|
+
return {
|
|
139
|
+
sql: `metadata#>>'{${jsonPathKey}}' ${symbol} $${paramIndex}::text`,
|
|
140
|
+
needsValue: true
|
|
141
|
+
};
|
|
142
|
+
}
|
|
135
143
|
};
|
|
136
144
|
};
|
|
137
145
|
function buildElemMatchConditions(value, paramIndex) {
|
|
@@ -312,6 +320,83 @@ var parseJsonPathKey = (key) => {
|
|
|
312
320
|
function escapeLikePattern(str) {
|
|
313
321
|
return str.replace(/([%_\\])/g, "\\$1");
|
|
314
322
|
}
|
|
323
|
+
function buildDeleteFilterQuery(filter) {
|
|
324
|
+
const values = [];
|
|
325
|
+
function buildCondition(key, value, parentPath) {
|
|
326
|
+
if (["$and", "$or", "$not", "$nor"].includes(key)) {
|
|
327
|
+
return handleLogicalOperator(key, value);
|
|
328
|
+
}
|
|
329
|
+
if (!value || typeof value !== "object") {
|
|
330
|
+
values.push(value);
|
|
331
|
+
return `metadata#>>'{${parseJsonPathKey(key)}}' = $${values.length}`;
|
|
332
|
+
}
|
|
333
|
+
const [[operator, operatorValue] = []] = Object.entries(value);
|
|
334
|
+
if (operator === "$not") {
|
|
335
|
+
const entries = Object.entries(operatorValue);
|
|
336
|
+
const conditions2 = entries.map(([nestedOp, nestedValue]) => {
|
|
337
|
+
if (!FILTER_OPERATORS[nestedOp]) {
|
|
338
|
+
throw new Error(`Invalid operator in $not condition: ${nestedOp}`);
|
|
339
|
+
}
|
|
340
|
+
const operatorFn2 = FILTER_OPERATORS[nestedOp];
|
|
341
|
+
const operatorResult2 = operatorFn2(key, values.length + 1, nestedValue);
|
|
342
|
+
if (operatorResult2.needsValue) {
|
|
343
|
+
values.push(nestedValue);
|
|
344
|
+
}
|
|
345
|
+
return operatorResult2.sql;
|
|
346
|
+
}).join(" AND ");
|
|
347
|
+
return `NOT (${conditions2})`;
|
|
348
|
+
}
|
|
349
|
+
const operatorFn = FILTER_OPERATORS[operator];
|
|
350
|
+
const operatorResult = operatorFn(key, values.length + 1, operatorValue);
|
|
351
|
+
if (operatorResult.needsValue) {
|
|
352
|
+
const transformedValue = operatorResult.transformValue ? operatorResult.transformValue() : operatorValue;
|
|
353
|
+
if (Array.isArray(transformedValue) && operator === "$elemMatch") {
|
|
354
|
+
values.push(...transformedValue);
|
|
355
|
+
} else {
|
|
356
|
+
values.push(transformedValue);
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
return operatorResult.sql;
|
|
360
|
+
}
|
|
361
|
+
function handleLogicalOperator(key, value, parentPath) {
|
|
362
|
+
if (key === "$not") {
|
|
363
|
+
const entries = Object.entries(value);
|
|
364
|
+
const conditions3 = entries.map(([fieldKey, fieldValue]) => buildCondition(fieldKey, fieldValue)).join(" AND ");
|
|
365
|
+
return `NOT (${conditions3})`;
|
|
366
|
+
}
|
|
367
|
+
if (!value || value.length === 0) {
|
|
368
|
+
switch (key) {
|
|
369
|
+
case "$and":
|
|
370
|
+
case "$nor":
|
|
371
|
+
return "true";
|
|
372
|
+
// Empty $and/$nor match everything
|
|
373
|
+
case "$or":
|
|
374
|
+
return "false";
|
|
375
|
+
// Empty $or matches nothing
|
|
376
|
+
default:
|
|
377
|
+
return "true";
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
const joinOperator = key === "$or" || key === "$nor" ? "OR" : "AND";
|
|
381
|
+
const conditions2 = value.map((f) => {
|
|
382
|
+
const entries = Object.entries(f || {});
|
|
383
|
+
if (entries.length === 0) return "";
|
|
384
|
+
const [firstKey, firstValue] = entries[0] || [];
|
|
385
|
+
if (["$and", "$or", "$not", "$nor"].includes(firstKey)) {
|
|
386
|
+
return buildCondition(firstKey, firstValue);
|
|
387
|
+
}
|
|
388
|
+
return entries.map(([k, v]) => buildCondition(k, v)).join(` ${joinOperator} `);
|
|
389
|
+
});
|
|
390
|
+
const joined = conditions2.join(` ${joinOperator} `);
|
|
391
|
+
const operatorFn = FILTER_OPERATORS[key];
|
|
392
|
+
return operatorFn(joined, 0, value).sql;
|
|
393
|
+
}
|
|
394
|
+
if (!filter) {
|
|
395
|
+
return { sql: "", values };
|
|
396
|
+
}
|
|
397
|
+
const conditions = Object.entries(filter).map(([key, value]) => buildCondition(key, value)).filter(Boolean).join(" AND ");
|
|
398
|
+
return { sql: conditions ? `WHERE ${conditions}` : "", values };
|
|
399
|
+
}
|
|
315
400
|
function buildFilterQuery(filter, minScore, topK) {
|
|
316
401
|
const values = [minScore, topK];
|
|
317
402
|
function buildCondition(key, value, parentPath) {
|
|
@@ -511,9 +596,6 @@ var PgVector = class extends MastraVector {
|
|
|
511
596
|
if (this.vectorExtensionSchema === "pg_catalog") {
|
|
512
597
|
return "vector";
|
|
513
598
|
}
|
|
514
|
-
if (this.vectorExtensionSchema === (this.schema || "public")) {
|
|
515
|
-
return "vector";
|
|
516
|
-
}
|
|
517
599
|
const validatedSchema = parseSqlIdentifier(this.vectorExtensionSchema, "vector extension schema");
|
|
518
600
|
return `${validatedSchema}.vector`;
|
|
519
601
|
}
|
|
@@ -633,11 +715,31 @@ var PgVector = class extends MastraVector {
|
|
|
633
715
|
client.release();
|
|
634
716
|
}
|
|
635
717
|
}
|
|
636
|
-
async upsert({
|
|
718
|
+
async upsert({
|
|
719
|
+
indexName,
|
|
720
|
+
vectors,
|
|
721
|
+
metadata,
|
|
722
|
+
ids,
|
|
723
|
+
deleteFilter
|
|
724
|
+
}) {
|
|
637
725
|
const { tableName } = this.getTableName(indexName);
|
|
638
726
|
const client = await this.pool.connect();
|
|
639
727
|
try {
|
|
640
728
|
await client.query("BEGIN");
|
|
729
|
+
if (deleteFilter) {
|
|
730
|
+
this.logger?.debug(`Deleting vectors matching filter before upsert`, { indexName, deleteFilter });
|
|
731
|
+
const translatedFilter = this.transformFilter(deleteFilter);
|
|
732
|
+
const { sql: filterQuery, values: filterValues } = buildDeleteFilterQuery(translatedFilter);
|
|
733
|
+
const whereClause = filterQuery.trim().replace(/^WHERE\s+/i, "");
|
|
734
|
+
if (whereClause) {
|
|
735
|
+
const deleteQuery = `DELETE FROM ${tableName} WHERE ${whereClause}`;
|
|
736
|
+
const result = await client.query(deleteQuery, filterValues);
|
|
737
|
+
this.logger?.debug(`Deleted ${result.rowCount || 0} vectors before upsert`, {
|
|
738
|
+
indexName,
|
|
739
|
+
deletedCount: result.rowCount || 0
|
|
740
|
+
});
|
|
741
|
+
}
|
|
742
|
+
}
|
|
641
743
|
const vectorIds = ids || vectors.map(() => crypto.randomUUID());
|
|
642
744
|
const vectorType = this.getVectorTypeName();
|
|
643
745
|
for (let i = 0; i < vectors.length; i++) {
|
|
@@ -653,6 +755,11 @@ var PgVector = class extends MastraVector {
|
|
|
653
755
|
await client.query(query, [vectorIds[i], `[${vectors[i]?.join(",")}]`, JSON.stringify(metadata?.[i] || {})]);
|
|
654
756
|
}
|
|
655
757
|
await client.query("COMMIT");
|
|
758
|
+
this.logger?.debug(`Upserted ${vectors.length} vectors to ${indexName}`, {
|
|
759
|
+
indexName,
|
|
760
|
+
vectorCount: vectors.length,
|
|
761
|
+
hadDeleteFilter: !!deleteFilter
|
|
762
|
+
});
|
|
656
763
|
return vectorIds;
|
|
657
764
|
} catch (error) {
|
|
658
765
|
await client.query("ROLLBACK");
|
|
@@ -1210,17 +1317,36 @@ var PgVector = class extends MastraVector {
|
|
|
1210
1317
|
* @returns A promise that resolves when the update is complete.
|
|
1211
1318
|
* @throws Will throw an error if no updates are provided or if the update operation fails.
|
|
1212
1319
|
*/
|
|
1213
|
-
async updateVector({ indexName, id, update }) {
|
|
1320
|
+
async updateVector({ indexName, id, filter, update }) {
|
|
1214
1321
|
let client;
|
|
1215
1322
|
try {
|
|
1216
1323
|
if (!update.vector && !update.metadata) {
|
|
1217
1324
|
throw new Error("No updates provided");
|
|
1218
1325
|
}
|
|
1326
|
+
if (!id && !filter) {
|
|
1327
|
+
throw new MastraError({
|
|
1328
|
+
id: "MASTRA_STORAGE_PG_VECTOR_UPDATE_MISSING_PARAMS",
|
|
1329
|
+
text: "Either id or filter must be provided",
|
|
1330
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1331
|
+
category: ErrorCategory.USER,
|
|
1332
|
+
details: { indexName }
|
|
1333
|
+
});
|
|
1334
|
+
}
|
|
1335
|
+
if (id && filter) {
|
|
1336
|
+
throw new MastraError({
|
|
1337
|
+
id: "MASTRA_STORAGE_PG_VECTOR_UPDATE_CONFLICTING_PARAMS",
|
|
1338
|
+
text: "Cannot provide both id and filter - they are mutually exclusive",
|
|
1339
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1340
|
+
category: ErrorCategory.USER,
|
|
1341
|
+
details: { indexName }
|
|
1342
|
+
});
|
|
1343
|
+
}
|
|
1219
1344
|
client = await this.pool.connect();
|
|
1220
|
-
|
|
1221
|
-
let values = [id];
|
|
1222
|
-
let valueIndex = 2;
|
|
1345
|
+
const { tableName } = this.getTableName(indexName);
|
|
1223
1346
|
const vectorType = this.getVectorTypeName();
|
|
1347
|
+
let updateParts = [];
|
|
1348
|
+
let values = [];
|
|
1349
|
+
let valueIndex = 1;
|
|
1224
1350
|
if (update.vector) {
|
|
1225
1351
|
updateParts.push(`embedding = $${valueIndex}::${vectorType}`);
|
|
1226
1352
|
values.push(`[${update.vector.join(",")}]`);
|
|
@@ -1229,18 +1355,60 @@ var PgVector = class extends MastraVector {
|
|
|
1229
1355
|
if (update.metadata) {
|
|
1230
1356
|
updateParts.push(`metadata = $${valueIndex}::jsonb`);
|
|
1231
1357
|
values.push(JSON.stringify(update.metadata));
|
|
1358
|
+
valueIndex++;
|
|
1232
1359
|
}
|
|
1233
1360
|
if (updateParts.length === 0) {
|
|
1234
1361
|
return;
|
|
1235
1362
|
}
|
|
1236
|
-
|
|
1363
|
+
let whereClause;
|
|
1364
|
+
let whereValues;
|
|
1365
|
+
if (id) {
|
|
1366
|
+
whereClause = `vector_id = $${valueIndex}`;
|
|
1367
|
+
whereValues = [id];
|
|
1368
|
+
} else {
|
|
1369
|
+
if (!filter || Object.keys(filter).length === 0) {
|
|
1370
|
+
throw new MastraError({
|
|
1371
|
+
id: "MASTRA_STORAGE_PG_VECTOR_UPDATE_EMPTY_FILTER",
|
|
1372
|
+
text: "Cannot update with empty filter",
|
|
1373
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1374
|
+
category: ErrorCategory.USER,
|
|
1375
|
+
details: { indexName }
|
|
1376
|
+
});
|
|
1377
|
+
}
|
|
1378
|
+
const translatedFilter = this.transformFilter(filter);
|
|
1379
|
+
const { sql: filterQuery, values: filterValues } = buildDeleteFilterQuery(translatedFilter);
|
|
1380
|
+
whereClause = filterQuery.trim().replace(/^WHERE\s+/i, "");
|
|
1381
|
+
if (!whereClause) {
|
|
1382
|
+
throw new MastraError({
|
|
1383
|
+
id: "MASTRA_STORAGE_PG_VECTOR_UPDATE_INVALID_FILTER",
|
|
1384
|
+
text: "Filter produced empty WHERE clause",
|
|
1385
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1386
|
+
category: ErrorCategory.USER,
|
|
1387
|
+
details: { indexName, filter: JSON.stringify(filter) }
|
|
1388
|
+
});
|
|
1389
|
+
}
|
|
1390
|
+
whereClause = whereClause.replace(/\$(\d+)/g, (match, num) => {
|
|
1391
|
+
const newIndex = parseInt(num) + valueIndex - 1;
|
|
1392
|
+
return `$${newIndex}`;
|
|
1393
|
+
});
|
|
1394
|
+
whereValues = filterValues;
|
|
1395
|
+
}
|
|
1237
1396
|
const query = `
|
|
1238
1397
|
UPDATE ${tableName}
|
|
1239
1398
|
SET ${updateParts.join(", ")}
|
|
1240
|
-
WHERE
|
|
1399
|
+
WHERE ${whereClause}
|
|
1241
1400
|
`;
|
|
1242
|
-
await client.query(query, values);
|
|
1401
|
+
const result = await client.query(query, [...values, ...whereValues]);
|
|
1402
|
+
this.logger?.info(`Updated ${result.rowCount || 0} vectors in ${indexName}`, {
|
|
1403
|
+
indexName,
|
|
1404
|
+
id: id ? id : void 0,
|
|
1405
|
+
filter: filter ? filter : void 0,
|
|
1406
|
+
updatedCount: result.rowCount || 0
|
|
1407
|
+
});
|
|
1243
1408
|
} catch (error) {
|
|
1409
|
+
if (error instanceof MastraError) {
|
|
1410
|
+
throw error;
|
|
1411
|
+
}
|
|
1244
1412
|
const mastraError = new MastraError(
|
|
1245
1413
|
{
|
|
1246
1414
|
id: "MASTRA_STORAGE_PG_VECTOR_UPDATE_VECTOR_FAILED",
|
|
@@ -1248,7 +1416,8 @@ var PgVector = class extends MastraVector {
|
|
|
1248
1416
|
category: ErrorCategory.THIRD_PARTY,
|
|
1249
1417
|
details: {
|
|
1250
1418
|
indexName,
|
|
1251
|
-
id
|
|
1419
|
+
...id && { id },
|
|
1420
|
+
...filter && { filter: JSON.stringify(filter) }
|
|
1252
1421
|
}
|
|
1253
1422
|
},
|
|
1254
1423
|
error
|
|
@@ -1295,6 +1464,106 @@ var PgVector = class extends MastraVector {
|
|
|
1295
1464
|
client?.release();
|
|
1296
1465
|
}
|
|
1297
1466
|
}
|
|
1467
|
+
/**
|
|
1468
|
+
* Delete vectors matching a metadata filter.
|
|
1469
|
+
* @param indexName - The name of the index containing the vectors.
|
|
1470
|
+
* @param filter - The filter to match vectors for deletion.
|
|
1471
|
+
* @returns A promise that resolves when the deletion is complete.
|
|
1472
|
+
* @throws Will throw an error if the deletion operation fails.
|
|
1473
|
+
*/
|
|
1474
|
+
async deleteVectors({ indexName, filter, ids }) {
|
|
1475
|
+
let client;
|
|
1476
|
+
try {
|
|
1477
|
+
client = await this.pool.connect();
|
|
1478
|
+
const { tableName } = this.getTableName(indexName);
|
|
1479
|
+
if (!filter && !ids) {
|
|
1480
|
+
throw new MastraError({
|
|
1481
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_MISSING_PARAMS",
|
|
1482
|
+
text: "Either filter or ids must be provided",
|
|
1483
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1484
|
+
category: ErrorCategory.USER,
|
|
1485
|
+
details: { indexName }
|
|
1486
|
+
});
|
|
1487
|
+
}
|
|
1488
|
+
if (filter && ids) {
|
|
1489
|
+
throw new MastraError({
|
|
1490
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_CONFLICTING_PARAMS",
|
|
1491
|
+
text: "Cannot provide both filter and ids - they are mutually exclusive",
|
|
1492
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1493
|
+
category: ErrorCategory.USER,
|
|
1494
|
+
details: { indexName }
|
|
1495
|
+
});
|
|
1496
|
+
}
|
|
1497
|
+
let query;
|
|
1498
|
+
let values;
|
|
1499
|
+
if (ids) {
|
|
1500
|
+
if (ids.length === 0) {
|
|
1501
|
+
throw new MastraError({
|
|
1502
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_EMPTY_IDS",
|
|
1503
|
+
text: "Cannot delete with empty ids array",
|
|
1504
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1505
|
+
category: ErrorCategory.USER,
|
|
1506
|
+
details: { indexName }
|
|
1507
|
+
});
|
|
1508
|
+
}
|
|
1509
|
+
const placeholders = ids.map((_, i) => `$${i + 1}`).join(", ");
|
|
1510
|
+
query = `DELETE FROM ${tableName} WHERE vector_id IN (${placeholders})`;
|
|
1511
|
+
values = ids;
|
|
1512
|
+
} else {
|
|
1513
|
+
if (!filter || Object.keys(filter).length === 0) {
|
|
1514
|
+
throw new MastraError({
|
|
1515
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_EMPTY_FILTER",
|
|
1516
|
+
text: "Cannot delete with empty filter. Use deleteIndex to delete all vectors.",
|
|
1517
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1518
|
+
category: ErrorCategory.USER,
|
|
1519
|
+
details: { indexName }
|
|
1520
|
+
});
|
|
1521
|
+
}
|
|
1522
|
+
const translatedFilter = this.transformFilter(filter);
|
|
1523
|
+
const { sql: filterQuery, values: filterValues } = buildDeleteFilterQuery(translatedFilter);
|
|
1524
|
+
const whereClause = filterQuery.trim().replace(/^WHERE\s+/i, "");
|
|
1525
|
+
if (!whereClause) {
|
|
1526
|
+
throw new MastraError({
|
|
1527
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_INVALID_FILTER",
|
|
1528
|
+
text: "Filter produced empty WHERE clause",
|
|
1529
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1530
|
+
category: ErrorCategory.USER,
|
|
1531
|
+
details: { indexName, filter: JSON.stringify(filter) }
|
|
1532
|
+
});
|
|
1533
|
+
}
|
|
1534
|
+
query = `DELETE FROM ${tableName} WHERE ${whereClause}`;
|
|
1535
|
+
values = filterValues;
|
|
1536
|
+
}
|
|
1537
|
+
const result = await client.query(query, values);
|
|
1538
|
+
this.logger?.info(`Deleted ${result.rowCount || 0} vectors from ${indexName}`, {
|
|
1539
|
+
indexName,
|
|
1540
|
+
filter: filter ? filter : void 0,
|
|
1541
|
+
ids: ids ? ids : void 0,
|
|
1542
|
+
deletedCount: result.rowCount || 0
|
|
1543
|
+
});
|
|
1544
|
+
} catch (error) {
|
|
1545
|
+
if (error instanceof MastraError) {
|
|
1546
|
+
throw error;
|
|
1547
|
+
}
|
|
1548
|
+
const mastraError = new MastraError(
|
|
1549
|
+
{
|
|
1550
|
+
id: "MASTRA_STORAGE_PG_VECTOR_DELETE_VECTORS_FAILED",
|
|
1551
|
+
domain: ErrorDomain.MASTRA_VECTOR,
|
|
1552
|
+
category: ErrorCategory.THIRD_PARTY,
|
|
1553
|
+
details: {
|
|
1554
|
+
indexName,
|
|
1555
|
+
...filter && { filter: JSON.stringify(filter) },
|
|
1556
|
+
...ids && { idsCount: ids.length }
|
|
1557
|
+
}
|
|
1558
|
+
},
|
|
1559
|
+
error
|
|
1560
|
+
);
|
|
1561
|
+
this.logger?.trackException(mastraError);
|
|
1562
|
+
throw mastraError;
|
|
1563
|
+
} finally {
|
|
1564
|
+
client?.release();
|
|
1565
|
+
}
|
|
1566
|
+
}
|
|
1298
1567
|
};
|
|
1299
1568
|
function getSchemaName(schema) {
|
|
1300
1569
|
return schema ? `"${parseSqlIdentifier(schema, "schema name")}"` : void 0;
|
|
@@ -1616,6 +1885,20 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1616
1885
|
const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
|
|
1617
1886
|
await this.client.tx(async (t) => {
|
|
1618
1887
|
await t.none(`DELETE FROM ${tableName} WHERE thread_id = $1`, [threadId]);
|
|
1888
|
+
const schemaName = this.schema || "public";
|
|
1889
|
+
const vectorTables = await t.manyOrNone(
|
|
1890
|
+
`
|
|
1891
|
+
SELECT tablename
|
|
1892
|
+
FROM pg_tables
|
|
1893
|
+
WHERE schemaname = $1
|
|
1894
|
+
AND (tablename = 'memory_messages' OR tablename LIKE 'memory_messages_%')
|
|
1895
|
+
`,
|
|
1896
|
+
[schemaName]
|
|
1897
|
+
);
|
|
1898
|
+
for (const { tablename } of vectorTables) {
|
|
1899
|
+
const vectorTableName = getTableName({ indexName: tablename, schemaName: getSchemaName(this.schema) });
|
|
1900
|
+
await t.none(`DELETE FROM ${vectorTableName} WHERE metadata->>'thread_id' = $1`, [threadId]);
|
|
1901
|
+
}
|
|
1619
1902
|
await t.none(`DELETE FROM ${threadTableName} WHERE id = $1`, [threadId]);
|
|
1620
1903
|
});
|
|
1621
1904
|
} catch (error) {
|
|
@@ -1632,28 +1915,26 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1632
1915
|
);
|
|
1633
1916
|
}
|
|
1634
1917
|
}
|
|
1635
|
-
async _getIncludedMessages({
|
|
1636
|
-
|
|
1637
|
-
include
|
|
1638
|
-
}) {
|
|
1639
|
-
if (!threadId.trim()) throw new Error("threadId must be a non-empty string");
|
|
1640
|
-
if (!include) return null;
|
|
1918
|
+
async _getIncludedMessages({ include }) {
|
|
1919
|
+
if (!include || include.length === 0) return null;
|
|
1641
1920
|
const unionQueries = [];
|
|
1642
1921
|
const params = [];
|
|
1643
1922
|
let paramIdx = 1;
|
|
1644
1923
|
const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
|
|
1645
1924
|
for (const inc of include) {
|
|
1646
1925
|
const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
|
|
1647
|
-
const searchId = inc.threadId || threadId;
|
|
1648
1926
|
unionQueries.push(
|
|
1649
1927
|
`
|
|
1650
1928
|
SELECT * FROM (
|
|
1651
|
-
WITH
|
|
1929
|
+
WITH target_thread AS (
|
|
1930
|
+
SELECT thread_id FROM ${tableName} WHERE id = $${paramIdx}
|
|
1931
|
+
),
|
|
1932
|
+
ordered_messages AS (
|
|
1652
1933
|
SELECT
|
|
1653
1934
|
*,
|
|
1654
1935
|
ROW_NUMBER() OVER (ORDER BY "createdAt" ASC) as row_num
|
|
1655
1936
|
FROM ${tableName}
|
|
1656
|
-
WHERE thread_id =
|
|
1937
|
+
WHERE thread_id = (SELECT thread_id FROM target_thread)
|
|
1657
1938
|
)
|
|
1658
1939
|
SELECT
|
|
1659
1940
|
m.id,
|
|
@@ -1665,24 +1946,24 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1665
1946
|
m.thread_id AS "threadId",
|
|
1666
1947
|
m."resourceId"
|
|
1667
1948
|
FROM ordered_messages m
|
|
1668
|
-
WHERE m.id = $${paramIdx
|
|
1949
|
+
WHERE m.id = $${paramIdx}
|
|
1669
1950
|
OR EXISTS (
|
|
1670
1951
|
SELECT 1 FROM ordered_messages target
|
|
1671
|
-
WHERE target.id = $${paramIdx
|
|
1952
|
+
WHERE target.id = $${paramIdx}
|
|
1672
1953
|
AND (
|
|
1673
1954
|
-- Get previous messages (messages that come BEFORE the target)
|
|
1674
|
-
(m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx +
|
|
1955
|
+
(m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx + 1})
|
|
1675
1956
|
OR
|
|
1676
1957
|
-- Get next messages (messages that come AFTER the target)
|
|
1677
|
-
(m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx +
|
|
1958
|
+
(m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx + 2})
|
|
1678
1959
|
)
|
|
1679
1960
|
)
|
|
1680
1961
|
) AS query_${paramIdx}
|
|
1681
1962
|
`
|
|
1682
1963
|
// Keep ASC for final sorting after fetching context
|
|
1683
1964
|
);
|
|
1684
|
-
params.push(
|
|
1685
|
-
paramIdx +=
|
|
1965
|
+
params.push(id, withPreviousMessages, withNextMessages);
|
|
1966
|
+
paramIdx += 3;
|
|
1686
1967
|
}
|
|
1687
1968
|
const finalQuery = unionQueries.join(" UNION ALL ") + ' ORDER BY "createdAt" ASC';
|
|
1688
1969
|
const includedRows = await this.client.manyOrNone(finalQuery, params);
|
|
@@ -1746,15 +2027,18 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1746
2027
|
}
|
|
1747
2028
|
async listMessages(args) {
|
|
1748
2029
|
const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
|
|
1749
|
-
|
|
2030
|
+
const threadIds = (Array.isArray(threadId) ? threadId : [threadId]).filter(
|
|
2031
|
+
(id) => typeof id === "string"
|
|
2032
|
+
);
|
|
2033
|
+
if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
|
|
1750
2034
|
throw new MastraError(
|
|
1751
2035
|
{
|
|
1752
2036
|
id: "STORAGE_PG_LIST_MESSAGES_INVALID_THREAD_ID",
|
|
1753
2037
|
domain: ErrorDomain.STORAGE,
|
|
1754
2038
|
category: ErrorCategory.THIRD_PARTY,
|
|
1755
|
-
details: { threadId }
|
|
2039
|
+
details: { threadId: Array.isArray(threadId) ? String(threadId) : String(threadId) }
|
|
1756
2040
|
},
|
|
1757
|
-
new Error("threadId must be a non-empty string")
|
|
2041
|
+
new Error("threadId must be a non-empty string or array of non-empty strings")
|
|
1758
2042
|
);
|
|
1759
2043
|
}
|
|
1760
2044
|
if (page < 0) {
|
|
@@ -1764,7 +2048,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1764
2048
|
category: ErrorCategory.USER,
|
|
1765
2049
|
text: "Page number must be non-negative",
|
|
1766
2050
|
details: {
|
|
1767
|
-
threadId,
|
|
2051
|
+
threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
|
|
1768
2052
|
page
|
|
1769
2053
|
}
|
|
1770
2054
|
});
|
|
@@ -1776,9 +2060,10 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1776
2060
|
const orderByStatement = `ORDER BY "${field}" ${direction}`;
|
|
1777
2061
|
const selectStatement = `SELECT id, content, role, type, "createdAt", "createdAtZ", thread_id AS "threadId", "resourceId"`;
|
|
1778
2062
|
const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
|
|
1779
|
-
const
|
|
1780
|
-
const
|
|
1781
|
-
|
|
2063
|
+
const threadPlaceholders = threadIds.map((_, i) => `$${i + 1}`).join(", ");
|
|
2064
|
+
const conditions = [`thread_id IN (${threadPlaceholders})`];
|
|
2065
|
+
const queryParams = [...threadIds];
|
|
2066
|
+
let paramIndex = threadIds.length + 1;
|
|
1782
2067
|
if (resourceId) {
|
|
1783
2068
|
conditions.push(`"resourceId" = $${paramIndex++}`);
|
|
1784
2069
|
queryParams.push(resourceId);
|
|
@@ -1810,7 +2095,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1810
2095
|
}
|
|
1811
2096
|
const messageIds = new Set(messages.map((m) => m.id));
|
|
1812
2097
|
if (include && include.length > 0) {
|
|
1813
|
-
const includeMessages = await this._getIncludedMessages({
|
|
2098
|
+
const includeMessages = await this._getIncludedMessages({ include });
|
|
1814
2099
|
if (includeMessages) {
|
|
1815
2100
|
for (const includeMsg of includeMessages) {
|
|
1816
2101
|
if (!messageIds.has(includeMsg.id)) {
|
|
@@ -1837,7 +2122,10 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1837
2122
|
}
|
|
1838
2123
|
return direction === "ASC" ? String(aValue).localeCompare(String(bValue)) : String(bValue).localeCompare(String(aValue));
|
|
1839
2124
|
});
|
|
1840
|
-
const
|
|
2125
|
+
const threadIdSet = new Set(threadIds);
|
|
2126
|
+
const returnedThreadMessageIds = new Set(
|
|
2127
|
+
finalMessages.filter((m) => m.threadId && threadIdSet.has(m.threadId)).map((m) => m.id)
|
|
2128
|
+
);
|
|
1841
2129
|
const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
|
|
1842
2130
|
const hasMore = perPageInput !== false && !allThreadMessagesReturned && offset + perPage < total;
|
|
1843
2131
|
return {
|
|
@@ -1854,7 +2142,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1854
2142
|
domain: ErrorDomain.STORAGE,
|
|
1855
2143
|
category: ErrorCategory.THIRD_PARTY,
|
|
1856
2144
|
details: {
|
|
1857
|
-
threadId,
|
|
2145
|
+
threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
|
|
1858
2146
|
resourceId: resourceId ?? ""
|
|
1859
2147
|
}
|
|
1860
2148
|
},
|
|
@@ -3309,20 +3597,12 @@ var StoreOperationsPG = class extends StoreOperations {
|
|
|
3309
3597
|
}
|
|
3310
3598
|
};
|
|
3311
3599
|
function transformScoreRow(row) {
|
|
3312
|
-
return {
|
|
3313
|
-
|
|
3314
|
-
|
|
3315
|
-
|
|
3316
|
-
|
|
3317
|
-
|
|
3318
|
-
metadata: safelyParseJSON(row.metadata),
|
|
3319
|
-
output: safelyParseJSON(row.output),
|
|
3320
|
-
additionalContext: safelyParseJSON(row.additionalContext),
|
|
3321
|
-
requestContext: safelyParseJSON(row.requestContext),
|
|
3322
|
-
entity: safelyParseJSON(row.entity),
|
|
3323
|
-
createdAt: row.createdAtZ || row.createdAt,
|
|
3324
|
-
updatedAt: row.updatedAtZ || row.updatedAt
|
|
3325
|
-
};
|
|
3600
|
+
return transformScoreRow$1(row, {
|
|
3601
|
+
preferredTimestampFields: {
|
|
3602
|
+
createdAt: "createdAtZ",
|
|
3603
|
+
updatedAt: "updatedAtZ"
|
|
3604
|
+
}
|
|
3605
|
+
});
|
|
3326
3606
|
}
|
|
3327
3607
|
var ScoresPG = class extends ScoresStorage {
|
|
3328
3608
|
client;
|
|
@@ -3469,8 +3749,6 @@ var ScoresPG = class extends ScoresStorage {
|
|
|
3469
3749
|
scorer: scorer ? JSON.stringify(scorer) : null,
|
|
3470
3750
|
preprocessStepResult: preprocessStepResult ? JSON.stringify(preprocessStepResult) : null,
|
|
3471
3751
|
analyzeStepResult: analyzeStepResult ? JSON.stringify(analyzeStepResult) : null,
|
|
3472
|
-
metadata: metadata ? JSON.stringify(metadata) : null,
|
|
3473
|
-
additionalContext: additionalContext ? JSON.stringify(additionalContext) : null,
|
|
3474
3752
|
requestContext: requestContext ? JSON.stringify(requestContext) : null,
|
|
3475
3753
|
entity: entity ? JSON.stringify(entity) : null,
|
|
3476
3754
|
createdAt: (/* @__PURE__ */ new Date()).toISOString(),
|