model-context-protocol-rb 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -1
  3. data/README.md +155 -12
  4. data/lib/model_context_protocol/server/cancellable.rb +54 -0
  5. data/lib/model_context_protocol/server/configuration.rb +4 -9
  6. data/lib/model_context_protocol/server/progressable.rb +72 -0
  7. data/lib/model_context_protocol/server/prompt.rb +3 -1
  8. data/lib/model_context_protocol/server/redis_client_proxy.rb +134 -0
  9. data/lib/model_context_protocol/server/redis_config.rb +108 -0
  10. data/lib/model_context_protocol/server/redis_pool_manager.rb +110 -0
  11. data/lib/model_context_protocol/server/resource.rb +3 -0
  12. data/lib/model_context_protocol/server/router.rb +36 -3
  13. data/lib/model_context_protocol/server/stdio_transport/request_store.rb +102 -0
  14. data/lib/model_context_protocol/server/stdio_transport.rb +31 -6
  15. data/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +35 -0
  16. data/lib/model_context_protocol/server/streamable_http_transport/message_poller.rb +101 -0
  17. data/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +80 -0
  18. data/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +224 -0
  19. data/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +120 -0
  20. data/lib/model_context_protocol/server/{session_store.rb → streamable_http_transport/session_store.rb} +30 -16
  21. data/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +119 -0
  22. data/lib/model_context_protocol/server/streamable_http_transport.rb +162 -79
  23. data/lib/model_context_protocol/server/tool.rb +4 -0
  24. data/lib/model_context_protocol/server.rb +9 -3
  25. data/lib/model_context_protocol/version.rb +1 -1
  26. data/tasks/templates/dev-http.erb +58 -14
  27. metadata +57 -3
@@ -0,0 +1,224 @@
1
+ require "json"
2
+
3
module ModelContextProtocol
  class Server::StreamableHttpTransport
    # Redis-backed record of in-flight requests and their cancellation state.
    # StreamableHttpTransport uses it so that every server instance sharing
    # the same Redis can see whether a request is running or was cancelled.
    #
    # Keys written by this store (each with an expiry of +ttl+ seconds):
    #   request:active:<request_id>                JSON {session_id, server_instance, started_at}
    #   request:cancelled:<request_id>             JSON {cancelled_at, reason}
    #   request:session:<session_id>:<request_id>  marker tying a request to its session
    class RequestStore
      REQUEST_KEY_PREFIX = "request:active:"
      CANCELLED_KEY_PREFIX = "request:cancelled:"
      SESSION_KEY_PREFIX = "request:session:"
      DEFAULT_TTL = 60 # 1 minute TTL for request entries

      # @param redis_client [Object] Redis client (or proxy) used for every read and write
      # @param server_instance [Object] value recorded as +server_instance+ in each
      #   request entry (presumably identifies the owning server — confirm with caller)
      # @param ttl [Integer] expiry, in seconds, applied to every key this store writes
      def initialize(redis_client, server_instance, ttl: DEFAULT_TTL)
        @redis = redis_client
        @server_instance = server_instance
        @ttl = ttl
      end

      # Records a request as active and, when a session is given, links it to
      # that session. Both writes go out in a single MULTI transaction.
      #
      # @param request_id [String] the unique request identifier
      # @param session_id [String, nil] owning session; nil for sessionless requests
      # @return [void]
      def register_request(request_id, session_id = nil)
        entry = {
          session_id: session_id,
          server_instance: @server_instance,
          started_at: Time.now.to_f
        }

        @redis.multi do |tx|
          tx.set(active_key_for(request_id), entry.to_json, ex: @ttl)
          tx.set(session_marker_key_for(session_id, request_id), true, ex: @ttl) if session_id
        end
      end

      # Stores a cancellation record for the request.
      #
      # @param request_id [String] the unique request identifier
      # @param reason [String, nil] optional human-readable reason
      # @return [Boolean] true when Redis acknowledged the SET with "OK"
      def mark_cancelled(request_id, reason = nil)
        details = { cancelled_at: Time.now.to_f, reason: reason }
        @redis.set(cancelled_key_for(request_id), details.to_json, ex: @ttl) == "OK"
      end

      # @param request_id [String] the unique request identifier
      # @return [Boolean] true when a cancellation record exists for the request
      def cancelled?(request_id)
        @redis.exists(cancelled_key_for(request_id)) == 1
      end

      # @param request_id [String] the unique request identifier
      # @return [Hash, nil] the parsed cancellation record, or nil when absent or unparseable
      def get_cancellation_info(request_id)
        parse_stored_json(@redis.get(cancelled_key_for(request_id)))
      end

      # Removes the active entry, any cancellation record, and (when the
      # stored entry names a session) the session marker for the request.
      #
      # @param request_id [String] the unique request identifier
      # @return [void]
      def unregister_request(request_id)
        stored = @redis.get(active_key_for(request_id))
        doomed = [active_key_for(request_id), cancelled_key_for(request_id)]

        if stored
          owner = stored_session_id(stored)
          doomed << session_marker_key_for(owner, request_id) if owner
        end

        @redis.del(*doomed)
      end

      # @param request_id [String] the unique request identifier
      # @return [Hash, nil] the parsed active entry, or nil when absent or unparseable
      def get_request(request_id)
        parse_stored_json(@redis.get(active_key_for(request_id)))
      end

      # @param request_id [String] the unique request identifier
      # @return [Boolean] true when an active entry exists for the request
      def active?(request_id)
        @redis.exists(active_key_for(request_id)) == 1
      end

      # Deletes every request linked to the session: the active entries,
      # cancellation records, and the session markers themselves.
      # NOTE(review): relies on KEYS, which walks the whole keyspace in one
      # blocking call; SCAN may be preferable on large databases.
      #
      # @param session_id [String] the session identifier
      # @return [Array<String>] IDs of the requests that were cleaned up
      def cleanup_session_requests(session_id)
        prefix = session_prefix_for(session_id)
        markers = @redis.keys("#{prefix}*")
        return [] if markers.empty?

        ids = markers.map { |marker| marker.sub(prefix, "") }
        doomed = ids.flat_map { |id| [active_key_for(id), cancelled_key_for(id)] } + markers

        @redis.del(*doomed)
        ids
      end

      # @param session_id [String] the session identifier
      # @return [Array<String>] IDs of requests that still have a marker for the session
      def get_session_requests(session_id)
        prefix = session_prefix_for(session_id)
        @redis.keys("#{prefix}*").map { |marker| marker.sub(prefix, "") }
      end

      # @return [Array<String>] IDs of every request with an active entry, across all sessions
      def get_all_active_requests
        @redis.keys("#{REQUEST_KEY_PREFIX}*").map { |key| key.sub(REQUEST_KEY_PREFIX, "") }
      end

      # Sweeps the active-request keys listed by KEYS. A key found with no
      # expiry (TTL reply -1) gets +ttl+ applied again. A key that has already
      # vanished by the time its TTL is queried (TTL reply -2) is counted.
      # Nothing is deleted here; Redis expiry performs the actual removal.
      #
      # @return [Integer] number of listed keys that no longer existed when checked
      def cleanup_expired_requests
        no_expiry = -1
        missing = -2
        vanished = 0

        @redis.keys("#{REQUEST_KEY_PREFIX}*").each do |key|
          case @redis.ttl(key)
          when no_expiry then @redis.expire(key, @ttl)
          when missing then vanished += 1
          end
        end

        vanished
      end

      # Extends the expiry of the request's active entry, cancellation record,
      # and session marker (when the stored entry names a session).
      #
      # @param request_id [String] the unique request identifier
      # @return [Boolean] false when no active entry was found, true otherwise
      def refresh_request_ttl(request_id)
        stored = @redis.get(active_key_for(request_id))
        return false unless stored

        @redis.multi do |tx|
          tx.expire(active_key_for(request_id), @ttl)
          tx.expire(cancelled_key_for(request_id), @ttl)

          owner = stored_session_id(stored)
          tx.expire(session_marker_key_for(owner, request_id), @ttl) if owner
        end

        true
      end

      private

      # Key holding the active entry for a request.
      def active_key_for(request_id)
        "#{REQUEST_KEY_PREFIX}#{request_id}"
      end

      # Key holding the cancellation record for a request.
      def cancelled_key_for(request_id)
        "#{CANCELLED_KEY_PREFIX}#{request_id}"
      end

      # Common prefix of every session marker belonging to one session.
      def session_prefix_for(session_id)
        "#{SESSION_KEY_PREFIX}#{session_id}:"
      end

      # Marker key linking one request to one session.
      def session_marker_key_for(session_id, request_id)
        "#{session_prefix_for(session_id)}#{request_id}"
      end

      # Parses a stored JSON value; nil in, or malformed JSON, yields nil.
      def parse_stored_json(raw)
        return nil unless raw

        JSON.parse(raw)
      rescue JSON::ParserError
        nil
      end

      # Pulls "session_id" out of a stored active entry; malformed JSON yields nil.
      def stored_session_id(raw)
        JSON.parse(raw)["session_id"]
      rescue JSON::ParserError
        nil
      end
    end
  end
end
@@ -0,0 +1,120 @@
1
+ require "json"
2
+ require "securerandom"
3
+
4
module ModelContextProtocol
  class Server::StreamableHttpTransport
    # Per-session outbound message queue stored in a Redis list, plus a
    # token-guarded Redis lock scoped to the same session.
    #
    # Messages are LPUSHed, so the newest sits at the head of the list. Readers
    # reverse the list to return messages oldest-first. After each push the
    # list is trimmed to its first MAX_MESSAGES entries, which keeps the newest ones.
    class SessionMessageQueue
      QUEUE_KEY_PREFIX = "session_messages:"
      LOCK_KEY_PREFIX = "session_lock:"
      DEFAULT_TTL = 3600 # 1 hour
      MAX_MESSAGES = 1000
      LOCK_TIMEOUT = 5 # seconds

      # Reads the whole list and deletes it in one atomic server-side step.
      POLL_AND_CLEAR_SCRIPT = <<~LUA
        local messages = redis.call('lrange', KEYS[1], 0, -1)
        if #messages > 0 then
          redis.call('del', KEYS[1])
        end
        return messages
      LUA

      # Deletes the lock only when it still holds the caller's token.
      RELEASE_LOCK_SCRIPT = <<~LUA
        if redis.call("get", KEYS[1]) == ARGV[1] then
          return redis.call("del", KEYS[1])
        else
          return 0
        end
      LUA

      private_constant :POLL_AND_CLEAR_SCRIPT, :RELEASE_LOCK_SCRIPT

      # @param redis_client [Object] Redis client (or proxy)
      # @param session_id [String] session whose queue and lock this object manages
      # @param ttl [Integer] expiry, in seconds, re-applied to the queue on every push
      def initialize(redis_client, session_id, ttl: DEFAULT_TTL)
        @redis = redis_client
        @session_id = session_id
        @queue_key = "#{QUEUE_KEY_PREFIX}#{session_id}"
        @lock_key = "#{LOCK_KEY_PREFIX}#{session_id}"
        @ttl = ttl
      end

      # Enqueues one message (a String is stored as-is; anything else via #to_json).
      def push_message(message)
        push_messages([message])
      end

      # Enqueues several messages in one MULTI transaction, then refreshes the
      # queue's expiry and trims it. Does nothing when +messages+ is empty.
      def push_messages(messages)
        return if messages.empty?

        payloads = messages.map { |msg| serialize_message(msg) }

        @redis.multi do |tx|
          payloads.each { |payload| tx.lpush(@queue_key, payload) }
          tx.expire(@queue_key, @ttl)
          tx.ltrim(@queue_key, 0, MAX_MESSAGES - 1)
        end
      end

      # Atomically takes every queued message, oldest first.
      # Any error is swallowed and reported as an empty result.
      #
      # @return [Array] parsed messages (raw strings where JSON parsing failed)
      def poll_messages
        raw = @redis.eval(POLL_AND_CLEAR_SCRIPT, keys: [@queue_key])
        return [] if !raw || raw.empty?

        raw.reverse_each.map { |json| deserialize_message(json) }
      rescue
        []
      end

      # Returns every queued message, oldest first, without removing anything.
      # Any error is swallowed and reported as an empty result.
      def peek_messages
        @redis.lrange(@queue_key, 0, -1).reverse_each.map { |json| deserialize_message(json) }
      rescue
        []
      end

      # @return [Boolean] whether the queue key exists (false on any error)
      def has_messages?
        @redis.exists(@queue_key).positive?
      rescue
        false
      end

      # @return [Integer] queue length (0 on any error)
      def message_count
        @redis.llen(@queue_key)
      rescue
        0
      end

      # Drops the whole queue; errors are swallowed and yield nil.
      def clear
        @redis.del(@queue_key)
      rescue
        nil
      end

      # Runs the block while holding the session lock.
      #
      # The lock is acquired with SET NX and expires on its own after
      # +timeout+ seconds. It is released afterwards, even if the block
      # raises, but only if it still holds this call's random token.
      #
      # @param timeout [Integer] lock expiry in seconds
      # @return [Boolean] false (block not run) if the lock was already held, true after the block ran
      def with_lock(timeout: LOCK_TIMEOUT, &block)
        token = SecureRandom.hex(16)
        return false unless @redis.set(@lock_key, token, nx: true, ex: timeout)

        begin
          yield
        ensure
          @redis.eval(RELEASE_LOCK_SCRIPT, keys: [@lock_key], argv: [token])
        end

        true
      end

      private

      # Strings pass through untouched; everything else is JSON-encoded.
      def serialize_message(message)
        return message if message.is_a?(String)

        message.to_json
      end

      # Parses JSON, falling back to the raw string when it is not valid JSON.
      def deserialize_message(json)
        JSON.parse(json)
      rescue JSON::ParserError
        json
      end
    end
  end
end
@@ -1,10 +1,9 @@
1
- # frozen_string_literal: true
2
-
3
1
  require "json"
4
2
  require "securerandom"
3
+ require_relative "session_message_queue"
5
4
 
6
5
  module ModelContextProtocol
7
- class Server
6
+ class Server::StreamableHttpTransport
8
7
  class SessionStore
9
8
  def initialize(redis_client, ttl: 3600)
10
9
  @redis = redis_client
@@ -69,25 +68,40 @@ module ModelContextProtocol
69
68
  @redis.del("session:#{session_id}")
70
69
  end
71
70
 
72
- def route_message_to_session(session_id, message)
73
- server_instance = get_session_server(session_id)
74
- return false unless server_instance
71
+ def queue_message_for_session(session_id, message)
72
+ return false unless session_exists?(session_id)
75
73
 
76
- # Publish to server-specific channel
77
- @redis.publish("server:#{server_instance}:messages", {
78
- session_id: session_id,
79
- message: message
80
- }.to_json)
74
+ queue = SessionMessageQueue.new(@redis, session_id, ttl: @ttl)
75
+ queue.push_message(message)
81
76
  true
77
+ rescue
78
+ false
79
+ end
80
+
81
+ def poll_messages_for_session(session_id)
82
+ return [] unless session_exists?(session_id)
83
+
84
+ queue = SessionMessageQueue.new(@redis, session_id, ttl: @ttl)
85
+ queue.poll_messages
86
+ rescue
87
+ []
82
88
  end
83
89
 
84
- def subscribe_to_server(server_instance, &block)
85
- @redis.subscribe("server:#{server_instance}:messages") do |on|
86
- on.message do |channel, message|
87
- data = JSON.parse(message)
88
- yield(data)
90
+ def get_sessions_with_messages
91
+ session_keys = @redis.keys("session:*")
92
+ sessions_with_messages = []
93
+
94
+ session_keys.each do |key|
95
+ session_id = key.sub("session:", "")
96
+ queue = SessionMessageQueue.new(@redis, session_id, ttl: @ttl)
97
+ if queue.has_messages?
98
+ sessions_with_messages << session_id
89
99
  end
90
100
  end
101
+
102
+ sessions_with_messages
103
+ rescue
104
+ []
91
105
  end
92
106
 
93
107
  def get_all_active_sessions
@@ -0,0 +1,119 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
module ModelContextProtocol
  class Server::StreamableHttpTransport
    # Tracks open streams in two places: a process-local Hash holding the
    # stream objects themselves, and Redis entries recording which server
    # instance owns each session's stream plus a last-heartbeat timestamp.
    #
    # Redis keys (each written with an expiry of +ttl+ seconds):
    #   stream:active:<session_id>     value: server_instance
    #   stream:heartbeat:<session_id>  value: Time.now.to_f at last register/refresh
    class StreamRegistry
      STREAM_KEY_PREFIX = "stream:active:"
      HEARTBEAT_KEY_PREFIX = "stream:heartbeat:"
      DEFAULT_TTL = 60 # 1 minute TTL for stream entries

      # @param redis_client [Object] Redis client (or proxy)
      # @param server_instance [Object] value stored as the owner of registered streams
      # @param ttl [Integer] expiry, in seconds, for the stream and heartbeat keys
      def initialize(redis_client, server_instance, ttl: DEFAULT_TTL)
        @redis = redis_client
        @server_instance = server_instance
        @ttl = ttl
        @local_streams = {} # session_id => stream object, visible only in this process
      end

      # Keeps the stream locally and advertises ownership plus a fresh heartbeat in Redis.
      def register_stream(session_id, stream)
        @local_streams[session_id] = stream

        @redis.multi do |tx|
          tx.set(stream_key(session_id), @server_instance, ex: @ttl)
          tx.set(heartbeat_key(session_id), Time.now.to_f, ex: @ttl)
        end
      end

      # Forgets the local stream and deletes both Redis keys for the session.
      def unregister_stream(session_id)
        @local_streams.delete(session_id)

        @redis.multi do |tx|
          tx.del(stream_key(session_id))
          tx.del(heartbeat_key(session_id))
        end
      end

      # @return [Object, nil] the stream registered in this process for the session
      def get_local_stream(session_id)
        @local_streams[session_id]
      end

      # @return [Boolean] whether this process holds a stream for the session
      def has_local_stream?(session_id)
        @local_streams.key?(session_id)
      end

      # @return [String, nil] the server_instance value recorded for the session's stream
      def get_stream_server(session_id)
        @redis.get(stream_key(session_id))
      end

      # @return [Boolean] whether Redis still holds an ownership key for the session
      def stream_active?(session_id)
        @redis.exists(stream_key(session_id)) == 1
      end

      # Rewrites the heartbeat timestamp and extends the ownership key's expiry.
      def refresh_heartbeat(session_id)
        @redis.multi do |tx|
          tx.set(heartbeat_key(session_id), Time.now.to_f, ex: @ttl)
          tx.expire(stream_key(session_id), @ttl)
        end
      end

      # @return [Hash] a shallow copy of the local session_id => stream map
      def get_all_local_streams
        @local_streams.dup
      end

      # @return [Boolean] whether this process holds at least one stream
      def has_any_local_streams?
        @local_streams.any?
      end

      # Drops local streams whose Redis ownership key no longer exists.
      # All existence checks are sent in one pipeline.
      #
      # @return [Array] session IDs removed from local storage
      def cleanup_expired_streams
        session_ids = @local_streams.keys

        existence = @redis.pipelined do |pipe|
          session_ids.each { |id| pipe.exists(stream_key(id)) }
        end

        gone = session_ids.select.with_index { |_id, i| existence[i] == 0 }
        gone.each { |id| @local_streams.delete(id) }
        gone
      end

      # Lists sessions whose stored heartbeat timestamp is older than
      # +max_age_seconds+ relative to this process's clock. Keys that vanish
      # between KEYS and MGET are skipped.
      #
      # NOTE(review): heartbeat keys are written with an expiry of +ttl+, so a
      # surviving heartbeat is at most +ttl+ seconds old, barring clock skew
      # between servers. With the defaults (ttl 60, max_age 90) nothing can be
      # reported. Confirm whether ttl is meant to exceed max_age_seconds.
      #
      # @param max_age_seconds [Numeric] age threshold in seconds
      # @return [Array<String>] session IDs with stale heartbeats
      def get_stale_streams(max_age_seconds = 90)
        now = Time.now.to_f

        keys = @redis.keys("#{HEARTBEAT_KEY_PREFIX}*")
        return [] if keys.empty?

        beats = @redis.mget(keys)

        keys.each_index
            .select { |i| beats[i] && now - beats[i].to_f > max_age_seconds }
            .map { |i| keys[i].sub(HEARTBEAT_KEY_PREFIX, "") }
      end

      private

      # Redis key recording which server owns the session's stream.
      def stream_key(session_id)
        "#{STREAM_KEY_PREFIX}#{session_id}"
      end

      # Redis key holding the session's last heartbeat timestamp.
      def heartbeat_key(session_id)
        "#{HEARTBEAT_KEY_PREFIX}#{session_id}"
      end
    end
  end
end