model-context-protocol-rb 0.4.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -1
- data/README.md +337 -158
- data/lib/model_context_protocol/server/cancellable.rb +54 -0
- data/lib/model_context_protocol/server/configuration.rb +4 -9
- data/lib/model_context_protocol/server/progressable.rb +72 -0
- data/lib/model_context_protocol/server/prompt.rb +3 -1
- data/lib/model_context_protocol/server/redis_client_proxy.rb +134 -0
- data/lib/model_context_protocol/server/redis_config.rb +108 -0
- data/lib/model_context_protocol/server/redis_pool_manager.rb +110 -0
- data/lib/model_context_protocol/server/resource.rb +3 -0
- data/lib/model_context_protocol/server/router.rb +36 -3
- data/lib/model_context_protocol/server/stdio_transport/request_store.rb +102 -0
- data/lib/model_context_protocol/server/stdio_transport.rb +31 -6
- data/lib/model_context_protocol/server/streamable_http_transport/event_counter.rb +35 -0
- data/lib/model_context_protocol/server/streamable_http_transport/message_poller.rb +101 -0
- data/lib/model_context_protocol/server/streamable_http_transport/notification_queue.rb +80 -0
- data/lib/model_context_protocol/server/streamable_http_transport/request_store.rb +224 -0
- data/lib/model_context_protocol/server/streamable_http_transport/session_message_queue.rb +120 -0
- data/lib/model_context_protocol/server/{session_store.rb → streamable_http_transport/session_store.rb} +30 -16
- data/lib/model_context_protocol/server/streamable_http_transport/stream_registry.rb +119 -0
- data/lib/model_context_protocol/server/streamable_http_transport.rb +181 -80
- data/lib/model_context_protocol/server/tool.rb +4 -0
- data/lib/model_context_protocol/server.rb +9 -3
- data/lib/model_context_protocol/version.rb +1 -1
- data/tasks/templates/dev-http.erb +58 -14
- metadata +57 -3
@@ -0,0 +1,54 @@
|
|
1
|
+
require "concurrent-ruby"
|
2
|
+
|
3
|
+
module ModelContextProtocol
|
4
|
+
module Server::Cancellable
|
5
|
+
# Raised when a request has been cancelled by the client
|
6
|
+
class CancellationError < StandardError; end
|
7
|
+
|
8
|
+
# Execute a block with automatic cancellation support for blocking I/O operations.
|
9
|
+
# This method uses Concurrent::TimerTask to poll for cancellation every 100ms
|
10
|
+
# and can interrupt even blocking operations like HTTP requests or database queries.
|
11
|
+
#
|
12
|
+
# @param interval [Float] polling interval in seconds (default: 0.1)
|
13
|
+
# @yield block to execute with cancellation support
|
14
|
+
# @return [Object] the result of the block
|
15
|
+
# @raise [CancellationError] if the request is cancelled during execution
|
16
|
+
#
|
17
|
+
# @example
|
18
|
+
# cancellable do
|
19
|
+
# response = Net::HTTP.get(URI('https://slow-api.example.com'))
|
20
|
+
# process_response(response)
|
21
|
+
# end
|
22
|
+
def cancellable(interval: 0.1, &block)
  context = Thread.current[:mcp_context]
  request_store = context && context[:request_store]
  request_id = context && context[:request_id]

  # Without a request store and id there is nothing to poll, so run the block
  # directly instead of starting a timer that could never fire a cancellation.
  return block.call unless request_store && request_id

  # Fail fast when the request was already cancelled before the block started.
  if request_store.cancelled?(request_id)
    raise CancellationError, "Request #{request_id} was cancelled"
  end

  # The timer runs on another thread; it reads the target thread through an
  # atomic reference so the ensure clause below can withdraw it.
  executing_thread = Concurrent::AtomicReference.new(Thread.current)

  timer_task = Concurrent::TimerTask.new(execution_interval: interval) do
    if request_store.cancelled?(request_id)
      thread = executing_thread.get
      thread.raise(CancellationError, "Request was cancelled") if thread&.alive?
    end
  end

  begin
    timer_task.execute
    block.call
  ensure
    # Clear the target first so later timer ticks find no thread to interrupt.
    executing_thread.set(nil)
    timer_task.shutdown if timer_task.running?
  end
end
|
53
|
+
end
|
54
|
+
end
|
@@ -200,15 +200,10 @@ module ModelContextProtocol
|
|
200
200
|
end
|
201
201
|
|
202
202
|
def validate_streamable_http_transport!
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
end
|
208
|
-
|
209
|
-
redis_client = options[:redis_client]
|
210
|
-
unless redis_client.respond_to?(:hset) && redis_client.respond_to?(:expire)
|
211
|
-
raise InvalidTransportError, "redis_client must be a Redis-compatible client"
|
203
|
+
unless ModelContextProtocol::Server::RedisConfig.configured?
|
204
|
+
raise InvalidTransportError,
|
205
|
+
"streamable_http transport requires Redis configuration. " \
|
206
|
+
"Call ModelContextProtocol::Server.configure_redis in an initializer."
|
212
207
|
end
|
213
208
|
end
|
214
209
|
|
@@ -0,0 +1,72 @@
|
|
1
|
+
require "concurrent-ruby"
|
2
|
+
|
3
|
+
module ModelContextProtocol
|
4
|
+
module Server::Progressable
|
5
|
+
# Execute a block with automatic time-based progress reporting.
|
6
|
+
# Uses Concurrent::TimerTask to send progress notifications at regular intervals.
|
7
|
+
#
|
8
|
+
# @param max_duration [Numeric] Expected duration in seconds
|
9
|
+
# @param message [String, nil] Optional custom progress message
|
10
|
+
# @yield block to execute with progress tracking
|
11
|
+
# @return [Object] the result of the block
|
12
|
+
#
|
13
|
+
# @example
|
14
|
+
# progressable(max_duration: 30) do # 30 seconds
|
15
|
+
# perform_long_operation
|
16
|
+
# end
|
17
|
+
def progressable(max_duration:, message: nil, &block)
  context = Thread.current[:mcp_context]

  # Progress reporting needs both a client-supplied token and a transport;
  # without them the block simply runs.
  return yield unless context && context[:progress_token] && context[:transport]

  progress_token = context[:progress_token]
  transport = context[:transport]
  label = message || "Processing..."
  # Monotonic clock: elapsed time must not jump when the wall clock is adjusted.
  started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)
  # Report every 5% of the expected duration, but never more often than once a second.
  update_interval = [1.0, max_duration * 0.05].max

  # Notifications are best-effort: a failing transport must not break the work.
  notify = lambda do |progress, text|
    transport.send_notification("notifications/progress", {
      progressToken: progress_token,
      progress: progress,
      total: 100,
      message: text
    })
  rescue
    nil
  end

  timer_task = Concurrent::TimerTask.new(execution_interval: update_interval) do
    elapsed_seconds = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started_at
    # Cap at 99 so only the completion notification ever reports 100.
    progress_pct = [(elapsed_seconds / max_duration) * 100, 99].min

    notify.call(progress_pct.round(1), "#{label} (#{elapsed_seconds.round(1)}s / ~#{max_duration}s)")

    timer_task.shutdown if elapsed_seconds >= max_duration
  end

  begin
    timer_task.execute
    result = yield

    # Stop the periodic updates before announcing completion so that a late
    # tick cannot be sent after the final "Completed" notification.
    timer_task.shutdown if timer_task.running?
    notify.call(100, "Completed")

    result
  ensure
    # Also covers the path where the block raised.
    timer_task.shutdown if timer_task.running?
  end
end
|
71
|
+
end
|
72
|
+
end
|
@@ -1,6 +1,8 @@
|
|
1
1
|
module ModelContextProtocol
|
2
2
|
class Server::Prompt
|
3
|
-
include Server::
|
3
|
+
include ModelContextProtocol::Server::Cancellable
|
4
|
+
include ModelContextProtocol::Server::ContentHelpers
|
5
|
+
include ModelContextProtocol::Server::Progressable
|
4
6
|
|
5
7
|
attr_reader :arguments, :context, :logger
|
6
8
|
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module ModelContextProtocol
|
4
|
+
class Server
|
5
|
+
# Exposes a fixed set of Redis commands on top of a connection pool.
# Every public method checks a connection out through the pool's #with and
# forwards the command to it, returning whatever the connection returns.
class RedisClientProxy
  # @param pool [#with] pool whose #with yields a Redis-compatible connection
  def initialize(pool)
    @pool = pool
  end

  def get(key)
    with_connection { |redis| redis.get(key) }
  end

  # @param options [Hash] forwarded unchanged as keyword arguments to the connection's #set
  def set(key, value, **options)
    with_connection { |redis| redis.set(key, value, **options) }
  end

  def del(*keys)
    with_connection { |redis| redis.del(*keys) }
  end

  def exists(*keys)
    with_connection { |redis| redis.exists(*keys) }
  end

  def expire(key, seconds)
    with_connection { |redis| redis.expire(key, seconds) }
  end

  def ttl(key)
    with_connection { |redis| redis.ttl(key) }
  end

  def hget(key, field)
    with_connection { |redis| redis.hget(key, field) }
  end

  # @param args [Array] forwarded unchanged to the connection's #hset
  def hset(key, *args)
    with_connection { |redis| redis.hset(key, *args) }
  end

  def hgetall(key)
    with_connection { |redis| redis.hgetall(key) }
  end

  def lpush(key, *values)
    with_connection { |redis| redis.lpush(key, *values) }
  end

  def rpop(key)
    with_connection { |redis| redis.rpop(key) }
  end

  def lrange(key, start, stop)
    with_connection { |redis| redis.lrange(key, start, stop) }
  end

  def llen(key)
    with_connection { |redis| redis.llen(key) }
  end

  def ltrim(key, start, stop)
    with_connection { |redis| redis.ltrim(key, start, stop) }
  end

  def incr(key)
    with_connection { |redis| redis.incr(key) }
  end

  def decr(key)
    with_connection { |redis| redis.decr(key) }
  end

  def keys(pattern)
    with_connection { |redis| redis.keys(pattern) }
  end

  # Run a MULTI block on a single pooled connection.
  #
  # @yieldparam multi_wrapper [RedisMultiWrapper] wrapper around the object yielded by the connection's #multi
  # @return [Object] whatever the connection's #multi returns
  def multi(&block)
    with_connection do |redis|
      redis.multi do |multi|
        multi_wrapper = RedisMultiWrapper.new(multi)
        block.call(multi_wrapper)
      end
    end
  end

  # Run a pipelined block on a single pooled connection.
  #
  # @yieldparam pipeline_wrapper [RedisMultiWrapper] wrapper around the object yielded by the connection's #pipelined
  # @return [Object] whatever the connection's #pipelined returns
  def pipelined(&block)
    with_connection do |redis|
      redis.pipelined do |pipeline|
        pipeline_wrapper = RedisMultiWrapper.new(pipeline)
        block.call(pipeline_wrapper)
      end
    end
  end

  def mget(*keys)
    with_connection { |redis| redis.mget(*keys) }
  end

  def eval(script, keys: [], argv: [])
    with_connection { |redis| redis.eval(script, keys: keys, argv: argv) }
  end

  def ping
    with_connection { |redis| redis.ping }
  end

  def flushdb
    with_connection { |redis| redis.flushdb }
  end

  private

  # Check out a connection from the pool for the duration of the block.
  def with_connection(&block)
    @pool.with(&block)
  end

  # Wrapper for Redis multi/pipeline operations.
  # Forwards every call it does not define itself to the wrapped object.
  class RedisMultiWrapper
    def initialize(multi)
      @multi = multi
    end

    # Forward with public_send rather than send: calling a private method of
    # the wrapper with an explicit receiver (e.g. Kernel#system or #exit) also
    # lands here, and send would then run the private method on the wrapped
    # object. public_send raises NoMethodError for such calls instead.
    def method_missing(method, *args, **kwargs, &block)
      @multi.public_send(method, *args, **kwargs, &block)
    end

    # Only public methods are forwarded, so only those are reported,
    # regardless of include_private.
    def respond_to_missing?(method, include_private = false)
      @multi.respond_to?(method)
    end
  end
end
|
133
|
+
end
|
134
|
+
end
|
@@ -0,0 +1,108 @@
|
|
1
|
+
require "singleton"
|
2
|
+
|
3
|
+
module ModelContextProtocol
|
4
|
+
# Process-wide holder for the Redis connection pool configuration.
# Implemented as a Singleton; the class-level methods delegate to the single
# instance, which owns one RedisPoolManager at a time.
class Server::RedisConfig
  include Singleton

  # Raised by #pool when no pool has been configured yet.
  class NotConfiguredError < StandardError
    DEFAULT_MESSAGE = "Redis not configured. Call ModelContextProtocol::Server.configure_redis first"

    # @param message [String] optional override for the default explanation
    def initialize(message = DEFAULT_MESSAGE)
      super
    end
  end

  # @return [Server::RedisPoolManager, nil] the active pool manager, if any
  attr_reader :manager

  def self.configure(&block)
    instance.configure(&block)
  end

  def self.configured?
    instance.configured?
  end

  def self.pool
    instance.pool
  end

  def self.shutdown!
    instance.shutdown!
  end

  def self.reset!
    instance.reset!
  end

  def self.stats
    instance.stats
  end

  def self.pool_manager
    instance.manager
  end

  def initialize
    reset!
  end

  # Build and start a new pool manager from the settings set in the block.
  # Any previously configured manager is shut down first.
  #
  # @yieldparam config [Configuration] settings object pre-filled with defaults
  # @return [Object] the result of starting the pool manager
  def configure(&block)
    shutdown! if configured?

    config = Configuration.new
    yield(config) if block_given?

    manager = Server::RedisPoolManager.new(
      redis_url: config.redis_url,
      pool_size: config.pool_size,
      pool_timeout: config.pool_timeout
    )

    if config.enable_reaper
      manager.configure_reaper(
        enabled: true,
        interval: config.reaper_interval,
        idle_timeout: config.idle_timeout
      )
    end

    # Publish the manager only once it has started, so a failed start does not
    # leave an unstarted manager reachable through #manager / .pool_manager.
    started = manager.start
    @manager = manager
    started
  end

  # @return [Boolean] true when a manager exists and it holds a pool
  def configured?
    !@manager.nil? && !@manager.pool.nil?
  end

  # @return [Object] the manager's connection pool
  # @raise [NotConfiguredError] if no pool has been configured
  def pool
    raise NotConfiguredError unless configured?
    @manager.pool
  end

  # Shut down the current manager (if any) and forget it.
  def shutdown!
    @manager&.shutdown
    @manager = nil
  end

  # Return to the unconfigured state; shutdown! already clears the manager.
  def reset!
    shutdown!
  end

  # @return [Hash] the manager's stats, or an empty hash when unconfigured
  def stats
    return {} unless configured?
    @manager.stats
  end

  # Settings consumed by #configure, initialised with the defaults below.
  class Configuration
    attr_accessor :redis_url, :pool_size, :pool_timeout,
      :enable_reaper, :reaper_interval, :idle_timeout

    def initialize
      @redis_url = nil
      @pool_size = 20
      @pool_timeout = 5
      @enable_reaper = true
      @reaper_interval = 60
      @idle_timeout = 300
    end
  end
end
|
108
|
+
end
|
@@ -0,0 +1,110 @@
|
|
1
|
+
module ModelContextProtocol
|
2
|
+
# Owns a ConnectionPool of Redis clients plus an optional background thread
# that periodically reaps idle connections from that pool.
class Server::RedisPoolManager
  attr_reader :pool, :reaper_thread

  # @param redis_url [String] URL passed to Redis.new for every pooled connection
  # @param pool_size [Integer] size passed to ConnectionPool.new
  # @param pool_timeout [Numeric] timeout passed to ConnectionPool.new
  def initialize(redis_url:, pool_size: 20, pool_timeout: 5)
    @redis_url = redis_url
    @pool_size = pool_size
    @pool_timeout = pool_timeout
    @pool = nil
    @reaper_thread = nil
    # The reaper stays off until configure_reaper enables it.
    @reaper_config = {
      enabled: false,
      interval: 60,
      idle_timeout: 300
    }
  end

  # Replace the reaper settings; takes effect on the next #start.
  #
  # @param enabled [Boolean] whether #start should launch the reaper thread
  # @param interval [Numeric] seconds the reaper sleeps between runs
  # @param idle_timeout [Numeric] value handed to the pool's #reap
  def configure_reaper(enabled:, interval: 60, idle_timeout: 300)
    @reaper_config = {
      enabled: enabled,
      interval: interval,
      idle_timeout: idle_timeout
    }
  end

  # Validate the settings, create the pool and, if enabled, start the reaper.
  #
  # @return [true]
  # @raise [ArgumentError] if a setting is missing or not positive
  def start
    validate!
    create_pool
    start_reaper if @reaper_config[:enabled]
    true
  end

  # Stop the reaper thread and close every pooled connection.
  def shutdown
    stop_reaper
    close_pool
  end

  # @return [Boolean] true when a pooled connection answers PING with "PONG";
  #   false when there is no pool or the check raises
  def healthy?
    return false unless @pool

    @pool.with do |conn|
      conn.ping == "PONG"
    end
  rescue
    false
  end

  # Ask the pool to reap connections using the configured idle timeout,
  # closing each reaped connection. No-op without a pool.
  def reap_now
    return unless @pool

    @pool.reap(@reaper_config[:idle_timeout]) do |conn|
      conn.close
    end
  end

  # @return [Hash] :size, :available and :idle counts, or {} without a pool
  def stats
    return {} unless @pool

    {
      size: @pool.size,
      available: @pool.available,
      # Use the pool's public reader instead of reaching into its instance
      # variables; fall back to 0 when the pool does not provide #idle.
      idle: @pool.respond_to?(:idle) ? @pool.idle : 0
    }
  end

  private

  def validate!
    raise ArgumentError, "redis_url is required" if @redis_url.nil? || @redis_url.empty?
    raise ArgumentError, "pool_size must be positive" if @pool_size <= 0
    raise ArgumentError, "pool_timeout must be positive" if @pool_timeout <= 0
  end

  def create_pool
    @pool = ConnectionPool.new(size: @pool_size, timeout: @pool_timeout) do
      Redis.new(url: @redis_url)
    end
  end

  def close_pool
    @pool&.shutdown { |conn| conn.close }
    @pool = nil
  end

  # Launch the background reaper unless one is already alive.
  def start_reaper
    return if @reaper_thread&.alive?

    @reaper_thread = Thread.new do
      loop do
        sleep @reaper_config[:interval]
        begin
          reap_now
        rescue => e
          # Keep the loop alive; a failed reap is only reported.
          warn "Redis reaper error: #{e.message}"
        end
      end
    end

    @reaper_thread.name = "MCP-Redis-Reaper"
  end

  # Kill the reaper thread and wait up to 5 seconds for it to finish.
  def stop_reaper
    return unless @reaper_thread&.alive?

    @reaper_thread.kill
    @reaper_thread.join(5)
    @reaper_thread = nil
  end
end
|
110
|
+
end
|
@@ -1,3 +1,5 @@
|
|
1
|
+
require_relative "cancellable"
|
2
|
+
|
1
3
|
module ModelContextProtocol
|
2
4
|
class Server::Router
|
3
5
|
# Raised when an invalid method is provided.
|
@@ -12,14 +14,45 @@ module ModelContextProtocol
|
|
12
14
|
@handlers[method] = handler
|
13
15
|
end
|
14
16
|
|
15
|
-
|
17
|
+
# Route a message to its handler with request tracking support
|
18
|
+
#
|
19
|
+
# @param message [Hash] the JSON-RPC message
|
20
|
+
# @param request_store [Object] the request store for tracking cancellation
|
21
|
+
# @param session_id [String, nil] the session ID for HTTP transport
|
22
|
+
# @param transport [Object, nil] the transport for sending notifications
|
23
|
+
# @return [Object] the handler result, or nil if cancelled
|
24
|
+
def route(message, request_store: nil, session_id: nil, transport: nil)
|
16
25
|
method = message["method"]
|
17
26
|
handler = @handlers[method]
|
18
27
|
raise MethodNotFoundError, "Method not found: #{method}" unless handler
|
19
28
|
|
20
|
-
|
21
|
-
|
29
|
+
request_id = message["id"]
|
30
|
+
progress_token = message.dig("params", "_meta", "progressToken")
|
31
|
+
|
32
|
+
if request_id && request_store
|
33
|
+
request_store.register_request(request_id, session_id)
|
34
|
+
end
|
35
|
+
|
36
|
+
result = nil
|
37
|
+
begin
|
38
|
+
with_environment(@configuration&.environment_variables) do
|
39
|
+
context = {request_id:, request_store:, session_id:, progress_token:, transport:}
|
40
|
+
|
41
|
+
Thread.current[:mcp_context] = context
|
42
|
+
|
43
|
+
result = handler.call(message)
|
44
|
+
end
|
45
|
+
rescue Server::Cancellable::CancellationError
|
46
|
+
return nil
|
47
|
+
ensure
|
48
|
+
if request_id && request_store
|
49
|
+
request_store.unregister_request(request_id)
|
50
|
+
end
|
51
|
+
|
52
|
+
Thread.current[:mcp_context] = nil
|
22
53
|
end
|
54
|
+
|
55
|
+
result
|
23
56
|
end
|
24
57
|
|
25
58
|
private
|
@@ -0,0 +1,102 @@
|
|
1
|
+
module ModelContextProtocol
|
2
|
+
class Server::StdioTransport
|
3
|
+
# Thread-safe in-memory storage for tracking active requests and their cancellation status.
|
4
|
+
# This store is used by StdioTransport to manage request lifecycle and handle cancellation.
|
5
|
+
# In-memory registry of in-flight requests, keyed by request id.
# Every access to the underlying hash goes through a single Mutex.
class RequestStore
  def initialize
    @mutex = Mutex.new
    @requests = {}
  end

  # Register a new request with its associated thread
  #
  # @param request_id [String] the unique request identifier
  # @param thread [Thread, nil] the thread processing this request; the current
  #   thread is used when omitted or when nil is passed explicitly
  # @return [void]
  def register_request(request_id, thread = Thread.current)
    # An explicit nil bypasses the parameter default, so fall back here too;
    # otherwise the entry would be stored without an owning thread.
    owner = thread || Thread.current

    @mutex.synchronize do
      @requests[request_id] = {
        thread: owner,
        cancelled: false,
        started_at: Time.now
      }
    end
  end

  # Mark a request as cancelled
  #
  # @param request_id [String] the unique request identifier
  # @return [Boolean] true if request was found and marked cancelled, false otherwise
  def mark_cancelled(request_id)
    @mutex.synchronize do
      entry = @requests[request_id]
      next false unless entry

      entry[:cancelled] = true
      true
    end
  end

  # Check if a request has been cancelled
  #
  # @param request_id [String] the unique request identifier
  # @return [Boolean] true if the request is cancelled, false otherwise
  #   (including when the request is unknown)
  def cancelled?(request_id)
    @mutex.synchronize do
      @requests[request_id]&.fetch(:cancelled, false) || false
    end
  end

  # Unregister a request (typically called when request completes)
  #
  # @param request_id [String] the unique request identifier
  # @return [Hash, nil] the removed request data, or nil if not found
  def unregister_request(request_id)
    @mutex.synchronize { @requests.delete(request_id) }
  end

  # Get information about a specific request
  #
  # @param request_id [String] the unique request identifier
  # @return [Hash, nil] a shallow copy of the request data, or nil if not found
  def get_request(request_id)
    @mutex.synchronize { @requests[request_id]&.dup }
  end

  # Get all active request IDs
  #
  # @return [Array<String>] list of active request IDs
  def active_requests
    @mutex.synchronize { @requests.keys.dup }
  end

  # Remove requests that started more than max_age_seconds ago
  #
  # @param max_age_seconds [Integer] maximum age of requests to keep
  # @return [Array<String>] list of cleaned up request IDs
  def cleanup_old_requests(max_age_seconds = 300)
    cutoff_time = Time.now - max_age_seconds
    removed_ids = []

    @mutex.synchronize do
      @requests.delete_if do |request_id, data|
        expired = data[:started_at] < cutoff_time
        removed_ids << request_id if expired
        expired
      end
    end

    removed_ids
  end
end
|
101
|
+
end
|
102
|
+
end
|