@mastra/memory 0.0.2-alpha.2 → 0.0.2-alpha.21

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
- import { MastraMemory, MessageType, ThreadType } from '@mastra/core';
2
- import { ToolResultPart, CoreToolMessage, ToolInvocation, Message as AiMessage, TextPart } from 'ai';
1
+ import { MastraMemory, MessageType, ThreadType, MessageResponse } from '@mastra/core';
2
+ import { ToolResultPart, Message as AiMessage, TextPart } from 'ai';
3
3
  import crypto from 'crypto';
4
4
  import pg from 'pg';
5
5
 
@@ -7,236 +7,18 @@ const { Pool } = pg;
7
7
 
8
8
  export class PgMemory extends MastraMemory {
9
9
  private pool: pg.Pool;
10
- private MAX_CONTEXT_TOKENS?: number;
11
-
10
+ hasTables: boolean = false;
12
11
  constructor(config: { connectionString: string; maxTokens?: number }) {
13
12
  super();
14
13
  this.pool = new Pool({ connectionString: config.connectionString });
15
14
  this.MAX_CONTEXT_TOKENS = config.maxTokens;
16
15
  }
17
16
 
18
- async drop() {
19
- const client = await this.pool.connect();
20
- await client.query('DELETE FROM mastra_messages');
21
- await client.query('DELETE FROM mastra_threads');
22
- client.release();
23
- await this.pool.end();
24
- }
25
-
26
- // Simplified token estimation
27
- private estimateTokens(text: string): number {
28
- return Math.ceil(text.split(' ').length * 1.3);
29
- }
30
-
31
- private processMessages(messages: MessageType[]): MessageType[] {
32
- return messages.map(mssg => ({
33
- ...mssg,
34
- content: typeof mssg.content === 'string' ? JSON.parse((mssg as MessageType).content as string) : mssg.content,
35
- }));
36
- }
37
-
38
- private convertToUIMessages(messages: MessageType[]): AiMessage[] {
39
- function addToolMessageToChat({
40
- toolMessage,
41
- messages,
42
- toolResultContents,
43
- }: {
44
- toolMessage: CoreToolMessage;
45
- messages: Array<AiMessage>;
46
- toolResultContents: Array<ToolResultPart>;
47
- }): { chatMessages: Array<AiMessage>; toolResultContents: Array<ToolResultPart> } {
48
- const chatMessages = messages.map(message => {
49
- if (message.toolInvocations) {
50
- return {
51
- ...message,
52
- toolInvocations: message.toolInvocations.map(toolInvocation => {
53
- const toolResult = toolMessage.content.find(tool => tool.toolCallId === toolInvocation.toolCallId);
54
-
55
- if (toolResult) {
56
- return {
57
- ...toolInvocation,
58
- state: 'result',
59
- result: toolResult.result,
60
- };
61
- }
62
-
63
- return toolInvocation;
64
- }),
65
- };
66
- }
67
-
68
- return message;
69
- }) as Array<AiMessage>;
70
-
71
- const resultContents = [...toolResultContents, ...toolMessage.content];
72
-
73
- return { chatMessages, toolResultContents: resultContents };
74
- }
75
-
76
- const { chatMessages } = messages.reduce(
77
- (obj: { chatMessages: Array<AiMessage>; toolResultContents: Array<ToolResultPart> }, message) => {
78
- if (message.role === 'tool') {
79
- return addToolMessageToChat({
80
- toolMessage: message as CoreToolMessage,
81
- messages: obj.chatMessages,
82
- toolResultContents: obj.toolResultContents,
83
- });
84
- }
85
-
86
- let textContent = '';
87
- let toolInvocations: Array<ToolInvocation> = [];
88
-
89
- if (typeof message.content === 'string') {
90
- textContent = message.content;
91
- } else if (Array.isArray(message.content)) {
92
- for (const content of message.content) {
93
- if (content.type === 'text') {
94
- textContent += content.text;
95
- } else if (content.type === 'tool-call') {
96
- const toolResult = obj.toolResultContents.find(tool => tool.toolCallId === content.toolCallId);
97
- toolInvocations.push({
98
- state: toolResult ? 'result' : 'call',
99
- toolCallId: content.toolCallId,
100
- toolName: content.toolName,
101
- args: content.args,
102
- result: toolResult?.result,
103
- });
104
- }
105
- }
106
- }
107
-
108
- obj.chatMessages.push({
109
- id: message.id,
110
- role: message.role as AiMessage['role'],
111
- content: textContent,
112
- toolInvocations,
113
- });
114
-
115
- return obj;
116
- },
117
- { chatMessages: [], toolResultContents: [] } as {
118
- chatMessages: Array<AiMessage>;
119
- toolResultContents: Array<ToolResultPart>;
120
- },
121
- );
122
-
123
- return chatMessages;
124
- }
125
-
126
- async ensureTablesExist(): Promise<void> {
127
- const client = await this.pool.connect();
128
- try {
129
- // Check if the threads table exists
130
- const threadsResult = await client.query<{ exists: boolean }>(`
131
- SELECT EXISTS (
132
- SELECT 1
133
- FROM information_schema.tables
134
- WHERE table_name = 'mastra_threads'
135
- );
136
- `);
137
-
138
- if (!threadsResult?.rows?.[0]?.exists) {
139
- await client.query(`
140
- CREATE TABLE IF NOT EXISTS mastra_threads (
141
- id UUID PRIMARY KEY,
142
- resourceid TEXT,
143
- title TEXT,
144
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
145
- updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
146
- metadata JSONB
147
- );
148
- `);
149
- }
150
-
151
- // Check if the messages table exists
152
- const messagesResult = await client.query<{ exists: boolean }>(`
153
- SELECT EXISTS (
154
- SELECT 1
155
- FROM information_schema.tables
156
- WHERE table_name = 'mastra_messages'
157
- );
158
- `);
159
-
160
- if (!messagesResult?.rows?.[0]?.exists) {
161
- await client.query(`
162
- CREATE TABLE IF NOT EXISTS mastra_messages (
163
- id UUID PRIMARY KEY,
164
- content TEXT NOT NULL,
165
- role VARCHAR(20) NOT NULL,
166
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
167
- tool_call_ids TEXT DEFAULT NULL,
168
- tool_call_args TEXT DEFAULT NULL,
169
- tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
170
- type VARCHAR(20) NOT NULL,
171
- tokens INTEGER DEFAULT NULL,
172
- thread_id UUID NOT NULL,
173
- FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
174
- );
175
- `);
176
- }
177
- } finally {
178
- client.release();
179
- }
180
- }
181
-
182
- async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
183
- const client = await this.pool.connect();
184
- try {
185
- const result = await client.query<ThreadType>(
186
- `
187
- UPDATE mastra_threads
188
- SET title = $1, metadata = $2, updated_at = NOW()
189
- WHERE id = $3
190
- RETURNING *
191
- `,
192
- [title, JSON.stringify(metadata), id],
193
- );
194
- return result?.rows?.[0]!;
195
- } finally {
196
- client.release();
197
- }
198
- }
199
-
200
- async deleteThread(id: string): Promise<void> {
201
- const client = await this.pool.connect();
202
- try {
203
- await client.query(
204
- `
205
- DELETE FROM mastra_messages
206
- WHERE thread_id = $1
207
- `,
208
- [id],
209
- );
210
-
211
- await client.query(
212
- `
213
- DELETE FROM mastra_threads
214
- WHERE id = $1
215
- `,
216
- [id],
217
- );
218
- } finally {
219
- client.release();
220
- }
221
- }
222
-
223
- async deleteMessage(id: string): Promise<void> {
224
- const client = await this.pool.connect();
225
- try {
226
- await client.query(
227
- `
228
- DELETE FROM mastra_messages
229
- WHERE id = $1
230
- `,
231
- [id],
232
- );
233
- } finally {
234
- client.release();
235
- }
236
- }
17
+ /**
18
+ * Threads
19
+ */
237
20
 
238
21
  async getThreadById({ threadId }: { threadId: string }): Promise<ThreadType | null> {
239
- console.log('getThreadById', threadId);
240
22
  await this.ensureTablesExist();
241
23
 
242
24
  const client = await this.pool.connect();
@@ -296,7 +78,52 @@ export class PgMemory extends MastraMemory {
296
78
  }
297
79
  }
298
80
 
299
- async checkIfValidArgExists({ hashedToolCallArgs }: { hashedToolCallArgs: string }): Promise<boolean> {
81
+ async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
82
+ const client = await this.pool.connect();
83
+ try {
84
+ const result = await client.query<ThreadType>(
85
+ `
86
+ UPDATE mastra_threads
87
+ SET title = $1, metadata = $2, updated_at = NOW()
88
+ WHERE id = $3
89
+ RETURNING *
90
+ `,
91
+ [title, JSON.stringify(metadata), id],
92
+ );
93
+ return result?.rows?.[0]!;
94
+ } finally {
95
+ client.release();
96
+ }
97
+ }
98
+
99
+ async deleteThread(id: string): Promise<void> {
100
+ const client = await this.pool.connect();
101
+ try {
102
+ await client.query(
103
+ `
104
+ DELETE FROM mastra_messages
105
+ WHERE thread_id = $1
106
+ `,
107
+ [id],
108
+ );
109
+
110
+ await client.query(
111
+ `
112
+ DELETE FROM mastra_threads
113
+ WHERE id = $1
114
+ `,
115
+ [id],
116
+ );
117
+ } finally {
118
+ client.release();
119
+ }
120
+ }
121
+
122
+ /**
123
+ * Tool Cache
124
+ */
125
+
126
+ async validateToolCallArgs({ hashedArgs }: { hashedArgs: string }): Promise<boolean> {
300
127
  await this.ensureTablesExist();
301
128
 
302
129
  const client = await this.pool.connect();
@@ -311,7 +138,7 @@ export class PgMemory extends MastraMemory {
311
138
  AND tool_call_args_expire_at > $2
312
139
  ORDER BY created_at ASC
313
140
  LIMIT 1`,
314
- [JSON.stringify([hashedToolCallArgs]), new Date().toISOString()],
141
+ [JSON.stringify([hashedArgs]), new Date().toISOString()],
315
142
  );
316
143
 
317
144
  return toolArgsResult.rows.length > 0;
@@ -323,7 +150,7 @@ export class PgMemory extends MastraMemory {
323
150
  }
324
151
  }
325
152
 
326
- async getCachedToolResult({
153
+ async getToolResult({
327
154
  threadId,
328
155
  toolArgs,
329
156
  toolName,
@@ -343,8 +170,6 @@ export class PgMemory extends MastraMemory {
343
170
  .update(JSON.stringify({ args: toolArgs, threadId, toolName }))
344
171
  .digest('hex');
345
172
 
346
- console.log('hashedToolArgs====', hashedToolArgs);
347
-
348
173
  const toolArgsResult = await client.query<{ tool_call_ids: string; tool_call_args: string; created_at: string }>(
349
174
  `SELECT tool_call_ids,
350
175
  tool_call_args,
@@ -358,20 +183,13 @@ export class PgMemory extends MastraMemory {
358
183
  );
359
184
 
360
185
  if (toolArgsResult.rows.length > 0) {
361
- console.log('toolArgsResult====', JSON.stringify(toolArgsResult.rows[0], null, 2));
362
186
  const toolCallArgs = JSON.parse(toolArgsResult.rows[0]?.tool_call_args!) as string[];
363
187
  const toolCallIds = JSON.parse(toolArgsResult.rows[0]?.tool_call_ids!) as string[];
364
188
  const createdAt = toolArgsResult.rows[0]?.created_at!;
365
189
 
366
- console.log('toolCallArgs====', JSON.stringify(toolCallArgs, null, 2));
367
- console.log('toolCallIds====', JSON.stringify(toolCallIds, null, 2));
368
- console.log('createdAt====', createdAt);
369
-
370
190
  const toolCallArgsIndex = toolCallArgs.findIndex(arg => arg === hashedToolArgs);
371
191
  const correspondingToolCallId = toolCallIds[toolCallArgsIndex];
372
192
 
373
- console.log('correspondingToolCallId====', { correspondingToolCallId, toolCallArgsIndex });
374
-
375
193
  const toolResult = await client.query<{ content: string }>(
376
194
  `SELECT content
377
195
  FROM mastra_messages
@@ -383,8 +201,6 @@ export class PgMemory extends MastraMemory {
383
201
  [threadId, `%${correspondingToolCallId}%`, new Date(createdAt).toISOString()],
384
202
  );
385
203
 
386
- console.log('called toolResult');
387
-
388
204
  if (toolResult.rows.length === 0) {
389
205
  console.log('no tool result found');
390
206
  return null;
@@ -393,8 +209,6 @@ export class PgMemory extends MastraMemory {
393
209
  const toolResultContent = JSON.parse(toolResult.rows[0]?.content!) as Array<ToolResultPart>;
394
210
  const requiredToolResult = toolResultContent.find(part => part.toolCallId === correspondingToolCallId);
395
211
 
396
- console.log('requiredToolResult====', JSON.stringify(requiredToolResult, null, 2));
397
-
398
212
  if (requiredToolResult) {
399
213
  return requiredToolResult.result;
400
214
  }
@@ -409,17 +223,18 @@ export class PgMemory extends MastraMemory {
409
223
  }
410
224
  }
411
225
 
412
- async getContextWindow({
226
+ async getContextWindow<T extends 'raw' | 'core_message'>({
413
227
  threadId,
414
228
  startDate,
415
229
  endDate,
230
+ format = 'raw' as T,
416
231
  }: {
232
+ format?: T;
417
233
  threadId: string;
418
234
  startDate?: Date;
419
235
  endDate?: Date;
420
- }): Promise<MessageType[]> {
236
+ }) {
421
237
  await this.ensureTablesExist();
422
- console.log('table exists');
423
238
  const client = await this.pool.connect();
424
239
 
425
240
  try {
@@ -431,7 +246,7 @@ export class PgMemory extends MastraMemory {
431
246
  SUM(tokens) OVER (ORDER BY created_at DESC) as running_total
432
247
  FROM mastra_messages
433
248
  WHERE thread_id = $1
434
- AND type = 'text'
249
+ AND type IN ('text', 'tool-result')
435
250
  ${startDate ? `AND created_at >= '${startDate.toISOString()}'` : ''}
436
251
  ${endDate ? `AND created_at <= '${endDate.toISOString()}'` : ''}
437
252
  ORDER BY created_at DESC
@@ -448,9 +263,8 @@ export class PgMemory extends MastraMemory {
448
263
  [threadId, this.MAX_CONTEXT_TOKENS],
449
264
  );
450
265
 
451
- console.log('result===', JSON.stringify(result.rows, null, 2));
452
-
453
- return this.processMessages(result.rows);
266
+ console.log('Format', format);
267
+ return this.parseMessages(result.rows) as MessageResponse<T>;
454
268
  }
455
269
 
456
270
  //get all messages
@@ -463,16 +277,15 @@ export class PgMemory extends MastraMemory {
463
277
  thread_id AS threadId
464
278
  FROM mastra_messages
465
279
  WHERE thread_id = $1
466
- AND type = 'text'
280
+ AND type IN ('text', 'tool-result')
467
281
  ${startDate ? `AND created_at >= '${startDate.toISOString()}'` : ''}
468
282
  ${endDate ? `AND created_at <= '${endDate.toISOString()}'` : ''}
469
283
  ORDER BY created_at ASC`,
470
284
  [threadId],
471
285
  );
472
286
 
473
- console.log('result===', JSON.stringify(result.rows, null, 2));
474
-
475
- return this.processMessages(result.rows);
287
+ console.log('Format', format);
288
+ return this.parseMessages(result.rows) as MessageResponse<T>;
476
289
  } catch (error) {
477
290
  console.log('error getting context window====', error);
478
291
  return [];
@@ -481,6 +294,40 @@ export class PgMemory extends MastraMemory {
481
294
  }
482
295
  }
483
296
 
297
+ /**
298
+ * Messages
299
+ */
300
+
301
+ async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
302
+ await this.ensureTablesExist();
303
+
304
+ const client = await this.pool.connect();
305
+ try {
306
+ const result = await client.query<MessageType>(
307
+ `
308
+ SELECT
309
+ id,
310
+ content,
311
+ role,
312
+ type,
313
+ created_at AS createdAt,
314
+ thread_id AS threadId
315
+ FROM mastra_messages
316
+ WHERE thread_id = $1
317
+ ORDER BY created_at ASC
318
+ `,
319
+ [threadId],
320
+ );
321
+
322
+ const messages = this.parseMessages(result.rows);
323
+ const uiMessages = this.convertToUIMessages(messages);
324
+
325
+ return { messages, uiMessages };
326
+ } finally {
327
+ client.release();
328
+ }
329
+ }
330
+
484
331
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
485
332
  await this.ensureTablesExist();
486
333
 
@@ -488,7 +335,6 @@ export class PgMemory extends MastraMemory {
488
335
  try {
489
336
  await client.query('BEGIN');
490
337
  for (const message of messages) {
491
- console.log('saving message====', JSON.stringify(message, null, 2));
492
338
  const { id, content, role, createdAt, threadId, toolCallIds, toolCallArgs, type, toolNames } = message;
493
339
  let tokens = null;
494
340
  if (type === 'text') {
@@ -511,8 +357,8 @@ export class PgMemory extends MastraMemory {
511
357
  // Check all args sequentially
512
358
  validArgExists = true; // Start true and set to false if any check fails
513
359
  for (let i = 0; i < hashedToolCallArgs.length; i++) {
514
- const isValid = await this.checkIfValidArgExists({
515
- hashedToolCallArgs: hashedToolCallArgs[i]!,
360
+ const isValid = await this.validateToolCallArgs({
361
+ hashedArgs: hashedToolCallArgs[i]!,
516
362
  });
517
363
  if (!isValid) {
518
364
  validArgExists = false;
@@ -527,8 +373,6 @@ export class PgMemory extends MastraMemory {
527
373
  ? createdAt
528
374
  : new Date(createdAt.getTime() + 5 * 60 * 1000); // 5 minutes
529
375
 
530
- console.log('just before query');
531
-
532
376
  await client.query(
533
377
  `
534
378
  INSERT INTO mastra_messages (id, content, role, created_at, thread_id, tool_call_ids, tool_call_args, type, tokens, tool_call_args_expire_at)
@@ -548,7 +392,6 @@ export class PgMemory extends MastraMemory {
548
392
  ],
549
393
  );
550
394
  }
551
- console.log('just after query');
552
395
  await client.query('COMMIT');
553
396
  return messages;
554
397
  } catch (error) {
@@ -559,33 +402,91 @@ export class PgMemory extends MastraMemory {
559
402
  }
560
403
  }
561
404
 
562
- async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
563
- await this.ensureTablesExist();
564
-
405
+ async deleteMessage(id: string): Promise<void> {
565
406
  const client = await this.pool.connect();
566
407
  try {
567
- const result = await client.query<MessageType>(
408
+ await client.query(
568
409
  `
569
- SELECT
570
- id,
571
- content,
572
- role,
573
- type,
574
- created_at AS createdAt,
575
- thread_id AS threadId
576
- FROM mastra_messages
577
- WHERE thread_id = $1
578
- ORDER BY created_at ASC
579
- `,
580
- [threadId],
410
+ DELETE FROM mastra_messages
411
+ WHERE id = $1
412
+ `,
413
+ [id],
581
414
  );
415
+ } finally {
416
+ client.release();
417
+ }
418
+ }
582
419
 
583
- const messages = this.processMessages(result.rows);
584
- const uiMessages = this.convertToUIMessages(messages);
420
+ /**
421
+ * Table Management
422
+ */
585
423
 
586
- return { messages, uiMessages };
424
+ async drop() {
425
+ const client = await this.pool.connect();
426
+ await client.query('DELETE FROM mastra_messages');
427
+ await client.query('DELETE FROM mastra_threads');
428
+ client.release();
429
+ await this.pool.end();
430
+ }
431
+
432
+ async ensureTablesExist(): Promise<void> {
433
+ if (this.hasTables) {
434
+ return;
435
+ }
436
+
437
+ const client = await this.pool.connect();
438
+ try {
439
+ // Check if the threads table exists
440
+ const threadsResult = await client.query<{ exists: boolean }>(`
441
+ SELECT EXISTS (
442
+ SELECT 1
443
+ FROM information_schema.tables
444
+ WHERE table_name = 'mastra_threads'
445
+ );
446
+ `);
447
+
448
+ if (!threadsResult?.rows?.[0]?.exists) {
449
+ await client.query(`
450
+ CREATE TABLE IF NOT EXISTS mastra_threads (
451
+ id UUID PRIMARY KEY,
452
+ resourceid TEXT,
453
+ title TEXT,
454
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL,
455
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
456
+ metadata JSONB
457
+ );
458
+ `);
459
+ }
460
+
461
+ // Check if the messages table exists
462
+ const messagesResult = await client.query<{ exists: boolean }>(`
463
+ SELECT EXISTS (
464
+ SELECT 1
465
+ FROM information_schema.tables
466
+ WHERE table_name = 'mastra_messages'
467
+ );
468
+ `);
469
+
470
+ if (!messagesResult?.rows?.[0]?.exists) {
471
+ await client.query(`
472
+ CREATE TABLE IF NOT EXISTS mastra_messages (
473
+ id UUID PRIMARY KEY,
474
+ content TEXT NOT NULL,
475
+ role VARCHAR(20) NOT NULL,
476
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL,
477
+ tool_call_ids TEXT DEFAULT NULL,
478
+ tool_call_args TEXT DEFAULT NULL,
479
+ tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
480
+ type VARCHAR(20) NOT NULL,
481
+ tokens INTEGER DEFAULT NULL,
482
+ thread_id UUID NOT NULL,
483
+ FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
484
+ );
485
+ `);
486
+ }
587
487
  } finally {
588
488
  client.release();
489
+ this.hasTables = true;
589
490
  }
590
491
  }
591
492
  }