@mastra/pg 0.12.3 → 0.12.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,160 @@
1
+ import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
2
+ import type { PaginationInfo, StorageGetTracesArg, StorageGetTracesPaginatedArg } from '@mastra/core/storage';
3
+ import { TABLE_TRACES, TracesStorage, safelyParseJSON } from '@mastra/core/storage';
4
+ import type { Trace } from '@mastra/core/telemetry';
5
+ import { parseFieldKey } from '@mastra/core/utils';
6
+ import type { IDatabase } from 'pg-promise';
7
+ import type { StoreOperationsPG } from '../operations';
8
+ import { getSchemaName, getTableName } from '../utils';
9
+
10
/**
 * PostgreSQL-backed trace storage.
 *
 * Read queries are issued directly through the pg-promise `client`; batch
 * writes are delegated to the shared `operations` helper.
 */
export class TracesPG extends TracesStorage {
  public client: IDatabase<{}>;
  private operations: StoreOperationsPG;
  private schema?: string;

  /**
   * @param client - pg-promise database handle used for the read queries.
   * @param operations - store operations helper used for batch inserts.
   * @param schema - optional schema name; it is passed through `getSchemaName`
   *   before being used to qualify the traces table.
   */
  constructor({
    client,
    operations,
    schema,
  }: {
    client: IDatabase<{}>;
    operations: StoreOperationsPG;
    schema?: string;
  }) {
    super();
    this.client = client;
    this.operations = operations;
    this.schema = schema;
  }

  /**
   * Returns the traces of a single page, without the pagination metadata.
   *
   * `fromDate` / `toDate` are translated into the `dateRange` shape that
   * `getTracesPaginated` reads.
   *
   * @param args - trace filters plus paging options.
   * @returns the traces of the requested page.
   * @throws MastraError with id `MASTRA_STORAGE_PG_STORE_GET_TRACES_FAILED`
   *   wrapping any error raised by the paginated query.
   */
  async getTraces(args: StorageGetTracesArg): Promise<Trace[]> {
    // Build a new argument object instead of assigning `dateRange` onto the
    // caller's `args`, so the caller's object is left untouched.
    const paginatedArgs: StorageGetTracesPaginatedArg =
      args.fromDate || args.toDate
        ? {
            ...args,
            dateRange: {
              start: args.fromDate,
              end: args.toDate,
            },
          }
        : args;

    try {
      const result = await this.getTracesPaginated(paginatedArgs);
      return result.traces;
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_TRACES_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
        },
        error,
      );
    }
  }

  /**
   * Queries one page of traces, newest `startTime` first, together with the
   * total number of rows matching the filters.
   *
   * Supported filters: `name` (prefix match via `LIKE`), `scope` (equality),
   * `attributes` (equality on keys of the `attributes` column), `filters`
   * (equality on arbitrary columns) and `dateRange` (bounds on `createdAt`).
   * `page` defaults to 0 and `perPage` to 100.
   *
   * @param args - filters and paging options.
   * @returns the page of traces plus `total`, `page`, `perPage` and `hasMore`.
   * @throws MastraError with id
   *   `MASTRA_STORAGE_PG_STORE_GET_TRACES_PAGINATED_FAILED` wrapping any error
   *   raised while querying.
   */
  async getTracesPaginated(args: StorageGetTracesPaginatedArg): Promise<PaginationInfo & { traces: Trace[] }> {
    const { name, scope, page = 0, perPage = 100, attributes, filters, dateRange } = args;
    const fromDate = dateRange?.start;
    const toDate = dateRange?.end;
    const currentOffset = page * perPage;

    // `conditions` holds SQL fragments with positional placeholders and
    // `queryParams` the matching values, in the same order.
    const queryParams: any[] = [];
    const conditions: string[] = [];
    let paramIndex = 1;

    if (name) {
      conditions.push(`name LIKE $${paramIndex++}`);
      queryParams.push(`${name}%`);
    }
    if (scope) {
      conditions.push(`scope = $${paramIndex++}`);
      queryParams.push(scope);
    }
    if (attributes) {
      for (const [key, value] of Object.entries(attributes)) {
        // The key is interpolated into the SQL text, so it goes through
        // parseFieldKey first. NOTE(review): presumably that helper
        // validates/sanitizes the key - confirm against @mastra/core/utils.
        const parsedKey = parseFieldKey(key);
        conditions.push(`attributes->>'${parsedKey}' = $${paramIndex++}`);
        queryParams.push(value);
      }
    }
    if (filters) {
      for (const [key, value] of Object.entries(filters)) {
        // Same as above: the column name is interpolated, the value is bound.
        const parsedKey = parseFieldKey(key);
        conditions.push(`"${parsedKey}" = $${paramIndex++}`);
        queryParams.push(value);
      }
    }
    if (fromDate) {
      conditions.push(`"createdAt" >= $${paramIndex++}`);
      queryParams.push(fromDate);
    }
    if (toDate) {
      conditions.push(`"createdAt" <= $${paramIndex++}`);
      queryParams.push(toDate);
    }

    const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';

    try {
      // Resolved inside the try block so that a failure while building the
      // table name is wrapped in the same MastraError as a query failure.
      const tableName = getTableName({ indexName: TABLE_TRACES, schemaName: getSchemaName(this.schema) });

      const countResult = await this.client.oneOrNone<{ count: string }>(
        `SELECT COUNT(*) FROM ${tableName} ${whereClause}`,
        queryParams,
      );
      const total = Number(countResult?.count ?? 0);

      // Nothing matches: skip the data query entirely.
      if (total === 0) {
        return {
          traces: [],
          total: 0,
          page,
          perPage,
          hasMore: false,
        };
      }

      // LIMIT and OFFSET take the next two placeholders after the filter values.
      const dataResult = await this.client.manyOrNone<Record<string, any>>(
        `SELECT * FROM ${tableName} ${whereClause} ORDER BY "startTime" DESC LIMIT $${paramIndex++} OFFSET $${paramIndex++}`,
        [...queryParams, perPage, currentOffset],
      );

      const traces = dataResult.map(row => ({
        id: row.id,
        parentSpanId: row.parentSpanId,
        traceId: row.traceId,
        name: row.name,
        scope: row.scope,
        kind: row.kind,
        status: safelyParseJSON(row.status as string),
        events: safelyParseJSON(row.events as string),
        links: safelyParseJSON(row.links as string),
        attributes: safelyParseJSON(row.attributes as string),
        startTime: row.startTime,
        endTime: row.endTime,
        other: safelyParseJSON(row.other as string),
        // Prefer the `createdAtZ` column when the row has a value for it.
        createdAt: row.createdAtZ || row.createdAt,
      })) as Trace[];

      return {
        traces,
        total,
        page,
        perPage,
        hasMore: currentOffset + traces.length < total,
      };
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_TRACES_PAGINATED_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
        },
        error,
      );
    }
  }

  /**
   * Inserts a batch of trace records into the traces table through the
   * shared operations helper.
   *
   * @param records - rows to insert.
   */
  async batchTraceInsert({ records }: { records: Record<string, any>[] }): Promise<void> {
    this.logger.debug('Batch inserting traces', { count: records.length });
    await this.operations.batchInsert({
      tableName: TABLE_TRACES,
      records,
    });
  }
}
@@ -0,0 +1,12 @@
1
+ import { parseSqlIdentifier } from '@mastra/core/utils';
2
+
3
/**
 * Produces the double-quoted form of a schema name for use in SQL text.
 *
 * @param schema - raw schema name; may be omitted.
 * @returns the name run through `parseSqlIdentifier` and wrapped in double
 *   quotes, or `undefined` when `schema` is missing or empty.
 */
export function getSchemaName(schema?: string) {
  if (!schema) {
    return undefined;
  }
  const checkedName = parseSqlIdentifier(schema, 'schema name');
  return `"${checkedName}"`;
}
6
+
7
/**
 * Builds the table reference to embed in SQL text.
 *
 * @param indexName - table name; it is run through `parseSqlIdentifier` and
 *   wrapped in double quotes.
 * @param schemaName - optional schema prefix. It is used verbatim, so it must
 *   already be in the form it should appear in the SQL (for example the
 *   output of `getSchemaName`).
 * @returns `schemaName.` followed by the quoted table name when a schema is
 *   given, otherwise just the quoted table name.
 */
export function getTableName({ indexName, schemaName }: { indexName: string; schemaName?: string }) {
  const quotedTable = `"${parseSqlIdentifier(indexName, 'index name')}"`;
  if (!schemaName) {
    return quotedTable;
  }
  return `${schemaName}.${quotedTable}`;
}
@@ -0,0 +1,253 @@
1
+ import type { WorkflowRun, WorkflowRuns, WorkflowRunState } from '@mastra/core';
2
+ import { ErrorCategory, ErrorDomain, MastraError } from '@mastra/core/error';
3
+ import { WorkflowsStorage, TABLE_WORKFLOW_SNAPSHOT } from '@mastra/core/storage';
4
+ import type { IDatabase } from 'pg-promise';
5
+ import type { StoreOperationsPG } from '../operations';
6
+ import { getTableName } from '../utils';
7
+
8
/**
 * Converts a raw workflow-snapshot row into a `WorkflowRun`.
 *
 * A string `snapshot` is JSON-parsed; if parsing fails a warning is logged
 * and the raw string is kept. A non-string `snapshot` is passed through as is.
 * The `createdAtZ` / `updatedAtZ` columns take precedence over `createdAt` /
 * `updatedAt` when they hold a value.
 *
 * @param row - database row for one workflow run.
 * @returns the row mapped onto the `WorkflowRun` shape.
 */
function parseWorkflowRun(row: Record<string, any>): WorkflowRun {
  const rawSnapshot = row.snapshot;
  let snapshot: WorkflowRunState | string = rawSnapshot as string;

  if (typeof rawSnapshot === 'string') {
    try {
      snapshot = JSON.parse(rawSnapshot) as WorkflowRunState;
    } catch (e) {
      // Parsing failed: keep the raw snapshot string and only warn.
      console.warn(`Failed to parse snapshot for workflow ${row.workflow_name}: ${e}`);
    }
  }

  const createdSource = row.createdAtZ || (row.createdAt as string);
  const updatedSource = row.updatedAtZ || (row.updatedAt as string);

  return {
    workflowName: row.workflow_name as string,
    runId: row.run_id as string,
    snapshot,
    resourceId: row.resourceId as string,
    createdAt: new Date(createdSource),
    updatedAt: new Date(updatedSource),
  };
}
27
+
28
/**
 * PostgreSQL-backed storage for workflow run snapshots.
 *
 * All SQL issued directly by this class targets the workflow snapshot table
 * qualified with the configured schema.
 */
export class WorkflowsPG extends WorkflowsStorage {
  public client: IDatabase<{}>;
  private operations: StoreOperationsPG;
  private schema: string;

  /**
   * @param client - pg-promise database handle used for the queries.
   * @param operations - store operations helper (used for `load` and
   *   `hasColumn`).
   * @param schema - schema prefix handed verbatim to `getTableName`.
   *   NOTE(review): it is not passed through `getSchemaName` here, so the
   *   caller presumably supplies it in the form it should appear in SQL -
   *   confirm against the code that constructs this class.
   */
  constructor({
    client,
    operations,
    schema,
  }: {
    client: IDatabase<{}>;
    operations: StoreOperationsPG;
    schema: string;
  }) {
    super();
    this.client = client;
    this.operations = operations;
    this.schema = schema;
  }

  /**
   * Inserts the snapshot for a workflow run, or overwrites the stored
   * snapshot and `updatedAt` when a row for the same
   * `(workflow_name, run_id)` already exists.
   *
   * @param workflowName - name of the workflow.
   * @param runId - identifier of the run.
   * @param snapshot - run state; stored as a JSON string.
   * @throws MastraError with id
   *   `MASTRA_STORAGE_PG_STORE_PERSIST_WORKFLOW_SNAPSHOT_FAILED` wrapping any
   *   error raised by the statement.
   */
  async persistWorkflowSnapshot({
    workflowName,
    runId,
    snapshot,
  }: {
    workflowName: string;
    runId: string;
    snapshot: WorkflowRunState;
  }): Promise<void> {
    try {
      const now = new Date().toISOString();
      // Use the schema-qualified table name, like every read query in this
      // class, so writes and reads always hit the same table.
      const tableName = getTableName({ indexName: TABLE_WORKFLOW_SNAPSHOT, schemaName: this.schema });
      await this.client.none(
        `INSERT INTO ${tableName} (workflow_name, run_id, snapshot, "createdAt", "updatedAt")
                 VALUES ($1, $2, $3, $4, $5)
                 ON CONFLICT (workflow_name, run_id) DO UPDATE
                 SET snapshot = $3, "updatedAt" = $5`,
        [workflowName, runId, JSON.stringify(snapshot), now, now],
      );
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_PERSIST_WORKFLOW_SNAPSHOT_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
        },
        error,
      );
    }
  }

  /**
   * Loads the stored snapshot for one workflow run through the operations
   * helper.
   *
   * @param workflowName - name of the workflow.
   * @param runId - identifier of the run.
   * @returns the stored snapshot, or `null` when the helper finds no row.
   * @throws MastraError with id
   *   `MASTRA_STORAGE_PG_STORE_LOAD_WORKFLOW_SNAPSHOT_FAILED` wrapping any
   *   error raised while loading.
   */
  async loadWorkflowSnapshot({
    workflowName,
    runId,
  }: {
    workflowName: string;
    runId: string;
  }): Promise<WorkflowRunState | null> {
    try {
      const result = await this.operations.load<{ snapshot: WorkflowRunState }>({
        tableName: TABLE_WORKFLOW_SNAPSHOT,
        keys: { workflow_name: workflowName, run_id: runId },
      });

      return result ? result.snapshot : null;
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_LOAD_WORKFLOW_SNAPSHOT_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
        },
        error,
      );
    }
  }

  /**
   * Fetches a single workflow run by run id, optionally narrowed by workflow
   * name.
   *
   * @param runId - identifier of the run.
   * @param workflowName - optional workflow name to match as well.
   * @returns the parsed run, or `null` when no row matches.
   * @throws MastraError with id
   *   `MASTRA_STORAGE_PG_STORE_GET_WORKFLOW_RUN_BY_ID_FAILED` wrapping any
   *   error raised by the query.
   */
  async getWorkflowRunById({
    runId,
    workflowName,
  }: {
    runId: string;
    workflowName?: string;
  }): Promise<WorkflowRun | null> {
    try {
      const conditions: string[] = [];
      const values: any[] = [];
      let paramIndex = 1;

      if (runId) {
        conditions.push(`run_id = $${paramIndex}`);
        values.push(runId);
        paramIndex++;
      }

      if (workflowName) {
        conditions.push(`workflow_name = $${paramIndex}`);
        values.push(workflowName);
        paramIndex++;
      }

      const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';

      const query = `
              SELECT * FROM ${getTableName({ indexName: TABLE_WORKFLOW_SNAPSHOT, schemaName: this.schema })}
              ${whereClause}
            `;

      const result = await this.client.oneOrNone(query, values);

      if (!result) {
        return null;
      }

      return parseWorkflowRun(result);
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_WORKFLOW_RUN_BY_ID_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            runId,
            workflowName: workflowName || '',
          },
        },
        error,
      );
    }
  }

  /**
   * Lists workflow runs, newest `createdAt` first.
   *
   * Optional filters: `workflowName`, `resourceId` (skipped with a warning
   * when the table has no `resourceId` column) and a `createdAt` window via
   * `fromDate` / `toDate`. Pagination is applied only when BOTH `limit` and
   * `offset` are provided; in that case `total` is the count of all matching
   * rows.
   *
   * @returns the matching runs and a total; when not paginating (or when the
   *   count is 0) the total is the number of runs returned.
   * @throws MastraError with id
   *   `MASTRA_STORAGE_PG_STORE_GET_WORKFLOW_RUNS_FAILED` wrapping any error
   *   raised while querying.
   */
  async getWorkflowRuns({
    workflowName,
    fromDate,
    toDate,
    limit,
    offset,
    resourceId,
  }: {
    workflowName?: string;
    fromDate?: Date;
    toDate?: Date;
    limit?: number;
    offset?: number;
    resourceId?: string;
  } = {}): Promise<WorkflowRuns> {
    try {
      const conditions: string[] = [];
      const values: any[] = [];
      let paramIndex = 1;

      if (workflowName) {
        conditions.push(`workflow_name = $${paramIndex}`);
        values.push(workflowName);
        paramIndex++;
      }

      if (resourceId) {
        // Only filter on resourceId when the table actually has that column.
        const hasResourceId = await this.operations.hasColumn(TABLE_WORKFLOW_SNAPSHOT, 'resourceId');
        if (hasResourceId) {
          conditions.push(`"resourceId" = $${paramIndex}`);
          values.push(resourceId);
          paramIndex++;
        } else {
          console.warn(`[${TABLE_WORKFLOW_SNAPSHOT}] resourceId column not found. Skipping resourceId filter.`);
        }
      }

      if (fromDate) {
        conditions.push(`"createdAt" >= $${paramIndex}`);
        values.push(fromDate);
        paramIndex++;
      }

      if (toDate) {
        conditions.push(`"createdAt" <= $${paramIndex}`);
        values.push(toDate);
        paramIndex++;
      }
      const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';

      const tableName = getTableName({ indexName: TABLE_WORKFLOW_SNAPSHOT, schemaName: this.schema });
      // Pagination requires both values; a lone `limit` or `offset` is ignored.
      const paginate = limit !== undefined && offset !== undefined;

      let total = 0;
      // Only get total count when using pagination
      if (paginate) {
        const countResult = await this.client.one(`SELECT COUNT(*) as count FROM ${tableName} ${whereClause}`, values);
        total = Number(countResult.count);
      }

      // LIMIT and OFFSET take the next two placeholders after the filter values.
      const query = `
              SELECT * FROM ${tableName}
              ${whereClause}
              ORDER BY "createdAt" DESC
              ${paginate ? ` LIMIT $${paramIndex} OFFSET $${paramIndex + 1}` : ''}
            `;

      const queryValues = paginate ? [...values, limit, offset] : values;

      const result = await this.client.manyOrNone(query, queryValues);

      const runs = (result || []).map(row => parseWorkflowRun(row));

      // Use runs.length as total when not paginating
      return { runs, total: total || runs.length };
    } catch (error) {
      throw new MastraError(
        {
          id: 'MASTRA_STORAGE_PG_STORE_GET_WORKFLOW_RUNS_FAILED',
          domain: ErrorDomain.STORAGE,
          category: ErrorCategory.THIRD_PARTY,
          details: {
            workflowName: workflowName || 'all',
          },
        },
        error,
      );
    }
  }
}