@mastra/pg 1.0.0-beta.2 → 1.0.0-beta.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +27 -0
- package/dist/index.cjs +50 -44
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +51 -45
- package/dist/index.js.map +1 -1
- package/dist/shared/config.d.ts +1 -1
- package/dist/shared/config.d.ts.map +1 -1
- package/dist/storage/domains/memory/index.d.ts.map +1 -1
- package/dist/storage/domains/scores/index.d.ts.map +1 -1
- package/dist/vector/index.d.ts.map +1 -1
- package/package.json +4 -4
package/dist/index.js
CHANGED
|
@@ -5,7 +5,7 @@ import { Mutex } from 'async-mutex';
|
|
|
5
5
|
import * as pg from 'pg';
|
|
6
6
|
import xxhash from 'xxhash-wasm';
|
|
7
7
|
import { BaseFilterTranslator } from '@mastra/core/vector/filter';
|
|
8
|
-
import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage,
|
|
8
|
+
import { MastraStorage, StoreOperations, TABLE_SCHEMAS, TABLE_WORKFLOW_SNAPSHOT, TABLE_SPANS, TABLE_THREADS, TABLE_MESSAGES, TABLE_TRACES, TABLE_SCORERS, ScoresStorage, normalizePerPage, calculatePagination, WorkflowsStorage, MemoryStorage, TABLE_RESOURCES, ObservabilityStorage, transformScoreRow as transformScoreRow$1 } from '@mastra/core/storage';
|
|
9
9
|
import pgPromise from 'pg-promise';
|
|
10
10
|
import { MessageList } from '@mastra/core/agent';
|
|
11
11
|
import { saveScorePayloadSchema } from '@mastra/core/evals';
|
|
@@ -596,9 +596,6 @@ var PgVector = class extends MastraVector {
|
|
|
596
596
|
if (this.vectorExtensionSchema === "pg_catalog") {
|
|
597
597
|
return "vector";
|
|
598
598
|
}
|
|
599
|
-
if (this.vectorExtensionSchema === (this.schema || "public")) {
|
|
600
|
-
return "vector";
|
|
601
|
-
}
|
|
602
599
|
const validatedSchema = parseSqlIdentifier(this.vectorExtensionSchema, "vector extension schema");
|
|
603
600
|
return `${validatedSchema}.vector`;
|
|
604
601
|
}
|
|
@@ -1888,6 +1885,20 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1888
1885
|
const threadTableName = getTableName({ indexName: TABLE_THREADS, schemaName: getSchemaName(this.schema) });
|
|
1889
1886
|
await this.client.tx(async (t) => {
|
|
1890
1887
|
await t.none(`DELETE FROM ${tableName} WHERE thread_id = $1`, [threadId]);
|
|
1888
|
+
const schemaName = this.schema || "public";
|
|
1889
|
+
const vectorTables = await t.manyOrNone(
|
|
1890
|
+
`
|
|
1891
|
+
SELECT tablename
|
|
1892
|
+
FROM pg_tables
|
|
1893
|
+
WHERE schemaname = $1
|
|
1894
|
+
AND (tablename = 'memory_messages' OR tablename LIKE 'memory_messages_%')
|
|
1895
|
+
`,
|
|
1896
|
+
[schemaName]
|
|
1897
|
+
);
|
|
1898
|
+
for (const { tablename } of vectorTables) {
|
|
1899
|
+
const vectorTableName = getTableName({ indexName: tablename, schemaName: getSchemaName(this.schema) });
|
|
1900
|
+
await t.none(`DELETE FROM ${vectorTableName} WHERE metadata->>'thread_id' = $1`, [threadId]);
|
|
1901
|
+
}
|
|
1891
1902
|
await t.none(`DELETE FROM ${threadTableName} WHERE id = $1`, [threadId]);
|
|
1892
1903
|
});
|
|
1893
1904
|
} catch (error) {
|
|
@@ -1904,28 +1915,26 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1904
1915
|
);
|
|
1905
1916
|
}
|
|
1906
1917
|
}
|
|
1907
|
-
async _getIncludedMessages({
|
|
1908
|
-
|
|
1909
|
-
include
|
|
1910
|
-
}) {
|
|
1911
|
-
if (!threadId.trim()) throw new Error("threadId must be a non-empty string");
|
|
1912
|
-
if (!include) return null;
|
|
1918
|
+
async _getIncludedMessages({ include }) {
|
|
1919
|
+
if (!include || include.length === 0) return null;
|
|
1913
1920
|
const unionQueries = [];
|
|
1914
1921
|
const params = [];
|
|
1915
1922
|
let paramIdx = 1;
|
|
1916
1923
|
const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
|
|
1917
1924
|
for (const inc of include) {
|
|
1918
1925
|
const { id, withPreviousMessages = 0, withNextMessages = 0 } = inc;
|
|
1919
|
-
const searchId = inc.threadId || threadId;
|
|
1920
1926
|
unionQueries.push(
|
|
1921
1927
|
`
|
|
1922
1928
|
SELECT * FROM (
|
|
1923
|
-
WITH
|
|
1929
|
+
WITH target_thread AS (
|
|
1930
|
+
SELECT thread_id FROM ${tableName} WHERE id = $${paramIdx}
|
|
1931
|
+
),
|
|
1932
|
+
ordered_messages AS (
|
|
1924
1933
|
SELECT
|
|
1925
1934
|
*,
|
|
1926
1935
|
ROW_NUMBER() OVER (ORDER BY "createdAt" ASC) as row_num
|
|
1927
1936
|
FROM ${tableName}
|
|
1928
|
-
WHERE thread_id =
|
|
1937
|
+
WHERE thread_id = (SELECT thread_id FROM target_thread)
|
|
1929
1938
|
)
|
|
1930
1939
|
SELECT
|
|
1931
1940
|
m.id,
|
|
@@ -1937,24 +1946,24 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
1937
1946
|
m.thread_id AS "threadId",
|
|
1938
1947
|
m."resourceId"
|
|
1939
1948
|
FROM ordered_messages m
|
|
1940
|
-
WHERE m.id = $${paramIdx
|
|
1949
|
+
WHERE m.id = $${paramIdx}
|
|
1941
1950
|
OR EXISTS (
|
|
1942
1951
|
SELECT 1 FROM ordered_messages target
|
|
1943
|
-
WHERE target.id = $${paramIdx
|
|
1952
|
+
WHERE target.id = $${paramIdx}
|
|
1944
1953
|
AND (
|
|
1945
1954
|
-- Get previous messages (messages that come BEFORE the target)
|
|
1946
|
-
(m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx +
|
|
1955
|
+
(m.row_num < target.row_num AND m.row_num >= target.row_num - $${paramIdx + 1})
|
|
1947
1956
|
OR
|
|
1948
1957
|
-- Get next messages (messages that come AFTER the target)
|
|
1949
|
-
(m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx +
|
|
1958
|
+
(m.row_num > target.row_num AND m.row_num <= target.row_num + $${paramIdx + 2})
|
|
1950
1959
|
)
|
|
1951
1960
|
)
|
|
1952
1961
|
) AS query_${paramIdx}
|
|
1953
1962
|
`
|
|
1954
1963
|
// Keep ASC for final sorting after fetching context
|
|
1955
1964
|
);
|
|
1956
|
-
params.push(
|
|
1957
|
-
paramIdx +=
|
|
1965
|
+
params.push(id, withPreviousMessages, withNextMessages);
|
|
1966
|
+
paramIdx += 3;
|
|
1958
1967
|
}
|
|
1959
1968
|
const finalQuery = unionQueries.join(" UNION ALL ") + ' ORDER BY "createdAt" ASC';
|
|
1960
1969
|
const includedRows = await this.client.manyOrNone(finalQuery, params);
|
|
@@ -2018,15 +2027,18 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2018
2027
|
}
|
|
2019
2028
|
async listMessages(args) {
|
|
2020
2029
|
const { threadId, resourceId, include, filter, perPage: perPageInput, page = 0, orderBy } = args;
|
|
2021
|
-
|
|
2030
|
+
const threadIds = (Array.isArray(threadId) ? threadId : [threadId]).filter(
|
|
2031
|
+
(id) => typeof id === "string"
|
|
2032
|
+
);
|
|
2033
|
+
if (threadIds.length === 0 || threadIds.some((id) => !id.trim())) {
|
|
2022
2034
|
throw new MastraError(
|
|
2023
2035
|
{
|
|
2024
2036
|
id: "STORAGE_PG_LIST_MESSAGES_INVALID_THREAD_ID",
|
|
2025
2037
|
domain: ErrorDomain.STORAGE,
|
|
2026
2038
|
category: ErrorCategory.THIRD_PARTY,
|
|
2027
|
-
details: { threadId }
|
|
2039
|
+
details: { threadId: Array.isArray(threadId) ? String(threadId) : String(threadId) }
|
|
2028
2040
|
},
|
|
2029
|
-
new Error("threadId must be a non-empty string")
|
|
2041
|
+
new Error("threadId must be a non-empty string or array of non-empty strings")
|
|
2030
2042
|
);
|
|
2031
2043
|
}
|
|
2032
2044
|
if (page < 0) {
|
|
@@ -2036,7 +2048,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2036
2048
|
category: ErrorCategory.USER,
|
|
2037
2049
|
text: "Page number must be non-negative",
|
|
2038
2050
|
details: {
|
|
2039
|
-
threadId,
|
|
2051
|
+
threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
|
|
2040
2052
|
page
|
|
2041
2053
|
}
|
|
2042
2054
|
});
|
|
@@ -2048,9 +2060,10 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2048
2060
|
const orderByStatement = `ORDER BY "${field}" ${direction}`;
|
|
2049
2061
|
const selectStatement = `SELECT id, content, role, type, "createdAt", "createdAtZ", thread_id AS "threadId", "resourceId"`;
|
|
2050
2062
|
const tableName = getTableName({ indexName: TABLE_MESSAGES, schemaName: getSchemaName(this.schema) });
|
|
2051
|
-
const
|
|
2052
|
-
const
|
|
2053
|
-
|
|
2063
|
+
const threadPlaceholders = threadIds.map((_, i) => `$${i + 1}`).join(", ");
|
|
2064
|
+
const conditions = [`thread_id IN (${threadPlaceholders})`];
|
|
2065
|
+
const queryParams = [...threadIds];
|
|
2066
|
+
let paramIndex = threadIds.length + 1;
|
|
2054
2067
|
if (resourceId) {
|
|
2055
2068
|
conditions.push(`"resourceId" = $${paramIndex++}`);
|
|
2056
2069
|
queryParams.push(resourceId);
|
|
@@ -2082,7 +2095,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2082
2095
|
}
|
|
2083
2096
|
const messageIds = new Set(messages.map((m) => m.id));
|
|
2084
2097
|
if (include && include.length > 0) {
|
|
2085
|
-
const includeMessages = await this._getIncludedMessages({
|
|
2098
|
+
const includeMessages = await this._getIncludedMessages({ include });
|
|
2086
2099
|
if (includeMessages) {
|
|
2087
2100
|
for (const includeMsg of includeMessages) {
|
|
2088
2101
|
if (!messageIds.has(includeMsg.id)) {
|
|
@@ -2109,7 +2122,10 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2109
2122
|
}
|
|
2110
2123
|
return direction === "ASC" ? String(aValue).localeCompare(String(bValue)) : String(bValue).localeCompare(String(aValue));
|
|
2111
2124
|
});
|
|
2112
|
-
const
|
|
2125
|
+
const threadIdSet = new Set(threadIds);
|
|
2126
|
+
const returnedThreadMessageIds = new Set(
|
|
2127
|
+
finalMessages.filter((m) => m.threadId && threadIdSet.has(m.threadId)).map((m) => m.id)
|
|
2128
|
+
);
|
|
2113
2129
|
const allThreadMessagesReturned = returnedThreadMessageIds.size >= total;
|
|
2114
2130
|
const hasMore = perPageInput !== false && !allThreadMessagesReturned && offset + perPage < total;
|
|
2115
2131
|
return {
|
|
@@ -2126,7 +2142,7 @@ var MemoryPG = class extends MemoryStorage {
|
|
|
2126
2142
|
domain: ErrorDomain.STORAGE,
|
|
2127
2143
|
category: ErrorCategory.THIRD_PARTY,
|
|
2128
2144
|
details: {
|
|
2129
|
-
threadId,
|
|
2145
|
+
threadId: Array.isArray(threadId) ? threadId.join(",") : threadId,
|
|
2130
2146
|
resourceId: resourceId ?? ""
|
|
2131
2147
|
}
|
|
2132
2148
|
},
|
|
@@ -3581,20 +3597,12 @@ var StoreOperationsPG = class extends StoreOperations {
|
|
|
3581
3597
|
}
|
|
3582
3598
|
};
|
|
3583
3599
|
function transformScoreRow(row) {
|
|
3584
|
-
return {
|
|
3585
|
-
|
|
3586
|
-
|
|
3587
|
-
|
|
3588
|
-
|
|
3589
|
-
|
|
3590
|
-
metadata: safelyParseJSON(row.metadata),
|
|
3591
|
-
output: safelyParseJSON(row.output),
|
|
3592
|
-
additionalContext: safelyParseJSON(row.additionalContext),
|
|
3593
|
-
requestContext: safelyParseJSON(row.requestContext),
|
|
3594
|
-
entity: safelyParseJSON(row.entity),
|
|
3595
|
-
createdAt: row.createdAtZ || row.createdAt,
|
|
3596
|
-
updatedAt: row.updatedAtZ || row.updatedAt
|
|
3597
|
-
};
|
|
3600
|
+
return transformScoreRow$1(row, {
|
|
3601
|
+
preferredTimestampFields: {
|
|
3602
|
+
createdAt: "createdAtZ",
|
|
3603
|
+
updatedAt: "updatedAtZ"
|
|
3604
|
+
}
|
|
3605
|
+
});
|
|
3598
3606
|
}
|
|
3599
3607
|
var ScoresPG = class extends ScoresStorage {
|
|
3600
3608
|
client;
|
|
@@ -3741,8 +3749,6 @@ var ScoresPG = class extends ScoresStorage {
|
|
|
3741
3749
|
scorer: scorer ? JSON.stringify(scorer) : null,
|
|
3742
3750
|
preprocessStepResult: preprocessStepResult ? JSON.stringify(preprocessStepResult) : null,
|
|
3743
3751
|
analyzeStepResult: analyzeStepResult ? JSON.stringify(analyzeStepResult) : null,
|
|
3744
|
-
metadata: metadata ? JSON.stringify(metadata) : null,
|
|
3745
|
-
additionalContext: additionalContext ? JSON.stringify(additionalContext) : null,
|
|
3746
3752
|
requestContext: requestContext ? JSON.stringify(requestContext) : null,
|
|
3747
3753
|
entity: entity ? JSON.stringify(entity) : null,
|
|
3748
3754
|
createdAt: (/* @__PURE__ */ new Date()).toISOString(),
|