model-context-protocol-rb 0.6.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -2
  3. data/README.md +174 -978
  4. data/lib/model_context_protocol/rspec/helpers.rb +54 -0
  5. data/lib/model_context_protocol/rspec/matchers/be_mcp_error_response.rb +123 -0
  6. data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_class.rb +103 -0
  7. data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_prompt_response.rb +126 -0
  8. data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_resource_response.rb +121 -0
  9. data/lib/model_context_protocol/rspec/matchers/be_valid_mcp_tool_response.rb +135 -0
  10. data/lib/model_context_protocol/rspec/matchers/have_audio_content.rb +109 -0
  11. data/lib/model_context_protocol/rspec/matchers/have_embedded_resource_content.rb +150 -0
  12. data/lib/model_context_protocol/rspec/matchers/have_image_content.rb +109 -0
  13. data/lib/model_context_protocol/rspec/matchers/have_message_count.rb +87 -0
  14. data/lib/model_context_protocol/rspec/matchers/have_message_with_role.rb +152 -0
  15. data/lib/model_context_protocol/rspec/matchers/have_resource_annotations.rb +135 -0
  16. data/lib/model_context_protocol/rspec/matchers/have_resource_blob.rb +108 -0
  17. data/lib/model_context_protocol/rspec/matchers/have_resource_link_content.rb +138 -0
  18. data/lib/model_context_protocol/rspec/matchers/have_resource_mime_type.rb +103 -0
  19. data/lib/model_context_protocol/rspec/matchers/have_resource_text.rb +112 -0
  20. data/lib/model_context_protocol/rspec/matchers/have_structured_content.rb +88 -0
  21. data/lib/model_context_protocol/rspec/matchers/have_text_content.rb +113 -0
  22. data/lib/model_context_protocol/rspec/matchers.rb +31 -0
  23. data/lib/model_context_protocol/rspec.rb +23 -0
  24. data/lib/model_context_protocol/server/client_logger.rb +1 -1
  25. data/lib/model_context_protocol/server/configuration.rb +195 -91
  26. data/lib/model_context_protocol/server/content_helpers.rb +1 -1
  27. data/lib/model_context_protocol/server/prompt.rb +0 -14
  28. data/lib/model_context_protocol/server/redis_client_proxy.rb +2 -14
  29. data/lib/model_context_protocol/server/redis_config.rb +5 -7
  30. data/lib/model_context_protocol/server/redis_pool_manager.rb +10 -13
  31. data/lib/model_context_protocol/server/registry.rb +8 -0
  32. data/lib/model_context_protocol/server/router.rb +279 -4
  33. data/lib/model_context_protocol/server/server_logger.rb +5 -2
  34. data/lib/model_context_protocol/server/stdio_configuration.rb +114 -0
  35. data/lib/model_context_protocol/server/stdio_transport/request_store.rb +0 -41
  36. data/lib/model_context_protocol/server/streamable_http_configuration.rb +218 -0
  37. data/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +0 -13
  38. data/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +0 -41
  39. data/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +0 -103
  40. data/lib/model_context_protocol/server/streamable_http_transport/server_request_store.rb +0 -64
  41. data/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +0 -58
  42. data/lib/model_context_protocol/server/streamable_http_transport/session_store.rb +17 -31
  43. data/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +0 -34
  44. data/lib/model_context_protocol/server/streamable_http_transport.rb +192 -56
  45. data/lib/model_context_protocol/server/tool.rb +67 -1
  46. data/lib/model_context_protocol/server.rb +203 -262
  47. data/lib/model_context_protocol/version.rb +1 -1
  48. data/lib/model_context_protocol.rb +4 -1
  49. data/lib/puma/plugin/mcp.rb +39 -0
  50. data/tasks/mcp.rake +26 -0
  51. data/tasks/templates/dev-http-puma.erb +251 -0
  52. data/tasks/templates/dev-http.erb +166 -184
  53. data/tasks/templates/dev.erb +29 -7
  54. metadata +26 -2
@@ -43,14 +43,6 @@ module ModelContextProtocol
43
43
  @local_streams.key?(session_id)
44
44
  end
45
45
 
46
- def get_stream_server(session_id)
47
- @redis.get("#{STREAM_KEY_PREFIX}#{session_id}")
48
- end
49
-
50
- def stream_active?(session_id)
51
- @redis.exists("#{STREAM_KEY_PREFIX}#{session_id}") == 1
52
- end
53
-
54
46
  def refresh_heartbeat(session_id)
55
47
  @redis.multi do |multi|
56
48
  multi.set("#{HEARTBEAT_KEY_PREFIX}#{session_id}", Time.now.to_f, ex: @ttl)
@@ -88,32 +80,6 @@ module ModelContextProtocol
88
80
 
89
81
  expired_sessions
90
82
  end
91
-
92
- def get_stale_streams(max_age_seconds = 90)
93
- current_time = Time.now.to_f
94
- stale_streams = []
95
-
96
- # Get all heartbeat keys
97
- heartbeat_keys = @redis.keys("#{HEARTBEAT_KEY_PREFIX}*")
98
-
99
- return stale_streams if heartbeat_keys.empty?
100
-
101
- # Get all heartbeat timestamps
102
- heartbeat_values = @redis.mget(heartbeat_keys)
103
-
104
- heartbeat_keys.each_with_index do |key, index|
105
- next unless heartbeat_values[index]
106
-
107
- session_id = key.sub(HEARTBEAT_KEY_PREFIX, "")
108
- last_heartbeat = heartbeat_values[index].to_f
109
-
110
- if current_time - last_heartbeat > max_age_seconds
111
- stale_streams << session_id
112
- end
113
- end
114
-
115
- stale_streams
116
- end
117
83
  end
118
84
  end
119
85
  end
@@ -1,5 +1,6 @@
1
1
  require "json"
2
2
  require "securerandom"
3
+ require "concurrent"
3
4
 
4
5
  module ModelContextProtocol
5
6
  class Server::StreamableHttpTransport
@@ -28,21 +29,20 @@ module ModelContextProtocol
28
29
  @redis_pool = ModelContextProtocol::Server::RedisConfig.pool
29
30
  @redis = ModelContextProtocol::Server::RedisClientProxy.new(@redis_pool)
30
31
 
31
- transport_options = @configuration.transport_options
32
- @require_sessions = transport_options.fetch(:require_sessions, false)
33
- @default_protocol_version = transport_options.fetch(:default_protocol_version, "2025-03-26")
34
- @session_protocol_versions = {}
35
- @validate_origin = transport_options.fetch(:validate_origin, true)
36
- @allowed_origins = transport_options.fetch(:allowed_origins, ["http://localhost", "https://localhost", "http://127.0.0.1", "https://127.0.0.1"])
32
+ @require_sessions = @configuration.require_sessions
33
+ # Use Concurrent::Map for thread-safe access from multiple request threads
34
+ @session_protocol_versions = Concurrent::Map.new
35
+ @validate_origin = @configuration.validate_origin
36
+ @allowed_origins = @configuration.allowed_origins
37
37
 
38
- @session_store = SessionStore.new(@redis, ttl: transport_options[:session_ttl] || 3600)
38
+ @session_store = SessionStore.new(@redis, ttl: @configuration.session_ttl)
39
39
  @server_instance = "#{Socket.gethostname}-#{Process.pid}-#{SecureRandom.hex(4)}"
40
40
  @stream_registry = StreamRegistry.new(@redis, @server_instance)
41
41
  @notification_queue = NotificationQueue.new(@redis, @server_instance)
42
42
  @event_counter = EventCounter.new(@redis, @server_instance)
43
43
  @request_store = RequestStore.new(@redis, @server_instance)
44
44
  @server_request_store = ServerRequestStore.new(@redis, @server_instance)
45
- @ping_timeout = transport_options.fetch(:ping_timeout, 10)
45
+ @ping_timeout = @configuration.ping_timeout
46
46
 
47
47
  @message_poller = MessagePoller.new(@redis, @stream_registry, @client_logger) do |stream, message|
48
48
  send_to_stream(stream, message)
@@ -55,7 +55,8 @@ module ModelContextProtocol
55
55
  end
56
56
 
57
57
  # Gracefully shut down the transport by stopping background threads and cleaning up resources
58
- # Closes all active streams and returns Redis connections to the pool
58
+ # Closes all active streams. Redis entries are left to expire naturally (they have TTLs).
59
+ # This method is signal-safe and avoids mutex operations.
59
60
  def shutdown
60
61
  @server_logger.info("Shutting down StreamableHttpTransport")
61
62
 
@@ -67,32 +68,31 @@ module ModelContextProtocol
67
68
  @stream_monitor_thread.join(5)
68
69
  end
69
70
 
70
- @stream_registry.get_all_local_streams.each do |session_id, _stream|
71
- close_stream(session_id, reason: "shutdown")
72
- rescue => e
73
- @server_logger.error("Error during stream cleanup for session #{session_id}: #{e.message}")
71
+ # Close streams directly without Redis cleanup (signal-safe).
72
+ # Redis entries will expire naturally via TTL.
73
+ @stream_registry.get_all_local_streams.each do |session_id, stream|
74
+ begin
75
+ stream.close
76
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
77
+ # Stream already closed, ignore
78
+ end
79
+ @server_logger.info("← SSE stream [closed] (#{session_id}) [shutdown]")
74
80
  end
75
81
 
76
- @redis_pool.checkin(@redis) if @redis_pool && @redis
77
-
78
82
  @server_logger.info("StreamableHttpTransport shutdown complete")
79
83
  end
80
84
 
81
85
  # Main entry point for handling HTTP requests (POST, GET, DELETE)
82
86
  # Routes requests to appropriate handlers and manages the request/response lifecycle
83
- def handle
87
+ # @param env [Hash] Rack environment hash (required)
88
+ # @param session_context [Hash] Per-request context that will be merged with server context
89
+ def handle(env:, session_context: {})
84
90
  @server_logger.debug("Handling streamable HTTP transport request")
85
91
 
86
- env = @configuration.transport_options[:env]
87
-
88
- unless env
89
- raise ArgumentError, "StreamableHTTP transport requires Rack env hash in transport_options"
90
- end
91
-
92
92
  case env["REQUEST_METHOD"]
93
93
  when "POST"
94
94
  @server_logger.debug("Handling POST request")
95
- handle_post_request(env)
95
+ handle_post_request(env, session_context: session_context)
96
96
  when "GET"
97
97
  @server_logger.debug("Handling GET request")
98
98
  handle_get_request(env)
@@ -165,16 +165,11 @@ module ModelContextProtocol
165
165
  end
166
166
  end
167
167
 
168
- # Validate HTTP headers for required content type and CORS origin
169
- # Returns error response if headers are invalid, nil if valid
170
- def validate_headers(env)
171
- if @validate_origin
172
- origin = env["HTTP_ORIGIN"]
173
- if origin && !@allowed_origins.any? { |allowed| origin.start_with?(allowed) }
174
- error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Origin not allowed"}]
175
- return {json: error_response.serialized, status: 403}
176
- end
177
- end
168
+ # Validate HTTP headers for POST requests: CORS origin, content type, and protocol version.
169
+ # Returns error response hash if headers are invalid, nil if valid.
170
+ def validate_headers(env, session_id: nil)
171
+ origin_error = validate_origin!(env)
172
+ return origin_error if origin_error
178
173
 
179
174
  accept_header = env["HTTP_ACCEPT"]
180
175
  if accept_header
@@ -184,15 +179,53 @@ module ModelContextProtocol
184
179
  end
185
180
  end
186
181
 
182
+ validate_protocol_version!(env, session_id: session_id)
183
+ end
184
+
185
+ # Validate CORS Origin header against allowed origins.
186
+ # The MCP spec requires servers to validate Origin on all incoming connections
187
+ # to prevent DNS rebinding attacks.
188
+ def validate_origin!(env)
189
+ return nil unless @validate_origin
190
+
191
+ origin = env["HTTP_ORIGIN"]
192
+ if origin && !@allowed_origins.any? { |allowed| origin.start_with?(allowed) }
193
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Origin not allowed"}]
194
+ return {json: error_response.serialized, status: 403}
195
+ end
196
+
197
+ nil
198
+ end
199
+
200
+ # Validate MCP-Protocol-Version header against negotiated version.
201
+ # Per the MCP spec, the server MUST respond with 400 Bad Request for invalid
202
+ # or unsupported protocol versions. When a session_id is provided, validation
203
+ # is scoped to that session's negotiated version.
204
+ def validate_protocol_version!(env, session_id: nil)
187
205
  protocol_version = env["HTTP_MCP_PROTOCOL_VERSION"]
188
- if protocol_version
189
- valid_versions = @session_protocol_versions.values.compact.uniq
190
- unless valid_versions.empty? || valid_versions.include?(protocol_version)
191
- error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected one of: #{valid_versions.join(", ")}"}]
192
- return {json: error_response.serialized, status: 400}
206
+ return nil unless protocol_version
207
+
208
+ # When a session_id is provided, try session-specific validation first.
209
+ # If the session has a known negotiated version, validate strictly against it.
210
+ if session_id
211
+ expected_version = @session_protocol_versions[session_id]
212
+ if expected_version
213
+ if protocol_version != expected_version
214
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected: #{expected_version}"}]
215
+ return {json: error_response.serialized, status: 400}
216
+ end
217
+ return nil
193
218
  end
194
219
  end
195
220
 
221
+ # Fallback: validate against all known negotiated versions (covers cases
222
+ # where session_id is nil or has no entry, e.g. sessions not required).
223
+ valid_versions = @session_protocol_versions.values.compact.uniq
224
+ unless valid_versions.empty? || valid_versions.include?(protocol_version)
225
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected one of: #{valid_versions.join(", ")}"}]
226
+ return {json: error_response.serialized, status: 400}
227
+ end
228
+
196
229
  nil
197
230
  end
198
231
 
@@ -212,13 +245,15 @@ module ModelContextProtocol
212
245
 
213
246
  # Handle HTTP POST requests containing JSON-RPC messages
214
247
  # Parses request body and routes to initialization or regular request handlers
215
- def handle_post_request(env)
216
- validation_error = validate_headers(env)
248
+ # @param env [Hash] Rack environment hash
249
+ # @param session_context [Hash] Per-request context for initialization
250
+ def handle_post_request(env, session_context: {})
251
+ session_id = env["HTTP_MCP_SESSION_ID"]
252
+ validation_error = validate_headers(env, session_id: session_id)
217
253
  return validation_error if validation_error
218
254
 
219
255
  body_string = env["rack.input"].read
220
256
  body = JSON.parse(body_string)
221
- session_id = env["HTTP_MCP_SESSION_ID"]
222
257
  accept_header = env["HTTP_ACCEPT"] || ""
223
258
 
224
259
  log_to_server_with_context(request_id: body["id"]) do |logger|
@@ -237,7 +272,7 @@ module ModelContextProtocol
237
272
  end
238
273
 
239
274
  if body["method"] == "initialize"
240
- handle_initialization(body, accept_header)
275
+ handle_initialization(body, accept_header, session_context: session_context)
241
276
  else
242
277
  handle_regular_request(body, session_id, accept_header)
243
278
  end
@@ -276,7 +311,10 @@ module ModelContextProtocol
276
311
 
277
312
  # Handle MCP initialization requests to establish protocol version and optional sessions
278
313
  # Always returns JSON response regardless of Accept header to keep initialization simple
279
- def handle_initialization(body, accept_header)
314
+ # @param body [Hash] Parsed JSON-RPC request body
315
+ # @param accept_header [String] HTTP Accept header value
316
+ # @param session_context [Hash] Per-request context to merge with server context
317
+ def handle_initialization(body, accept_header, session_context: {})
280
318
  result = @router.route(body, transport: self)
281
319
  response = Response[id: body["id"], result: result.serialized]
282
320
  response_headers = {}
@@ -284,12 +322,17 @@ module ModelContextProtocol
284
322
 
285
323
  if @require_sessions
286
324
  session_id = SecureRandom.uuid
325
+ # Merge server-level defaults with request-level context
326
+ merged_context = (@configuration.context || {}).merge(session_context)
287
327
  @session_store.create_session(session_id, {
288
328
  server_instance: @server_instance,
289
- context: @configuration.context || {},
329
+ context: merged_context,
290
330
  created_at: Time.now.to_f,
291
331
  negotiated_protocol_version: negotiated_protocol_version
292
332
  })
333
+ # Store initial handler names for list_changed detection
334
+ current_handlers = @configuration.registry.handler_names
335
+ @session_store.store_registered_handlers(session_id, **current_handlers)
293
336
  response_headers["Mcp-Session-Id"] = session_id
294
337
  @session_protocol_versions[session_id] = negotiated_protocol_version
295
338
  log_to_server_with_context { |logger| logger.info("Session created: #{session_id} (protocol: #{negotiated_protocol_version})") }
@@ -314,9 +357,15 @@ module ModelContextProtocol
314
357
  # Handle regular MCP requests (tools, resources, prompts) with streaming/JSON decision logic
315
358
  # Defaults to SSE streaming but returns JSON when client explicitly requests JSON only
316
359
  def handle_regular_request(body, session_id, accept_header)
360
+ session_context = {}
361
+
317
362
  if @require_sessions
363
+ # Per the MCP spec, servers SHOULD respond to requests without a valid
364
+ # Mcp-Session-Id header (other than initialization) with HTTP 400.
365
+ # The session ID MUST be present on all subsequent requests after initialization,
366
+ # including notifications like notifications/initialized.
318
367
  unless session_id && @session_store.session_exists?(session_id)
319
- if session_id && !@session_store.session_exists?(session_id)
368
+ if session_id
320
369
  error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Session terminated"}]
321
370
  return {json: error_response.serialized, status: 404}
322
371
  else
@@ -324,6 +373,9 @@ module ModelContextProtocol
324
373
  return {json: error_response.serialized, status: 400}
325
374
  end
326
375
  end
376
+
377
+ session_context = @session_store.get_session_context(session_id)
378
+ check_and_notify_handler_changes(session_id)
327
379
  end
328
380
 
329
381
  message_type = determine_message_type(body)
@@ -348,7 +400,7 @@ module ModelContextProtocol
348
400
  log_to_server_with_context do |logger|
349
401
  logger.info("← Notification [accepted]")
350
402
  end
351
- {json: {}, status: 202}
403
+ {status: 202}
352
404
 
353
405
  when :request
354
406
  if accept_header.include?("text/event-stream")
@@ -359,9 +411,9 @@ module ModelContextProtocol
359
411
  "Cache-Control" => "no-cache",
360
412
  "Connection" => "keep-alive"
361
413
  },
362
- stream_proc: create_request_response_sse_stream_proc(body, session_id)
414
+ stream_proc: create_request_response_sse_stream_proc(body, session_id, session_context: session_context)
363
415
  }
364
- elsif (result = @router.route(body, request_store: @request_store, session_id: session_id, transport: self))
416
+ elsif (result = @router.route(body, request_store: @request_store, session_id: session_id, transport: self, session_context: session_context))
365
417
  response = Response[id: body["id"], result: result.serialized]
366
418
 
367
419
  log_to_server_with_context(request_id: response.id) do |logger|
@@ -386,6 +438,9 @@ module ModelContextProtocol
386
438
  # Handle HTTP GET requests to establish persistent SSE connections for notifications
387
439
  # Validates session requirements and Accept headers before opening long-lived streams
388
440
  def handle_get_request(env)
441
+ origin_error = validate_origin!(env)
442
+ return origin_error if origin_error
443
+
389
444
  accept_header = env["HTTP_ACCEPT"] || ""
390
445
  unless accept_header.include?("text/event-stream")
391
446
  error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Accept header must include text/event-stream"}]
@@ -393,6 +448,10 @@ module ModelContextProtocol
393
448
  end
394
449
 
395
450
  session_id = env["HTTP_MCP_SESSION_ID"]
451
+
452
+ protocol_error = validate_protocol_version!(env, session_id: session_id)
453
+ return protocol_error if protocol_error
454
+
396
455
  last_event_id = env["HTTP_LAST_EVENT_ID"]
397
456
 
398
457
  if @require_sessions
@@ -422,10 +481,28 @@ module ModelContextProtocol
422
481
  # Handle HTTP DELETE requests to clean up sessions and associated resources
423
482
  # Removes session data, closes streams, and cleans up request store entries
424
483
  def handle_delete_request(env)
484
+ origin_error = validate_origin!(env)
485
+ return origin_error if origin_error
486
+
425
487
  session_id = env["HTTP_MCP_SESSION_ID"]
426
488
 
489
+ protocol_error = validate_protocol_version!(env, session_id: session_id)
490
+ return protocol_error if protocol_error
491
+
427
492
  @server_logger.info("→ DELETE /mcp [Session cleanup: #{session_id || "unknown"}]")
428
493
 
494
+ if @require_sessions
495
+ unless session_id
496
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid or missing session ID"}]
497
+ return {json: error_response.serialized, status: 400}
498
+ end
499
+
500
+ unless @session_store.session_exists?(session_id)
501
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Session terminated"}]
502
+ return {json: error_response.serialized, status: 404}
503
+ end
504
+ end
505
+
429
506
  if session_id
430
507
  cleanup_session(session_id)
431
508
  log_to_server_with_context { |logger| logger.info("Session cleanup: #{session_id}") }
@@ -441,7 +518,10 @@ module ModelContextProtocol
441
518
  # Create SSE stream processor for request-response pattern with real-time progress support
442
519
  # Opens stream → Executes request → Sends response → Closes stream
443
520
  # Enables progress notifications during long-running operations like tool calls
444
- def create_request_response_sse_stream_proc(request_body, session_id)
521
+ # @param request_body [Hash] Parsed JSON-RPC request
522
+ # @param session_id [String, nil] Session ID for this request
523
+ # @param session_context [Hash] Context to pass to handlers
524
+ def create_request_response_sse_stream_proc(request_body, session_id, session_context: {})
445
525
  proc do |stream|
446
526
  temp_stream_id = "temp-#{SecureRandom.hex(8)}"
447
527
  @stream_registry.register_stream(temp_stream_id, stream)
@@ -452,7 +532,7 @@ module ModelContextProtocol
452
532
  end
453
533
 
454
534
  begin
455
- if (result = @router.route(request_body, request_store: @request_store, session_id: session_id, transport: self, stream_id: temp_stream_id))
535
+ if (result = @router.route(request_body, request_store: @request_store, session_id: session_id, transport: self, stream_id: temp_stream_id, session_context: session_context))
456
536
  response = Response[id: request_body["id"], result: result.serialized]
457
537
  event_id = next_event_id
458
538
  send_sse_event(stream, response.serialized, event_id)
@@ -467,6 +547,16 @@ module ModelContextProtocol
467
547
  rescue IOError, Errno::EPIPE, Errno::ECONNRESET => e
468
548
  @server_logger.debug("Client disconnected during progressive request processing: #{e.class.name}")
469
549
  log_to_server_with_context { |logger| logger.info("← SSE stream [closed] (#{temp_stream_id}) [client_disconnected]") }
550
+ rescue ModelContextProtocol::Server::ParameterValidationError => e
551
+ @client_logger.error("Validation error", error: e.message)
552
+ error_response = ErrorResponse[id: request_body["id"], error: {code: -32602, message: e.message}]
553
+ send_sse_event(stream, error_response.serialized, next_event_id)
554
+ close_stream(temp_stream_id, reason: "validation_error")
555
+ rescue => e
556
+ @client_logger.error("Internal error", error: e.message, backtrace: e.backtrace)
557
+ error_response = ErrorResponse[id: request_body["id"], error: {code: -32603, message: e.message}]
558
+ send_sse_event(stream, error_response.serialized, next_event_id)
559
+ close_stream(temp_stream_id, reason: "internal_error")
470
560
  ensure
471
561
  @stream_registry.unregister_stream(temp_stream_id)
472
562
  end
@@ -525,8 +615,15 @@ module ModelContextProtocol
525
615
  flush_notifications_to_stream(stream)
526
616
  end
527
617
 
618
+ # Also flush any messages queued in Redis from other server instances
619
+ poll_and_deliver_redis_messages(stream, session_id) if session_id
620
+
528
621
  loop do
529
622
  break unless stream_connected?(stream)
623
+
624
+ # Poll for queued messages from Redis (cross-server delivery)
625
+ poll_and_deliver_redis_messages(stream, session_id) if session_id
626
+
530
627
  sleep 0.1
531
628
  end
532
629
  ensure
@@ -688,12 +785,6 @@ module ModelContextProtocol
688
785
  @server_request_store.cleanup_session_requests(session_id)
689
786
  end
690
787
 
691
- # Check if this transport instance has any active local streams
692
- # Used to determine if notifications should be queued or delivered immediately
693
- def has_active_streams?
694
- @stream_registry.has_any_local_streams?
695
- end
696
-
697
788
  # Broadcast notification to all active streams on this transport instance
698
789
  # Handles connection errors gracefully and removes disconnected streams
699
790
  def deliver_to_active_streams(notification)
@@ -723,6 +814,22 @@ module ModelContextProtocol
723
814
  @server_logger.debug("Delivered notifications to #{delivered_count} streams, cleaned up #{disconnected_streams.size} disconnected streams")
724
815
  end
725
816
 
817
+ # Poll for messages queued in Redis and deliver to the stream
818
+ # Handles cross-server message delivery when notifications are queued by other server instances
819
+ def poll_and_deliver_redis_messages(stream, session_id)
820
+ return unless session_id
821
+
822
+ messages = @session_store.poll_messages_for_session(session_id)
823
+ return if messages.empty?
824
+
825
+ @server_logger.debug("Delivering #{messages.size} queued messages from Redis to stream #{session_id}")
826
+ messages.each do |message|
827
+ send_to_stream(stream, message)
828
+ end
829
+ rescue => e
830
+ @server_logger.error("Error polling Redis messages: #{e.message}")
831
+ end
832
+
726
833
  # Flush any queued notifications to a newly connected stream
727
834
  # Ensures clients receive notifications that were queued while disconnected
728
835
  def flush_notifications_to_stream(stream)
@@ -783,5 +890,34 @@ module ModelContextProtocol
783
890
  end
784
891
  nil
785
892
  end
893
+
894
+ # Check if registered handlers have changed for a session and send notifications
895
+ # Compares current handlers against previously stored handlers in Redis
896
+ def check_and_notify_handler_changes(session_id)
897
+ return unless session_id
898
+ return unless @session_store.session_exists?(session_id)
899
+
900
+ current = @configuration.registry.handler_names
901
+ previous = @session_store.get_registered_handlers(session_id)
902
+
903
+ return if previous.nil? # First request after init
904
+
905
+ changed_types = []
906
+ changed_types << :prompts if current[:prompts].sort != previous[:prompts]&.sort
907
+ changed_types << :resources if current[:resources].sort != previous[:resources]&.sort
908
+ changed_types << :tools if current[:tools].sort != previous[:tools]&.sort
909
+
910
+ return if changed_types.empty?
911
+
912
+ changed_types.each do |type|
913
+ send_notification("notifications/#{type}/list_changed", {}, session_id: session_id)
914
+ end
915
+
916
+ @session_store.store_registered_handlers(session_id, **current)
917
+ rescue => e
918
+ @server_logger.error("Error checking handler changes: #{e.class.name}: #{e.message}")
919
+ @server_logger.debug("Backtrace: #{e.backtrace.first(5).join("\n")}")
920
+ # Don't re-raise - handler change detection is optional, allow request to proceed
921
+ end
786
922
  end
787
923
  end
@@ -83,7 +83,7 @@ module ModelContextProtocol
83
83
  end
84
84
 
85
85
  class << self
86
- attr_reader :name, :description, :title, :input_schema, :output_schema
86
+ attr_reader :name, :description, :title, :input_schema, :output_schema, :annotations, :security_schemes
87
87
 
88
88
  def define(&block)
89
89
  definition_dsl = DefinitionDSL.new
@@ -94,6 +94,8 @@ module ModelContextProtocol
94
94
  @title = definition_dsl.title
95
95
  @input_schema = definition_dsl.input_schema
96
96
  @output_schema = definition_dsl.output_schema
97
+ @annotations = definition_dsl.defined_annotations
98
+ @security_schemes = definition_dsl.security_schemes
97
99
  end
98
100
 
99
101
  def inherited(subclass)
@@ -102,6 +104,8 @@ module ModelContextProtocol
102
104
  subclass.instance_variable_set(:@title, @title)
103
105
  subclass.instance_variable_set(:@input_schema, @input_schema)
104
106
  subclass.instance_variable_set(:@output_schema, @output_schema)
107
+ subclass.instance_variable_set(:@annotations, @annotations&.dup)
108
+ subclass.instance_variable_set(:@security_schemes, @security_schemes)
105
109
  end
106
110
 
107
111
  def call(arguments, client_logger, server_logger, context = {})
@@ -120,6 +124,9 @@ module ModelContextProtocol
120
124
  result = {name: @name, description: @description, inputSchema: @input_schema}
121
125
  result[:title] = @title if @title
122
126
  result[:outputSchema] = @output_schema if @output_schema
127
+ annotations_hash = @annotations&.serialized
128
+ result[:annotations] = annotations_hash if annotations_hash
129
+ result[:securitySchemes] = @security_schemes if @security_schemes
123
130
  result
124
131
  end
125
132
  end
@@ -149,6 +156,65 @@ module ModelContextProtocol
149
156
  @output_schema = instance_eval(&block) if block_given?
150
157
  @output_schema
151
158
  end
159
+
160
+ attr_reader :defined_annotations
161
+
162
+ def annotations(&block)
163
+ @defined_annotations = AnnotationsDSL.new
164
+ @defined_annotations.instance_eval(&block)
165
+ @defined_annotations
166
+ end
167
+
168
+ def security_schemes(&block)
169
+ @security_schemes = instance_eval(&block) if block_given?
170
+ @security_schemes
171
+ end
172
+ end
173
+
174
+ class AnnotationsDSL
175
+ def initialize
176
+ @read_only_hint = nil
177
+ @destructive_hint = nil
178
+ @idempotent_hint = nil
179
+ @open_world_hint = nil
180
+ end
181
+
182
+ def read_only_hint(value)
183
+ validate_boolean!(:read_only_hint, value)
184
+ @read_only_hint = value
185
+ end
186
+
187
+ def destructive_hint(value)
188
+ validate_boolean!(:destructive_hint, value)
189
+ @destructive_hint = value
190
+ end
191
+
192
+ def idempotent_hint(value)
193
+ validate_boolean!(:idempotent_hint, value)
194
+ @idempotent_hint = value
195
+ end
196
+
197
+ def open_world_hint(value)
198
+ validate_boolean!(:open_world_hint, value)
199
+ @open_world_hint = value
200
+ end
201
+
202
+ def serialized
203
+ result = {}
204
+ result[:readOnlyHint] = @read_only_hint unless @read_only_hint.nil?
205
+ result[:destructiveHint] = @destructive_hint unless @destructive_hint.nil?
206
+ result[:idempotentHint] = @idempotent_hint unless @idempotent_hint.nil?
207
+ result[:openWorldHint] = @open_world_hint unless @open_world_hint.nil?
208
+ result.empty? ? nil : result
209
+ end
210
+
211
+ private
212
+
213
+ def validate_boolean!(field, value)
214
+ unless value.is_a?(TrueClass) || value.is_a?(FalseClass)
215
+ raise ArgumentError, "#{field} must be a boolean, got: #{value.inspect}"
216
+ end
217
+ end
152
218
  end
153
219
  end
154
220
  end