model-context-protocol-rb 0.5.1 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,231 @@
1
+ require "json"
2
+
3
+ module ModelContextProtocol
4
+ class Server::StreamableHttpTransport
5
# Redis-backed registry of server-initiated JSON-RPC requests (for example pings) that are
# still waiting for a client response. StreamableHttpTransport uses it so that server
# instances sharing one Redis can see, complete, and time out each other's outgoing requests.
#
# Key layout:
#   server_request:pending:<request_id>               => JSON {session_id, server_instance, type, created_at}
#   server_request:session:<session_id>:<request_id>  => marker used to list a session's requests
#
# Both keys are written with a TTL of +ttl+ seconds. Redis drops an entry once that TTL passes,
# so a timeout given to #get_expired_requests that exceeds +ttl+ will never match anything.
# NOTE(review): the session index joins session_id and request_id with ":", so ids that contain
# ":" themselves can make one session's lookup also match another session's keys.
class ServerRequestStore
  REQUEST_KEY_PREFIX = "server_request:pending:".freeze
  SESSION_KEY_PREFIX = "server_request:session:".freeze
  DEFAULT_TTL = 60 # 1 minute TTL for request entries

  # Value stored under session index keys; only the key's existence is ever read. It is a String
  # because redis-rb 5 (redis-client) rejects boolean command arguments with a TypeError.
  SESSION_MARKER = "true".freeze

  # Characters that carry special meaning in Redis MATCH glob patterns.
  GLOB_SPECIAL_CHARS = /[*?\[\]\\]/.freeze
  private_constant :SESSION_MARKER, :GLOB_SPECIAL_CHARS

  # @param redis_client [Object] client exposing the redis-rb command API used here
  #   (get, set, mget, del, exists, expire, multi, scan_each)
  # @param server_instance [String] identifier of this server instance, stored with each request
  # @param ttl [Integer] seconds after which Redis expires request and session-index keys
  def initialize(redis_client, server_instance, ttl: DEFAULT_TTL)
    @redis = redis_client
    @server_instance = server_instance
    @ttl = ttl
  end

  # Register a new server-initiated request and, when a session is given, index it by session.
  # Both writes happen inside one MULTI so they are applied together.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @param session_id [String, nil] the session identifier (nil for sessionless requests)
  # @param type [Symbol] the type of request (e.g., :ping)
  # @return [void]
  def register_request(request_id, session_id = nil, type: :ping)
    payload = {
      session_id: session_id,
      server_instance: @server_instance,
      type: type.to_s,
      # Wall-clock time: the value is compared by other instances, so a monotonic clock won't do.
      created_at: Time.now.to_f
    }.to_json

    @redis.multi do |transaction|
      transaction.set(request_key_for(request_id), payload, ex: @ttl)
      transaction.set(session_key_for(session_id, request_id), SESSION_MARKER, ex: @ttl) if session_id
    end
  end

  # Mark a server-initiated request as completed (response received) and remove its keys.
  #
  # Completion is claimed with DEL, whose reply says whether this call actually removed the key.
  # When several instances process the same response, only one of them gets +true+.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @return [Boolean] true if this call removed a pending request, false if it was not found
  def mark_completed(request_id)
    request_key = request_key_for(request_id)
    session_id = session_id_from(@redis.get(request_key))

    return false unless @redis.del(request_key).to_i.positive?

    @redis.del(session_key_for(session_id, request_id)) if session_id
    true
  end

  # Check if a server-initiated request is still pending.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @return [Boolean] true if the request key exists, false otherwise
  def pending?(request_id)
    reply = @redis.exists(request_key_for(request_id))
    # Older redis-rb releases (or exists_returns_integer disabled) return a Boolean;
    # newer ones return the number of existing keys.
    reply == true || (reply.is_a?(Integer) && reply.positive?)
  end

  # Get information about a specific pending request.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @return [Hash, nil] the stored request data (String keys), or nil if missing or unreadable
  def get_request(request_id)
    parse_request_data(@redis.get(request_key_for(request_id)))
  end

  # Find pending requests whose age exceeds the given timeout.
  # Entries that are not JSON objects or lack a numeric created_at are skipped.
  #
  # @param timeout_seconds [Numeric] timeout in seconds
  # @return [Array<Hash>] entries with :request_id, :session_id, :type and :age (seconds)
  def get_expired_requests(timeout_seconds)
    now = Time.now.to_f
    request_keys = scan_keys("#{REQUEST_KEY_PREFIX}*")
    return [] if request_keys.empty?

    raw_values = @redis.mget(request_keys)

    request_keys.zip(raw_values).each_with_object([]) do |(key, raw), expired|
      data = parse_request_data(raw)
      next unless data

      created_at = data["created_at"]
      next unless created_at.is_a?(Numeric)

      age = now - created_at
      next unless age > timeout_seconds

      expired << {
        request_id: key.delete_prefix(REQUEST_KEY_PREFIX),
        session_id: data["session_id"],
        type: data["type"],
        age: age
      }
    end
  end

  # Remove every pending request older than the given timeout.
  #
  # @param timeout_seconds [Numeric] timeout in seconds
  # @return [Array<String>] IDs of the requests that were cleaned up
  def cleanup_expired_requests(timeout_seconds)
    expired_ids = get_expired_requests(timeout_seconds).map { |info| info[:request_id] }
    expired_ids.each { |request_id| unregister_request(request_id) }
    expired_ids
  end

  # Unregister a request (typically called when the request completes or times out).
  # Deletes the request key and, if the stored data names a session, the session index key.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @return [Integer] number of keys Redis reports as deleted
  def unregister_request(request_id)
    request_key = request_key_for(request_id)
    session_id = session_id_from(@redis.get(request_key))

    keys_to_delete = [request_key]
    keys_to_delete << session_key_for(session_id, request_id) if session_id
    @redis.del(*keys_to_delete)
  end

  # Clean up all server requests associated with a session.
  # This is typically called when a session is terminated.
  #
  # @param session_id [String] the session identifier
  # @return [Array<String>] IDs of the requests that were cleaned up
  def cleanup_session_requests(session_id)
    session_keys = scan_keys(session_key_pattern(session_id))
    return [] if session_keys.empty?

    request_ids = request_ids_from_session_keys(session_id, session_keys)
    request_keys = request_ids.map { |request_id| request_key_for(request_id) }
    @redis.del(*request_keys, *session_keys)
    request_ids
  end

  # Get all pending request IDs indexed under a specific session.
  #
  # @param session_id [String] the session identifier
  # @return [Array<String>] pending request IDs for the session
  def get_session_requests(session_id)
    session_keys = scan_keys(session_key_pattern(session_id))
    request_ids_from_session_keys(session_id, session_keys)
  end

  # Get all pending request IDs across all sessions.
  #
  # @return [Array<String>] all pending request IDs
  def get_all_pending_requests
    scan_keys("#{REQUEST_KEY_PREFIX}*").map { |key| key.delete_prefix(REQUEST_KEY_PREFIX) }
  end

  # Refresh the TTL for a pending request and its session index key.
  #
  # @param request_id [String] the unique JSON-RPC request identifier
  # @return [Boolean] true if the request existed when checked, false otherwise
  def refresh_request_ttl(request_id)
    request_key = request_key_for(request_id)
    raw = @redis.get(request_key)
    return false unless raw

    session_id = session_id_from(raw)
    @redis.multi do |transaction|
      transaction.expire(request_key, @ttl)
      transaction.expire(session_key_for(session_id, request_id), @ttl) if session_id
    end

    true
  end

  private

  # Redis key holding the JSON payload for a request.
  def request_key_for(request_id)
    "#{REQUEST_KEY_PREFIX}#{request_id}"
  end

  # Redis key indexing a request under its session.
  def session_key_for(session_id, request_id)
    "#{SESSION_KEY_PREFIX}#{session_id}:#{request_id}"
  end

  # MATCH pattern for all of a session's index keys. The session id is escaped so its own glob
  # characters cannot widen the match to other sessions.
  def session_key_pattern(session_id)
    "#{SESSION_KEY_PREFIX}#{escape_glob(session_id)}:*"
  end

  # Strip the session prefix from session index keys, leaving the request IDs.
  def request_ids_from_session_keys(session_id, session_keys)
    prefix = "#{SESSION_KEY_PREFIX}#{session_id}:"
    session_keys.map { |key| key.delete_prefix(prefix) }
  end

  # Backslash-escape Redis glob metacharacters so the value matches literally.
  def escape_glob(value)
    value.to_s.gsub(GLOB_SPECIAL_CHARS) { |char| "\\#{char}" }
  end

  # Collect keys matching +pattern+ with incremental SCAN instead of the blocking KEYS command.
  # SCAN may return a key more than once, hence the uniq.
  def scan_keys(pattern)
    found = []
    @redis.scan_each(match: pattern) { |key| found << key }
    found.uniq
  end

  # Parse a stored request payload; nil when absent, malformed, or not a JSON object.
  def parse_request_data(raw)
    return nil unless raw

    parsed = JSON.parse(raw)
    parsed.is_a?(Hash) ? parsed : nil
  rescue JSON::ParserError
    nil
  end

  # Session id recorded in a stored request payload, or nil.
  def session_id_from(raw)
    data = parse_request_data(raw)
    data && data["session_id"]
  end
end
+ end
230
+ end
231
+ end