model-context-protocol-rb 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -2
- data/README.md +174 -978
- data/lib/model_context_protocol/rspec/helpers.rb +54 -0
- data/lib/model_context_protocol/rspec/matchers/be_mcp_error_response.rb +123 -0
- data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_class.rb +103 -0
- data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_prompt_response.rb +126 -0
- data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_resource_response.rb +121 -0
- data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_tool_response.rb +135 -0
- data/lib/model_context_protocol/rspec/matchers/have_audio_content.rb +109 -0
- data/lib/model_context_protocol/rspec/matchers/have_embedded_resource_content.rb +150 -0
- data/lib/model_context_protocol/rspec/matchers/have_image_content.rb +109 -0
- data/lib/model_context_protocol/rspec/matchers/have_message_count.rb +87 -0
- data/lib/model_context_protocol/rspec/matchers/have_message_with_role.rb +152 -0
- data/lib/model_context_protocol/rspec/matchers/have_resource_annotations.rb +135 -0
- data/lib/model_context_protocol/rspec/matchers/have_resource_blob.rb +108 -0
- data/lib/model_context_protocol/rspec/matchers/have_resource_link_content.rb +138 -0
- data/lib/model_context_protocol/rspec/matchers/have_resource_mime_type.rb +103 -0
- data/lib/model_context_protocol/rspec/matchers/have_resource_text.rb +112 -0
- data/lib/model_context_protocol/rspec/matchers/have_structured_content.rb +88 -0
- data/lib/model_context_protocol/rspec/matchers/have_text_content.rb +113 -0
- data/lib/model_context_protocol/rspec/matchers.rb +31 -0
- data/lib/model_context_protocol/rspec.rb +23 -0
- data/lib/model_context_protocol/server/client_logger.rb +1 -1
- data/lib/model_context_protocol/server/configuration.rb +195 -91
- data/lib/model_context_protocol/server/content_helpers.rb +1 -1
- data/lib/model_context_protocol/server/prompt.rb +0 -14
- data/lib/model_context_protocol/server/redis_client_proxy.rb +2 -14
- data/lib/model_context_protocol/server/redis_config.rb +5 -7
- data/lib/model_context_protocol/server/redis_pool_manager.rb +10 -13
- data/lib/model_context_protocol/server/registry.rb +8 -0
- data/lib/model_context_protocol/server/router.rb +279 -4
- data/lib/model_context_protocol/server/server_logger.rb +5 -2
- data/lib/model_context_protocol/server/stdio_configuration.rb +114 -0
- data/lib/model_context_protocol/server/stdio_transport/request_store.rb +0 -41
- data/lib/model_context_protocol/server/streamable_http_configuration.rb +218 -0
- data/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +0 -13
- data/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +0 -41
- data/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +0 -103
- data/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb +0 -64
- data/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +0 -58
- data/lib/model_context_protocol/server/streamable_http_transport/session_store.rb +17 -31
- data/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +0 -34
- data/lib/model_context_protocol/server/streamable_http_transport.rb +192 -56
- data/lib/model_context_protocol/server/tool.rb +67 -1
- data/lib/model_context_protocol/server.rb +203 -262
- data/lib/model_context_protocol/version.rb +1 -1
- data/lib/model_context_protocol.rb +4 -1
- data/lib/puma/plugin/mcp.rb +39 -0
- data/tasks/mcp.rake +26 -0
- data/tasks/templates/dev-http-puma.erb +251 -0
- data/tasks/templates/dev-http.erb +166 -184
- data/tasks/templates/dev.erb +29 -7
- metadata +26 -2
|
@@ -43,14 +43,6 @@ module ModelContextProtocol
|
|
|
43
43
|
@local_streams.key?(session_id)
|
|
44
44
|
end
|
|
45
45
|
|
|
46
|
-
def get_stream_server(session_id)
|
|
47
|
-
@redis.get("#{STREAM_KEY_PREFIX}#{session_id}")
|
|
48
|
-
end
|
|
49
|
-
|
|
50
|
-
def stream_active?(session_id)
|
|
51
|
-
@redis.exists("#{STREAM_KEY_PREFIX}#{session_id}") == 1
|
|
52
|
-
end
|
|
53
|
-
|
|
54
46
|
def refresh_heartbeat(session_id)
|
|
55
47
|
@redis.multi do |multi|
|
|
56
48
|
multi.set("#{HEARTBEAT_KEY_PREFIX}#{session_id}", Time.now.to_f, ex: @ttl)
|
|
@@ -88,32 +80,6 @@ module ModelContextProtocol
|
|
|
88
80
|
|
|
89
81
|
expired_sessions
|
|
90
82
|
end
|
|
91
|
-
|
|
92
|
-
def get_stale_streams(max_age_seconds = 90)
|
|
93
|
-
current_time = Time.now.to_f
|
|
94
|
-
stale_streams = []
|
|
95
|
-
|
|
96
|
-
# Get all heartbeat keys
|
|
97
|
-
heartbeat_keys = @redis.keys("#{HEARTBEAT_KEY_PREFIX}*")
|
|
98
|
-
|
|
99
|
-
return stale_streams if heartbeat_keys.empty?
|
|
100
|
-
|
|
101
|
-
# Get all heartbeat timestamps
|
|
102
|
-
heartbeat_values = @redis.mget(heartbeat_keys)
|
|
103
|
-
|
|
104
|
-
heartbeat_keys.each_with_index do |key, index|
|
|
105
|
-
next unless heartbeat_values[index]
|
|
106
|
-
|
|
107
|
-
session_id = key.sub(HEARTBEAT_KEY_PREFIX, "")
|
|
108
|
-
last_heartbeat = heartbeat_values[index].to_f
|
|
109
|
-
|
|
110
|
-
if current_time - last_heartbeat > max_age_seconds
|
|
111
|
-
stale_streams << session_id
|
|
112
|
-
end
|
|
113
|
-
end
|
|
114
|
-
|
|
115
|
-
stale_streams
|
|
116
|
-
end
|
|
117
83
|
end
|
|
118
84
|
end
|
|
119
85
|
end
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
require "json"
|
|
2
2
|
require "securerandom"
|
|
3
|
+
require "concurrent"
|
|
3
4
|
|
|
4
5
|
module ModelContextProtocol
|
|
5
6
|
class Server::StreamableHttpTransport
|
|
@@ -28,21 +29,20 @@ module ModelContextProtocol
|
|
|
28
29
|
@redis_pool = ModelContextProtocol::Server::RedisConfig.pool
|
|
29
30
|
@redis = ModelContextProtocol::Server::RedisClientProxy.new(@redis_pool)
|
|
30
31
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
@
|
|
34
|
-
@
|
|
35
|
-
@
|
|
36
|
-
@allowed_origins = transport_options.fetch(:allowed_origins, ["http://localhost", "https://localhost", "http://127.0.0.1", "https://127.0.0.1"])
|
|
32
|
+
@require_sessions = @configuration.require_sessions
|
|
33
|
+
# Use Concurrent::Map for thread-safe access from multiple request threads
|
|
34
|
+
@session_protocol_versions = Concurrent::Map.new
|
|
35
|
+
@validate_origin = @configuration.validate_origin
|
|
36
|
+
@allowed_origins = @configuration.allowed_origins
|
|
37
37
|
|
|
38
|
-
@session_store = SessionStore.new(@redis, ttl:
|
|
38
|
+
@session_store = SessionStore.new(@redis, ttl: @configuration.session_ttl)
|
|
39
39
|
@server_instance = "#{Socket.gethostname}-#{Process.pid}-#{SecureRandom.hex(4)}"
|
|
40
40
|
@stream_registry = StreamRegistry.new(@redis, @server_instance)
|
|
41
41
|
@notification_queue = NotificationQueue.new(@redis, @server_instance)
|
|
42
42
|
@event_counter = EventCounter.new(@redis, @server_instance)
|
|
43
43
|
@request_store = RequestStore.new(@redis, @server_instance)
|
|
44
44
|
@server_request_store = ServerRequestStore.new(@redis, @server_instance)
|
|
45
|
-
@ping_timeout =
|
|
45
|
+
@ping_timeout = @configuration.ping_timeout
|
|
46
46
|
|
|
47
47
|
@message_poller = MessagePoller.new(@redis, @stream_registry, @client_logger) do |stream, message|
|
|
48
48
|
send_to_stream(stream, message)
|
|
@@ -55,7 +55,8 @@ module ModelContextProtocol
|
|
|
55
55
|
end
|
|
56
56
|
|
|
57
57
|
# Gracefully shut down the transport by stopping background threads and cleaning up resources
|
|
58
|
-
# Closes all active streams
|
|
58
|
+
# Closes all active streams. Redis entries are left to expire naturally (they have TTLs).
|
|
59
|
+
# This method is signal-safe and avoids mutex operations.
|
|
59
60
|
def shutdown
|
|
60
61
|
@server_logger.info("Shutting down StreamableHttpTransport")
|
|
61
62
|
|
|
@@ -67,32 +68,31 @@ module ModelContextProtocol
|
|
|
67
68
|
@stream_monitor_thread.join(5)
|
|
68
69
|
end
|
|
69
70
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
71
|
+
# Close streams directly without Redis cleanup (signal-safe).
|
|
72
|
+
# Redis entries will expire naturally via TTL.
|
|
73
|
+
@stream_registry.get_all_local_streams.each do |session_id, stream|
|
|
74
|
+
begin
|
|
75
|
+
stream.close
|
|
76
|
+
rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
|
|
77
|
+
# Stream already closed, ignore
|
|
78
|
+
end
|
|
79
|
+
@server_logger.info("← SSE stream [closed] (#{session_id}) [shutdown]")
|
|
74
80
|
end
|
|
75
81
|
|
|
76
|
-
@redis_pool.checkin(@redis) if @redis_pool && @redis
|
|
77
|
-
|
|
78
82
|
@server_logger.info("StreamableHttpTransport shutdown complete")
|
|
79
83
|
end
|
|
80
84
|
|
|
81
85
|
# Main entry point for handling HTTP requests (POST, GET, DELETE)
|
|
82
86
|
# Routes requests to appropriate handlers and manages the request/response lifecycle
|
|
83
|
-
|
|
87
|
+
# @param env [Hash] Rack environment hash (required)
|
|
88
|
+
# @param session_context [Hash] Per-request context that will be merged with server context
|
|
89
|
+
def handle(env:, session_context: {})
|
|
84
90
|
@server_logger.debug("Handling streamable HTTP transport request")
|
|
85
91
|
|
|
86
|
-
env = @configuration.transport_options[:env]
|
|
87
|
-
|
|
88
|
-
unless env
|
|
89
|
-
raise ArgumentError, "StreamableHTTP transport requires Rack env hash in transport_options"
|
|
90
|
-
end
|
|
91
|
-
|
|
92
92
|
case env["REQUEST_METHOD"]
|
|
93
93
|
when "POST"
|
|
94
94
|
@server_logger.debug("Handling POST request")
|
|
95
|
-
handle_post_request(env)
|
|
95
|
+
handle_post_request(env, session_context: session_context)
|
|
96
96
|
when "GET"
|
|
97
97
|
@server_logger.debug("Handling GET request")
|
|
98
98
|
handle_get_request(env)
|
|
@@ -165,16 +165,11 @@ module ModelContextProtocol
|
|
|
165
165
|
end
|
|
166
166
|
end
|
|
167
167
|
|
|
168
|
-
# Validate HTTP headers for
|
|
169
|
-
# Returns error response if headers are invalid, nil if valid
|
|
170
|
-
def validate_headers(env)
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
if origin && !@allowed_origins.any? { |allowed| origin.start_with?(allowed) }
|
|
174
|
-
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Origin not allowed"}]
|
|
175
|
-
return {json: error_response.serialized, status: 403}
|
|
176
|
-
end
|
|
177
|
-
end
|
|
168
|
+
# Validate HTTP headers for POST requests: CORS origin, content type, and protocol version.
|
|
169
|
+
# Returns error response hash if headers are invalid, nil if valid.
|
|
170
|
+
def validate_headers(env, session_id: nil)
|
|
171
|
+
origin_error = validate_origin!(env)
|
|
172
|
+
return origin_error if origin_error
|
|
178
173
|
|
|
179
174
|
accept_header = env["HTTP_ACCEPT"]
|
|
180
175
|
if accept_header
|
|
@@ -184,15 +179,53 @@ module ModelContextProtocol
|
|
|
184
179
|
end
|
|
185
180
|
end
|
|
186
181
|
|
|
182
|
+
validate_protocol_version!(env, session_id: session_id)
|
|
183
|
+
end
|
|
184
|
+
|
|
185
|
+
# Validate CORS Origin header against allowed origins.
|
|
186
|
+
# The MCP spec requires servers to validate Origin on all incoming connections
|
|
187
|
+
# to prevent DNS rebinding attacks.
|
|
188
|
+
def validate_origin!(env)
|
|
189
|
+
return nil unless @validate_origin
|
|
190
|
+
|
|
191
|
+
origin = env["HTTP_ORIGIN"]
|
|
192
|
+
if origin && !@allowed_origins.any? { |allowed| origin.start_with?(allowed) }
|
|
193
|
+
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Origin not allowed"}]
|
|
194
|
+
return {json: error_response.serialized, status: 403}
|
|
195
|
+
end
|
|
196
|
+
|
|
197
|
+
nil
|
|
198
|
+
end
|
|
199
|
+
|
|
200
|
+
# Validate MCP-Protocol-Version header against negotiated version.
|
|
201
|
+
# Per the MCP spec, the server MUST respond with 400 Bad Request for invalid
|
|
202
|
+
# or unsupported protocol versions. When a session_id is provided, validation
|
|
203
|
+
# is scoped to that session's negotiated version.
|
|
204
|
+
def validate_protocol_version!(env, session_id: nil)
|
|
187
205
|
protocol_version = env["HTTP_MCP_PROTOCOL_VERSION"]
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
206
|
+
return nil unless protocol_version
|
|
207
|
+
|
|
208
|
+
# When a session_id is provided, try session-specific validation first.
|
|
209
|
+
# If the session has a known negotiated version, validate strictly against it.
|
|
210
|
+
if session_id
|
|
211
|
+
expected_version = @session_protocol_versions[session_id]
|
|
212
|
+
if expected_version
|
|
213
|
+
if protocol_version != expected_version
|
|
214
|
+
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected: #{expected_version}"}]
|
|
215
|
+
return {json: error_response.serialized, status: 400}
|
|
216
|
+
end
|
|
217
|
+
return nil
|
|
193
218
|
end
|
|
194
219
|
end
|
|
195
220
|
|
|
221
|
+
# Fallback: validate against all known negotiated versions (covers cases
|
|
222
|
+
# where session_id is nil or has no entry, e.g. sessions not required).
|
|
223
|
+
valid_versions = @session_protocol_versions.values.compact.uniq
|
|
224
|
+
unless valid_versions.empty? || valid_versions.include?(protocol_version)
|
|
225
|
+
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected one of: #{valid_versions.join(", ")}"}]
|
|
226
|
+
return {json: error_response.serialized, status: 400}
|
|
227
|
+
end
|
|
228
|
+
|
|
196
229
|
nil
|
|
197
230
|
end
|
|
198
231
|
|
|
@@ -212,13 +245,15 @@ module ModelContextProtocol
|
|
|
212
245
|
|
|
213
246
|
# Handle HTTP POST requests containing JSON-RPC messages
|
|
214
247
|
# Parses request body and routes to initialization or regular request handlers
|
|
215
|
-
|
|
216
|
-
|
|
248
|
+
# @param env [Hash] Rack environment hash
|
|
249
|
+
# @param session_context [Hash] Per-request context for initialization
|
|
250
|
+
def handle_post_request(env, session_context: {})
|
|
251
|
+
session_id = env["HTTP_MCP_SESSION_ID"]
|
|
252
|
+
validation_error = validate_headers(env, session_id: session_id)
|
|
217
253
|
return validation_error if validation_error
|
|
218
254
|
|
|
219
255
|
body_string = env["rack.input"].read
|
|
220
256
|
body = JSON.parse(body_string)
|
|
221
|
-
session_id = env["HTTP_MCP_SESSION_ID"]
|
|
222
257
|
accept_header = env["HTTP_ACCEPT"] || ""
|
|
223
258
|
|
|
224
259
|
log_to_server_with_context(request_id: body["id"]) do |logger|
|
|
@@ -237,7 +272,7 @@ module ModelContextProtocol
|
|
|
237
272
|
end
|
|
238
273
|
|
|
239
274
|
if body["method"] == "initialize"
|
|
240
|
-
handle_initialization(body, accept_header)
|
|
275
|
+
handle_initialization(body, accept_header, session_context: session_context)
|
|
241
276
|
else
|
|
242
277
|
handle_regular_request(body, session_id, accept_header)
|
|
243
278
|
end
|
|
@@ -276,7 +311,10 @@ module ModelContextProtocol
|
|
|
276
311
|
|
|
277
312
|
# Handle MCP initialization requests to establish protocol version and optional sessions
|
|
278
313
|
# Always returns JSON response regardless of Accept header to keep initialization simple
|
|
279
|
-
|
|
314
|
+
# @param body [Hash] Parsed JSON-RPC request body
|
|
315
|
+
# @param accept_header [String] HTTP Accept header value
|
|
316
|
+
# @param session_context [Hash] Per-request context to merge with server context
|
|
317
|
+
def handle_initialization(body, accept_header, session_context: {})
|
|
280
318
|
result = @router.route(body, transport: self)
|
|
281
319
|
response = Response[id: body["id"], result: result.serialized]
|
|
282
320
|
response_headers = {}
|
|
@@ -284,12 +322,17 @@ module ModelContextProtocol
|
|
|
284
322
|
|
|
285
323
|
if @require_sessions
|
|
286
324
|
session_id = SecureRandom.uuid
|
|
325
|
+
# Merge server-level defaults with request-level context
|
|
326
|
+
merged_context = (@configuration.context || {}).merge(session_context)
|
|
287
327
|
@session_store.create_session(session_id, {
|
|
288
328
|
server_instance: @server_instance,
|
|
289
|
-
context:
|
|
329
|
+
context: merged_context,
|
|
290
330
|
created_at: Time.now.to_f,
|
|
291
331
|
negotiated_protocol_version: negotiated_protocol_version
|
|
292
332
|
})
|
|
333
|
+
# Store initial handler names for list_changed detection
|
|
334
|
+
current_handlers = @configuration.registry.handler_names
|
|
335
|
+
@session_store.store_registered_handlers(session_id, **current_handlers)
|
|
293
336
|
response_headers["Mcp-Session-Id"] = session_id
|
|
294
337
|
@session_protocol_versions[session_id] = negotiated_protocol_version
|
|
295
338
|
log_to_server_with_context { |logger| logger.info("Session created: #{session_id} (protocol: #{negotiated_protocol_version})") }
|
|
@@ -314,9 +357,15 @@ module ModelContextProtocol
|
|
|
314
357
|
# Handle regular MCP requests (tools, resources, prompts) with streaming/JSON decision logic
|
|
315
358
|
# Defaults to SSE streaming but returns JSON when client explicitly requests JSON only
|
|
316
359
|
def handle_regular_request(body, session_id, accept_header)
|
|
360
|
+
session_context = {}
|
|
361
|
+
|
|
317
362
|
if @require_sessions
|
|
363
|
+
# Per the MCP spec, servers SHOULD respond to requests without a valid
|
|
364
|
+
# Mcp-Session-Id header (other than initialization) with HTTP 400.
|
|
365
|
+
# The session ID MUST be present on all subsequent requests after initialization,
|
|
366
|
+
# including notifications like notifications/initialized.
|
|
318
367
|
unless session_id && @session_store.session_exists?(session_id)
|
|
319
|
-
if session_id
|
|
368
|
+
if session_id
|
|
320
369
|
error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Session terminated"}]
|
|
321
370
|
return {json: error_response.serialized, status: 404}
|
|
322
371
|
else
|
|
@@ -324,6 +373,9 @@ module ModelContextProtocol
|
|
|
324
373
|
return {json: error_response.serialized, status: 400}
|
|
325
374
|
end
|
|
326
375
|
end
|
|
376
|
+
|
|
377
|
+
session_context = @session_store.get_session_context(session_id)
|
|
378
|
+
check_and_notify_handler_changes(session_id)
|
|
327
379
|
end
|
|
328
380
|
|
|
329
381
|
message_type = determine_message_type(body)
|
|
@@ -348,7 +400,7 @@ module ModelContextProtocol
|
|
|
348
400
|
log_to_server_with_context do |logger|
|
|
349
401
|
logger.info("← Notification [accepted]")
|
|
350
402
|
end
|
|
351
|
-
{
|
|
403
|
+
{status: 202}
|
|
352
404
|
|
|
353
405
|
when :request
|
|
354
406
|
if accept_header.include?("text/event-stream")
|
|
@@ -359,9 +411,9 @@ module ModelContextProtocol
|
|
|
359
411
|
"Cache-Control" => "no-cache",
|
|
360
412
|
"Connection" => "keep-alive"
|
|
361
413
|
},
|
|
362
|
-
stream_proc: create_request_response_sse_stream_proc(body, session_id)
|
|
414
|
+
stream_proc: create_request_response_sse_stream_proc(body, session_id, session_context: session_context)
|
|
363
415
|
}
|
|
364
|
-
elsif (result = @router.route(body, request_store: @request_store, session_id: session_id, transport: self))
|
|
416
|
+
elsif (result = @router.route(body, request_store: @request_store, session_id: session_id, transport: self, session_context: session_context))
|
|
365
417
|
response = Response[id: body["id"], result: result.serialized]
|
|
366
418
|
|
|
367
419
|
log_to_server_with_context(request_id: response.id) do |logger|
|
|
@@ -386,6 +438,9 @@ module ModelContextProtocol
|
|
|
386
438
|
# Handle HTTP GET requests to establish persistent SSE connections for notifications
|
|
387
439
|
# Validates session requirements and Accept headers before opening long-lived streams
|
|
388
440
|
def handle_get_request(env)
|
|
441
|
+
origin_error = validate_origin!(env)
|
|
442
|
+
return origin_error if origin_error
|
|
443
|
+
|
|
389
444
|
accept_header = env["HTTP_ACCEPT"] || ""
|
|
390
445
|
unless accept_header.include?("text/event-stream")
|
|
391
446
|
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Accept header must include text/event-stream"}]
|
|
@@ -393,6 +448,10 @@ module ModelContextProtocol
|
|
|
393
448
|
end
|
|
394
449
|
|
|
395
450
|
session_id = env["HTTP_MCP_SESSION_ID"]
|
|
451
|
+
|
|
452
|
+
protocol_error = validate_protocol_version!(env, session_id: session_id)
|
|
453
|
+
return protocol_error if protocol_error
|
|
454
|
+
|
|
396
455
|
last_event_id = env["HTTP_LAST_EVENT_ID"]
|
|
397
456
|
|
|
398
457
|
if @require_sessions
|
|
@@ -422,10 +481,28 @@ module ModelContextProtocol
|
|
|
422
481
|
# Handle HTTP DELETE requests to clean up sessions and associated resources
|
|
423
482
|
# Removes session data, closes streams, and cleans up request store entries
|
|
424
483
|
def handle_delete_request(env)
|
|
484
|
+
origin_error = validate_origin!(env)
|
|
485
|
+
return origin_error if origin_error
|
|
486
|
+
|
|
425
487
|
session_id = env["HTTP_MCP_SESSION_ID"]
|
|
426
488
|
|
|
489
|
+
protocol_error = validate_protocol_version!(env, session_id: session_id)
|
|
490
|
+
return protocol_error if protocol_error
|
|
491
|
+
|
|
427
492
|
@server_logger.info("→ DELETE /mcp [Session cleanup: #{session_id || "unknown"}]")
|
|
428
493
|
|
|
494
|
+
if @require_sessions
|
|
495
|
+
unless session_id
|
|
496
|
+
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid or missing session ID"}]
|
|
497
|
+
return {json: error_response.serialized, status: 400}
|
|
498
|
+
end
|
|
499
|
+
|
|
500
|
+
unless @session_store.session_exists?(session_id)
|
|
501
|
+
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Session terminated"}]
|
|
502
|
+
return {json: error_response.serialized, status: 404}
|
|
503
|
+
end
|
|
504
|
+
end
|
|
505
|
+
|
|
429
506
|
if session_id
|
|
430
507
|
cleanup_session(session_id)
|
|
431
508
|
log_to_server_with_context { |logger| logger.info("Session cleanup: #{session_id}") }
|
|
@@ -441,7 +518,10 @@ module ModelContextProtocol
|
|
|
441
518
|
# Create SSE stream processor for request-response pattern with real-time progress support
|
|
442
519
|
# Opens stream → Executes request → Sends response → Closes stream
|
|
443
520
|
# Enables progress notifications during long-running operations like tool calls
|
|
444
|
-
|
|
521
|
+
# @param request_body [Hash] Parsed JSON-RPC request
|
|
522
|
+
# @param session_id [String, nil] Session ID for this request
|
|
523
|
+
# @param session_context [Hash] Context to pass to handlers
|
|
524
|
+
def create_request_response_sse_stream_proc(request_body, session_id, session_context: {})
|
|
445
525
|
proc do |stream|
|
|
446
526
|
temp_stream_id = "temp-#{SecureRandom.hex(8)}"
|
|
447
527
|
@stream_registry.register_stream(temp_stream_id, stream)
|
|
@@ -452,7 +532,7 @@ module ModelContextProtocol
|
|
|
452
532
|
end
|
|
453
533
|
|
|
454
534
|
begin
|
|
455
|
-
if (result = @router.route(request_body, request_store: @request_store, session_id: session_id, transport: self, stream_id: temp_stream_id))
|
|
535
|
+
if (result = @router.route(request_body, request_store: @request_store, session_id: session_id, transport: self, stream_id: temp_stream_id, session_context: session_context))
|
|
456
536
|
response = Response[id: request_body["id"], result: result.serialized]
|
|
457
537
|
event_id = next_event_id
|
|
458
538
|
send_sse_event(stream, response.serialized, event_id)
|
|
@@ -467,6 +547,16 @@ module ModelContextProtocol
|
|
|
467
547
|
rescue IOError, Errno::EPIPE, Errno::ECONNRESET => e
|
|
468
548
|
@server_logger.debug("Client disconnected during progressive request processing: #{e.class.name}")
|
|
469
549
|
log_to_server_with_context { |logger| logger.info("← SSE stream [closed] (#{temp_stream_id}) [client_disconnected]") }
|
|
550
|
+
rescue ModelContextProtocol::Server::ParameterValidationError => e
|
|
551
|
+
@client_logger.error("Validation error", error: e.message)
|
|
552
|
+
error_response = ErrorResponse[id: request_body["id"], error: {code: -32602, message: e.message}]
|
|
553
|
+
send_sse_event(stream, error_response.serialized, next_event_id)
|
|
554
|
+
close_stream(temp_stream_id, reason: "validation_error")
|
|
555
|
+
rescue => e
|
|
556
|
+
@client_logger.error("Internal error", error: e.message, backtrace: e.backtrace)
|
|
557
|
+
error_response = ErrorResponse[id: request_body["id"], error: {code: -32603, message: e.message}]
|
|
558
|
+
send_sse_event(stream, error_response.serialized, next_event_id)
|
|
559
|
+
close_stream(temp_stream_id, reason: "internal_error")
|
|
470
560
|
ensure
|
|
471
561
|
@stream_registry.unregister_stream(temp_stream_id)
|
|
472
562
|
end
|
|
@@ -525,8 +615,15 @@ module ModelContextProtocol
|
|
|
525
615
|
flush_notifications_to_stream(stream)
|
|
526
616
|
end
|
|
527
617
|
|
|
618
|
+
# Also flush any messages queued in Redis from other server instances
|
|
619
|
+
poll_and_deliver_redis_messages(stream, session_id) if session_id
|
|
620
|
+
|
|
528
621
|
loop do
|
|
529
622
|
break unless stream_connected?(stream)
|
|
623
|
+
|
|
624
|
+
# Poll for queued messages from Redis (cross-server delivery)
|
|
625
|
+
poll_and_deliver_redis_messages(stream, session_id) if session_id
|
|
626
|
+
|
|
530
627
|
sleep 0.1
|
|
531
628
|
end
|
|
532
629
|
ensure
|
|
@@ -688,12 +785,6 @@ module ModelContextProtocol
|
|
|
688
785
|
@server_request_store.cleanup_session_requests(session_id)
|
|
689
786
|
end
|
|
690
787
|
|
|
691
|
-
# Check if this transport instance has any active local streams
|
|
692
|
-
# Used to determine if notifications should be queued or delivered immediately
|
|
693
|
-
def has_active_streams?
|
|
694
|
-
@stream_registry.has_any_local_streams?
|
|
695
|
-
end
|
|
696
|
-
|
|
697
788
|
# Broadcast notification to all active streams on this transport instance
|
|
698
789
|
# Handles connection errors gracefully and removes disconnected streams
|
|
699
790
|
def deliver_to_active_streams(notification)
|
|
@@ -723,6 +814,22 @@ module ModelContextProtocol
|
|
|
723
814
|
@server_logger.debug("Delivered notifications to #{delivered_count} streams, cleaned up #{disconnected_streams.size} disconnected streams")
|
|
724
815
|
end
|
|
725
816
|
|
|
817
|
+
# Poll for messages queued in Redis and deliver to the stream
|
|
818
|
+
# Handles cross-server message delivery when notifications are queued by other server instances
|
|
819
|
+
def poll_and_deliver_redis_messages(stream, session_id)
|
|
820
|
+
return unless session_id
|
|
821
|
+
|
|
822
|
+
messages = @session_store.poll_messages_for_session(session_id)
|
|
823
|
+
return if messages.empty?
|
|
824
|
+
|
|
825
|
+
@server_logger.debug("Delivering #{messages.size} queued messages from Redis to stream #{session_id}")
|
|
826
|
+
messages.each do |message|
|
|
827
|
+
send_to_stream(stream, message)
|
|
828
|
+
end
|
|
829
|
+
rescue => e
|
|
830
|
+
@server_logger.error("Error polling Redis messages: #{e.message}")
|
|
831
|
+
end
|
|
832
|
+
|
|
726
833
|
# Flush any queued notifications to a newly connected stream
|
|
727
834
|
# Ensures clients receive notifications that were queued while disconnected
|
|
728
835
|
def flush_notifications_to_stream(stream)
|
|
@@ -783,5 +890,34 @@ module ModelContextProtocol
|
|
|
783
890
|
end
|
|
784
891
|
nil
|
|
785
892
|
end
|
|
893
|
+
|
|
894
|
+
# Check if registered handlers have changed for a session and send notifications
|
|
895
|
+
# Compares current handlers against previously stored handlers in Redis
|
|
896
|
+
def check_and_notify_handler_changes(session_id)
|
|
897
|
+
return unless session_id
|
|
898
|
+
return unless @session_store.session_exists?(session_id)
|
|
899
|
+
|
|
900
|
+
current = @configuration.registry.handler_names
|
|
901
|
+
previous = @session_store.get_registered_handlers(session_id)
|
|
902
|
+
|
|
903
|
+
return if previous.nil? # First request after init
|
|
904
|
+
|
|
905
|
+
changed_types = []
|
|
906
|
+
changed_types << :prompts if current[:prompts].sort != previous[:prompts]&.sort
|
|
907
|
+
changed_types << :resources if current[:resources].sort != previous[:resources]&.sort
|
|
908
|
+
changed_types << :tools if current[:tools].sort != previous[:tools]&.sort
|
|
909
|
+
|
|
910
|
+
return if changed_types.empty?
|
|
911
|
+
|
|
912
|
+
changed_types.each do |type|
|
|
913
|
+
send_notification("notifications/#{type}/list_changed", {}, session_id: session_id)
|
|
914
|
+
end
|
|
915
|
+
|
|
916
|
+
@session_store.store_registered_handlers(session_id, **current)
|
|
917
|
+
rescue => e
|
|
918
|
+
@server_logger.error("Error checking handler changes: #{e.class.name}: #{e.message}")
|
|
919
|
+
@server_logger.debug("Backtrace: #{e.backtrace.first(5).join("\n")}")
|
|
920
|
+
# Don't re-raise - handler change detection is optional, allow request to proceed
|
|
921
|
+
end
|
|
786
922
|
end
|
|
787
923
|
end
|
|
@@ -83,7 +83,7 @@ module ModelContextProtocol
|
|
|
83
83
|
end
|
|
84
84
|
|
|
85
85
|
class << self
|
|
86
|
-
attr_reader :name, :description, :title, :input_schema, :output_schema
|
|
86
|
+
attr_reader :name, :description, :title, :input_schema, :output_schema, :annotations, :security_schemes
|
|
87
87
|
|
|
88
88
|
def define(&block)
|
|
89
89
|
definition_dsl = DefinitionDSL.new
|
|
@@ -94,6 +94,8 @@ module ModelContextProtocol
|
|
|
94
94
|
@title = definition_dsl.title
|
|
95
95
|
@input_schema = definition_dsl.input_schema
|
|
96
96
|
@output_schema = definition_dsl.output_schema
|
|
97
|
+
@annotations = definition_dsl.defined_annotations
|
|
98
|
+
@security_schemes = definition_dsl.security_schemes
|
|
97
99
|
end
|
|
98
100
|
|
|
99
101
|
def inherited(subclass)
|
|
@@ -102,6 +104,8 @@ module ModelContextProtocol
|
|
|
102
104
|
subclass.instance_variable_set(:@title, @title)
|
|
103
105
|
subclass.instance_variable_set(:@input_schema, @input_schema)
|
|
104
106
|
subclass.instance_variable_set(:@output_schema, @output_schema)
|
|
107
|
+
subclass.instance_variable_set(:@annotations, @annotations&.dup)
|
|
108
|
+
subclass.instance_variable_set(:@security_schemes, @security_schemes)
|
|
105
109
|
end
|
|
106
110
|
|
|
107
111
|
def call(arguments, client_logger, server_logger, context = {})
|
|
@@ -120,6 +124,9 @@ module ModelContextProtocol
|
|
|
120
124
|
result = {name: @name, description: @description, inputSchema: @input_schema}
|
|
121
125
|
result[:title] = @title if @title
|
|
122
126
|
result[:outputSchema] = @output_schema if @output_schema
|
|
127
|
+
annotations_hash = @annotations&.serialized
|
|
128
|
+
result[:annotations] = annotations_hash if annotations_hash
|
|
129
|
+
result[:securitySchemes] = @security_schemes if @security_schemes
|
|
123
130
|
result
|
|
124
131
|
end
|
|
125
132
|
end
|
|
@@ -149,6 +156,65 @@ module ModelContextProtocol
|
|
|
149
156
|
@output_schema = instance_eval(&block) if block_given?
|
|
150
157
|
@output_schema
|
|
151
158
|
end
|
|
159
|
+
|
|
160
|
+
attr_reader :defined_annotations
|
|
161
|
+
|
|
162
|
+
def annotations(&block)
|
|
163
|
+
@defined_annotations = AnnotationsDSL.new
|
|
164
|
+
@defined_annotations.instance_eval(&block)
|
|
165
|
+
@defined_annotations
|
|
166
|
+
end
|
|
167
|
+
|
|
168
|
+
def security_schemes(&block)
|
|
169
|
+
@security_schemes = instance_eval(&block) if block_given?
|
|
170
|
+
@security_schemes
|
|
171
|
+
end
|
|
172
|
+
end
|
|
173
|
+
|
|
174
|
+
class AnnotationsDSL
|
|
175
|
+
def initialize
|
|
176
|
+
@read_only_hint = nil
|
|
177
|
+
@destructive_hint = nil
|
|
178
|
+
@idempotent_hint = nil
|
|
179
|
+
@open_world_hint = nil
|
|
180
|
+
end
|
|
181
|
+
|
|
182
|
+
def read_only_hint(value)
|
|
183
|
+
validate_boolean!(:read_only_hint, value)
|
|
184
|
+
@read_only_hint = value
|
|
185
|
+
end
|
|
186
|
+
|
|
187
|
+
def destructive_hint(value)
|
|
188
|
+
validate_boolean!(:destructive_hint, value)
|
|
189
|
+
@destructive_hint = value
|
|
190
|
+
end
|
|
191
|
+
|
|
192
|
+
def idempotent_hint(value)
|
|
193
|
+
validate_boolean!(:idempotent_hint, value)
|
|
194
|
+
@idempotent_hint = value
|
|
195
|
+
end
|
|
196
|
+
|
|
197
|
+
def open_world_hint(value)
|
|
198
|
+
validate_boolean!(:open_world_hint, value)
|
|
199
|
+
@open_world_hint = value
|
|
200
|
+
end
|
|
201
|
+
|
|
202
|
+
def serialized
|
|
203
|
+
result = {}
|
|
204
|
+
result[:readOnlyHint] = @read_only_hint unless @read_only_hint.nil?
|
|
205
|
+
result[:destructiveHint] = @destructive_hint unless @destructive_hint.nil?
|
|
206
|
+
result[:idempotentHint] = @idempotent_hint unless @idempotent_hint.nil?
|
|
207
|
+
result[:openWorldHint] = @open_world_hint unless @open_world_hint.nil?
|
|
208
|
+
result.empty? ? nil : result
|
|
209
|
+
end
|
|
210
|
+
|
|
211
|
+
private
|
|
212
|
+
|
|
213
|
+
def validate_boolean!(field, value)
|
|
214
|
+
unless value.is_a?(TrueClass) || value.is_a?(FalseClass)
|
|
215
|
+
raise ArgumentError, "#{field} must be a boolean, got: #{value.inspect}"
|
|
216
|
+
end
|
|
217
|
+
end
|
|
152
218
|
end
|
|
153
219
|
end
|
|
154
220
|
end
|