model-context-protocol-rb 0.4.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -1
- data/README.md +337 -158
- data/lib/model_context_protocol/server/cancellable.rb +54 -0
- data/lib/model_context_protocol/server/configuration.rb +4 -9
- data/lib/model_context_protocol/server/progressable.rb +72 -0
- data/lib/model_context_protocol/server/prompt.rb +3 -1
- data/lib/model_context_protocol/server/redis_client_proxy.rb +134 -0
- data/lib/model_context_protocol/server/redis_config.rb +108 -0
- data/lib/model_context_protocol/server/redis_pool_manager.rb +110 -0
- data/lib/model_context_protocol/server/resource.rb +3 -0
- data/lib/model_context_protocol/server/router.rb +36 -3
- data/lib/model_context_protocol/server/stdio_transport/request_store.rb +102 -0
- data/lib/model_context_protocol/server/stdio_transport.rb +31 -6
- data/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +35 -0
- data/lib/model_context_protocol/server/streamable_http_transport/message_poller.rb +101 -0
- data/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +80 -0
- data/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +224 -0
- data/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +120 -0
- data/lib/model_context_protocol/server/{session_store.rb → streamable_http_transport/session_store.rb} +30 -16
- data/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +119 -0
- data/lib/model_context_protocol/server/streamable_http_transport.rb +181 -80
- data/lib/model_context_protocol/server/tool.rb +4 -0
- data/lib/model_context_protocol/server.rb +9 -3
- data/lib/model_context_protocol/version.rb +1 -1
- data/tasks/templates/dev-http.erb +58 -14
- metadata +57 -3
@@ -0,0 +1,119 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "json"
|
4
|
+
|
5
|
+
module ModelContextProtocol
  class Server::StreamableHttpTransport
    # Tracks the SSE streams served by this transport.
    #
    # Stream objects are kept in an in-process hash, while Redis keys record
    # which server instance owns each session's stream and when its heartbeat
    # was last refreshed. Both Redis keys are written with an expiry of +ttl+
    # seconds and are kept alive by #refresh_heartbeat.
    class StreamRegistry
      STREAM_KEY_PREFIX = "stream:active:"
      HEARTBEAT_KEY_PREFIX = "stream:heartbeat:"
      DEFAULT_TTL = 60 # 1 minute TTL for stream entries

      # @param redis_client [Object] Redis client or proxy responding to multi,
      #   get, exists, pipelined, keys and mget
      # @param server_instance [String] value stored as the owner of each stream
      # @param ttl [Integer] expiry, in seconds, of the stream/heartbeat keys
      def initialize(redis_client, server_instance, ttl: DEFAULT_TTL)
        @redis = redis_client
        @server_instance = server_instance
        @ttl = ttl
        @local_streams = {} # Keep local reference for direct stream access
        # The transport reads and writes this hash from request threads, the
        # stream monitor thread and the message poller, so serialize access.
        @local_streams_lock = Mutex.new
      end

      # Records +stream+ locally and publishes ownership and a heartbeat to Redis.
      #
      # @param session_id [String] registry key for the stream
      # @param stream [Object] writable SSE stream
      def register_stream(session_id, stream)
        @local_streams_lock.synchronize { @local_streams[session_id] = stream }

        # Store stream registration in Redis with TTL
        @redis.multi do |multi|
          multi.set(stream_key(session_id), @server_instance, ex: @ttl)
          multi.set(heartbeat_key(session_id), Time.now.to_f, ex: @ttl)
        end
      end

      # Forgets the local stream for +session_id+ and deletes its Redis keys.
      def unregister_stream(session_id)
        @local_streams_lock.synchronize { @local_streams.delete(session_id) }

        @redis.multi do |multi|
          multi.del(stream_key(session_id))
          multi.del(heartbeat_key(session_id))
        end
      end

      # @return [Object, nil] the stream this process holds for +session_id+
      def get_local_stream(session_id)
        @local_streams_lock.synchronize { @local_streams[session_id] }
      end

      # @return [Boolean] whether this process holds a stream for +session_id+
      def has_local_stream?(session_id)
        @local_streams_lock.synchronize { @local_streams.key?(session_id) }
      end

      # @return [String, nil] server instance recorded as owning the stream
      def get_stream_server(session_id)
        @redis.get(stream_key(session_id))
      end

      # @return [Boolean] whether the stream key currently exists in Redis
      def stream_active?(session_id)
        key_exists?(@redis.exists(stream_key(session_id)))
      end

      # Rewrites the heartbeat timestamp and extends the stream key's expiry.
      def refresh_heartbeat(session_id)
        @redis.multi do |multi|
          multi.set(heartbeat_key(session_id), Time.now.to_f, ex: @ttl)
          multi.expire(stream_key(session_id), @ttl)
        end
      end

      # @return [Hash] a snapshot copy, safe to iterate while streams change
      def get_all_local_streams
        @local_streams_lock.synchronize { @local_streams.dup }
      end

      # @return [Boolean] whether this process holds any stream at all
      def has_any_local_streams?
        @local_streams_lock.synchronize { !@local_streams.empty? }
      end

      # Drops local streams whose Redis stream key no longer exists.
      #
      # @return [Array<String>] the session ids that were removed locally
      def cleanup_expired_streams
        local_session_ids = @local_streams_lock.synchronize { @local_streams.keys }
        return [] if local_session_ids.empty?

        # Check which ones are still active in Redis
        pipeline_results = @redis.pipelined do |pipeline|
          local_session_ids.each { |session_id| pipeline.exists(stream_key(session_id)) }
        end

        expired_sessions = local_session_ids.each_with_index.filter_map do |session_id, index|
          session_id if key_missing?(pipeline_results[index])
        end

        # Remove expired streams from local storage
        @local_streams_lock.synchronize do
          expired_sessions.each { |session_id| @local_streams.delete(session_id) }
        end

        expired_sessions
      end

      # Finds sessions whose heartbeat is older than +max_age_seconds+.
      #
      # NOTE(review): KEYS walks the whole keyspace and blocks Redis while it
      # runs; consider SCAN once the Redis proxy is known to support it.
      #
      # @param max_age_seconds [Numeric] heartbeat age considered stale
      # @return [Array<String>] session ids with stale heartbeats
      def get_stale_streams(max_age_seconds = 90)
        current_time = Time.now.to_f

        # Get all heartbeat keys
        heartbeat_keys = @redis.keys("#{HEARTBEAT_KEY_PREFIX}*")
        return [] if heartbeat_keys.empty?

        # Get all heartbeat timestamps
        heartbeat_values = @redis.mget(heartbeat_keys)

        heartbeat_keys.zip(heartbeat_values).filter_map do |key, value|
          next unless value

          key.delete_prefix(HEARTBEAT_KEY_PREFIX) if current_time - value.to_f > max_age_seconds
        end
      end

      private

      def stream_key(session_id)
        "#{STREAM_KEY_PREFIX}#{session_id}"
      end

      def heartbeat_key(session_id)
        "#{HEARTBEAT_KEY_PREFIX}#{session_id}"
      end

      # EXISTS replies with an Integer count in current redis-rb, but some
      # redis-rb 4.x configurations reply with true/false; accept both.
      def key_exists?(reply)
        reply == true || (reply.is_a?(Integer) && reply.positive?)
      end

      # A definite "not present" reply (0 or false); nil/unknown is not treated as missing.
      def key_missing?(reply)
        reply == false || reply == 0
      end
    end
  end
end
|
@@ -14,29 +14,62 @@ module ModelContextProtocol
|
|
14
14
|
{jsonrpc: "2.0", id:, error:}
|
15
15
|
end
|
16
16
|
end
|
17
|
+
|
17
18
|
def initialize(router:, configuration:)
  # Builds the transport and starts its background workers.
  #
  # Reads transport options from +configuration+ and connects the Redis-backed
  # stores (sessions, stream registry, notification queue, event counter,
  # request store) to a RedisClientProxy over RedisConfig.pool. Finally it
  # starts the message poller and the stream monitor thread.
  #
  # @param router [#route] dispatcher for JSON-RPC messages
  # @param configuration [#transport_options, #logger]
  @router = router
  @configuration = configuration

  options = @configuration.transport_options
  @redis_pool = ModelContextProtocol::Server::RedisConfig.pool
  @require_sessions = options.fetch(:require_sessions, false)
  @default_protocol_version = options.fetch(:default_protocol_version, "2025-03-26")
  @session_protocol_versions = {}
  @validate_origin = options.fetch(:validate_origin, true)
  @allowed_origins = options.fetch(:allowed_origins, ["http://localhost", "https://localhost", "http://127.0.0.1", "https://127.0.0.1"])
  @redis = ModelContextProtocol::Server::RedisClientProxy.new(@redis_pool)

  @session_store = SessionStore.new(@redis, ttl: options[:session_ttl] || 3600)

  # hostname + pid + random suffix; passed to the Redis-backed stores as this
  # process's identity (StreamRegistry stores it as a stream's owner).
  @server_instance = "#{Socket.gethostname}-#{Process.pid}-#{SecureRandom.hex(4)}"
  @stream_registry = StreamRegistry.new(@redis, @server_instance)
  @notification_queue = NotificationQueue.new(@redis, @server_instance)
  @event_counter = EventCounter.new(@redis, @server_instance)
  @request_store = RequestStore.new(@redis, @server_instance)
  @stream_monitor_thread = nil

  # The poller hands each (stream, message) pair back to this transport.
  @message_poller = MessagePoller.new(@redis, @stream_registry, @configuration.logger) do |stream, message|
    send_to_stream(stream, message)
  end

  start_message_poller
  start_stream_monitor
end
|
49
|
+
|
50
|
+
# Stops background work and releases every stream held by this process.
#
# The steps run in this order:
# 1. Stop the message poller.
# 2. Kill the stream monitor thread and wait up to 5 seconds for it.
# 3. Unregister each local stream and mark it inactive in the session store.
# 4. Hand the Redis handle back to the pool.
#
# A failure while cleaning up one stream, or while releasing Redis, is
# logged and does not stop the remaining steps.
#
# @return [void]
def shutdown
  @configuration.logger.info("Shutting down StreamableHttpTransport")

  # Stop the message poller
  @message_poller&.stop

  # Stop the stream monitor thread. Thread#join takes its limit (seconds) as a
  # positional argument; `join(timeout: 5)` passes a Hash and raises TypeError,
  # which used to abort the rest of shutdown.
  if @stream_monitor_thread&.alive?
    @stream_monitor_thread.kill
    @stream_monitor_thread.join(5)
  end

  # Unregister all local streams
  @stream_registry.get_all_local_streams.each_key do |session_id|
    @stream_registry.unregister_stream(session_id)
    @session_store.mark_stream_inactive(session_id)
  rescue => e
    @configuration.logger.error("Error during stream cleanup", session_id: session_id, error: e.message)
  end

  # NOTE(review): @redis is the RedisClientProxy built around the pool, not a
  # connection checked out of it; confirm RedisConfig.pool#checkin accepts it.
  begin
    @redis_pool.checkin(@redis) if @redis_pool && @redis
  rescue => e
    @configuration.logger.error("Error releasing Redis connection", error: e.message)
  end

  @configuration.logger.info("StreamableHttpTransport shutdown complete")
end
|
41
74
|
|
42
75
|
def handle
|
@@ -68,10 +101,10 @@ module ModelContextProtocol
|
|
68
101
|
params: params
|
69
102
|
}
|
70
103
|
|
71
|
-
if
|
104
|
+
if @stream_registry.has_any_local_streams?
|
72
105
|
deliver_to_active_streams(notification)
|
73
106
|
else
|
74
|
-
@notification_queue
|
107
|
+
@notification_queue.push(notification)
|
75
108
|
end
|
76
109
|
end
|
77
110
|
|
@@ -96,7 +129,6 @@ module ModelContextProtocol
|
|
96
129
|
|
97
130
|
protocol_version = env["HTTP_MCP_PROTOCOL_VERSION"]
|
98
131
|
if protocol_version
|
99
|
-
# Check if this matches a known negotiated version
|
100
132
|
valid_versions = @session_protocol_versions.values.compact.uniq
|
101
133
|
unless valid_versions.empty? || valid_versions.include?(protocol_version)
|
102
134
|
error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected one of: #{valid_versions.join(", ")}"}]
|
@@ -133,9 +165,37 @@ module ModelContextProtocol
|
|
133
165
|
end
|
134
166
|
end
|
135
167
|
|
168
|
+
# Builds the proc that answers one JSON-RPC request over its own SSE response.
#
# The proc works as follows:
# - Registers the stream under the session id, or under a random "temp-..."
#   id when there is no session.
# - Routes the request and writes the serialized response as a single SSE
#   event. An empty object is written when the router returns nil.
# - Closes the stream with reason "request_completed".
#
# IOError/EPIPE/ECONNRESET raised on the way are ignored, and the stream is
# always unregistered at the end.
#
# NOTE(review): with a session id this registers under the same registry key
# that create_sse_stream_proc uses for the session's GET stream. That replaces
# the GET stream's entry for the duration of the request, and the entry is
# deleted afterwards. Confirm this is intended.
#
# @param request_body [Hash] parsed JSON-RPC request
# @param session_id [String, nil]
# @return [Proc] proc that receives the response stream
def create_progressive_request_sse_stream_proc(request_body, session_id)
  proc do |stream|
    stream_key = session_id || "temp-#{SecureRandom.hex(8)}"
    @stream_registry.register_stream(stream_key, stream)

    begin
      result = @router.route(request_body, request_store: @request_store, session_id: session_id, transport: self)
      response = result && Response[id: request_body["id"], result: result.serialized]
      event_id = next_event_id
      send_sse_event(stream, response ? response.serialized : {}, event_id)

      # The request is finished, so end the SSE response right away.
      close_stream(stream_key, reason: "request_completed")
    rescue IOError, Errno::EPIPE, Errno::ECONNRESET
      # The client went away mid-request; nothing left to send.
    ensure
      # Safety net in case close_stream was never reached.
      @stream_registry.unregister_stream(stream_key)
    end
  end
end
|
196
|
+
|
136
197
|
# Returns the next SSE event id, delegating to the transport's EventCounter.
def next_event_id = @event_counter.next_event_id
|
140
200
|
|
141
201
|
def send_sse_event(stream, data, event_id = nil)
|
@@ -147,6 +207,20 @@ module ModelContextProtocol
|
|
147
207
|
stream.flush if stream.respond_to?(:flush)
|
148
208
|
end
|
149
209
|
|
210
|
+
# Ends the local SSE stream registered under +session_id+.
#
# It sends a final "stream_complete" event, closes the stream and
# unregisters it. When sessions are required, the session is also marked as
# having no active stream.
#
# Does nothing when this process holds no stream for +session_id+. Connection
# errors while writing or closing are ignored; the cleanup still runs.
#
# @param session_id [String] registry key of the stream
# @param reason [String] reason reported in the completion event
def close_stream(session_id, reason: "completed")
  stream = @stream_registry.get_local_stream(session_id)
  return unless stream

  begin
    send_sse_event(stream, {type: "stream_complete", reason: reason})
    stream.close
  rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
    # The client is already gone; fall through to the cleanup below.
  end

  @stream_registry.unregister_stream(session_id)
  @session_store.mark_stream_inactive(session_id) if @require_sessions
end
|
223
|
+
|
150
224
|
def handle_post_request(env)
|
151
225
|
validation_error = validate_headers(env)
|
152
226
|
return validation_error if validation_error
|
@@ -176,7 +250,7 @@ module ModelContextProtocol
|
|
176
250
|
end
|
177
251
|
|
178
252
|
def handle_initialization(body, accept_header)
|
179
|
-
result = @router.route(body)
|
253
|
+
result = @router.route(body, transport: self)
|
180
254
|
response = Response[id: body["id"], result: result.serialized]
|
181
255
|
response_headers = {}
|
182
256
|
|
@@ -235,21 +309,19 @@ module ModelContextProtocol
|
|
235
309
|
|
236
310
|
case message_type
|
237
311
|
when :notification, :response
|
238
|
-
if
|
312
|
+
if body["method"] == "notifications/cancelled"
|
313
|
+
handle_cancellation(body, session_id)
|
314
|
+
elsif session_id && @session_store.session_has_active_stream?(session_id)
|
239
315
|
deliver_to_session_stream(session_id, body)
|
240
316
|
end
|
241
317
|
{json: {}, status: 202}
|
242
318
|
|
243
319
|
when :request
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
if session_id && @session_store.session_has_active_stream?(session_id)
|
248
|
-
deliver_to_session_stream(session_id, response.serialized)
|
249
|
-
return {json: {accepted: true}, status: 200}
|
250
|
-
end
|
320
|
+
has_progress_token = body.dig("params", "_meta", "progressToken")
|
321
|
+
should_stream = (accept_header.include?("text/event-stream") && !accept_header.include?("application/json")) ||
|
322
|
+
has_progress_token
|
251
323
|
|
252
|
-
if
|
324
|
+
if should_stream
|
253
325
|
{
|
254
326
|
stream: true,
|
255
327
|
headers: {
|
@@ -257,14 +329,27 @@ module ModelContextProtocol
|
|
257
329
|
"Cache-Control" => "no-cache",
|
258
330
|
"Connection" => "keep-alive"
|
259
331
|
},
|
260
|
-
stream_proc:
|
332
|
+
stream_proc: create_progressive_request_sse_stream_proc(body, session_id)
|
261
333
|
}
|
262
334
|
else
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
335
|
+
result = @router.route(body, request_store: @request_store, session_id: session_id, transport: self)
|
336
|
+
|
337
|
+
if result
|
338
|
+
response = Response[id: body["id"], result: result.serialized]
|
339
|
+
|
340
|
+
if session_id && @session_store.session_has_active_stream?(session_id)
|
341
|
+
deliver_to_session_stream(session_id, response.serialized)
|
342
|
+
return {json: {accepted: true}, status: 200}
|
343
|
+
end
|
344
|
+
|
345
|
+
{
|
346
|
+
json: response.serialized,
|
347
|
+
status: 200,
|
348
|
+
headers: {"Content-Type" => "application/json"}
|
349
|
+
}
|
350
|
+
else
|
351
|
+
{json: {}, status: 204}
|
352
|
+
end
|
268
353
|
end
|
269
354
|
end
|
270
355
|
end
|
@@ -315,7 +400,7 @@ module ModelContextProtocol
|
|
315
400
|
|
316
401
|
def create_sse_stream_proc(session_id, last_event_id = nil)
|
317
402
|
proc do |stream|
|
318
|
-
|
403
|
+
@stream_registry.register_stream(session_id, stream) if session_id
|
319
404
|
|
320
405
|
if last_event_id
|
321
406
|
replay_messages_after_event_id(stream, session_id, last_event_id)
|
@@ -323,26 +408,15 @@ module ModelContextProtocol
|
|
323
408
|
flush_notifications_to_stream(stream)
|
324
409
|
end
|
325
410
|
|
326
|
-
start_keepalive_thread(session_id, stream)
|
327
|
-
|
328
411
|
loop do
|
329
412
|
break unless stream_connected?(stream)
|
330
413
|
sleep 0.1
|
331
414
|
end
|
332
415
|
ensure
|
333
|
-
|
416
|
+
@stream_registry.unregister_stream(session_id) if session_id
|
334
417
|
end
|
335
418
|
end
|
336
419
|
|
337
|
-
def register_local_stream(session_id, stream)
|
338
|
-
@local_streams[session_id] = stream
|
339
|
-
end
|
340
|
-
|
341
|
-
def cleanup_local_stream(session_id)
|
342
|
-
@local_streams.delete(session_id)
|
343
|
-
@session_store.mark_stream_inactive(session_id)
|
344
|
-
end
|
345
|
-
|
346
420
|
def stream_connected?(stream)
|
347
421
|
return false unless stream
|
348
422
|
|
@@ -350,27 +424,44 @@ module ModelContextProtocol
|
|
350
424
|
stream.write(": ping\n\n")
|
351
425
|
stream.flush if stream.respond_to?(:flush)
|
352
426
|
true
|
353
|
-
rescue IOError, Errno::EPIPE, Errno::ECONNRESET
|
427
|
+
rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
|
354
428
|
false
|
355
429
|
end
|
356
430
|
end
|
357
431
|
|
358
|
-
# Starts the background thread that runs #monitor_streams every 30 seconds.
#
# An error from a single pass is logged and the loop keeps going. Any other
# StandardError that escapes the loop is logged, and the loop restarts after
# a 5 second pause. The thread is stored in @stream_monitor_thread and
# returned.
def start_stream_monitor
  @stream_monitor_thread = Thread.new do
    begin
      loop do
        sleep 30 # Check every 30 seconds

        begin
          monitor_streams
        rescue => pass_error
          @configuration.logger.error("Stream monitor error", error: pass_error.message)
        end
      end
    rescue => loop_error
      @configuration.logger.error("Stream monitor thread error", error: loop_error.message)
      sleep 5
      retry
    end
  end
end
|
449
|
+
|
450
|
+
# One housekeeping pass, run by the stream monitor thread.
#
# First, local streams whose Redis registration has expired are dropped and
# their sessions marked inactive. Then every remaining local stream is
# checked:
# - connected: it is pinged and its heartbeat refreshed;
# - not connected: it is closed with reason "client_disconnected";
# - a connection error along the way: it is closed with reason "network_error".
def monitor_streams
  @stream_registry.cleanup_expired_streams.each do |expired_id|
    @session_store.mark_stream_inactive(expired_id)
  end

  @stream_registry.get_all_local_streams.each do |sid, io|
    unless stream_connected?(io)
      close_stream(sid, reason: "client_disconnected")
      next
    end

    send_ping_to_stream(io)
    @stream_registry.refresh_heartbeat(sid)
  rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
    close_stream(sid, reason: "network_error")
  end
end
|
376
467
|
|
@@ -389,60 +480,70 @@ module ModelContextProtocol
|
|
389
480
|
end
|
390
481
|
|
391
482
|
# Delivers +data+ to the SSE stream of one session.
#
# If this process holds the session's stream, the message is written to it
# directly. If that write fails with a connection error, the stream is closed
# and the message falls back to the session store's queue. The same fallback
# applies when no local stream exists.
#
# @param session_id [String]
# @param data [Hash] serialized JSON-RPC message
# @return [true, Object] true when written directly; otherwise whatever
#   queue_message_for_session returns
def deliver_to_session_stream(session_id, data)
  # Fetch the stream once. A separate has_local_stream?/get_local_stream pair
  # can race with the monitor thread removing the stream in between, which
  # would hand nil to send_to_stream.
  if (stream = @stream_registry.get_local_stream(session_id))
    begin
      send_to_stream(stream, data)
      return true
    rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
      # Same disconnect errors close_stream/monitor_streams treat as "client gone".
      close_stream(session_id, reason: "client_disconnected")
    end
  end

  @session_store.queue_message_for_session(session_id, data)
end
|
403
495
|
|
404
496
|
# Tears a session down across all three stores, in this order:
# 1. Unregister its stream from the stream registry.
# 2. Delegate to the session store's per-session cleanup.
# 3. Delegate to the request store's per-session cleanup.
def cleanup_session(session_id)
  registry = @stream_registry
  sessions = @session_store
  requests = @request_store

  registry.unregister_stream(session_id)
  sessions.cleanup_session(session_id)
  requests.cleanup_session_requests(session_id)
end
|
408
501
|
|
409
|
-
# Starts the MessagePoller built in #initialize. Its block there forwards
# each (stream, message) pair to send_to_stream.
def start_message_poller = @message_poller.start
|
429
505
|
|
430
506
|
# True when this process currently holds at least one SSE stream.
def has_active_streams? = @stream_registry.has_any_local_streams?
|
433
509
|
|
434
510
|
# Broadcasts +notification+ to every SSE stream held by this process.
#
# A stream whose write fails with a connection error is closed (reason
# "client_disconnected"), and delivery continues with the remaining streams.
#
# @param notification [Hash] JSON-RPC notification to send
# @return [void]
def deliver_to_active_streams(notification)
  @stream_registry.get_all_local_streams.each do |session_id, stream|
    send_to_stream(stream, notification)
  rescue IOError, Errno::EPIPE, Errno::ECONNRESET, Errno::ENOTCONN, Errno::EBADF
    # Same disconnect set as close_stream/monitor_streams; without ENOTCONN and
    # EBADF one dead socket would abort the broadcast to everyone else.
    close_stream(session_id, reason: "client_disconnected")
  end
end
|
441
517
|
|
442
518
|
# Drains the pending-notification queue onto +stream+. Notifications are
# written in the order returned by NotificationQueue#pop_all, and the drained
# list is returned.
def flush_notifications_to_stream(stream)
  pending = @notification_queue.pop_all
  pending.each { |note| send_to_stream(stream, note) }
end
|
524
|
+
|
525
|
+
# Handle a cancellation notification from the client.
#
# Marks the referenced request as cancelled in the request store, passing
# along the optional reason. Notifications without a params hash or without a
# requestId are ignored.
#
# Cancellation is best-effort and never raises. A failure while recording it
# is logged and nil is returned, instead of the error being silently dropped.
#
# @param message [Hash] the cancellation notification message
# @param session_id [String, nil] the session ID if available
# @return [Object, nil] the request store's result, or nil when ignored/failed
def handle_cancellation(message, session_id = nil)
  params = message["params"]
  return unless params.is_a?(Hash)

  request_id = params["requestId"]
  return unless request_id

  @request_store.mark_cancelled(request_id, params["reason"])
rescue => e
  @configuration.logger.error("Error handling cancellation", request_id: request_id, session_id: session_id, error: e.message)
  nil
end
|
542
|
+
|
543
|
+
# Lightweight teardown: stops the message poller (if any), kills the stream
# monitor thread (if any) and drops the Redis handle. Local streams are left
# registered.
def cleanup
  poller = @message_poller
  monitor = @stream_monitor_thread

  poller&.stop
  monitor&.kill
  @redis = nil
end
|
447
548
|
end
|
448
549
|
end
|
@@ -5,7 +5,9 @@ module ModelContextProtocol
|
|
5
5
|
# Raised when output schema validation fails.
|
6
6
|
class OutputSchemaValidationError < StandardError; end
|
7
7
|
|
8
|
+
include ModelContextProtocol::Server::Cancellable
|
8
9
|
include ModelContextProtocol::Server::ContentHelpers
|
10
|
+
include ModelContextProtocol::Server::Progressable
|
9
11
|
|
10
12
|
attr_reader :arguments, :context, :logger
|
11
13
|
|
@@ -107,6 +109,8 @@ module ModelContextProtocol
|
|
107
109
|
raise ModelContextProtocol::Server::ParameterValidationError, validation_error.message
|
108
110
|
rescue OutputSchemaValidationError, ModelContextProtocol::Server::ResponseArgumentsError => tool_error
|
109
111
|
raise tool_error, tool_error.message
|
112
|
+
rescue Server::Cancellable::CancellationError
|
113
|
+
raise
|
110
114
|
rescue => error
|
111
115
|
ErrorResponse[error: error.message]
|
112
116
|
end
|
@@ -8,7 +8,7 @@ module ModelContextProtocol
|
|
8
8
|
# Raised when invalid parameters are provided.
|
9
9
|
class ParameterValidationError < StandardError; end
|
10
10
|
|
11
|
-
attr_reader :configuration, :router
|
11
|
+
attr_reader :configuration, :router, :transport
|
12
12
|
|
13
13
|
def initialize
|
14
14
|
@configuration = Configuration.new
|
@@ -20,7 +20,7 @@ module ModelContextProtocol
|
|
20
20
|
def start
|
21
21
|
configuration.validate!
|
22
22
|
|
23
|
-
transport = case configuration.transport_type
|
23
|
+
@transport = case configuration.transport_type
|
24
24
|
when :stdio, nil
|
25
25
|
StdioTransport.new(router: @router, configuration: @configuration)
|
26
26
|
when :streamable_http
|
@@ -32,7 +32,7 @@ module ModelContextProtocol
|
|
32
32
|
raise ArgumentError, "Unknown transport: #{configuration.transport_type}"
|
33
33
|
end
|
34
34
|
|
35
|
-
transport.handle
|
35
|
+
@transport.handle
|
36
36
|
end
|
37
37
|
|
38
38
|
private
|
@@ -281,5 +281,11 @@ module ModelContextProtocol
|
|
281
281
|
end
|
282
282
|
end
|
283
283
|
end
|
284
|
+
|
285
|
+
# Configures the Redis settings used by the Redis-backed transport pieces.
# Within this gem, StreamableHttpTransport reads RedisConfig.pool.
#
# The block is forwarded unchanged to RedisConfig.configure.
#
# @return whatever RedisConfig.configure returns
def self.configure_redis(&block)
  RedisConfig.configure(&block)
end
|
284
290
|
end
|
285
291
|
end
|