@mastra/pg 0.11.0 → 0.11.1-alpha.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +7 -7
- package/CHANGELOG.md +10 -0
- package/dist/_tsup-dts-rollup.d.cts +10 -0
- package/dist/_tsup-dts-rollup.d.ts +10 -0
- package/dist/index.cjs +95 -0
- package/dist/index.js +95 -0
- package/package.json +3 -3
- package/src/storage/index.test.ts +202 -11
- package/src/storage/index.ts +131 -1
package/.turbo/turbo-build.log
CHANGED
|
@@ -1,23 +1,23 @@
|
|
|
1
1
|
|
|
2
|
-
> @mastra/pg@0.11.
|
|
2
|
+
> @mastra/pg@0.11.1-alpha.0 build /home/runner/work/mastra/mastra/stores/pg
|
|
3
3
|
> tsup src/index.ts --format esm,cjs --experimental-dts --clean --treeshake=smallest --splitting
|
|
4
4
|
|
|
5
5
|
[34mCLI[39m Building entry: src/index.ts
|
|
6
6
|
[34mCLI[39m Using tsconfig: tsconfig.json
|
|
7
7
|
[34mCLI[39m tsup v8.5.0
|
|
8
8
|
[34mTSC[39m Build start
|
|
9
|
-
[32mTSC[39m ⚡️ Build success in
|
|
9
|
+
[32mTSC[39m ⚡️ Build success in 11031ms
|
|
10
10
|
[34mDTS[39m Build start
|
|
11
11
|
[34mCLI[39m Target: es2022
|
|
12
12
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
13
13
|
[36mWriting package typings: /home/runner/work/mastra/mastra/stores/pg/dist/_tsup-dts-rollup.d.ts[39m
|
|
14
14
|
Analysis will use the bundled TypeScript version 5.8.3
|
|
15
15
|
[36mWriting package typings: /home/runner/work/mastra/mastra/stores/pg/dist/_tsup-dts-rollup.d.cts[39m
|
|
16
|
-
[32mDTS[39m ⚡️ Build success in
|
|
16
|
+
[32mDTS[39m ⚡️ Build success in 11090ms
|
|
17
17
|
[34mCLI[39m Cleaning output folder
|
|
18
18
|
[34mESM[39m Build start
|
|
19
19
|
[34mCJS[39m Build start
|
|
20
|
-
[32mESM[39m [1mdist/index.js [22m[
|
|
21
|
-
[32mESM[39m ⚡️ Build success in
|
|
22
|
-
[32mCJS[39m [1mdist/index.cjs [22m[
|
|
23
|
-
[32mCJS[39m ⚡️ Build success in
|
|
20
|
+
[32mESM[39m [1mdist/index.js [22m[32m71.25 KB[39m
|
|
21
|
+
[32mESM[39m ⚡️ Build success in 1507ms
|
|
22
|
+
[32mCJS[39m [1mdist/index.cjs [22m[32m71.84 KB[39m
|
|
23
|
+
[32mCJS[39m ⚡️ Build success in 1506ms
|
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,15 @@
|
|
|
1
1
|
# @mastra/pg
|
|
2
2
|
|
|
3
|
+
## 0.11.1-alpha.0
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- d8f2d19: Add updateMessages API to storage classes (only support for PG and LibSQL for now) and to memory class. Additionally allow for metadata to be saved in the content field of a message.
|
|
8
|
+
- Updated dependencies [d8f2d19]
|
|
9
|
+
- Updated dependencies [9d52b17]
|
|
10
|
+
- Updated dependencies [8ba1b51]
|
|
11
|
+
- @mastra/core@0.10.7-alpha.0
|
|
12
|
+
|
|
3
13
|
## 0.11.0
|
|
4
14
|
|
|
5
15
|
### Minor Changes
|
|
@@ -6,6 +6,7 @@ import type { DescribeIndexParams } from '@mastra/core/vector';
|
|
|
6
6
|
import type { EvalRow } from '@mastra/core/storage';
|
|
7
7
|
import type { IndexStats } from '@mastra/core/vector';
|
|
8
8
|
import type { ISSLConfig } from 'pg-promise/typescript/pg-subset';
|
|
9
|
+
import type { MastraMessageContentV2 } from '@mastra/core/agent';
|
|
9
10
|
import type { MastraMessageV1 } from '@mastra/core/memory';
|
|
10
11
|
import type { MastraMessageV2 } from '@mastra/core/agent';
|
|
11
12
|
import { MastraStorage } from '@mastra/core/storage';
|
|
@@ -411,6 +412,15 @@ declare class PostgresStore extends MastraStorage {
|
|
|
411
412
|
} & PaginationArgs): Promise<PaginationInfo & {
|
|
412
413
|
evals: EvalRow[];
|
|
413
414
|
}>;
|
|
415
|
+
updateMessages({ messages, }: {
|
|
416
|
+
messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
|
|
417
|
+
id: string;
|
|
418
|
+
content?: {
|
|
419
|
+
metadata?: MastraMessageContentV2['metadata'];
|
|
420
|
+
content?: MastraMessageContentV2['content'];
|
|
421
|
+
};
|
|
422
|
+
})[];
|
|
423
|
+
}): Promise<MastraMessageV2[]>;
|
|
414
424
|
}
|
|
415
425
|
export { PostgresStore }
|
|
416
426
|
export { PostgresStore as PostgresStore_alias_1 }
|
|
@@ -6,6 +6,7 @@ import type { DescribeIndexParams } from '@mastra/core/vector';
|
|
|
6
6
|
import type { EvalRow } from '@mastra/core/storage';
|
|
7
7
|
import type { IndexStats } from '@mastra/core/vector';
|
|
8
8
|
import type { ISSLConfig } from 'pg-promise/typescript/pg-subset';
|
|
9
|
+
import type { MastraMessageContentV2 } from '@mastra/core/agent';
|
|
9
10
|
import type { MastraMessageV1 } from '@mastra/core/memory';
|
|
10
11
|
import type { MastraMessageV2 } from '@mastra/core/agent';
|
|
11
12
|
import { MastraStorage } from '@mastra/core/storage';
|
|
@@ -411,6 +412,15 @@ declare class PostgresStore extends MastraStorage {
|
|
|
411
412
|
} & PaginationArgs): Promise<PaginationInfo & {
|
|
412
413
|
evals: EvalRow[];
|
|
413
414
|
}>;
|
|
415
|
+
updateMessages({ messages, }: {
|
|
416
|
+
messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
|
|
417
|
+
id: string;
|
|
418
|
+
content?: {
|
|
419
|
+
metadata?: MastraMessageContentV2['metadata'];
|
|
420
|
+
content?: MastraMessageContentV2['content'];
|
|
421
|
+
};
|
|
422
|
+
})[];
|
|
423
|
+
}): Promise<MastraMessageV2[]>;
|
|
414
424
|
}
|
|
415
425
|
export { PostgresStore }
|
|
416
426
|
export { PostgresStore as PostgresStore_alias_1 }
|
package/dist/index.cjs
CHANGED
|
@@ -1869,6 +1869,101 @@ var PostgresStore = class extends storage.MastraStorage {
|
|
|
1869
1869
|
hasMore: currentOffset + (rows?.length ?? 0) < total
|
|
1870
1870
|
};
|
|
1871
1871
|
}
|
|
1872
|
+
async updateMessages({
|
|
1873
|
+
messages
|
|
1874
|
+
}) {
|
|
1875
|
+
if (messages.length === 0) {
|
|
1876
|
+
return [];
|
|
1877
|
+
}
|
|
1878
|
+
const messageIds = messages.map((m) => m.id);
|
|
1879
|
+
const selectQuery = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId" FROM ${this.getTableName(
|
|
1880
|
+
storage.TABLE_MESSAGES
|
|
1881
|
+
)} WHERE id IN ($1:list)`;
|
|
1882
|
+
const existingMessagesDb = await this.db.manyOrNone(selectQuery, [messageIds]);
|
|
1883
|
+
if (existingMessagesDb.length === 0) {
|
|
1884
|
+
return [];
|
|
1885
|
+
}
|
|
1886
|
+
const existingMessages = existingMessagesDb.map((msg) => {
|
|
1887
|
+
if (typeof msg.content === "string") {
|
|
1888
|
+
try {
|
|
1889
|
+
msg.content = JSON.parse(msg.content);
|
|
1890
|
+
} catch {
|
|
1891
|
+
}
|
|
1892
|
+
}
|
|
1893
|
+
return msg;
|
|
1894
|
+
});
|
|
1895
|
+
const threadIdsToUpdate = /* @__PURE__ */ new Set();
|
|
1896
|
+
await this.db.tx(async (t) => {
|
|
1897
|
+
const queries = [];
|
|
1898
|
+
const columnMapping = {
|
|
1899
|
+
threadId: "thread_id"
|
|
1900
|
+
};
|
|
1901
|
+
for (const existingMessage of existingMessages) {
|
|
1902
|
+
const updatePayload = messages.find((m) => m.id === existingMessage.id);
|
|
1903
|
+
if (!updatePayload) continue;
|
|
1904
|
+
const { id, ...fieldsToUpdate } = updatePayload;
|
|
1905
|
+
if (Object.keys(fieldsToUpdate).length === 0) continue;
|
|
1906
|
+
threadIdsToUpdate.add(existingMessage.threadId);
|
|
1907
|
+
if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
|
|
1908
|
+
threadIdsToUpdate.add(updatePayload.threadId);
|
|
1909
|
+
}
|
|
1910
|
+
const setClauses = [];
|
|
1911
|
+
const values = [];
|
|
1912
|
+
let paramIndex = 1;
|
|
1913
|
+
const updatableFields = { ...fieldsToUpdate };
|
|
1914
|
+
if (updatableFields.content) {
|
|
1915
|
+
const newContent = {
|
|
1916
|
+
...existingMessage.content,
|
|
1917
|
+
...updatableFields.content,
|
|
1918
|
+
// Deep merge metadata if it exists on both
|
|
1919
|
+
...existingMessage.content?.metadata && updatableFields.content.metadata ? {
|
|
1920
|
+
metadata: {
|
|
1921
|
+
...existingMessage.content.metadata,
|
|
1922
|
+
...updatableFields.content.metadata
|
|
1923
|
+
}
|
|
1924
|
+
} : {}
|
|
1925
|
+
};
|
|
1926
|
+
setClauses.push(`content = $${paramIndex++}`);
|
|
1927
|
+
values.push(newContent);
|
|
1928
|
+
delete updatableFields.content;
|
|
1929
|
+
}
|
|
1930
|
+
for (const key in updatableFields) {
|
|
1931
|
+
if (Object.prototype.hasOwnProperty.call(updatableFields, key)) {
|
|
1932
|
+
const dbColumn = columnMapping[key] || key;
|
|
1933
|
+
setClauses.push(`"${dbColumn}" = $${paramIndex++}`);
|
|
1934
|
+
values.push(updatableFields[key]);
|
|
1935
|
+
}
|
|
1936
|
+
}
|
|
1937
|
+
if (setClauses.length > 0) {
|
|
1938
|
+
values.push(id);
|
|
1939
|
+
const sql = `UPDATE ${this.getTableName(
|
|
1940
|
+
storage.TABLE_MESSAGES
|
|
1941
|
+
)} SET ${setClauses.join(", ")} WHERE id = $${paramIndex}`;
|
|
1942
|
+
queries.push(t.none(sql, values));
|
|
1943
|
+
}
|
|
1944
|
+
}
|
|
1945
|
+
if (threadIdsToUpdate.size > 0) {
|
|
1946
|
+
queries.push(
|
|
1947
|
+
t.none(`UPDATE ${this.getTableName(storage.TABLE_THREADS)} SET "updatedAt" = NOW() WHERE id IN ($1:list)`, [
|
|
1948
|
+
Array.from(threadIdsToUpdate)
|
|
1949
|
+
])
|
|
1950
|
+
);
|
|
1951
|
+
}
|
|
1952
|
+
if (queries.length > 0) {
|
|
1953
|
+
await t.batch(queries);
|
|
1954
|
+
}
|
|
1955
|
+
});
|
|
1956
|
+
const updatedMessages = await this.db.manyOrNone(selectQuery, [messageIds]);
|
|
1957
|
+
return (updatedMessages || []).map((message) => {
|
|
1958
|
+
if (typeof message.content === "string") {
|
|
1959
|
+
try {
|
|
1960
|
+
message.content = JSON.parse(message.content);
|
|
1961
|
+
} catch {
|
|
1962
|
+
}
|
|
1963
|
+
}
|
|
1964
|
+
return message;
|
|
1965
|
+
});
|
|
1966
|
+
}
|
|
1872
1967
|
};
|
|
1873
1968
|
|
|
1874
1969
|
// src/vector/prompt.ts
|
package/dist/index.js
CHANGED
|
@@ -1861,6 +1861,101 @@ var PostgresStore = class extends MastraStorage {
|
|
|
1861
1861
|
hasMore: currentOffset + (rows?.length ?? 0) < total
|
|
1862
1862
|
};
|
|
1863
1863
|
}
|
|
1864
|
+
async updateMessages({
|
|
1865
|
+
messages
|
|
1866
|
+
}) {
|
|
1867
|
+
if (messages.length === 0) {
|
|
1868
|
+
return [];
|
|
1869
|
+
}
|
|
1870
|
+
const messageIds = messages.map((m) => m.id);
|
|
1871
|
+
const selectQuery = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId" FROM ${this.getTableName(
|
|
1872
|
+
TABLE_MESSAGES
|
|
1873
|
+
)} WHERE id IN ($1:list)`;
|
|
1874
|
+
const existingMessagesDb = await this.db.manyOrNone(selectQuery, [messageIds]);
|
|
1875
|
+
if (existingMessagesDb.length === 0) {
|
|
1876
|
+
return [];
|
|
1877
|
+
}
|
|
1878
|
+
const existingMessages = existingMessagesDb.map((msg) => {
|
|
1879
|
+
if (typeof msg.content === "string") {
|
|
1880
|
+
try {
|
|
1881
|
+
msg.content = JSON.parse(msg.content);
|
|
1882
|
+
} catch {
|
|
1883
|
+
}
|
|
1884
|
+
}
|
|
1885
|
+
return msg;
|
|
1886
|
+
});
|
|
1887
|
+
const threadIdsToUpdate = /* @__PURE__ */ new Set();
|
|
1888
|
+
await this.db.tx(async (t) => {
|
|
1889
|
+
const queries = [];
|
|
1890
|
+
const columnMapping = {
|
|
1891
|
+
threadId: "thread_id"
|
|
1892
|
+
};
|
|
1893
|
+
for (const existingMessage of existingMessages) {
|
|
1894
|
+
const updatePayload = messages.find((m) => m.id === existingMessage.id);
|
|
1895
|
+
if (!updatePayload) continue;
|
|
1896
|
+
const { id, ...fieldsToUpdate } = updatePayload;
|
|
1897
|
+
if (Object.keys(fieldsToUpdate).length === 0) continue;
|
|
1898
|
+
threadIdsToUpdate.add(existingMessage.threadId);
|
|
1899
|
+
if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
|
|
1900
|
+
threadIdsToUpdate.add(updatePayload.threadId);
|
|
1901
|
+
}
|
|
1902
|
+
const setClauses = [];
|
|
1903
|
+
const values = [];
|
|
1904
|
+
let paramIndex = 1;
|
|
1905
|
+
const updatableFields = { ...fieldsToUpdate };
|
|
1906
|
+
if (updatableFields.content) {
|
|
1907
|
+
const newContent = {
|
|
1908
|
+
...existingMessage.content,
|
|
1909
|
+
...updatableFields.content,
|
|
1910
|
+
// Deep merge metadata if it exists on both
|
|
1911
|
+
...existingMessage.content?.metadata && updatableFields.content.metadata ? {
|
|
1912
|
+
metadata: {
|
|
1913
|
+
...existingMessage.content.metadata,
|
|
1914
|
+
...updatableFields.content.metadata
|
|
1915
|
+
}
|
|
1916
|
+
} : {}
|
|
1917
|
+
};
|
|
1918
|
+
setClauses.push(`content = $${paramIndex++}`);
|
|
1919
|
+
values.push(newContent);
|
|
1920
|
+
delete updatableFields.content;
|
|
1921
|
+
}
|
|
1922
|
+
for (const key in updatableFields) {
|
|
1923
|
+
if (Object.prototype.hasOwnProperty.call(updatableFields, key)) {
|
|
1924
|
+
const dbColumn = columnMapping[key] || key;
|
|
1925
|
+
setClauses.push(`"${dbColumn}" = $${paramIndex++}`);
|
|
1926
|
+
values.push(updatableFields[key]);
|
|
1927
|
+
}
|
|
1928
|
+
}
|
|
1929
|
+
if (setClauses.length > 0) {
|
|
1930
|
+
values.push(id);
|
|
1931
|
+
const sql = `UPDATE ${this.getTableName(
|
|
1932
|
+
TABLE_MESSAGES
|
|
1933
|
+
)} SET ${setClauses.join(", ")} WHERE id = $${paramIndex}`;
|
|
1934
|
+
queries.push(t.none(sql, values));
|
|
1935
|
+
}
|
|
1936
|
+
}
|
|
1937
|
+
if (threadIdsToUpdate.size > 0) {
|
|
1938
|
+
queries.push(
|
|
1939
|
+
t.none(`UPDATE ${this.getTableName(TABLE_THREADS)} SET "updatedAt" = NOW() WHERE id IN ($1:list)`, [
|
|
1940
|
+
Array.from(threadIdsToUpdate)
|
|
1941
|
+
])
|
|
1942
|
+
);
|
|
1943
|
+
}
|
|
1944
|
+
if (queries.length > 0) {
|
|
1945
|
+
await t.batch(queries);
|
|
1946
|
+
}
|
|
1947
|
+
});
|
|
1948
|
+
const updatedMessages = await this.db.manyOrNone(selectQuery, [messageIds]);
|
|
1949
|
+
return (updatedMessages || []).map((message) => {
|
|
1950
|
+
if (typeof message.content === "string") {
|
|
1951
|
+
try {
|
|
1952
|
+
message.content = JSON.parse(message.content);
|
|
1953
|
+
} catch {
|
|
1954
|
+
}
|
|
1955
|
+
}
|
|
1956
|
+
return message;
|
|
1957
|
+
});
|
|
1958
|
+
}
|
|
1864
1959
|
};
|
|
1865
1960
|
|
|
1866
1961
|
// src/vector/prompt.ts
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mastra/pg",
|
|
3
|
-
"version": "0.11.0",
|
|
3
|
+
"version": "0.11.1-alpha.0",
|
|
4
4
|
"description": "Postgres provider for Mastra - includes both vector and db storage capabilities",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -34,8 +34,8 @@
|
|
|
34
34
|
"typescript": "^5.8.3",
|
|
35
35
|
"vitest": "^3.2.3",
|
|
36
36
|
"@internal/lint": "0.0.13",
|
|
37
|
-
"@
|
|
38
|
-
"@
|
|
37
|
+
"@internal/storage-test-utils": "0.0.9",
|
|
38
|
+
"@mastra/core": "0.10.7-alpha.0"
|
|
39
39
|
},
|
|
40
40
|
"peerDependencies": {
|
|
41
41
|
"@mastra/core": ">=0.10.4-0 <0.11.0"
|
|
@@ -4,12 +4,12 @@ import {
|
|
|
4
4
|
createSampleTraceForDB,
|
|
5
5
|
createSampleThread,
|
|
6
6
|
createSampleMessageV1,
|
|
7
|
-
createSampleMessageV2,
|
|
8
7
|
createSampleWorkflowSnapshot,
|
|
9
8
|
resetRole,
|
|
10
9
|
checkWorkflowSnapshot,
|
|
11
10
|
} from '@internal/storage-test-utils';
|
|
12
|
-
import type {
|
|
11
|
+
import type { MastraMessageContentV2, MastraMessageV2 } from '@mastra/core/agent';
|
|
12
|
+
import type { MastraMessageV1, StorageThreadType } from '@mastra/core/memory';
|
|
13
13
|
import type { StorageColumn, TABLE_NAMES } from '@mastra/core/storage';
|
|
14
14
|
import {
|
|
15
15
|
TABLE_WORKFLOW_SNAPSHOT,
|
|
@@ -37,6 +37,37 @@ const connectionString = `postgresql://${TEST_CONFIG.user}:${TEST_CONFIG.passwor
|
|
|
37
37
|
|
|
38
38
|
vi.setConfig({ testTimeout: 60_000, hookTimeout: 60_000 });
|
|
39
39
|
|
|
40
|
+
const createSampleMessageV2 = ({
|
|
41
|
+
threadId,
|
|
42
|
+
resourceId,
|
|
43
|
+
role = 'user',
|
|
44
|
+
content,
|
|
45
|
+
createdAt,
|
|
46
|
+
thread,
|
|
47
|
+
}: {
|
|
48
|
+
threadId: string;
|
|
49
|
+
resourceId?: string;
|
|
50
|
+
role?: 'user' | 'assistant';
|
|
51
|
+
content?: Partial<MastraMessageContentV2>;
|
|
52
|
+
createdAt?: Date;
|
|
53
|
+
thread?: StorageThreadType;
|
|
54
|
+
}): MastraMessageV2 => {
|
|
55
|
+
return {
|
|
56
|
+
id: randomUUID(),
|
|
57
|
+
threadId,
|
|
58
|
+
resourceId: resourceId || thread?.resourceId || 'test-resource',
|
|
59
|
+
role,
|
|
60
|
+
createdAt: createdAt || new Date(),
|
|
61
|
+
content: {
|
|
62
|
+
format: 2,
|
|
63
|
+
parts: content?.parts || [{ type: 'text', text: content?.content ?? '' }],
|
|
64
|
+
content: content?.content || `Sample content ${randomUUID()}`,
|
|
65
|
+
...content,
|
|
66
|
+
},
|
|
67
|
+
type: 'v2',
|
|
68
|
+
};
|
|
69
|
+
};
|
|
70
|
+
|
|
40
71
|
describe('PostgresStore', () => {
|
|
41
72
|
let store: PostgresStore;
|
|
42
73
|
|
|
@@ -229,7 +260,9 @@ describe('PostgresStore', () => {
|
|
|
229
260
|
|
|
230
261
|
const messageContent = ['First', 'Second', 'Third'];
|
|
231
262
|
|
|
232
|
-
const messages = messageContent.map(content =>
|
|
263
|
+
const messages = messageContent.map(content =>
|
|
264
|
+
createSampleMessageV2({ threadId: thread.id, content: { content, parts: [{ type: 'text', text: content }] } }),
|
|
265
|
+
);
|
|
233
266
|
|
|
234
267
|
await store.saveMessages({ messages, format: 'v2' });
|
|
235
268
|
|
|
@@ -269,16 +302,48 @@ describe('PostgresStore', () => {
|
|
|
269
302
|
await store.saveThread({ thread: thread3 });
|
|
270
303
|
|
|
271
304
|
const messages: MastraMessageV2[] = [
|
|
272
|
-
createSampleMessageV2({
|
|
273
|
-
|
|
274
|
-
|
|
305
|
+
createSampleMessageV2({
|
|
306
|
+
threadId: 'thread-one',
|
|
307
|
+
content: { content: 'First' },
|
|
308
|
+
resourceId: 'cross-thread-resource',
|
|
309
|
+
}),
|
|
310
|
+
createSampleMessageV2({
|
|
311
|
+
threadId: 'thread-one',
|
|
312
|
+
content: { content: 'Second' },
|
|
313
|
+
resourceId: 'cross-thread-resource',
|
|
314
|
+
}),
|
|
315
|
+
createSampleMessageV2({
|
|
316
|
+
threadId: 'thread-one',
|
|
317
|
+
content: { content: 'Third' },
|
|
318
|
+
resourceId: 'cross-thread-resource',
|
|
319
|
+
}),
|
|
275
320
|
|
|
276
|
-
createSampleMessageV2({
|
|
277
|
-
|
|
278
|
-
|
|
321
|
+
createSampleMessageV2({
|
|
322
|
+
threadId: 'thread-two',
|
|
323
|
+
content: { content: 'Fourth' },
|
|
324
|
+
resourceId: 'cross-thread-resource',
|
|
325
|
+
}),
|
|
326
|
+
createSampleMessageV2({
|
|
327
|
+
threadId: 'thread-two',
|
|
328
|
+
content: { content: 'Fifth' },
|
|
329
|
+
resourceId: 'cross-thread-resource',
|
|
330
|
+
}),
|
|
331
|
+
createSampleMessageV2({
|
|
332
|
+
threadId: 'thread-two',
|
|
333
|
+
content: { content: 'Sixth' },
|
|
334
|
+
resourceId: 'cross-thread-resource',
|
|
335
|
+
}),
|
|
279
336
|
|
|
280
|
-
createSampleMessageV2({
|
|
281
|
-
|
|
337
|
+
createSampleMessageV2({
|
|
338
|
+
threadId: 'thread-three',
|
|
339
|
+
content: { content: 'Seventh' },
|
|
340
|
+
resourceId: 'other-resource',
|
|
341
|
+
}),
|
|
342
|
+
createSampleMessageV2({
|
|
343
|
+
threadId: 'thread-three',
|
|
344
|
+
content: { content: 'Eighth' },
|
|
345
|
+
resourceId: 'other-resource',
|
|
346
|
+
}),
|
|
282
347
|
];
|
|
283
348
|
|
|
284
349
|
await store.saveMessages({ messages: messages, format: 'v2' });
|
|
@@ -363,6 +428,132 @@ describe('PostgresStore', () => {
|
|
|
363
428
|
});
|
|
364
429
|
});
|
|
365
430
|
|
|
431
|
+
describe('updateMessages', () => {
|
|
432
|
+
let thread: StorageThreadType;
|
|
433
|
+
|
|
434
|
+
beforeEach(async () => {
|
|
435
|
+
const threadData = createSampleThread();
|
|
436
|
+
thread = await store.saveThread({ thread: threadData as StorageThreadType });
|
|
437
|
+
});
|
|
438
|
+
|
|
439
|
+
it('should update a single field of a message (e.g., role)', async () => {
|
|
440
|
+
const originalMessage = createSampleMessageV2({ threadId: thread.id, role: 'user', thread });
|
|
441
|
+
await store.saveMessages({ messages: [originalMessage], format: 'v2' });
|
|
442
|
+
|
|
443
|
+
const updatedMessages = await store.updateMessages({
|
|
444
|
+
messages: [{ id: originalMessage.id, role: 'assistant' }],
|
|
445
|
+
});
|
|
446
|
+
|
|
447
|
+
expect(updatedMessages).toHaveLength(1);
|
|
448
|
+
expect(updatedMessages[0].role).toBe('assistant');
|
|
449
|
+
expect(updatedMessages[0].content).toEqual(originalMessage.content); // Ensure content is unchanged
|
|
450
|
+
});
|
|
451
|
+
|
|
452
|
+
it('should update only the metadata within the content field, preserving other content', async () => {
|
|
453
|
+
const originalMessage = createSampleMessageV2({
|
|
454
|
+
threadId: thread.id,
|
|
455
|
+
content: { content: 'hello world', parts: [{ type: 'text', text: 'hello world' }] },
|
|
456
|
+
thread,
|
|
457
|
+
});
|
|
458
|
+
await store.saveMessages({ messages: [originalMessage], format: 'v2' });
|
|
459
|
+
|
|
460
|
+
const newMetadata = { someKey: 'someValue' };
|
|
461
|
+
await store.updateMessages({
|
|
462
|
+
messages: [{ id: originalMessage.id, content: { metadata: newMetadata } as any }],
|
|
463
|
+
});
|
|
464
|
+
|
|
465
|
+
const fromDb = await store.getMessages({ threadId: thread.id, format: 'v2' });
|
|
466
|
+
expect(fromDb[0].content.metadata).toEqual(newMetadata);
|
|
467
|
+
expect(fromDb[0].content.content).toBe('hello world');
|
|
468
|
+
expect(fromDb[0].content.parts).toEqual([{ type: 'text', text: 'hello world' }]);
|
|
469
|
+
});
|
|
470
|
+
|
|
471
|
+
it('should deep merge metadata, not overwrite it', async () => {
|
|
472
|
+
const originalMessage = createSampleMessageV2({
|
|
473
|
+
threadId: thread.id,
|
|
474
|
+
content: { metadata: { initial: true }, content: 'old content' },
|
|
475
|
+
thread,
|
|
476
|
+
});
|
|
477
|
+
await store.saveMessages({ messages: [originalMessage], format: 'v2' });
|
|
478
|
+
|
|
479
|
+
const newMetadata = { updated: true };
|
|
480
|
+
await store.updateMessages({
|
|
481
|
+
messages: [{ id: originalMessage.id, content: { metadata: newMetadata } as any }],
|
|
482
|
+
});
|
|
483
|
+
|
|
484
|
+
const fromDb = await store.getMessages({ threadId: thread.id, format: 'v2' });
|
|
485
|
+
expect(fromDb[0].content.metadata).toEqual({ initial: true, updated: true });
|
|
486
|
+
});
|
|
487
|
+
|
|
488
|
+
it('should update multiple messages at once', async () => {
|
|
489
|
+
const msg1 = createSampleMessageV2({ threadId: thread.id, role: 'user', thread });
|
|
490
|
+
const msg2 = createSampleMessageV2({ threadId: thread.id, content: { content: 'original' }, thread });
|
|
491
|
+
await store.saveMessages({ messages: [msg1, msg2], format: 'v2' });
|
|
492
|
+
|
|
493
|
+
await store.updateMessages({
|
|
494
|
+
messages: [
|
|
495
|
+
{ id: msg1.id, role: 'assistant' },
|
|
496
|
+
{ id: msg2.id, content: { content: 'updated' } as any },
|
|
497
|
+
],
|
|
498
|
+
});
|
|
499
|
+
|
|
500
|
+
const fromDb = await store.getMessages({ threadId: thread.id, format: 'v2' });
|
|
501
|
+
const updatedMsg1 = fromDb.find(m => m.id === msg1.id)!;
|
|
502
|
+
const updatedMsg2 = fromDb.find(m => m.id === msg2.id)!;
|
|
503
|
+
|
|
504
|
+
expect(updatedMsg1.role).toBe('assistant');
|
|
505
|
+
expect(updatedMsg2.content.content).toBe('updated');
|
|
506
|
+
});
|
|
507
|
+
|
|
508
|
+
it('should update the parent thread updatedAt timestamp', async () => {
|
|
509
|
+
const originalMessage = createSampleMessageV2({ threadId: thread.id, thread });
|
|
510
|
+
await store.saveMessages({ messages: [originalMessage], format: 'v2' });
|
|
511
|
+
const initialThread = await store.getThreadById({ threadId: thread.id });
|
|
512
|
+
|
|
513
|
+
await new Promise(r => setTimeout(r, 10));
|
|
514
|
+
|
|
515
|
+
await store.updateMessages({ messages: [{ id: originalMessage.id, role: 'assistant' }] });
|
|
516
|
+
|
|
517
|
+
const updatedThread = await store.getThreadById({ threadId: thread.id });
|
|
518
|
+
|
|
519
|
+
expect(new Date(updatedThread!.updatedAt).getTime()).toBeGreaterThan(
|
|
520
|
+
new Date(initialThread!.updatedAt).getTime(),
|
|
521
|
+
);
|
|
522
|
+
});
|
|
523
|
+
|
|
524
|
+
it('should update timestamps on both threads when moving a message', async () => {
|
|
525
|
+
const thread2 = await store.saveThread({ thread: createSampleThread() });
|
|
526
|
+
const message = createSampleMessageV2({ threadId: thread.id, thread });
|
|
527
|
+
await store.saveMessages({ messages: [message], format: 'v2' });
|
|
528
|
+
|
|
529
|
+
const initialThread1 = await store.getThreadById({ threadId: thread.id });
|
|
530
|
+
const initialThread2 = await store.getThreadById({ threadId: thread2.id });
|
|
531
|
+
|
|
532
|
+
await new Promise(r => setTimeout(r, 10));
|
|
533
|
+
|
|
534
|
+
await store.updateMessages({
|
|
535
|
+
messages: [{ id: message.id, threadId: thread2.id }],
|
|
536
|
+
});
|
|
537
|
+
|
|
538
|
+
const updatedThread1 = await store.getThreadById({ threadId: thread.id });
|
|
539
|
+
const updatedThread2 = await store.getThreadById({ threadId: thread2.id });
|
|
540
|
+
|
|
541
|
+
expect(new Date(updatedThread1!.updatedAt).getTime()).toBeGreaterThan(
|
|
542
|
+
new Date(initialThread1!.updatedAt).getTime(),
|
|
543
|
+
);
|
|
544
|
+
expect(new Date(updatedThread2!.updatedAt).getTime()).toBeGreaterThan(
|
|
545
|
+
new Date(initialThread2!.updatedAt).getTime(),
|
|
546
|
+
);
|
|
547
|
+
|
|
548
|
+
// Verify the message was moved
|
|
549
|
+
const thread1Messages = await store.getMessages({ threadId: thread.id, format: 'v2' });
|
|
550
|
+
const thread2Messages = await store.getMessages({ threadId: thread2.id, format: 'v2' });
|
|
551
|
+
expect(thread1Messages).toHaveLength(0);
|
|
552
|
+
expect(thread2Messages).toHaveLength(1);
|
|
553
|
+
expect(thread2Messages[0].id).toBe(message.id);
|
|
554
|
+
});
|
|
555
|
+
});
|
|
556
|
+
|
|
366
557
|
describe('Edge Cases and Error Handling', () => {
|
|
367
558
|
it('should handle large metadata objects', async () => {
|
|
368
559
|
const thread = createSampleThread();
|
package/src/storage/index.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { MessageList } from '@mastra/core/agent';
|
|
2
|
-
import type { MastraMessageV2 } from '@mastra/core/agent';
|
|
2
|
+
import type { MastraMessageContentV2, MastraMessageV2 } from '@mastra/core/agent';
|
|
3
3
|
import type { MetricResult } from '@mastra/core/eval';
|
|
4
4
|
import type { MastraMessageV1, StorageThreadType } from '@mastra/core/memory';
|
|
5
5
|
import {
|
|
@@ -1243,4 +1243,134 @@ export class PostgresStore extends MastraStorage {
|
|
|
1243
1243
|
hasMore: currentOffset + (rows?.length ?? 0) < total,
|
|
1244
1244
|
};
|
|
1245
1245
|
}
|
|
1246
|
+
|
|
1247
|
+
async updateMessages({
|
|
1248
|
+
messages,
|
|
1249
|
+
}: {
|
|
1250
|
+
messages: (Partial<Omit<MastraMessageV2, 'createdAt'>> & {
|
|
1251
|
+
id: string;
|
|
1252
|
+
content?: {
|
|
1253
|
+
metadata?: MastraMessageContentV2['metadata'];
|
|
1254
|
+
content?: MastraMessageContentV2['content'];
|
|
1255
|
+
};
|
|
1256
|
+
})[];
|
|
1257
|
+
}): Promise<MastraMessageV2[]> {
|
|
1258
|
+
if (messages.length === 0) {
|
|
1259
|
+
return [];
|
|
1260
|
+
}
|
|
1261
|
+
|
|
1262
|
+
const messageIds = messages.map(m => m.id);
|
|
1263
|
+
|
|
1264
|
+
const selectQuery = `SELECT id, content, role, type, "createdAt", thread_id AS "threadId", "resourceId" FROM ${this.getTableName(
|
|
1265
|
+
TABLE_MESSAGES,
|
|
1266
|
+
)} WHERE id IN ($1:list)`;
|
|
1267
|
+
|
|
1268
|
+
const existingMessagesDb = await this.db.manyOrNone(selectQuery, [messageIds]);
|
|
1269
|
+
|
|
1270
|
+
if (existingMessagesDb.length === 0) {
|
|
1271
|
+
return [];
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1274
|
+
// Parse content from string to object for merging
|
|
1275
|
+
const existingMessages: MastraMessageV2[] = existingMessagesDb.map(msg => {
|
|
1276
|
+
if (typeof msg.content === 'string') {
|
|
1277
|
+
try {
|
|
1278
|
+
msg.content = JSON.parse(msg.content);
|
|
1279
|
+
} catch {
|
|
1280
|
+
// ignore if not valid json
|
|
1281
|
+
}
|
|
1282
|
+
}
|
|
1283
|
+
return msg as MastraMessageV2;
|
|
1284
|
+
});
|
|
1285
|
+
|
|
1286
|
+
const threadIdsToUpdate = new Set<string>();
|
|
1287
|
+
|
|
1288
|
+
await this.db.tx(async t => {
|
|
1289
|
+
const queries = [];
|
|
1290
|
+
const columnMapping: Record<string, string> = {
|
|
1291
|
+
threadId: 'thread_id',
|
|
1292
|
+
};
|
|
1293
|
+
|
|
1294
|
+
for (const existingMessage of existingMessages) {
|
|
1295
|
+
const updatePayload = messages.find(m => m.id === existingMessage.id);
|
|
1296
|
+
if (!updatePayload) continue;
|
|
1297
|
+
|
|
1298
|
+
const { id, ...fieldsToUpdate } = updatePayload;
|
|
1299
|
+
if (Object.keys(fieldsToUpdate).length === 0) continue;
|
|
1300
|
+
|
|
1301
|
+
threadIdsToUpdate.add(existingMessage.threadId!);
|
|
1302
|
+
if (updatePayload.threadId && updatePayload.threadId !== existingMessage.threadId) {
|
|
1303
|
+
threadIdsToUpdate.add(updatePayload.threadId);
|
|
1304
|
+
}
|
|
1305
|
+
|
|
1306
|
+
const setClauses: string[] = [];
|
|
1307
|
+
const values: any[] = [];
|
|
1308
|
+
let paramIndex = 1;
|
|
1309
|
+
|
|
1310
|
+
const updatableFields = { ...fieldsToUpdate };
|
|
1311
|
+
|
|
1312
|
+
// Special handling for content: merge in code, then update the whole field
|
|
1313
|
+
if (updatableFields.content) {
|
|
1314
|
+
const newContent = {
|
|
1315
|
+
...existingMessage.content,
|
|
1316
|
+
...updatableFields.content,
|
|
1317
|
+
// Deep merge metadata if it exists on both
|
|
1318
|
+
...(existingMessage.content?.metadata && updatableFields.content.metadata
|
|
1319
|
+
? {
|
|
1320
|
+
metadata: {
|
|
1321
|
+
...existingMessage.content.metadata,
|
|
1322
|
+
...updatableFields.content.metadata,
|
|
1323
|
+
},
|
|
1324
|
+
}
|
|
1325
|
+
: {}),
|
|
1326
|
+
};
|
|
1327
|
+
setClauses.push(`content = $${paramIndex++}`);
|
|
1328
|
+
values.push(newContent);
|
|
1329
|
+
delete updatableFields.content;
|
|
1330
|
+
}
|
|
1331
|
+
|
|
1332
|
+
for (const key in updatableFields) {
|
|
1333
|
+
if (Object.prototype.hasOwnProperty.call(updatableFields, key)) {
|
|
1334
|
+
const dbColumn = columnMapping[key] || key;
|
|
1335
|
+
setClauses.push(`"${dbColumn}" = $${paramIndex++}`);
|
|
1336
|
+
values.push(updatableFields[key as keyof typeof updatableFields]);
|
|
1337
|
+
}
|
|
1338
|
+
}
|
|
1339
|
+
|
|
1340
|
+
if (setClauses.length > 0) {
|
|
1341
|
+
values.push(id);
|
|
1342
|
+
const sql = `UPDATE ${this.getTableName(
|
|
1343
|
+
TABLE_MESSAGES,
|
|
1344
|
+
)} SET ${setClauses.join(', ')} WHERE id = $${paramIndex}`;
|
|
1345
|
+
queries.push(t.none(sql, values));
|
|
1346
|
+
}
|
|
1347
|
+
}
|
|
1348
|
+
|
|
1349
|
+
if (threadIdsToUpdate.size > 0) {
|
|
1350
|
+
queries.push(
|
|
1351
|
+
t.none(`UPDATE ${this.getTableName(TABLE_THREADS)} SET "updatedAt" = NOW() WHERE id IN ($1:list)`, [
|
|
1352
|
+
Array.from(threadIdsToUpdate),
|
|
1353
|
+
]),
|
|
1354
|
+
);
|
|
1355
|
+
}
|
|
1356
|
+
|
|
1357
|
+
if (queries.length > 0) {
|
|
1358
|
+
await t.batch(queries);
|
|
1359
|
+
}
|
|
1360
|
+
});
|
|
1361
|
+
|
|
1362
|
+
// Re-fetch to return the fully updated messages
|
|
1363
|
+
const updatedMessages = await this.db.manyOrNone<MastraMessageV2>(selectQuery, [messageIds]);
|
|
1364
|
+
|
|
1365
|
+
return (updatedMessages || []).map(message => {
|
|
1366
|
+
if (typeof message.content === 'string') {
|
|
1367
|
+
try {
|
|
1368
|
+
message.content = JSON.parse(message.content);
|
|
1369
|
+
} catch {
|
|
1370
|
+
/* ignore */
|
|
1371
|
+
}
|
|
1372
|
+
}
|
|
1373
|
+
return message;
|
|
1374
|
+
});
|
|
1375
|
+
}
|
|
1246
1376
|
}
|