@mastra/memory 0.0.2-alpha.9 → 0.1.0-alpha.65

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions exactly as they appear in their respective public registries. The hunk header below (`@@ -1,492 +0,0 @@`) means that all 492 lines of this file in the older version were removed, so the file is absent at this path in the newer version.
@@ -1,492 +0,0 @@
1
- import { MastraMemory, MessageType, ThreadType, MessageResponse } from '@mastra/core';
2
- import { ToolResultPart, Message as AiMessage, TextPart } from 'ai';
3
- import crypto from 'crypto';
4
- import pg from 'pg';
5
-
6
- const { Pool } = pg;
7
-
8
- export class PgMemory extends MastraMemory {
9
- private pool: pg.Pool;
10
- hasTables: boolean = false;
11
- constructor(config: { connectionString: string; maxTokens?: number }) {
12
- super();
13
- this.pool = new Pool({ connectionString: config.connectionString });
14
- this.MAX_CONTEXT_TOKENS = config.maxTokens;
15
- }
16
-
17
- /**
18
- * Threads
19
- */
20
-
21
- async getThreadById({ threadId }: { threadId: string }): Promise<ThreadType | null> {
22
- await this.ensureTablesExist();
23
-
24
- const client = await this.pool.connect();
25
- try {
26
- const result = await client.query<ThreadType>(
27
- `
28
- SELECT id, title, created_at AS createdAt, updated_at AS updatedAt, resourceid, metadata
29
- FROM mastra_threads
30
- WHERE id = $1
31
- `,
32
- [threadId],
33
- );
34
-
35
- return result.rows[0] || null;
36
- } finally {
37
- client.release();
38
- }
39
- }
40
-
41
- async getThreadsByResourceId({ resourceid }: { resourceid: string }): Promise<ThreadType[]> {
42
- await this.ensureTablesExist();
43
-
44
- const client = await this.pool.connect();
45
- try {
46
- const result = await client.query<ThreadType>(
47
- `
48
- SELECT id, title, resourceid, created_at AS createdAt, updated_at AS updatedAt, metadata
49
- FROM mastra_threads
50
- WHERE resourceid = $1
51
- `,
52
- [resourceid],
53
- );
54
- return result.rows;
55
- } finally {
56
- client.release();
57
- }
58
- }
59
-
60
- async saveThread({ thread }: { thread: ThreadType }): Promise<ThreadType> {
61
- await this.ensureTablesExist();
62
-
63
- const client = await this.pool.connect();
64
- try {
65
- const { id, title, createdAt, updatedAt, resourceid, metadata } = thread;
66
- const result = await client.query<ThreadType>(
67
- `
68
- INSERT INTO mastra_threads (id, title, created_at, updated_at, resourceid, metadata)
69
- VALUES ($1, $2, $3, $4, $5, $6)
70
- ON CONFLICT (id) DO UPDATE SET title = $2, updated_at = $4, resourceid = $5, metadata = $6
71
- RETURNING id, title, created_at AS createdAt, updated_at AS updatedAt, resourceid, metadata
72
- `,
73
- [id, title, createdAt, updatedAt, resourceid, JSON.stringify(metadata)],
74
- );
75
- return result?.rows?.[0]!;
76
- } finally {
77
- client.release();
78
- }
79
- }
80
-
81
- async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
82
- const client = await this.pool.connect();
83
- try {
84
- const result = await client.query<ThreadType>(
85
- `
86
- UPDATE mastra_threads
87
- SET title = $1, metadata = $2, updated_at = NOW()
88
- WHERE id = $3
89
- RETURNING *
90
- `,
91
- [title, JSON.stringify(metadata), id],
92
- );
93
- return result?.rows?.[0]!;
94
- } finally {
95
- client.release();
96
- }
97
- }
98
-
99
- async deleteThread(id: string): Promise<void> {
100
- const client = await this.pool.connect();
101
- try {
102
- await client.query(
103
- `
104
- DELETE FROM mastra_messages
105
- WHERE thread_id = $1
106
- `,
107
- [id],
108
- );
109
-
110
- await client.query(
111
- `
112
- DELETE FROM mastra_threads
113
- WHERE id = $1
114
- `,
115
- [id],
116
- );
117
- } finally {
118
- client.release();
119
- }
120
- }
121
-
122
- /**
123
- * Tool Cache
124
- */
125
-
126
- async validateToolCallArgs({ hashedArgs }: { hashedArgs: string }): Promise<boolean> {
127
- await this.ensureTablesExist();
128
-
129
- const client = await this.pool.connect();
130
-
131
- try {
132
- const toolArgsResult = await client.query<{ toolCallIds: string; toolCallArgs: string; createdAt: string }>(
133
- ` SELECT tool_call_ids as toolCallIds,
134
- tool_call_args as toolCallArgs,
135
- created_at AS createdAt
136
- FROM mastra_messages
137
- WHERE tool_call_args::jsonb @> $1
138
- AND tool_call_args_expire_at > $2
139
- ORDER BY created_at ASC
140
- LIMIT 1`,
141
- [JSON.stringify([hashedArgs]), new Date().toISOString()],
142
- );
143
-
144
- return toolArgsResult.rows.length > 0;
145
- } catch (error) {
146
- console.log('error checking if valid arg exists====', error);
147
- return false;
148
- } finally {
149
- client.release();
150
- }
151
- }
152
-
153
- async getToolResult({
154
- threadId,
155
- toolArgs,
156
- toolName,
157
- }: {
158
- threadId: string;
159
- toolArgs: Record<string, unknown>;
160
- toolName: string;
161
- }): Promise<ToolResultPart['result'] | null> {
162
- await this.ensureTablesExist();
163
- console.log('checking for cached tool result====', JSON.stringify(toolArgs, null, 2));
164
-
165
- const client = await this.pool.connect();
166
-
167
- try {
168
- const hashedToolArgs = crypto
169
- .createHash('sha256')
170
- .update(JSON.stringify({ args: toolArgs, threadId, toolName }))
171
- .digest('hex');
172
-
173
- const toolArgsResult = await client.query<{ tool_call_ids: string; tool_call_args: string; created_at: string }>(
174
- `SELECT tool_call_ids,
175
- tool_call_args,
176
- created_at
177
- FROM mastra_messages
178
- WHERE tool_call_args::jsonb @> $1
179
- AND tool_call_args_expire_at > $2
180
- ORDER BY created_at ASC
181
- LIMIT 1`,
182
- [JSON.stringify([hashedToolArgs]), new Date().toISOString()],
183
- );
184
-
185
- if (toolArgsResult.rows.length > 0) {
186
- const toolCallArgs = JSON.parse(toolArgsResult.rows[0]?.tool_call_args!) as string[];
187
- const toolCallIds = JSON.parse(toolArgsResult.rows[0]?.tool_call_ids!) as string[];
188
- const createdAt = toolArgsResult.rows[0]?.created_at!;
189
-
190
- const toolCallArgsIndex = toolCallArgs.findIndex(arg => arg === hashedToolArgs);
191
- const correspondingToolCallId = toolCallIds[toolCallArgsIndex];
192
-
193
- const toolResult = await client.query<{ content: string }>(
194
- `SELECT content
195
- FROM mastra_messages
196
- WHERE thread_id = $1
197
- AND tool_call_ids ILIKE $2
198
- AND type = 'tool-result'
199
- AND created_at = $3
200
- LIMIT 1`,
201
- [threadId, `%${correspondingToolCallId}%`, new Date(createdAt).toISOString()],
202
- );
203
-
204
- if (toolResult.rows.length === 0) {
205
- console.log('no tool result found');
206
- return null;
207
- }
208
-
209
- const toolResultContent = JSON.parse(toolResult.rows[0]?.content!) as Array<ToolResultPart>;
210
- const requiredToolResult = toolResultContent.find(part => part.toolCallId === correspondingToolCallId);
211
-
212
- if (requiredToolResult) {
213
- return requiredToolResult.result;
214
- }
215
- }
216
-
217
- return null;
218
- } catch (error) {
219
- console.log('error getting cached tool result====', error);
220
- return null;
221
- } finally {
222
- client.release();
223
- }
224
- }
225
-
226
- async getContextWindow<T extends 'raw' | 'core_message'>({
227
- threadId,
228
- startDate,
229
- endDate,
230
- format = 'raw' as T,
231
- }: {
232
- format?: T;
233
- threadId: string;
234
- startDate?: Date;
235
- endDate?: Date;
236
- }) {
237
- await this.ensureTablesExist();
238
- const client = await this.pool.connect();
239
-
240
- try {
241
- if (this.MAX_CONTEXT_TOKENS) {
242
- // Get messages with token limit and time filter
243
- const result = await client.query<MessageType>(
244
- `WITH RankedMessages AS (
245
- SELECT *,
246
- SUM(tokens) OVER (ORDER BY created_at DESC) as running_total
247
- FROM mastra_messages
248
- WHERE thread_id = $1
249
- AND type = 'text'
250
- ${startDate ? `AND created_at >= '${startDate.toISOString()}'` : ''}
251
- ${endDate ? `AND created_at <= '${endDate.toISOString()}'` : ''}
252
- ORDER BY created_at DESC
253
- )
254
- SELECT id,
255
- content,
256
- role,
257
- type,
258
- created_at AS createdAt,
259
- thread_id AS threadId
260
- FROM RankedMessages
261
- WHERE running_total <= $2
262
- ORDER BY created_at ASC`,
263
- [threadId, this.MAX_CONTEXT_TOKENS],
264
- );
265
-
266
- console.log('Format', format);
267
- return this.parseMessages(result.rows) as MessageResponse<T>;
268
- }
269
-
270
- //get all messages
271
- const result = await client.query<MessageType>(
272
- `SELECT id,
273
- content,
274
- role,
275
- type,
276
- created_at AS createdAt,
277
- thread_id AS threadId
278
- FROM mastra_messages
279
- WHERE thread_id = $1
280
- AND type = 'text'
281
- ${startDate ? `AND created_at >= '${startDate.toISOString()}'` : ''}
282
- ${endDate ? `AND created_at <= '${endDate.toISOString()}'` : ''}
283
- ORDER BY created_at ASC`,
284
- [threadId],
285
- );
286
-
287
- console.log('Format', format);
288
- return this.parseMessages(result.rows) as MessageResponse<T>;
289
- } catch (error) {
290
- console.log('error getting context window====', error);
291
- return [];
292
- } finally {
293
- client.release();
294
- }
295
- }
296
-
297
- /**
298
- * Messages
299
- */
300
-
301
- async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
302
- await this.ensureTablesExist();
303
-
304
- const client = await this.pool.connect();
305
- try {
306
- const result = await client.query<MessageType>(
307
- `
308
- SELECT
309
- id,
310
- content,
311
- role,
312
- type,
313
- created_at AS createdAt,
314
- thread_id AS threadId
315
- FROM mastra_messages
316
- WHERE thread_id = $1
317
- ORDER BY created_at ASC
318
- `,
319
- [threadId],
320
- );
321
-
322
- const messages = this.parseMessages(result.rows);
323
- const uiMessages = this.convertToUIMessages(messages);
324
-
325
- return { messages, uiMessages };
326
- } finally {
327
- client.release();
328
- }
329
- }
330
-
331
- async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
332
- await this.ensureTablesExist();
333
-
334
- const client = await this.pool.connect();
335
- try {
336
- await client.query('BEGIN');
337
- for (const message of messages) {
338
- const { id, content, role, createdAt, threadId, toolCallIds, toolCallArgs, type, toolNames } = message;
339
- let tokens = null;
340
- if (type === 'text') {
341
- const contentMssg = role === 'assistant' ? (content as Array<TextPart>)[0]?.text || '' : (content as string);
342
- tokens = this.estimateTokens(contentMssg);
343
- }
344
-
345
- // Hash the toolCallArgs if they exist
346
- const hashedToolCallArgs = toolCallArgs
347
- ? toolCallArgs.map((args, index) =>
348
- crypto
349
- .createHash('sha256')
350
- .update(JSON.stringify({ args, threadId, toolName: toolNames?.[index] }))
351
- .digest('hex'),
352
- )
353
- : null;
354
-
355
- let validArgExists = false;
356
- if (hashedToolCallArgs?.length) {
357
- // Check all args sequentially
358
- validArgExists = true; // Start true and set to false if any check fails
359
- for (let i = 0; i < hashedToolCallArgs.length; i++) {
360
- const isValid = await this.validateToolCallArgs({
361
- hashedArgs: hashedToolCallArgs[i]!,
362
- });
363
- if (!isValid) {
364
- validArgExists = false;
365
- break;
366
- }
367
- }
368
- }
369
-
370
- const toolCallArgsExpireAt = !toolCallArgs
371
- ? null
372
- : validArgExists
373
- ? createdAt
374
- : new Date(createdAt.getTime() + 5 * 60 * 1000); // 5 minutes
375
-
376
- await client.query(
377
- `
378
- INSERT INTO mastra_messages (id, content, role, created_at, thread_id, tool_call_ids, tool_call_args, type, tokens, tool_call_args_expire_at)
379
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
380
- `,
381
- [
382
- id,
383
- JSON.stringify(content),
384
- role,
385
- createdAt.toISOString(),
386
- threadId,
387
- JSON.stringify(toolCallIds),
388
- JSON.stringify(hashedToolCallArgs),
389
- type,
390
- tokens,
391
- toolCallArgsExpireAt?.toISOString(),
392
- ],
393
- );
394
- }
395
- await client.query('COMMIT');
396
- return messages;
397
- } catch (error) {
398
- await client.query('ROLLBACK');
399
- throw error;
400
- } finally {
401
- client.release();
402
- }
403
- }
404
-
405
- async deleteMessage(id: string): Promise<void> {
406
- const client = await this.pool.connect();
407
- try {
408
- await client.query(
409
- `
410
- DELETE FROM mastra_messages
411
- WHERE id = $1
412
- `,
413
- [id],
414
- );
415
- } finally {
416
- client.release();
417
- }
418
- }
419
-
420
- /**
421
- * Table Management
422
- */
423
-
424
- async drop() {
425
- const client = await this.pool.connect();
426
- await client.query('DELETE FROM mastra_messages');
427
- await client.query('DELETE FROM mastra_threads');
428
- client.release();
429
- await this.pool.end();
430
- }
431
-
432
- async ensureTablesExist(): Promise<void> {
433
- if (this.hasTables) {
434
- return;
435
- }
436
-
437
- const client = await this.pool.connect();
438
- try {
439
- // Check if the threads table exists
440
- const threadsResult = await client.query<{ exists: boolean }>(`
441
- SELECT EXISTS (
442
- SELECT 1
443
- FROM information_schema.tables
444
- WHERE table_name = 'mastra_threads'
445
- );
446
- `);
447
-
448
- if (!threadsResult?.rows?.[0]?.exists) {
449
- await client.query(`
450
- CREATE TABLE IF NOT EXISTS mastra_threads (
451
- id UUID PRIMARY KEY,
452
- resourceid TEXT,
453
- title TEXT,
454
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
455
- updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
456
- metadata JSONB
457
- );
458
- `);
459
- }
460
-
461
- // Check if the messages table exists
462
- const messagesResult = await client.query<{ exists: boolean }>(`
463
- SELECT EXISTS (
464
- SELECT 1
465
- FROM information_schema.tables
466
- WHERE table_name = 'mastra_messages'
467
- );
468
- `);
469
-
470
- if (!messagesResult?.rows?.[0]?.exists) {
471
- await client.query(`
472
- CREATE TABLE IF NOT EXISTS mastra_messages (
473
- id UUID PRIMARY KEY,
474
- content TEXT NOT NULL,
475
- role VARCHAR(20) NOT NULL,
476
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
477
- tool_call_ids TEXT DEFAULT NULL,
478
- tool_call_args TEXT DEFAULT NULL,
479
- tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
480
- type VARCHAR(20) NOT NULL,
481
- tokens INTEGER DEFAULT NULL,
482
- thread_id UUID NOT NULL,
483
- FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
484
- );
485
- `);
486
- }
487
- } finally {
488
- client.release();
489
- this.hasTables = true;
490
- }
491
- }
492
- }