@mastra/pg 0.1.0-alpha.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +46 -0
- package/LICENSE +44 -0
- package/README.md +155 -0
- package/dist/index.d.ts +82 -0
- package/dist/index.js +886 -0
- package/docker-compose.yaml +14 -0
- package/package.json +39 -0
- package/src/index.ts +2 -0
- package/src/storage/index.test.ts +379 -0
- package/src/storage/index.ts +477 -0
- package/src/vector/filter.test.ts +967 -0
- package/src/vector/filter.ts +106 -0
- package/src/vector/index.test.ts +1205 -0
- package/src/vector/index.ts +282 -0
- package/src/vector/sql-builder.ts +285 -0
- package/tsconfig.json +15 -0
- package/vitest.config.ts +11 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
import { type MessageType, type StorageThreadType } from '@mastra/core/memory';
|
|
2
|
+
import { MastraStorage, type StorageColumn, type StorageGetMessagesArg, type TABLE_NAMES } from '@mastra/core/storage';
|
|
3
|
+
import { type WorkflowRunState } from '@mastra/core/workflows';
|
|
4
|
+
import pgPromise from 'pg-promise';
|
|
5
|
+
|
|
6
|
+
/**
 * Connection settings accepted by `PostgresStore`.
 *
 * Provide either the individual connection parameters, or a single
 * `connectionString` that encodes all of them.
 */
export type PostgresConfig =
  | {
      /** Hostname of the PostgreSQL server. */
      host: string;
      /** TCP port of the PostgreSQL server. */
      port: number;
      /** Name of the database to connect to. */
      database: string;
      /** Login role name. */
      user: string;
      /** Password for `user`. */
      password: string;
    }
  | {
      /** Full connection string, e.g. `postgres://user:pass@host:5432/db`. */
      connectionString: string;
    };
|
|
17
|
+
|
|
18
|
+
/**
 * Normalizes a thread row read from Postgres.
 *
 * `metadata` is JSON-decoded when the driver returns it as a string (presumably when the
 * column is TEXT rather than JSON/JSONB — NOTE(review): confirm against the schema).
 */
function normalizeThreadRow(thread: StorageThreadType): StorageThreadType {
  return {
    ...thread,
    metadata: typeof thread.metadata === 'string' ? JSON.parse(thread.metadata) : thread.metadata,
  };
}

/**
 * Encodes message content for the `content` column.
 *
 * Every value, plain strings included, is JSON-encoded so that `decodeMessageContent`
 * restores it exactly. Storing strings raw would make a message whose text happens to be
 * valid JSON (e.g. "42", "true", "[1]") come back as a number/boolean/array.
 */
function encodeMessageContent(content: unknown): string {
  return JSON.stringify(content);
}

/**
 * Decodes a `content` value read from the database.
 *
 * Non-string values are returned unchanged. Strings are JSON-decoded; strings that are not
 * valid JSON (e.g. rows written as raw text by earlier versions) are returned as-is.
 */
function decodeMessageContent(raw: unknown): unknown {
  if (typeof raw !== 'string') return raw;
  try {
    return JSON.parse(raw);
  } catch {
    return raw;
  }
}

/**
 * PostgreSQL-backed `MastraStorage` built on pg-promise.
 *
 * Stores threads, messages and workflow snapshots in the tables named by the
 * `MastraStorage.TABLE_*` constants. Call `close()` to release the connection pool.
 */
export class PostgresStore extends MastraStorage {
  private db: pgPromise.IDatabase<{}>;
  private pgp: pgPromise.IMain;
  /** Shutdown of this store's pool, memoized so `close()` can be called repeatedly. */
  private closing?: Promise<void>;

  /**
   * @param config Either discrete connection parameters or a connection string.
   */
  constructor(config: PostgresConfig) {
    super({ name: 'PostgresStore' });
    this.pgp = pgPromise();
    this.db = this.pgp(
      `connectionString` in config
        ? { connectionString: config.connectionString }
        : {
            host: config.host,
            port: config.port,
            database: config.database,
            user: config.user,
            password: config.password,
          },
    );
  }

  /**
   * Creates `tableName` with the given column schema if it does not exist yet.
   *
   * For the workflow-snapshot table this also adds (only if missing) the
   * UNIQUE (workflow_name, run_id) constraint that `persistWorkflowSnapshot`'s
   * `ON CONFLICT (workflow_name, run_id)` upsert relies on.
   *
   * @throws Re-throws any database error after logging it.
   */
  async createTable({
    tableName,
    schema,
  }: {
    tableName: TABLE_NAMES;
    schema: Record<string, StorageColumn>;
  }): Promise<void> {
    try {
      const columns = Object.entries(schema)
        .map(([name, def]) => {
          const constraints: string[] = [];
          if (def.primaryKey) constraints.push('PRIMARY KEY');
          if (!def.nullable) constraints.push('NOT NULL');
          return `"${name}" ${def.type.toUpperCase()} ${constraints.join(' ')}`;
        })
        .join(',\n');

      const snapshotUniqueKeySql =
        tableName === MastraStorage.TABLE_WORKFLOW_SNAPSHOT
          ? `
        DO $$ BEGIN
          IF NOT EXISTS (
            SELECT 1 FROM pg_constraint WHERE conname = 'mastra_workflow_snapshot_workflow_name_run_id_key'
          ) THEN
            ALTER TABLE "${tableName}"
            ADD CONSTRAINT mastra_workflow_snapshot_workflow_name_run_id_key
            UNIQUE (workflow_name, run_id);
          END IF;
        END $$;
        `
          : '';

      // Table names are quoted like in every other query of this class.
      const sql = `
        CREATE TABLE IF NOT EXISTS "${tableName}" (
          ${columns}
        );
        ${snapshotUniqueKeySql}
      `;

      await this.db.none(sql);
    } catch (error) {
      console.error(`Error creating table ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Removes all rows from `tableName` (TRUNCATE ... CASCADE).
   *
   * @throws Re-throws any database error after logging it.
   */
  async clearTable({ tableName }: { tableName: TABLE_NAMES }): Promise<void> {
    try {
      await this.db.none(`TRUNCATE TABLE "${tableName}" CASCADE`);
    } catch (error) {
      console.error(`Error clearing table ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Inserts one row; the keys of `record` are used as (quoted) column names and its
   * values are passed as query parameters.
   *
   * @throws Re-throws any database error after logging it.
   */
  async insert({ tableName, record }: { tableName: TABLE_NAMES; record: Record<string, any> }): Promise<void> {
    try {
      const columns = Object.keys(record);
      const values = Object.values(record);
      const placeholders = values.map((_, i) => `$${i + 1}`).join(', ');

      await this.db.none(
        `INSERT INTO "${tableName}" (${columns.map(c => `"${c}"`).join(', ')}) VALUES (${placeholders})`,
        values,
      );
    } catch (error) {
      console.error(`Error inserting into ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Loads the single row whose columns equal every entry of `keys`.
   *
   * For the workflow-snapshot table, a string `snapshot` column is JSON-decoded.
   *
   * @returns The row, or `null` when none matches.
   * @throws If `keys` is empty, if more than one row matches (pg-promise `oneOrNone`),
   *   or on any database error.
   */
  async load<R>({ tableName, keys }: { tableName: TABLE_NAMES; keys: Record<string, string> }): Promise<R | null> {
    try {
      const keyEntries = Object.entries(keys);
      if (keyEntries.length === 0) {
        // Without any key the generated SQL would end in a bare `WHERE`.
        throw new Error(`Cannot load from ${tableName} without at least one key`);
      }
      const conditions = keyEntries.map(([key], index) => `"${key}" = $${index + 1}`).join(' AND ');
      const values = keyEntries.map(([, value]) => value);

      const result = await this.db.oneOrNone<R>(`SELECT * FROM "${tableName}" WHERE ${conditions}`, values);

      if (!result) {
        return null;
      }

      // If this is a workflow snapshot, parse the snapshot field
      if (tableName === MastraStorage.TABLE_WORKFLOW_SNAPSHOT) {
        const snapshot = result as any;
        if (typeof snapshot.snapshot === 'string') {
          snapshot.snapshot = JSON.parse(snapshot.snapshot);
        }
        return snapshot;
      }

      return result;
    } catch (error) {
      console.error(`Error loading from ${tableName}:`, error);
      throw error;
    }
  }

  /**
   * Fetches a thread by id.
   *
   * @returns The thread with decoded metadata, or `null` when it does not exist.
   */
  async getThreadById({ threadId }: { threadId: string }): Promise<StorageThreadType | null> {
    try {
      const thread = await this.db.oneOrNone<StorageThreadType>(
        `SELECT
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        FROM "${MastraStorage.TABLE_THREADS}"
        WHERE id = $1`,
        [threadId],
      );

      return thread ? normalizeThreadRow(thread) : null;
    } catch (error) {
      console.error(`Error getting thread ${threadId}:`, error);
      throw error;
    }
  }

  /**
   * Lists all threads that belong to `resourceId` (no ordering is applied).
   */
  async getThreadsByResourceId({ resourceId }: { resourceId: string }): Promise<StorageThreadType[]> {
    try {
      const threads = await this.db.manyOrNone<StorageThreadType>(
        `SELECT
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        FROM "${MastraStorage.TABLE_THREADS}"
        WHERE "resourceId" = $1`,
        [resourceId],
      );

      return threads.map(thread => normalizeThreadRow(thread));
    } catch (error) {
      console.error(`Error getting threads for resource ${resourceId}:`, error);
      throw error;
    }
  }

  /**
   * Inserts a thread, or overwrites every column of an existing thread with the same id.
   *
   * @returns The thread object that was passed in.
   */
  async saveThread({ thread }: { thread: StorageThreadType }): Promise<StorageThreadType> {
    try {
      await this.db.none(
        `INSERT INTO "${MastraStorage.TABLE_THREADS}" (
          id,
          "resourceId",
          title,
          metadata,
          "createdAt",
          "updatedAt"
        ) VALUES ($1, $2, $3, $4, $5, $6)
        ON CONFLICT (id) DO UPDATE SET
          "resourceId" = EXCLUDED."resourceId",
          title = EXCLUDED.title,
          metadata = EXCLUDED.metadata,
          "createdAt" = EXCLUDED."createdAt",
          "updatedAt" = EXCLUDED."updatedAt"`,
        [
          thread.id,
          thread.resourceId,
          thread.title,
          thread.metadata ? JSON.stringify(thread.metadata) : null,
          thread.createdAt,
          thread.updatedAt,
        ],
      );

      return thread;
    } catch (error) {
      console.error('Error saving thread:', error);
      throw error;
    }
  }

  /**
   * Updates a thread's title and shallow-merges `metadata` into its existing metadata
   * (new keys win). `updatedAt` is set to the current time.
   *
   * @throws If the thread does not exist, or on any database error.
   */
  async updateThread({
    id,
    title,
    metadata,
  }: {
    id: string;
    title: string;
    metadata: Record<string, unknown>;
  }): Promise<StorageThreadType> {
    try {
      // First get the existing thread to merge metadata
      const existingThread = await this.getThreadById({ threadId: id });
      if (!existingThread) {
        throw new Error(`Thread ${id} not found`);
      }

      const mergedMetadata = {
        ...existingThread.metadata,
        ...metadata,
      };

      const thread = await this.db.one<StorageThreadType>(
        `UPDATE "${MastraStorage.TABLE_THREADS}"
        SET title = $1,
            metadata = $2,
            "updatedAt" = $3
        WHERE id = $4
        RETURNING *`,
        [title, mergedMetadata, new Date().toISOString(), id],
      );

      return normalizeThreadRow(thread);
    } catch (error) {
      console.error('Error updating thread:', error);
      throw error;
    }
  }

  /**
   * Deletes a thread and all of its messages in a single transaction.
   */
  async deleteThread({ threadId }: { threadId: string }): Promise<void> {
    try {
      await this.db.tx(async t => {
        // Messages first, then the thread itself.
        await t.none(`DELETE FROM "${MastraStorage.TABLE_MESSAGES}" WHERE thread_id = $1`, [threadId]);
        await t.none(`DELETE FROM "${MastraStorage.TABLE_THREADS}" WHERE id = $1`, [threadId]);
      });
    } catch (error) {
      console.error('Error deleting thread:', error);
      throw error;
    }
  }

  /**
   * Returns messages of a thread, oldest first.
   *
   * - `selectBy.include`: the listed message ids plus neighbouring messages. The window
   *   used is the largest `withPreviousMessages` / `withNextMessages` across ALL include
   *   entries, applied to every included id.
   * - `selectBy.last` (default 40): the most recent messages not already returned above.
   *
   * Content is decoded with `decodeMessageContent`.
   */
  async getMessages<T = unknown>({ threadId, selectBy }: StorageGetMessagesArg): Promise<T> {
    try {
      const messages: any[] = [];
      const limit = typeof selectBy?.last === `number` ? selectBy.last : 40;
      const include = selectBy?.include || [];

      if (include.length) {
        const includeResult = await this.db.manyOrNone(
          `
          WITH ordered_messages AS (
            SELECT
              *,
              ROW_NUMBER() OVER (ORDER BY "createdAt") as row_num
            FROM "${MastraStorage.TABLE_MESSAGES}"
            WHERE thread_id = $1
          )
          SELECT DISTINCT ON (m.id)
            m.id,
            m.content,
            m.role,
            m.type,
            m."createdAt",
            m.thread_id AS "threadId"
          FROM ordered_messages m
          WHERE m.id = ANY($2)
          OR EXISTS (
            SELECT 1 FROM ordered_messages target
            WHERE target.id = ANY($2)
            AND (
              -- Get previous messages based on the max withPreviousMessages
              (m.row_num >= target.row_num - $3 AND m.row_num < target.row_num)
              OR
              -- Get next messages based on the max withNextMessages
              (m.row_num <= target.row_num + $4 AND m.row_num > target.row_num)
            )
          )
          ORDER BY m.id, m."createdAt"
          `,
          [
            threadId,
            include.map(i => i.id),
            Math.max(...include.map(i => i.withPreviousMessages || 0)),
            Math.max(...include.map(i => i.withNextMessages || 0)),
          ],
        );

        messages.push(...includeResult);
      }

      // Then get the remaining messages, excluding the ids we just fetched
      const result = await this.db.manyOrNone(
        `
        SELECT
          id,
          content,
          role,
          type,
          "createdAt",
          thread_id AS "threadId"
        FROM "${MastraStorage.TABLE_MESSAGES}"
        WHERE thread_id = $1
        AND id != ALL($2)
        ORDER BY "createdAt" DESC
        LIMIT $3
        `,
        [threadId, messages.map(m => m.id), limit],
      );

      messages.push(...result);

      // Sort all messages by creation date (ascending).
      messages.sort((a, b) => new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime());

      for (const message of messages) {
        message.content = decodeMessageContent(message.content);
      }

      return messages as T;
    } catch (error) {
      console.error('Error getting messages:', error);
      throw error;
    }
  }

  /**
   * Inserts messages in one transaction.
   *
   * Content is always JSON-encoded (see `encodeMessageContent`) so it reads back
   * exactly through `getMessages`.
   *
   * @throws If the first message has no threadId, if that thread does not exist,
   *   or on any database error (the transaction is rolled back).
   */
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
    if (messages.length === 0) return messages;

    try {
      // NOTE(review): every message is stored under the FIRST message's threadId;
      // per-message threadId values are not consulted. Callers are assumed to batch by thread.
      const threadId = messages[0]?.threadId;
      if (!threadId) {
        throw new Error('Thread ID is required');
      }

      // Check if thread exists
      const thread = await this.getThreadById({ threadId });
      if (!thread) {
        throw new Error(`Thread ${threadId} not found`);
      }

      await this.db.tx(async t => {
        for (const message of messages) {
          await t.none(
            `INSERT INTO "${MastraStorage.TABLE_MESSAGES}" (id, thread_id, content, "createdAt", role, type)
             VALUES ($1, $2, $3, $4, $5, $6)`,
            [
              message.id,
              threadId,
              encodeMessageContent(message.content),
              message.createdAt || new Date().toISOString(),
              message.role,
              message.type,
            ],
          );
        }
      });

      return messages;
    } catch (error) {
      console.error('Error saving messages:', error);
      throw error;
    }
  }

  /**
   * Upserts the snapshot for (`workflowName`, `runId`). On conflict only `snapshot` and
   * `updatedAt` are overwritten; `createdAt` keeps its original value.
   */
  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot,
  }: {
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    try {
      const now = new Date().toISOString();
      await this.db.none(
        `INSERT INTO "${MastraStorage.TABLE_WORKFLOW_SNAPSHOT}" (
          workflow_name,
          run_id,
          snapshot,
          "createdAt",
          "updatedAt"
        ) VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (workflow_name, run_id) DO UPDATE
        SET snapshot = EXCLUDED.snapshot,
            "updatedAt" = EXCLUDED."updatedAt"`,
        [workflowName, runId, JSON.stringify(snapshot), now, now],
      );
    } catch (error) {
      console.error('Error persisting workflow snapshot:', error);
      throw error;
    }
  }

  /**
   * Loads the snapshot for (`workflowName`, `runId`).
   *
   * @returns The decoded snapshot, or `null` when no row exists.
   */
  async loadWorkflowSnapshot({
    workflowName,
    runId,
  }: {
    workflowName: string;
    runId: string;
  }): Promise<WorkflowRunState | null> {
    try {
      const result = await this.load({
        tableName: MastraStorage.TABLE_WORKFLOW_SNAPSHOT,
        keys: {
          workflow_name: workflowName,
          run_id: runId,
        },
      });

      if (!result) {
        return null;
      }

      return (result as any).snapshot;
    } catch (error) {
      console.error('Error loading workflow snapshot:', error);
      throw error;
    }
  }

  /**
   * Shuts down this store's connection pool. Safe to call more than once.
   *
   * Only this database's pool is ended (`db.$pool.end()`). `pgp.end()` shuts down every
   * pg-promise pool in the process, which would also break other stores still in use.
   */
  async close(): Promise<void> {
    if (!this.closing) {
      this.closing = this.db.$pool.end().then(() => undefined);
    }
    await this.closing;
  }
}
|