model-context-protocol-rb 0.3.2 → 0.3.4

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,12 @@
1
1
  module ModelContextProtocol
2
2
  class Server::Prompt
3
- attr_reader :params
3
+ attr_reader :arguments, :context, :logger
4
4
 
5
- def initialize(params)
6
- validate!(params)
7
- @params = params
5
+ def initialize(arguments, logger, context = {})
6
+ validate!(arguments)
7
+ @arguments = arguments
8
+ @context = context
9
+ @logger = logger
8
10
  end
9
11
 
10
12
  def call
@@ -22,18 +24,18 @@ module ModelContextProtocol
22
24
  Response[messages:, description: self.class.description]
23
25
  end
24
26
 
25
- private def validate!(params = {})
26
- arguments = self.class.arguments || []
27
- required_args = arguments.select { |arg| arg[:required] }.map { |arg| arg[:name] }
28
- valid_arg_names = arguments.map { |arg| arg[:name] }
27
+ private def validate!(arguments = {})
28
+ defined_arguments = self.class.defined_arguments || []
29
+ required_args = defined_arguments.select { |arg| arg[:required] }.map { |arg| arg[:name].to_sym }
30
+ valid_arg_names = defined_arguments.map { |arg| arg[:name].to_sym }
29
31
 
30
- missing_args = required_args - params.keys
32
+ missing_args = required_args - arguments.keys
31
33
  unless missing_args.empty?
32
34
  missing_args_list = missing_args.join(", ")
33
35
  raise ArgumentError, "Missing required arguments: #{missing_args_list}"
34
36
  end
35
37
 
36
- extra_args = params.keys - valid_arg_names
38
+ extra_args = arguments.keys - valid_arg_names
37
39
  unless extra_args.empty?
38
40
  extra_args_list = extra_args.join(", ")
39
41
  raise ArgumentError, "Unexpected arguments: #{extra_args_list}"
@@ -41,10 +43,10 @@ module ModelContextProtocol
41
43
  end
42
44
 
43
45
  class << self
44
- attr_reader :name, :description, :arguments
46
+ attr_reader :name, :description, :defined_arguments
45
47
 
46
48
  def with_metadata(&block)
47
- @arguments ||= []
49
+ @defined_arguments ||= []
48
50
 
49
51
  metadata_dsl = MetadataDSL.new
50
52
  metadata_dsl.instance_eval(&block)
@@ -54,12 +56,12 @@ module ModelContextProtocol
54
56
  end
55
57
 
56
58
  def with_argument(&block)
57
- @arguments ||= []
59
+ @defined_arguments ||= []
58
60
 
59
61
  argument_dsl = ArgumentDSL.new
60
62
  argument_dsl.instance_eval(&block)
61
63
 
62
- @arguments << {
64
+ @defined_arguments << {
63
65
  name: argument_dsl.name,
64
66
  description: argument_dsl.description,
65
67
  required: argument_dsl.required,
@@ -70,21 +72,21 @@ module ModelContextProtocol
70
72
  def inherited(subclass)
71
73
  subclass.instance_variable_set(:@name, @name)
72
74
  subclass.instance_variable_set(:@description, @description)
73
- subclass.instance_variable_set(:@arguments, @arguments&.dup)
75
+ subclass.instance_variable_set(:@defined_arguments, @defined_arguments&.dup)
74
76
  end
75
77
 
76
- def call(params)
77
- new(params).call
78
+ def call(arguments, logger, context = {})
79
+ new(arguments, logger, context).call
78
80
  rescue ArgumentError => error
79
81
  raise ModelContextProtocol::Server::ParameterValidationError, error.message
80
82
  end
81
83
 
82
84
  def metadata
83
- {name: @name, description: @description, arguments: @arguments}
85
+ {name: @name, description: @description, arguments: @defined_arguments}
84
86
  end
85
87
 
86
88
  def complete_for(arg_name, value)
87
- arg = @arguments&.find { |a| a[:name] == arg_name.to_s }
89
+ arg = @defined_arguments&.find { |a| a[:name] == arg_name.to_s }
88
90
  completion = (arg && arg[:completion]) ? arg[:completion] : ModelContextProtocol::Server::NullCompletion
89
91
  completion.call(arg_name.to_s, value)
90
92
  end
@@ -1,10 +1,12 @@
1
1
  module ModelContextProtocol
2
2
  class Server::Resource
3
- attr_reader :mime_type, :uri
3
+ attr_reader :mime_type, :uri, :context, :logger
4
4
 
5
- def initialize
5
+ def initialize(logger, context = {})
6
6
  @mime_type = self.class.mime_type
7
7
  @uri = self.class.uri
8
+ @context = context
9
+ @logger = logger
8
10
  end
9
11
 
10
12
  def call
@@ -56,8 +58,8 @@ module ModelContextProtocol
56
58
  subclass.instance_variable_set(:@uri, @uri)
57
59
  end
58
60
 
59
- def call
60
- new.call
61
+ def call(logger, context = {})
62
+ new(logger, context).call
61
63
  end
62
64
 
63
65
  def metadata
@@ -0,0 +1,108 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "securerandom"
5
+
6
+ module ModelContextProtocol
7
+ class Server
8
+ class SessionStore
9
+ def initialize(redis_client, ttl: 3600)
10
+ @redis = redis_client
11
+ @ttl = ttl
12
+ end
13
+
14
+ def create_session(session_id, data)
15
+ session_data = {
16
+ id: session_id,
17
+ server_instance: data[:server_instance],
18
+ context: data[:context] || {},
19
+ created_at: data[:created_at] || Time.now.to_f,
20
+ last_activity: Time.now.to_f,
21
+ active_stream: false
22
+ }
23
+
24
+ @redis.hset("session:#{session_id}", session_data.transform_values(&:to_json))
25
+ @redis.expire("session:#{session_id}", @ttl)
26
+ session_id
27
+ end
28
+
29
+ def mark_stream_active(session_id, server_instance)
30
+ @redis.multi do |multi|
31
+ multi.hset("session:#{session_id}",
32
+ "active_stream", true.to_json,
33
+ "stream_server", server_instance.to_json,
34
+ "last_activity", Time.now.to_f.to_json)
35
+ multi.expire("session:#{session_id}", @ttl)
36
+ end
37
+ end
38
+
39
+ def mark_stream_inactive(session_id)
40
+ @redis.multi do |multi|
41
+ multi.hset("session:#{session_id}",
42
+ "active_stream", false.to_json,
43
+ "stream_server", nil.to_json,
44
+ "last_activity", Time.now.to_f.to_json)
45
+ multi.expire("session:#{session_id}", @ttl)
46
+ end
47
+ end
48
+
49
+ def get_session_server(session_id)
50
+ server_data = @redis.hget("session:#{session_id}", "stream_server")
51
+ server_data ? JSON.parse(server_data) : nil
52
+ end
53
+
54
+ def session_exists?(session_id)
55
+ @redis.exists("session:#{session_id}") == 1
56
+ end
57
+
58
+ def session_has_active_stream?(session_id)
59
+ stream_data = @redis.hget("session:#{session_id}", "active_stream")
60
+ stream_data ? JSON.parse(stream_data) : false
61
+ end
62
+
63
+ def get_session_context(session_id)
64
+ context_data = @redis.hget("session:#{session_id}", "context")
65
+ context_data ? JSON.parse(context_data) : {}
66
+ end
67
+
68
+ def cleanup_session(session_id)
69
+ @redis.del("session:#{session_id}")
70
+ end
71
+
72
+ def route_message_to_session(session_id, message)
73
+ server_instance = get_session_server(session_id)
74
+ return false unless server_instance
75
+
76
+ # Publish to server-specific channel
77
+ @redis.publish("server:#{server_instance}:messages", {
78
+ session_id: session_id,
79
+ message: message
80
+ }.to_json)
81
+ true
82
+ end
83
+
84
+ def subscribe_to_server(server_instance, &block)
85
+ @redis.subscribe("server:#{server_instance}:messages") do |on|
86
+ on.message do |channel, message|
87
+ data = JSON.parse(message)
88
+ yield(data)
89
+ end
90
+ end
91
+ end
92
+
93
+ def get_all_active_sessions
94
+ keys = @redis.keys("session:*")
95
+ active_sessions = []
96
+
97
+ keys.each do |key|
98
+ session_id = key.sub("session:", "")
99
+ if session_has_active_stream?(session_id)
100
+ active_sessions << session_id
101
+ end
102
+ end
103
+
104
+ active_sessions
105
+ end
106
+ end
107
+ end
108
+ end
@@ -12,16 +12,19 @@ module ModelContextProtocol
12
12
  end
13
13
  end
14
14
 
15
- attr_reader :logger, :router
15
+ attr_reader :router, :configuration
16
16
 
17
- def initialize(logger:, router:)
18
- @logger = logger
17
+ def initialize(router:, configuration:)
19
18
  @router = router
19
+ @configuration = configuration
20
20
  end
21
21
 
22
- def begin
22
+ def handle
23
+ # Connect logger to transport
24
+ @configuration.logger.connect_transport(self)
25
+
23
26
  loop do
24
- line = $stdin.gets
27
+ line = receive_message
25
28
  break unless line
26
29
 
27
30
  begin
@@ -31,18 +34,17 @@ module ModelContextProtocol
31
34
  result = router.route(message)
32
35
  send_message(Response[id: message["id"], result: result.serialized])
33
36
  rescue ModelContextProtocol::Server::ParameterValidationError => validation_error
34
- log("Validation error: #{validation_error.message}")
37
+ @configuration.logger.error("Validation error", error: validation_error.message)
35
38
  send_message(
36
39
  ErrorResponse[id: message["id"], error: {code: -32602, message: validation_error.message}]
37
40
  )
38
41
  rescue JSON::ParserError => parser_error
39
- log("Parser error: #{parser_error.message}")
42
+ @configuration.logger.error("Parser error", error: parser_error.message)
40
43
  send_message(
41
44
  ErrorResponse[id: "", error: {code: -32700, message: parser_error.message}]
42
45
  )
43
46
  rescue => error
44
- log("Internal error: #{error.message}")
45
- log(error.backtrace)
47
+ @configuration.logger.error("Internal error", error: error.message, backtrace: error.backtrace.first(5))
46
48
  send_message(
47
49
  ErrorResponse[id: message["id"], error: {code: -32603, message: error.message}]
48
50
  )
@@ -50,10 +52,23 @@ module ModelContextProtocol
50
52
  end
51
53
  end
52
54
 
55
+ def send_notification(method, params)
56
+ notification = {
57
+ jsonrpc: "2.0",
58
+ method: method,
59
+ params: params
60
+ }
61
+ $stdout.puts(JSON.generate(notification))
62
+ $stdout.flush
63
+ rescue IOError => e
64
+ # Handle broken pipe gracefully
65
+ @configuration.logger.debug("Failed to send notification", error: e.message) if @configuration.logging_enabled?
66
+ end
67
+
53
68
  private
54
69
 
55
- def log(output, level = :error)
56
- logger.send(level.to_sym, output)
70
+ def receive_message
71
+ $stdin.gets
57
72
  end
58
73
 
59
74
  def send_message(message)
@@ -0,0 +1,291 @@
1
+ require "json"
2
+ require "securerandom"
3
+
4
+ module ModelContextProtocol
5
+ class Server::StreamableHttpTransport
6
+ Response = Data.define(:id, :result) do
7
+ def serialized
8
+ {jsonrpc: "2.0", id:, result:}
9
+ end
10
+ end
11
+
12
+ ErrorResponse = Data.define(:id, :error) do
13
+ def serialized
14
+ {jsonrpc: "2.0", id:, error:}
15
+ end
16
+ end
17
+ def initialize(router:, configuration:)
18
+ @router = router
19
+ @configuration = configuration
20
+
21
+ transport_options = @configuration.transport_options
22
+ @redis = transport_options[:redis_client]
23
+
24
+ @session_store = ModelContextProtocol::Server::SessionStore.new(
25
+ @redis,
26
+ ttl: transport_options[:session_ttl] || 3600
27
+ )
28
+
29
+ @server_instance = "#{Socket.gethostname}-#{Process.pid}-#{SecureRandom.hex(4)}"
30
+ @local_streams = {}
31
+ @notification_queue = []
32
+
33
+ setup_redis_subscriber
34
+ end
35
+
36
+ def handle
37
+ @configuration.logger.connect_transport(self)
38
+
39
+ request = @configuration.transport_options[:request]
40
+ response = @configuration.transport_options[:response]
41
+
42
+ unless request && response
43
+ raise ArgumentError, "StreamableHTTP transport requires request and response objects in transport_options"
44
+ end
45
+
46
+ case request.method
47
+ when "POST"
48
+ handle_post_request(request)
49
+ when "GET"
50
+ handle_sse_request(request, response)
51
+ when "DELETE"
52
+ handle_delete_request(request)
53
+ else
54
+ error_response = ErrorResponse[id: nil, error: {code: -32601, message: "Method not allowed"}]
55
+ {json: error_response.serialized, status: 405}
56
+ end
57
+ end
58
+
59
+ def send_notification(method, params)
60
+ notification = {
61
+ jsonrpc: "2.0",
62
+ method: method,
63
+ params: params
64
+ }
65
+
66
+ if has_active_streams?
67
+ deliver_to_active_streams(notification)
68
+ else
69
+ @notification_queue << notification
70
+ end
71
+ end
72
+
73
+ private
74
+
75
+ def handle_post_request(request)
76
+ body_string = request.body.read
77
+ body = JSON.parse(body_string)
78
+ session_id = request.headers["Mcp-Session-Id"]
79
+
80
+ case body["method"]
81
+ when "initialize"
82
+ handle_initialization(body)
83
+ else
84
+ handle_regular_request(body, session_id)
85
+ end
86
+ rescue JSON::ParserError
87
+ error_response = ErrorResponse[id: "", error: {code: -32700, message: "Parse error"}]
88
+ {json: error_response.serialized, status: 400}
89
+ rescue ModelContextProtocol::Server::ParameterValidationError => validation_error
90
+ @configuration.logger.error("Validation error", error: validation_error.message)
91
+ error_response = ErrorResponse[id: body&.dig("id"), error: {code: -32602, message: validation_error.message}]
92
+ {json: error_response.serialized, status: 400}
93
+ rescue => e
94
+ @configuration.logger.error("Error handling POST request", error: e.message, backtrace: e.backtrace.first(5))
95
+ error_response = ErrorResponse[id: body&.dig("id"), error: {code: -32603, message: "Internal error"}]
96
+ {json: error_response.serialized, status: 500}
97
+ end
98
+
99
+ def handle_initialization(body)
100
+ session_id = SecureRandom.uuid
101
+
102
+ @session_store.create_session(session_id, {
103
+ server_instance: @server_instance,
104
+ context: @configuration.context || {},
105
+ created_at: Time.now.to_f
106
+ })
107
+
108
+ result = @router.route(body)
109
+ response = Response[id: body["id"], result: result.serialized]
110
+
111
+ {
112
+ json: response.serialized,
113
+ status: 200,
114
+ headers: {"Mcp-Session-Id" => session_id}
115
+ }
116
+ end
117
+
118
+ def handle_regular_request(body, session_id)
119
+ unless session_id && @session_store.session_exists?(session_id)
120
+ error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Invalid or missing session ID"}]
121
+ return {json: error_response.serialized, status: 400}
122
+ end
123
+
124
+ result = @router.route(body)
125
+ response = Response[id: body["id"], result: result.serialized]
126
+
127
+ if @session_store.session_has_active_stream?(session_id)
128
+ deliver_to_session_stream(session_id, response.serialized)
129
+ {json: {accepted: true}, status: 200}
130
+ else
131
+ {json: response.serialized, status: 200}
132
+ end
133
+ end
134
+
135
+ def handle_sse_request(request, response)
136
+ session_id = request.headers["Mcp-Session-Id"]
137
+
138
+ unless session_id && @session_store.session_exists?(session_id)
139
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid or missing session ID"}]
140
+ return {json: error_response.serialized, status: 400}
141
+ end
142
+
143
+ @session_store.mark_stream_active(session_id, @server_instance)
144
+
145
+ {
146
+ stream: true,
147
+ headers: {
148
+ "Content-Type" => "text/event-stream",
149
+ "Cache-Control" => "no-cache",
150
+ "Connection" => "keep-alive"
151
+ },
152
+ stream_proc: create_sse_stream_proc(session_id)
153
+ }
154
+ end
155
+
156
+ def handle_delete_request(request)
157
+ session_id = request.headers["Mcp-Session-Id"]
158
+
159
+ if session_id
160
+ cleanup_session(session_id)
161
+ end
162
+
163
+ {json: {success: true}, status: 200}
164
+ end
165
+
166
+ def create_sse_stream_proc(session_id)
167
+ proc do |stream|
168
+ register_local_stream(session_id, stream)
169
+
170
+ flush_notifications_to_stream(stream)
171
+
172
+ start_keepalive_thread(session_id, stream)
173
+
174
+ loop do
175
+ break unless stream_connected?(stream)
176
+ sleep 0.1
177
+ end
178
+ ensure
179
+ cleanup_local_stream(session_id)
180
+ end
181
+ end
182
+
183
+ def register_local_stream(session_id, stream)
184
+ @local_streams[session_id] = stream
185
+ end
186
+
187
+ def cleanup_local_stream(session_id)
188
+ @local_streams.delete(session_id)
189
+ @session_store.mark_stream_inactive(session_id)
190
+ end
191
+
192
+ def stream_connected?(stream)
193
+ return false unless stream
194
+
195
+ begin
196
+ stream.write(": ping\n\n")
197
+ stream.flush if stream.respond_to?(:flush)
198
+ true
199
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET
200
+ false
201
+ end
202
+ end
203
+
204
+ def start_keepalive_thread(session_id, stream)
205
+ Thread.new do
206
+ loop do
207
+ sleep 30
208
+ break unless stream_connected?(stream)
209
+
210
+ begin
211
+ send_ping_to_stream(stream)
212
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET
213
+ break
214
+ end
215
+ end
216
+ rescue => e
217
+ @configuration.logger.error("Keepalive thread error", error: e.message)
218
+ ensure
219
+ cleanup_local_stream(session_id)
220
+ end
221
+ end
222
+
223
+ def send_ping_to_stream(stream)
224
+ stream.write(": ping #{Time.now.iso8601}\n\n")
225
+ stream.flush if stream.respond_to?(:flush)
226
+ end
227
+
228
+ def send_to_stream(stream, data)
229
+ message = data.is_a?(String) ? data : data.to_json
230
+ stream.write("data: #{message}\n\n")
231
+ stream.flush if stream.respond_to?(:flush)
232
+ end
233
+
234
+ def deliver_to_session_stream(session_id, data)
235
+ if @local_streams[session_id]
236
+ begin
237
+ send_to_stream(@local_streams[session_id], data)
238
+ return true
239
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET
240
+ cleanup_local_stream(session_id)
241
+ end
242
+ end
243
+
244
+ @session_store.route_message_to_session(session_id, data)
245
+ end
246
+
247
+ def cleanup_session(session_id)
248
+ cleanup_local_stream(session_id)
249
+ @session_store.cleanup_session(session_id)
250
+ end
251
+
252
+ def setup_redis_subscriber
253
+ Thread.new do
254
+ @session_store.subscribe_to_server(@server_instance) do |data|
255
+ session_id = data["session_id"]
256
+ message = data["message"]
257
+
258
+ if @local_streams[session_id]
259
+ begin
260
+ send_to_stream(@local_streams[session_id], message)
261
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET
262
+ cleanup_local_stream(session_id)
263
+ end
264
+ end
265
+ end
266
+ rescue => e
267
+ @configuration.logger.error("Redis subscriber error", error: e.message, backtrace: e.backtrace.first(5))
268
+ sleep 5
269
+ retry
270
+ end
271
+ end
272
+
273
+ def has_active_streams?
274
+ @local_streams.any?
275
+ end
276
+
277
+ def deliver_to_active_streams(notification)
278
+ @local_streams.each do |session_id, stream|
279
+ send_to_stream(stream, notification)
280
+ rescue IOError, Errno::EPIPE, Errno::ECONNRESET
281
+ cleanup_local_stream(session_id)
282
+ end
283
+ end
284
+
285
+ def flush_notifications_to_stream(stream)
286
+ while (notification = @notification_queue.shift)
287
+ send_to_stream(stream, notification)
288
+ end
289
+ end
290
+ end
291
+ end
@@ -2,11 +2,13 @@ require "json-schema"
2
2
 
3
3
  module ModelContextProtocol
4
4
  class Server::Tool
5
- attr_reader :params
5
+ attr_reader :arguments, :context, :logger
6
6
 
7
- def initialize(params)
8
- validate!(params)
9
- @params = params
7
+ def initialize(arguments, logger, context = {})
8
+ validate!(arguments)
9
+ @arguments = arguments
10
+ @context = context
11
+ @logger = logger
10
12
  end
11
13
 
12
14
  def call
@@ -68,8 +70,8 @@ module ModelContextProtocol
68
70
  end
69
71
  end
70
72
 
71
- private def validate!(params)
72
- JSON::Validator.validate!(self.class.input_schema, params)
73
+ private def validate!(arguments)
74
+ JSON::Validator.validate!(self.class.input_schema, arguments)
73
75
  end
74
76
 
75
77
  class << self
@@ -90,8 +92,8 @@ module ModelContextProtocol
90
92
  subclass.instance_variable_set(:@input_schema, @input_schema)
91
93
  end
92
94
 
93
- def call(params)
94
- new(params).call
95
+ def call(arguments, logger, context = {})
96
+ new(arguments, logger, context).call
95
97
  rescue JSON::Schema::ValidationError => validation_error
96
98
  raise ModelContextProtocol::Server::ParameterValidationError, validation_error.message
97
99
  rescue ModelContextProtocol::Server::ResponseArgumentsError => response_arguments_error