@mastra/mcp 0.4.3 → 0.5.0-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- import http from 'http';
1
+ import http from 'node:http';
2
2
  import path from 'path';
3
3
  import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
4
4
  import { describe, it, expect, beforeAll, afterAll, afterEach, vi } from 'vitest';
@@ -31,11 +31,11 @@ describe('MCPServer', () => {
31
31
  });
32
32
  });
33
33
 
34
- await new Promise(resolve => httpServer.listen(PORT, resolve));
34
+ await new Promise<void>(resolve => httpServer.listen(PORT, () => resolve()));
35
35
  });
36
36
 
37
37
  afterAll(async () => {
38
- await new Promise(resolve => httpServer.close(resolve));
38
+ await new Promise<void>(resolve => httpServer.close(() => resolve()));
39
39
  });
40
40
 
41
41
  describe('MCPServer SSE transport', () => {
@@ -77,7 +77,6 @@ describe('MCPServer', () => {
77
77
  });
78
78
 
79
79
  it('should return 503 if message sent before SSE connection', async () => {
80
- // Manually clear the SSE transport
81
80
  (server as any).sseTransport = undefined;
82
81
  const res = await fetch(`http://localhost:${PORT}/message`, {
83
82
  method: 'POST',
package/src/server.ts CHANGED
@@ -1,29 +1,24 @@
1
- import { isVercelTool, isZodType, resolveSerializedZodOutput } from '@mastra/core';
1
+ import { randomUUID } from 'crypto';
2
+ import type * as http from 'node:http';
3
+ import type { InternalCoreTool } from '@mastra/core';
4
+ import { makeCoreTool } from '@mastra/core';
2
5
  import type { ToolsInput } from '@mastra/core/agent';
6
+ import { MCPServerBase } from '@mastra/core/mcp';
7
+ import type { MCPServerSSEOptions, ConvertedTool } from '@mastra/core/mcp';
8
+ import { RuntimeContext } from '@mastra/core/runtime-context';
3
9
  import { Server } from '@modelcontextprotocol/sdk/server/index.js';
4
10
  import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
5
11
  import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
12
+ import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
13
+ import type { StreamableHTTPServerTransportOptions } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
6
14
  import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
7
- import jsonSchemaToZod from 'json-schema-to-zod';
8
15
  import { z } from 'zod';
9
- import { zodToJsonSchema } from 'zod-to-json-schema';
10
- import { createLogger } from './logger';
11
16
 
12
- const logger = createLogger();
13
-
14
- type ConvertedTool = {
15
- name: string;
16
- description?: string;
17
- inputSchema: any;
18
- zodSchema: z.ZodTypeAny;
19
- execute: any;
20
- };
21
-
22
- export class MCPServer {
17
+ export class MCPServer extends MCPServerBase {
23
18
  private server: Server;
24
- private convertedTools: Record<string, ConvertedTool>;
25
19
  private stdioTransport?: StdioServerTransport;
26
20
  private sseTransport?: SSEServerTransport;
21
+ private streamableHTTPTransport?: StreamableHTTPServerTransport;
27
22
 
28
23
  /**
29
24
  * Get the current stdio transport.
@@ -40,10 +35,10 @@ export class MCPServer {
40
35
  }
41
36
 
42
37
  /**
43
- * Get a read-only view of the registered tools (for testing/introspection).
38
+ * Get the current streamable HTTP transport.
44
39
  */
45
- tools(): Readonly<Record<string, ConvertedTool>> {
46
- return this.convertedTools;
40
+ public getStreamableHTTPTransport(): StreamableHTTPServerTransport | undefined {
41
+ return this.streamableHTTPTransport;
47
42
  }
48
43
 
49
44
  /**
@@ -53,9 +48,11 @@ export class MCPServer {
53
48
  * @param opts.tools - Tool definitions to register
54
49
  */
55
50
  constructor({ name, version, tools }: { name: string; version: string; tools: ToolsInput }) {
51
+ super({ name, version, tools });
52
+
56
53
  this.server = new Server({ name, version }, { capabilities: { tools: {}, logging: { enabled: true } } });
57
- this.convertedTools = this.convertTools(tools);
58
- void logger.info(
54
+
55
+ this.logger.info(
59
56
  `Initialized MCPServer '${name}' v${version} with tools: ${Object.keys(this.convertedTools).join(', ')}`,
60
57
  );
61
58
 
@@ -68,55 +65,39 @@ export class MCPServer {
68
65
  * @param tools Tool definitions
69
66
  * @returns Converted tools registry
70
67
  */
71
- private convertTools(tools: ToolsInput): Record<string, ConvertedTool> {
68
+ convertTools(tools: ToolsInput): Record<string, ConvertedTool> {
72
69
  const convertedTools: Record<string, ConvertedTool> = {};
73
70
  for (const toolName of Object.keys(tools)) {
74
- let inputSchema: any;
75
- let zodSchema: z.ZodTypeAny;
76
71
  const toolInstance = tools[toolName];
77
72
  if (!toolInstance) {
78
- void logger.warning(`Tool instance for '${toolName}' is undefined. Skipping.`);
73
+ this.logger.warn(`Tool instance for '${toolName}' is undefined. Skipping.`);
79
74
  continue;
80
75
  }
76
+
81
77
  if (typeof toolInstance.execute !== 'function') {
82
- void logger.warning(`Tool '${toolName}' does not have a valid execute function. Skipping.`);
78
+ this.logger.warn(`Tool '${toolName}' does not have a valid execute function. Skipping.`);
83
79
  continue;
84
80
  }
85
- // Vercel tools: .parameters is either Zod or JSON schema
86
- if (isVercelTool(toolInstance)) {
87
- if (isZodType(toolInstance.parameters)) {
88
- zodSchema = toolInstance.parameters;
89
- inputSchema = zodToJsonSchema(zodSchema);
90
- } else if (typeof toolInstance.parameters === 'object') {
91
- zodSchema = resolveSerializedZodOutput(jsonSchemaToZod(toolInstance.parameters));
92
- inputSchema = toolInstance.parameters;
93
- } else {
94
- zodSchema = z.object({});
95
- inputSchema = zodToJsonSchema(zodSchema);
96
- }
97
- } else {
98
- // Mastra tools: .inputSchema is always Zod
99
- zodSchema = toolInstance?.inputSchema ?? z.object({});
100
- inputSchema = zodToJsonSchema(zodSchema);
101
- }
102
81
 
103
- // Wrap execute to support both signatures (typed, returns Promise<any>)
104
- const execute: (args: any, execOptions?: any) => Promise<any> = async (args, execOptions) => {
105
- if (isVercelTool(toolInstance)) {
106
- return (await toolInstance.execute?.(args, execOptions)) ?? undefined;
107
- }
108
- return (await toolInstance.execute?.({ context: args }, execOptions)) ?? undefined;
82
+ const options = {
83
+ name: toolName,
84
+ runtimeContext: new RuntimeContext(),
85
+ mastra: this.mastra,
86
+ logger: this.logger,
87
+ description: toolInstance?.description,
109
88
  };
89
+
90
+ const coreTool = makeCoreTool(toolInstance, options) as InternalCoreTool;
91
+
110
92
  convertedTools[toolName] = {
111
93
  name: toolName,
112
- description: toolInstance?.description,
113
- inputSchema,
114
- zodSchema,
115
- execute,
94
+ description: coreTool.description,
95
+ parameters: coreTool.parameters,
96
+ execute: coreTool.execute,
116
97
  };
117
- void logger.info(`Registered tool: '${toolName}' [${toolInstance?.description || 'No description'}]`);
98
+ this.logger.info(`Registered tool: '${toolName}' [${toolInstance?.description || 'No description'}]`);
118
99
  }
119
- void logger.info(`Total tools registered: ${Object.keys(convertedTools).length}`);
100
+ this.logger.info(`Total tools registered: ${Object.keys(convertedTools).length}`);
120
101
  return convertedTools;
121
102
  }
122
103
 
@@ -125,12 +106,12 @@ export class MCPServer {
125
106
  */
126
107
  private registerListToolsHandler() {
127
108
  this.server.setRequestHandler(ListToolsRequestSchema, async () => {
128
- await logger.debug('Handling ListTools request');
109
+ this.logger.debug('Handling ListTools request');
129
110
  return {
130
111
  tools: Object.values(this.convertedTools).map(tool => ({
131
112
  name: tool.name,
132
113
  description: tool.description,
133
- inputSchema: tool.inputSchema,
114
+ inputSchema: tool.parameters.jsonSchema,
134
115
  })),
135
116
  };
136
117
  });
@@ -145,17 +126,36 @@ export class MCPServer {
145
126
  try {
146
127
  const tool = this.convertedTools[request.params.name];
147
128
  if (!tool) {
148
- await logger.warning(`CallTool: Unknown tool '${request.params.name}' requested.`);
129
+ this.logger.warn(`CallTool: Unknown tool '${request.params.name}' requested.`);
149
130
  return {
150
131
  content: [{ type: 'text', text: `Unknown tool: ${request.params.name}` }],
151
132
  isError: true,
152
133
  };
153
134
  }
154
- await logger.debug(`CallTool: Invoking '${request.params.name}' with arguments:`, request.params.arguments);
155
- const args = tool.zodSchema.parse(request.params.arguments ?? {});
156
- const result = await tool.execute(args, request.params);
135
+
136
+ this.logger.debug(`CallTool: Invoking '${request.params.name}' with arguments:`, request.params.arguments);
137
+
138
+ const validation = tool.parameters.validate?.(request.params.arguments ?? {});
139
+ if (validation && !validation.success) {
140
+ this.logger.warn(`CallTool: Invalid tool arguments for '${request.params.name}'`, {
141
+ errors: validation.error,
142
+ });
143
+ return {
144
+ content: [{ type: 'text', text: `Invalid tool arguments: ${JSON.stringify(validation.error)}` }],
145
+ isError: true,
146
+ };
147
+ }
148
+ if (!tool.execute) {
149
+ this.logger.warn(`CallTool: Tool '${request.params.name}' does not have an execute function.`);
150
+ return {
151
+ content: [{ type: 'text', text: `Tool '${request.params.name}' does not have an execute function.` }],
152
+ isError: true,
153
+ };
154
+ }
155
+
156
+ const result = await tool.execute(validation?.value, { messages: [], toolCallId: '' });
157
157
  const duration = Date.now() - startTime;
158
- await logger.info(`Tool '${request.params.name}' executed successfully in ${duration}ms.`);
158
+ this.logger.info(`Tool '${request.params.name}' executed successfully in ${duration}ms.`);
159
159
  return {
160
160
  content: [
161
161
  {
@@ -168,7 +168,7 @@ export class MCPServer {
168
168
  } catch (error) {
169
169
  const duration = Date.now() - startTime;
170
170
  if (error instanceof z.ZodError) {
171
- await logger.warning('Invalid tool arguments', {
171
+ this.logger.warn('Invalid tool arguments', {
172
172
  tool: request.params.name,
173
173
  errors: error.errors,
174
174
  duration: `${duration}ms`,
@@ -183,7 +183,7 @@ export class MCPServer {
183
183
  isError: true,
184
184
  };
185
185
  }
186
- await logger.error(`Tool execution failed: ${request.params.name}`, error);
186
+ this.logger.error(`Tool execution failed: ${request.params.name}`, { error });
187
187
  return {
188
188
  content: [{ type: 'text', text: `Error: ${error instanceof Error ? error.message : String(error)}` }],
189
189
  isError: true,
@@ -195,10 +195,10 @@ export class MCPServer {
195
195
  /**
196
196
  * Start the MCP server using stdio transport (for Windsurf integration).
197
197
  */
198
- async startStdio() {
198
+ public async startStdio(): Promise<void> {
199
199
  this.stdioTransport = new StdioServerTransport();
200
200
  await this.server.connect(this.stdioTransport);
201
- await logger.info('Started MCP Server (stdio)');
201
+ this.logger.info('Started MCP Server (stdio)');
202
202
  }
203
203
 
204
204
  /**
@@ -211,43 +211,135 @@ export class MCPServer {
211
211
  * @param req Incoming HTTP request
212
212
  * @param res HTTP response (must support .write/.end)
213
213
  */
214
- async startSSE({
214
+ public async startSSE({ url, ssePath, messagePath, req, res }: MCPServerSSEOptions): Promise<void> {
215
+ if (url.pathname === ssePath) {
216
+ await this.connectSSE({
217
+ messagePath,
218
+ res,
219
+ });
220
+ } else if (url.pathname === messagePath) {
221
+ this.logger.debug('Received message');
222
+ if (!this.sseTransport) {
223
+ res.writeHead(503);
224
+ res.end('SSE connection not established');
225
+ return;
226
+ }
227
+ await this.sseTransport.handlePostMessage(req, res);
228
+ } else {
229
+ this.logger.debug('Unknown path:', { path: url.pathname });
230
+ res.writeHead(404);
231
+ res.end();
232
+ }
233
+ }
234
+
235
+ /**
236
+ * Handles MCP-over-StreamableHTTP protocol for user-provided HTTP servers.
237
+ * Call this from your HTTP server for the streamable HTTP endpoint.
238
+ *
239
+ * @param url Parsed URL of the incoming request
240
+ * @param httpPath Path for establishing the streamable HTTP connection (e.g. '/mcp')
241
+ * @param req Incoming HTTP request
242
+ * @param res HTTP response (must support .write/.end)
243
+ * @param options Optional options to pass to the transport (e.g. sessionIdGenerator)
244
+ */
245
+ public async startHTTP({
215
246
  url,
216
- ssePath,
217
- messagePath,
247
+ httpPath,
218
248
  req,
219
249
  res,
250
+ options = { sessionIdGenerator: () => randomUUID() },
220
251
  }: {
221
252
  url: URL;
222
- ssePath: string;
223
- messagePath: string;
224
- req: any;
225
- res: any;
253
+ httpPath: string;
254
+ req: http.IncomingMessage;
255
+ res: http.ServerResponse<http.IncomingMessage>;
256
+ options?: StreamableHTTPServerTransportOptions;
226
257
  }) {
227
- if (url.pathname === ssePath) {
228
- await logger.debug('Received SSE connection');
229
- this.sseTransport = new SSEServerTransport(messagePath, res);
230
- await this.server.connect(this.sseTransport);
258
+ if (url.pathname === httpPath) {
259
+ this.streamableHTTPTransport = new StreamableHTTPServerTransport(options);
260
+ try {
261
+ await this.server.connect(this.streamableHTTPTransport);
262
+ } catch (error) {
263
+ this.logger.error('Error connecting to MCP server', { error });
264
+ res.writeHead(500);
265
+ res.end('Error connecting to MCP server');
266
+ return;
267
+ }
268
+
269
+ try {
270
+ await this.streamableHTTPTransport.handleRequest(req, res);
271
+ } catch (error) {
272
+ this.logger.error('Error handling MCP connection', { error });
273
+ res.writeHead(500);
274
+ res.end('Error handling MCP connection');
275
+ return;
276
+ }
231
277
 
232
278
  this.server.onclose = async () => {
279
+ this.streamableHTTPTransport = undefined;
233
280
  await this.server.close();
234
- this.sseTransport = undefined;
235
281
  };
282
+
236
283
  res.on('close', () => {
237
- this.sseTransport = undefined;
284
+ this.streamableHTTPTransport = undefined;
238
285
  });
239
- } else if (url.pathname === messagePath) {
240
- await logger.debug('Received message');
241
- if (!this.sseTransport) {
242
- res.writeHead(503);
243
- res.end('SSE connection not established');
244
- return;
245
- }
246
- await this.sseTransport.handlePostMessage(req, res);
247
286
  } else {
248
- await logger.debug('Unknown path:', url.pathname);
249
287
  res.writeHead(404);
250
288
  res.end();
251
289
  }
252
290
  }
291
+
292
+ public async handlePostMessage(req: http.IncomingMessage, res: http.ServerResponse<http.IncomingMessage>) {
293
+ if (!this.sseTransport) {
294
+ res.writeHead(503);
295
+ res.end('SSE connection not established');
296
+ return;
297
+ }
298
+ await this.sseTransport.handlePostMessage(req, res);
299
+ }
300
+
301
+ public async connectSSE({
302
+ messagePath,
303
+ res,
304
+ }: {
305
+ messagePath: string;
306
+ res: http.ServerResponse<http.IncomingMessage>;
307
+ }) {
308
+ this.logger.debug('Received SSE connection');
309
+ this.sseTransport = new SSEServerTransport(messagePath, res);
310
+ await this.server.connect(this.sseTransport);
311
+
312
+ this.server.onclose = async () => {
313
+ this.sseTransport = undefined;
314
+ await this.server.close();
315
+ };
316
+
317
+ res.on('close', () => {
318
+ this.sseTransport = undefined;
319
+ });
320
+ }
321
+
322
+ /**
323
+ * Close the MCP server and all its connections
324
+ */
325
+ async close() {
326
+ try {
327
+ if (this.stdioTransport) {
328
+ await this.stdioTransport.close?.();
329
+ this.stdioTransport = undefined;
330
+ }
331
+ if (this.sseTransport) {
332
+ await this.sseTransport.close?.();
333
+ this.sseTransport = undefined;
334
+ }
335
+ if (this.streamableHTTPTransport) {
336
+ await this.streamableHTTPTransport.close?.();
337
+ this.streamableHTTPTransport = undefined;
338
+ }
339
+ await this.server.close();
340
+ this.logger.info('MCP server closed.');
341
+ } catch (error) {
342
+ this.logger.error('Error closing MCP server:', { error });
343
+ }
344
+ }
253
345
  }