@mastra/mcp 0.4.1-alpha.2 → 0.4.1-alpha.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/client.ts CHANGED
@@ -5,7 +5,8 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
5
5
  import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
6
6
  import type { SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js';
7
7
  import { getDefaultEnvironment, StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
8
- import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js';
8
+ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
9
+ import type { StreamableHTTPClientTransportOptions } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
9
10
  import { DEFAULT_REQUEST_TIMEOUT_MSEC } from '@modelcontextprotocol/sdk/shared/protocol.js';
10
11
  import type { Protocol } from '@modelcontextprotocol/sdk/shared/protocol.js';
11
12
  import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
@@ -28,18 +29,43 @@ export interface LogMessage {
28
29
 
29
30
  export type LogHandler = (logMessage: LogMessage) => void;
30
31
 
31
- // Omit the fields we want to control from the SDK options
32
- type SSEClientParameters = {
33
- url: URL;
34
- } & SSEClientTransportOptions;
35
-
36
- export type MastraMCPServerDefinition = (StdioServerParameters | SSEClientParameters) & {
32
+ // Base options common to all server definitions
33
+ type BaseServerOptions = {
37
34
  logger?: LogHandler;
38
35
  timeout?: number;
39
36
  capabilities?: ClientCapabilities;
40
37
  enableServerLogs?: boolean;
41
38
  };
42
39
 
40
+ type StdioServerDefinition = BaseServerOptions & {
41
+ command: string; // 'command' is required for Stdio
42
+ args?: string[];
43
+ env?: Record<string, string>;
44
+
45
+ url?: never; // Exclude 'url' for Stdio
46
+ requestInit?: never; // Exclude HTTP options for Stdio
47
+ eventSourceInit?: never; // Exclude HTTP options for Stdio
48
+ reconnectionOptions?: never; // Exclude Streamable HTTP specific options
49
+ sessionId?: never; // Exclude Streamable HTTP specific options
50
+ };
51
+
52
+ // HTTP Server Definition (Streamable HTTP or SSE fallback)
53
+ type HttpServerDefinition = BaseServerOptions & {
54
+ url: URL; // 'url' is required for HTTP
55
+
56
+ command?: never; // Exclude 'command' for HTTP
57
+ args?: never; // Exclude Stdio options for HTTP
58
+ env?: never; // Exclude Stdio options for HTTP
59
+
60
+ // Include relevant options from SDK HTTP transport types
61
+ requestInit?: StreamableHTTPClientTransportOptions['requestInit'];
62
+ eventSourceInit?: SSEClientTransportOptions['eventSourceInit'];
63
+ reconnectionOptions?: StreamableHTTPClientTransportOptions['reconnectionOptions'];
64
+ sessionId?: StreamableHTTPClientTransportOptions['sessionId'];
65
+ };
66
+
67
+ export type MastraMCPServerDefinition = StdioServerDefinition | HttpServerDefinition;
68
+
43
69
  /**
44
70
  * Convert an MCP LoggingLevel to a logger method name that exists in our logger
45
71
  */
@@ -65,11 +91,12 @@ function convertLogLevelToLoggerMethod(level: LoggingLevel): 'debug' | 'info' |
65
91
 
66
92
  export class MastraMCPClient extends MastraBase {
67
93
  name: string;
68
- private transport: Transport;
69
94
  private client: Client;
70
95
  private readonly timeout: number;
71
96
  private logHandler?: LogHandler;
72
97
  private enableServerLogs?: boolean;
98
+ private serverConfig: MastraMCPServerDefinition;
99
+ private transport?: Transport;
73
100
 
74
101
  constructor({
75
102
  name,
@@ -89,22 +116,7 @@ export class MastraMCPClient extends MastraBase {
89
116
  this.timeout = timeout;
90
117
  this.logHandler = server.logger;
91
118
  this.enableServerLogs = server.enableServerLogs ?? true;
92
-
93
- // Extract log handler from server config to avoid passing it to transport
94
- const { logger, enableServerLogs, ...serverConfig } = server;
95
-
96
- if (`url` in serverConfig) {
97
- this.transport = new SSEClientTransport(serverConfig.url, {
98
- requestInit: serverConfig.requestInit,
99
- eventSourceInit: serverConfig.eventSourceInit,
100
- });
101
- } else {
102
- this.transport = new StdioClientTransport({
103
- ...serverConfig,
104
- // without ...getDefaultEnvironment() commands like npx will fail because there will be no PATH env var
105
- env: { ...getDefaultEnvironment(), ...(serverConfig.env || {}) },
106
- });
107
- }
119
+ this.serverConfig = server;
108
120
 
109
121
  this.client = new Client(
110
122
  {
@@ -130,14 +142,16 @@ export class MastraMCPClient extends MastraBase {
130
142
  // Convert MCP logging level to our logger method
131
143
  const loggerMethod = convertLogLevelToLoggerMethod(level);
132
144
 
145
+ const msg = `[${this.name}] ${message}`;
146
+
133
147
  // Log to internal logger
134
- this.logger[loggerMethod](message, details);
148
+ this.logger[loggerMethod](msg, details);
135
149
 
136
150
  // Send to registered handler if available
137
151
  if (this.logHandler) {
138
152
  this.logHandler({
139
153
  level,
140
- message,
154
+ message: msg,
141
155
  timestamp: new Date(),
142
156
  serverName: this.name,
143
157
  details,
@@ -164,46 +178,135 @@ export class MastraMCPClient extends MastraBase {
164
178
  }
165
179
  }
166
180
 
181
+ private async connectStdio(command: string) {
182
+ this.log('debug', `Using Stdio transport for command: ${command}`);
183
+ try {
184
+ this.transport = new StdioClientTransport({
185
+ command,
186
+ args: this.serverConfig.args,
187
+ env: { ...getDefaultEnvironment(), ...(this.serverConfig.env || {}) },
188
+ });
189
+ await this.client.connect(this.transport, { timeout: this.serverConfig.timeout ?? this.timeout });
190
+ this.log('debug', `Successfully connected to MCP server via Stdio`);
191
+ } catch (e) {
192
+ this.log('error', e instanceof Error ? e.stack || e.message : JSON.stringify(e));
193
+ throw e;
194
+ }
195
+ }
196
+
197
+ private async connectHttp(url: URL) {
198
+ const { requestInit, eventSourceInit } = this.serverConfig;
199
+
200
+ this.log('debug', `Attempting to connect to URL: ${url}`);
201
+
202
+ // Assume /sse means sse.
203
+ let shouldTrySSE = url.pathname.endsWith(`/sse`);
204
+
205
+ if (!shouldTrySSE) {
206
+ try {
207
+ // Try Streamable HTTP transport first
208
+ this.log('debug', 'Trying Streamable HTTP transport...');
209
+ const streamableTransport = new StreamableHTTPClientTransport(url, {
210
+ requestInit,
211
+ reconnectionOptions: this.serverConfig.reconnectionOptions,
212
+ sessionId: this.serverConfig.sessionId,
213
+ });
214
+ await this.client.connect(streamableTransport, {
215
+ timeout:
216
+ // this is hardcoded to 3s because the long default timeout would be extremely slow for sse backwards compat (60s)
217
+ 3000,
218
+ });
219
+ this.transport = streamableTransport;
220
+ this.log('debug', 'Successfully connected using Streamable HTTP transport.');
221
+ } catch (error) {
222
+ this.log('debug', `Streamable HTTP transport failed: ${error}`);
223
+ shouldTrySSE = true;
224
+ }
225
+ }
226
+
227
+ if (shouldTrySSE) {
228
+ this.log('debug', 'Falling back to deprecated HTTP+SSE transport...');
229
+ try {
230
+ // Fallback to SSE transport
231
+ const sseTransport = new SSEClientTransport(url, { requestInit, eventSourceInit });
232
+ await this.client.connect(sseTransport, { timeout: this.serverConfig.timeout ?? this.timeout });
233
+ this.transport = sseTransport;
234
+ this.log('debug', 'Successfully connected using deprecated HTTP+SSE transport.');
235
+ } catch (sseError) {
236
+ this.log(
237
+ 'error',
238
+ `Failed to connect with SSE transport after failing to connect to Streamable HTTP transport first. SSE error: ${sseError}`,
239
+ );
240
+ throw new Error('Could not connect to server with any available HTTP transport');
241
+ }
242
+ }
243
+ }
244
+
167
245
  private isConnected = false;
168
246
 
169
247
  async connect() {
170
248
  if (this.isConnected) return;
171
- try {
172
- this.log('debug', `Connecting to MCP server`);
173
- await this.client.connect(this.transport, {
174
- timeout: this.timeout,
175
- });
176
- this.isConnected = true;
177
- const originalOnClose = this.client.onclose;
178
- this.client.onclose = () => {
179
- this.log('debug', `MCP server connection closed`);
180
- this.isConnected = false;
181
- if (typeof originalOnClose === `function`) {
182
- originalOnClose();
183
- }
184
- };
185
- asyncExitHook(
186
- async () => {
187
- this.log('debug', `Disconnecting MCP server during exit`);
188
- await this.disconnect();
189
- },
190
- { wait: 5000 },
191
- );
192
249
 
193
- process.on('SIGTERM', () => gracefulExit());
194
- this.log('info', `Successfully connected to MCP server`);
195
- } catch (e) {
196
- this.log('error', `Failed connecting to MCP server`, {
197
- error: e instanceof Error ? e.stack : JSON.stringify(e, null, 2),
198
- });
250
+ const { command, url } = this.serverConfig;
251
+
252
+ if (command) {
253
+ await this.connectStdio(command);
254
+ } else if (url) {
255
+ await this.connectHttp(url);
256
+ } else {
257
+ throw new Error('Server configuration must include either a command or a url.');
258
+ }
259
+
260
+ this.isConnected = true;
261
+ const originalOnClose = this.client.onclose;
262
+ this.client.onclose = () => {
263
+ this.log('debug', `MCP server connection closed`);
199
264
  this.isConnected = false;
200
- throw e;
265
+ if (typeof originalOnClose === `function`) {
266
+ originalOnClose();
267
+ }
268
+ };
269
+ asyncExitHook(
270
+ async () => {
271
+ this.log('debug', `Disconnecting MCP server during exit`);
272
+ await this.disconnect();
273
+ },
274
+ { wait: 5000 },
275
+ );
276
+
277
+ process.on('SIGTERM', () => gracefulExit());
278
+ this.log('debug', `Successfully connected to MCP server`);
279
+ }
280
+
281
+ /**
282
+ * Get the current session ID if using the Streamable HTTP transport.
283
+ * Returns undefined if not connected or not using Streamable HTTP.
284
+ */
285
+ get sessionId(): string | undefined {
286
+ if (this.transport instanceof StreamableHTTPClientTransport) {
287
+ return this.transport.sessionId;
201
288
  }
289
+ return undefined;
202
290
  }
203
291
 
204
292
  async disconnect() {
293
+ if (!this.transport) {
294
+ this.log('debug', 'Disconnect called but no transport was connected.');
295
+ return;
296
+ }
205
297
  this.log('debug', `Disconnecting from MCP server`);
206
- return await this.client.close();
298
+ try {
299
+ await this.transport.close();
300
+ this.log('debug', 'Successfully disconnected from MCP server');
301
+ } catch (e) {
302
+ this.log('error', 'Error during MCP server disconnect', {
303
+ error: e instanceof Error ? e.stack : JSON.stringify(e, null, 2),
304
+ });
305
+ throw e;
306
+ } finally {
307
+ this.transport = undefined;
308
+ this.isConnected = false;
309
+ }
207
310
  }
208
311
 
209
312
  // TODO: do the type magic to return the right method type. Right now we get infinitely deep infered type errors from Zod without using "any"
@@ -244,6 +244,7 @@ describe('MCPConfiguration', () => {
244
244
  slowServer: {
245
245
  command: 'node',
246
246
  args: ['-e', 'setTimeout(() => process.exit(0), 65000)'], // Simulate a server that takes 65 seconds to start
247
+ timeout: 1000,
247
248
  },
248
249
  },
249
250
  });
@@ -252,14 +253,14 @@ describe('MCPConfiguration', () => {
252
253
  await slowConfig.disconnect();
253
254
  });
254
255
 
255
- it('timeout should be longer than default timeout', async () => {
256
+ it('timeout should be longer than configured timeout', async () => {
256
257
  const slowConfig = new MCPConfiguration({
257
258
  id: 'test-slow-server',
258
- timeout: 70000,
259
+ timeout: 2000,
259
260
  servers: {
260
261
  slowServer: {
261
262
  command: 'node',
262
- args: ['-e', 'setTimeout(() => process.exit(0), 65000)'], // Simulate a server that takes 65 seconds to start
263
+ args: ['-e', 'setTimeout(() => process.exit(0), 1000)'], // Simulate a server that takes 1 second to start
263
264
  },
264
265
  },
265
266
  });
@@ -270,22 +271,6 @@ describe('MCPConfiguration', () => {
270
271
  await slowConfig.disconnect();
271
272
  });
272
273
 
273
- it('should respect custom timeout configuration', async () => {
274
- const quickConfig = new MCPConfiguration({
275
- id: 'test-quick-timeout',
276
- timeout: 1000, // Very short global timeout
277
- servers: {
278
- slowServer: {
279
- command: 'node',
280
- args: ['-e', 'setTimeout(() => process.exit(0), 30000)'], // Takes 30 seconds to exit
281
- },
282
- },
283
- });
284
-
285
- await expect(quickConfig.getTools()).rejects.toThrow(/Request timed out/);
286
- await quickConfig.disconnect();
287
- });
288
-
289
274
  it('should respect per-server timeout configuration', async () => {
290
275
  const mixedConfig = new MCPConfiguration({
291
276
  id: 'test-mixed-timeout',
@@ -89,6 +89,20 @@ To fix this you have three different options:
89
89
  return connectedToolsets;
90
90
  }
91
91
 
92
+ /**
93
+ * Get the current session IDs for all connected MCP clients using the Streamable HTTP transport.
94
+ * Returns an object mapping server names to their session IDs.
95
+ */
96
+ get sessionIds(): Record<string, string> {
97
+ const sessionIds: Record<string, string> = {};
98
+ for (const [serverName, client] of this.mcpClientsById.entries()) {
99
+ if (client.sessionId) {
100
+ sessionIds[serverName] = client.sessionId;
101
+ }
102
+ }
103
+ return sessionIds;
104
+ }
105
+
92
106
  private mcpClientsById = new Map<string, MastraMCPClient>();
93
107
  private async getConnectedClient(name: string, config: MastraMCPServerDefinition) {
94
108
  const exists = this.mcpClientsById.has(name);
@@ -117,7 +131,9 @@ To fix this you have three different options:
117
131
  this.logger.error(`MCPConfiguration errored connecting to MCP server ${name}`, {
118
132
  error: e instanceof Error ? e.message : String(e),
119
133
  });
120
- throw new Error(`Failed to connect to MCP server ${name}: ${e instanceof Error ? e.message : String(e)}`);
134
+ throw new Error(
135
+ `Failed to connect to MCP server ${name}: ${e instanceof Error ? e.stack || e.message : String(e)}`,
136
+ );
121
137
  }
122
138
 
123
139
  this.logger.debug(`Connected to ${name} MCP server`);