@mastra/memory 0.0.2-alpha.1 → 0.0.2-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
- import { MastraMemory, MessageType, ThreadType } from '@mastra/core';
2
- import { ToolResultPart, CoreToolMessage, ToolInvocation, Message as AiMessage, TextPart } from 'ai';
1
+ import { MastraMemory, MessageType, ThreadType, MessageResponse } from '@mastra/core';
2
+ import { ToolResultPart, Message as AiMessage, TextPart } from 'ai';
3
3
  import crypto from 'crypto';
4
4
  import pg from 'pg';
5
5
 
@@ -7,236 +7,18 @@ const { Pool } = pg;
7
7
 
8
8
  export class PgMemory extends MastraMemory {
9
9
  private pool: pg.Pool;
10
- private MAX_CONTEXT_TOKENS?: number;
11
-
10
+ hasTables: boolean = false;
12
11
  constructor(config: { connectionString: string; maxTokens?: number }) {
13
12
  super();
14
13
  this.pool = new Pool({ connectionString: config.connectionString });
15
14
  this.MAX_CONTEXT_TOKENS = config.maxTokens;
16
15
  }
17
16
 
18
- async drop() {
19
- const client = await this.pool.connect();
20
- await client.query('DELETE FROM mastra_messages');
21
- await client.query('DELETE FROM mastra_threads');
22
- client.release();
23
- await this.pool.end();
24
- }
25
-
26
- // Simplified token estimation
27
- private estimateTokens(text: string): number {
28
- return Math.ceil(text.split(' ').length * 1.3);
29
- }
30
-
31
- private processMessages(messages: MessageType[]): MessageType[] {
32
- return messages.map(mssg => ({
33
- ...mssg,
34
- content: typeof mssg.content === 'string' ? JSON.parse((mssg as MessageType).content as string) : mssg.content,
35
- }));
36
- }
37
-
38
- private convertToUIMessages(messages: MessageType[]): AiMessage[] {
39
- function addToolMessageToChat({
40
- toolMessage,
41
- messages,
42
- toolResultContents,
43
- }: {
44
- toolMessage: CoreToolMessage;
45
- messages: Array<AiMessage>;
46
- toolResultContents: Array<ToolResultPart>;
47
- }): { chatMessages: Array<AiMessage>; toolResultContents: Array<ToolResultPart> } {
48
- const chatMessages = messages.map(message => {
49
- if (message.toolInvocations) {
50
- return {
51
- ...message,
52
- toolInvocations: message.toolInvocations.map(toolInvocation => {
53
- const toolResult = toolMessage.content.find(tool => tool.toolCallId === toolInvocation.toolCallId);
54
-
55
- if (toolResult) {
56
- return {
57
- ...toolInvocation,
58
- state: 'result',
59
- result: toolResult.result,
60
- };
61
- }
62
-
63
- return toolInvocation;
64
- }),
65
- };
66
- }
67
-
68
- return message;
69
- }) as Array<AiMessage>;
70
-
71
- const resultContents = [...toolResultContents, ...toolMessage.content];
72
-
73
- return { chatMessages, toolResultContents: resultContents };
74
- }
75
-
76
- const { chatMessages } = messages.reduce(
77
- (obj: { chatMessages: Array<AiMessage>; toolResultContents: Array<ToolResultPart> }, message) => {
78
- if (message.role === 'tool') {
79
- return addToolMessageToChat({
80
- toolMessage: message as CoreToolMessage,
81
- messages: obj.chatMessages,
82
- toolResultContents: obj.toolResultContents,
83
- });
84
- }
85
-
86
- let textContent = '';
87
- let toolInvocations: Array<ToolInvocation> = [];
88
-
89
- if (typeof message.content === 'string') {
90
- textContent = message.content;
91
- } else if (Array.isArray(message.content)) {
92
- for (const content of message.content) {
93
- if (content.type === 'text') {
94
- textContent += content.text;
95
- } else if (content.type === 'tool-call') {
96
- const toolResult = obj.toolResultContents.find(tool => tool.toolCallId === content.toolCallId);
97
- toolInvocations.push({
98
- state: toolResult ? 'result' : 'call',
99
- toolCallId: content.toolCallId,
100
- toolName: content.toolName,
101
- args: content.args,
102
- result: toolResult?.result,
103
- });
104
- }
105
- }
106
- }
107
-
108
- obj.chatMessages.push({
109
- id: message.id,
110
- role: message.role as AiMessage['role'],
111
- content: textContent,
112
- toolInvocations,
113
- });
114
-
115
- return obj;
116
- },
117
- { chatMessages: [], toolResultContents: [] } as {
118
- chatMessages: Array<AiMessage>;
119
- toolResultContents: Array<ToolResultPart>;
120
- },
121
- );
122
-
123
- return chatMessages;
124
- }
125
-
126
- async ensureTablesExist(): Promise<void> {
127
- const client = await this.pool.connect();
128
- try {
129
- // Check if the threads table exists
130
- const threadsResult = await client.query<{ exists: boolean }>(`
131
- SELECT EXISTS (
132
- SELECT 1
133
- FROM information_schema.tables
134
- WHERE table_name = 'mastra_threads'
135
- );
136
- `);
137
-
138
- if (!threadsResult?.rows?.[0]?.exists) {
139
- await client.query(`
140
- CREATE TABLE IF NOT EXISTS mastra_threads (
141
- id UUID PRIMARY KEY,
142
- resourceid TEXT,
143
- title TEXT,
144
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
145
- updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
146
- metadata JSONB
147
- );
148
- `);
149
- }
150
-
151
- // Check if the messages table exists
152
- const messagesResult = await client.query<{ exists: boolean }>(`
153
- SELECT EXISTS (
154
- SELECT 1
155
- FROM information_schema.tables
156
- WHERE table_name = 'mastra_messages'
157
- );
158
- `);
159
-
160
- if (!messagesResult?.rows?.[0]?.exists) {
161
- await client.query(`
162
- CREATE TABLE IF NOT EXISTS mastra_messages (
163
- id UUID PRIMARY KEY,
164
- content TEXT NOT NULL,
165
- role VARCHAR(20) NOT NULL,
166
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
167
- tool_call_ids TEXT DEFAULT NULL,
168
- tool_call_args TEXT DEFAULT NULL,
169
- tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
170
- type VARCHAR(20) NOT NULL,
171
- tokens INTEGER DEFAULT NULL,
172
- thread_id UUID NOT NULL,
173
- FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
174
- );
175
- `);
176
- }
177
- } finally {
178
- client.release();
179
- }
180
- }
181
-
182
- async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
183
- const client = await this.pool.connect();
184
- try {
185
- const result = await client.query<ThreadType>(
186
- `
187
- UPDATE mastra_threads
188
- SET title = $1, metadata = $2, updated_at = NOW()
189
- WHERE id = $3
190
- RETURNING *
191
- `,
192
- [title, JSON.stringify(metadata), id],
193
- );
194
- return result?.rows?.[0]!;
195
- } finally {
196
- client.release();
197
- }
198
- }
199
-
200
- async deleteThread(id: string): Promise<void> {
201
- const client = await this.pool.connect();
202
- try {
203
- await client.query(
204
- `
205
- DELETE FROM mastra_messages
206
- WHERE thread_id = $1
207
- `,
208
- [id],
209
- );
210
-
211
- await client.query(
212
- `
213
- DELETE FROM mastra_threads
214
- WHERE id = $1
215
- `,
216
- [id],
217
- );
218
- } finally {
219
- client.release();
220
- }
221
- }
222
-
223
- async deleteMessage(id: string): Promise<void> {
224
- const client = await this.pool.connect();
225
- try {
226
- await client.query(
227
- `
228
- DELETE FROM mastra_messages
229
- WHERE id = $1
230
- `,
231
- [id],
232
- );
233
- } finally {
234
- client.release();
235
- }
236
- }
17
+ /**
18
+ * Threads
19
+ */
237
20
 
238
21
  async getThreadById({ threadId }: { threadId: string }): Promise<ThreadType | null> {
239
- console.log('getThreadById', threadId);
240
22
  await this.ensureTablesExist();
241
23
 
242
24
  const client = await this.pool.connect();
@@ -296,7 +78,52 @@ export class PgMemory extends MastraMemory {
296
78
  }
297
79
  }
298
80
 
299
- async checkIfValidArgExists({ hashedToolCallArgs }: { hashedToolCallArgs: string }): Promise<boolean> {
81
+ async updateThread(id: string, title: string, metadata: Record<string, unknown>): Promise<ThreadType> {
82
+ const client = await this.pool.connect();
83
+ try {
84
+ const result = await client.query<ThreadType>(
85
+ `
86
+ UPDATE mastra_threads
87
+ SET title = $1, metadata = $2, updated_at = NOW()
88
+ WHERE id = $3
89
+ RETURNING *
90
+ `,
91
+ [title, JSON.stringify(metadata), id],
92
+ );
93
+ return result?.rows?.[0]!;
94
+ } finally {
95
+ client.release();
96
+ }
97
+ }
98
+
99
+ async deleteThread(id: string): Promise<void> {
100
+ const client = await this.pool.connect();
101
+ try {
102
+ await client.query(
103
+ `
104
+ DELETE FROM mastra_messages
105
+ WHERE thread_id = $1
106
+ `,
107
+ [id],
108
+ );
109
+
110
+ await client.query(
111
+ `
112
+ DELETE FROM mastra_threads
113
+ WHERE id = $1
114
+ `,
115
+ [id],
116
+ );
117
+ } finally {
118
+ client.release();
119
+ }
120
+ }
121
+
122
+ /**
123
+ * Tool Cache
124
+ */
125
+
126
+ async validateToolCallArgs({ hashedArgs }: { hashedArgs: string }): Promise<boolean> {
300
127
  await this.ensureTablesExist();
301
128
 
302
129
  const client = await this.pool.connect();
@@ -311,7 +138,7 @@ export class PgMemory extends MastraMemory {
311
138
  AND tool_call_args_expire_at > $2
312
139
  ORDER BY created_at ASC
313
140
  LIMIT 1`,
314
- [JSON.stringify([hashedToolCallArgs]), new Date().toISOString()],
141
+ [JSON.stringify([hashedArgs]), new Date().toISOString()],
315
142
  );
316
143
 
317
144
  return toolArgsResult.rows.length > 0;
@@ -323,7 +150,7 @@ export class PgMemory extends MastraMemory {
323
150
  }
324
151
  }
325
152
 
326
- async getCachedToolResult({
153
+ async getToolResult({
327
154
  threadId,
328
155
  toolArgs,
329
156
  toolName,
@@ -409,17 +236,18 @@ export class PgMemory extends MastraMemory {
409
236
  }
410
237
  }
411
238
 
412
- async getContextWindow({
239
+ async getContextWindow<T extends 'raw' | 'core_message'>({
413
240
  threadId,
414
241
  startDate,
415
242
  endDate,
243
+ format = 'raw' as T,
416
244
  }: {
245
+ format?: T;
417
246
  threadId: string;
418
247
  startDate?: Date;
419
248
  endDate?: Date;
420
- }): Promise<MessageType[]> {
249
+ }) {
421
250
  await this.ensureTablesExist();
422
- console.log('table exists');
423
251
  const client = await this.pool.connect();
424
252
 
425
253
  try {
@@ -448,9 +276,8 @@ export class PgMemory extends MastraMemory {
448
276
  [threadId, this.MAX_CONTEXT_TOKENS],
449
277
  );
450
278
 
451
- console.log('result===', JSON.stringify(result.rows, null, 2));
452
-
453
- return this.processMessages(result.rows);
279
+ console.log('Format', format);
280
+ return this.parseMessages(result.rows) as MessageResponse<T>;
454
281
  }
455
282
 
456
283
  //get all messages
@@ -470,9 +297,8 @@ export class PgMemory extends MastraMemory {
470
297
  [threadId],
471
298
  );
472
299
 
473
- console.log('result===', JSON.stringify(result.rows, null, 2));
474
-
475
- return this.processMessages(result.rows);
300
+ console.log('Format', format);
301
+ return this.parseMessages(result.rows) as MessageResponse<T>;
476
302
  } catch (error) {
477
303
  console.log('error getting context window====', error);
478
304
  return [];
@@ -481,6 +307,40 @@ export class PgMemory extends MastraMemory {
481
307
  }
482
308
  }
483
309
 
310
+ /**
311
+ * Messages
312
+ */
313
+
314
+ async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
315
+ await this.ensureTablesExist();
316
+
317
+ const client = await this.pool.connect();
318
+ try {
319
+ const result = await client.query<MessageType>(
320
+ `
321
+ SELECT
322
+ id,
323
+ content,
324
+ role,
325
+ type,
326
+ created_at AS createdAt,
327
+ thread_id AS threadId
328
+ FROM mastra_messages
329
+ WHERE thread_id = $1
330
+ ORDER BY created_at ASC
331
+ `,
332
+ [threadId],
333
+ );
334
+
335
+ const messages = this.parseMessages(result.rows);
336
+ const uiMessages = this.convertToUIMessages(messages);
337
+
338
+ return { messages, uiMessages };
339
+ } finally {
340
+ client.release();
341
+ }
342
+ }
343
+
484
344
  async saveMessages({ messages }: { messages: MessageType[] }): Promise<MessageType[]> {
485
345
  await this.ensureTablesExist();
486
346
 
@@ -511,8 +371,8 @@ export class PgMemory extends MastraMemory {
511
371
  // Check all args sequentially
512
372
  validArgExists = true; // Start true and set to false if any check fails
513
373
  for (let i = 0; i < hashedToolCallArgs.length; i++) {
514
- const isValid = await this.checkIfValidArgExists({
515
- hashedToolCallArgs: hashedToolCallArgs[i]!,
374
+ const isValid = await this.validateToolCallArgs({
375
+ hashedArgs: hashedToolCallArgs[i]!,
516
376
  });
517
377
  if (!isValid) {
518
378
  validArgExists = false;
@@ -559,33 +419,91 @@ export class PgMemory extends MastraMemory {
559
419
  }
560
420
  }
561
421
 
562
- async getMessages({ threadId }: { threadId: string }): Promise<{ messages: MessageType[]; uiMessages: AiMessage[] }> {
563
- await this.ensureTablesExist();
564
-
422
+ async deleteMessage(id: string): Promise<void> {
565
423
  const client = await this.pool.connect();
566
424
  try {
567
- const result = await client.query<MessageType>(
425
+ await client.query(
568
426
  `
569
- SELECT
570
- id,
571
- content,
572
- role,
573
- type,
574
- created_at AS createdAt,
575
- thread_id AS threadId
576
- FROM mastra_messages
577
- WHERE thread_id = $1
578
- ORDER BY created_at ASC
579
- `,
580
- [threadId],
427
+ DELETE FROM mastra_messages
428
+ WHERE id = $1
429
+ `,
430
+ [id],
581
431
  );
432
+ } finally {
433
+ client.release();
434
+ }
435
+ }
582
436
 
583
- const messages = this.processMessages(result.rows);
584
- const uiMessages = this.convertToUIMessages(messages);
437
+ /**
438
+ * Table Management
439
+ */
585
440
 
586
- return { messages, uiMessages };
441
+ async drop() {
442
+ const client = await this.pool.connect();
443
+ await client.query('DELETE FROM mastra_messages');
444
+ await client.query('DELETE FROM mastra_threads');
445
+ client.release();
446
+ await this.pool.end();
447
+ }
448
+
449
+ async ensureTablesExist(): Promise<void> {
450
+ if (this.hasTables) {
451
+ return;
452
+ }
453
+
454
+ const client = await this.pool.connect();
455
+ try {
456
+ // Check if the threads table exists
457
+ const threadsResult = await client.query<{ exists: boolean }>(`
458
+ SELECT EXISTS (
459
+ SELECT 1
460
+ FROM information_schema.tables
461
+ WHERE table_name = 'mastra_threads'
462
+ );
463
+ `);
464
+
465
+ if (!threadsResult?.rows?.[0]?.exists) {
466
+ await client.query(`
467
+ CREATE TABLE IF NOT EXISTS mastra_threads (
468
+ id UUID PRIMARY KEY,
469
+ resourceid TEXT,
470
+ title TEXT,
471
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL,
472
+ updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
473
+ metadata JSONB
474
+ );
475
+ `);
476
+ }
477
+
478
+ // Check if the messages table exists
479
+ const messagesResult = await client.query<{ exists: boolean }>(`
480
+ SELECT EXISTS (
481
+ SELECT 1
482
+ FROM information_schema.tables
483
+ WHERE table_name = 'mastra_messages'
484
+ );
485
+ `);
486
+
487
+ if (!messagesResult?.rows?.[0]?.exists) {
488
+ await client.query(`
489
+ CREATE TABLE IF NOT EXISTS mastra_messages (
490
+ id UUID PRIMARY KEY,
491
+ content TEXT NOT NULL,
492
+ role VARCHAR(20) NOT NULL,
493
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL,
494
+ tool_call_ids TEXT DEFAULT NULL,
495
+ tool_call_args TEXT DEFAULT NULL,
496
+ tool_call_args_expire_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
497
+ type VARCHAR(20) NOT NULL,
498
+ tokens INTEGER DEFAULT NULL,
499
+ thread_id UUID NOT NULL,
500
+ FOREIGN KEY (thread_id) REFERENCES mastra_threads(id)
501
+ );
502
+ `);
503
+ }
587
504
  } finally {
588
505
  client.release();
506
+ this.hasTables = true;
589
507
  }
590
508
  }
591
509
  }