model-context-protocol-rb 0.3.3 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,11 @@ module ModelContextProtocol
20
20
 
21
21
  transport_options = @configuration.transport_options
22
22
  @redis = transport_options[:redis_client]
23
+ @require_sessions = transport_options.fetch(:require_sessions, false)
24
+ @default_protocol_version = transport_options.fetch(:default_protocol_version, "2025-03-26")
25
+ @session_protocol_versions = {} # Track protocol versions per session
26
+ @validate_origin = transport_options.fetch(:validate_origin, true)
27
+ @allowed_origins = transport_options.fetch(:allowed_origins, ["http://localhost", "https://localhost", "http://127.0.0.1", "https://127.0.0.1"])
23
28
 
24
29
  @session_store = ModelContextProtocol::Server::SessionStore.new(
25
30
  @redis,
@@ -29,6 +34,7 @@ module ModelContextProtocol
29
34
  @server_instance = "#{Socket.gethostname}-#{Process.pid}-#{SecureRandom.hex(4)}"
30
35
  @local_streams = {}
31
36
  @notification_queue = []
37
+ @sse_event_counter = 0
32
38
 
33
39
  setup_redis_subscriber
34
40
  end
@@ -36,20 +42,19 @@ module ModelContextProtocol
36
42
  def handle
37
43
  @configuration.logger.connect_transport(self)
38
44
 
39
- request = @configuration.transport_options[:request]
40
- response = @configuration.transport_options[:response]
45
+ env = @configuration.transport_options[:env]
41
46
 
42
- unless request && response
43
- raise ArgumentError, "StreamableHTTP transport requires request and response objects in transport_options"
47
+ unless env
48
+ raise ArgumentError, "StreamableHTTP transport requires Rack env hash in transport_options"
44
49
  end
45
50
 
46
- case request.method
51
+ case env["REQUEST_METHOD"]
47
52
  when "POST"
48
- handle_post_request(request)
53
+ handle_post_request(env)
49
54
  when "GET"
50
- handle_sse_request(request, response)
55
+ handle_sse_request(env)
51
56
  when "DELETE"
52
- handle_delete_request(request)
57
+ handle_delete_request(env)
53
58
  else
54
59
  error_response = ErrorResponse[id: nil, error: {code: -32601, message: "Method not allowed"}]
55
60
  {json: error_response.serialized, status: 405}
@@ -72,16 +77,90 @@ module ModelContextProtocol
72
77
 
73
78
  private
74
79
 
75
- def handle_post_request(request)
76
- body_string = request.body.read
80
# Validates the Origin, Accept and MCP protocol version headers of a request.
#
# Each check is skipped when its header is absent. The first failing check
# short-circuits and its error payload is returned.
#
# @param env [Hash] Rack environment of the current request
# @return [Hash, nil] a `{json:, status:}` error payload (403 for a rejected
#   origin, 400 for a bad Accept or protocol version header), or nil when
#   every check passes
def validate_headers(env)
  if @validate_origin
    origin = env["HTTP_ORIGIN"]
    if origin && !origin_allowed?(origin)
      error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Origin not allowed"}]
      return {json: error_response.serialized, status: 403}
    end
  end

  accept_header = env["HTTP_ACCEPT"]
  if accept_header
    unless accept_header.include?("application/json") || accept_header.include?("text/event-stream")
      error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid Accept header. Must include application/json or text/event-stream"}]
      return {json: error_response.serialized, status: 400}
    end
  end

  protocol_version = env["HTTP_MCP_PROTOCOL_VERSION"]
  if protocol_version
    # Only versions recorded in @session_protocol_versions are accepted; while
    # nothing has been recorded yet, any version is let through.
    valid_versions = @session_protocol_versions.values.compact.uniq
    unless valid_versions.empty? || valid_versions.include?(protocol_version)
      error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid MCP protocol version: #{protocol_version}. Expected one of: #{valid_versions.join(", ")}"}]
      return {json: error_response.serialized, status: 400}
    end
  end

  nil
end

# Tells whether +origin+ is covered by @allowed_origins.
#
# An origin matches when it equals an allowed entry exactly, or is that entry
# followed by a numeric port. A plain prefix comparison is deliberately
# avoided: it would accept "http://localhost.evil.com" for the entry
# "http://localhost".
#
# @param origin [String] value of the Origin request header
# @return [Boolean]
def origin_allowed?(origin)
  @allowed_origins.any? do |allowed|
    origin == allowed || origin.match?(/\A#{Regexp.escape(allowed)}:\d+\z/)
  end
end
109
+
110
# Classifies a parsed JSON-RPC message by the members it carries.
#
# @param body [Hash] parsed message with String keys
# @return [Symbol] :request when it has "method" and "id", :notification when
#   it has "method" but no "id", :response when it has "id" together with
#   "result" or "error", and :unknown otherwise
def determine_message_type(body)
  has_id = body.key?("id")

  if body.key?("method")
    has_id ? :request : :notification
  elsif has_id && (body.key?("result") || body.key?("error"))
    # Parenthesised on purpose: without the grouping, `&&` binds tighter than
    # `||` and an "error" member alone would count as a response.
    :response
  else
    :unknown
  end
end
121
+
122
# Builds the stream proc that delivers an initialization response as one SSE event.
#
# @param response_data [Object] payload handed to send_sse_event
# @return [Proc] proc taking the output stream as its only argument
def create_initialization_sse_stream_proc(response_data)
  single_event_stream_proc(response_data)
end

# Builds the stream proc that delivers a regular request's response as one SSE event.
#
# @param response_data [Object] payload handed to send_sse_event
# @return [Proc] proc taking the output stream as its only argument
def create_request_sse_stream_proc(response_data)
  single_event_stream_proc(response_data)
end

# Shared implementation of the two builders above: the returned proc writes
# +response_data+ to the stream as a single event. The event id is taken from
# next_event_id when the proc runs, not when it is built.
def single_event_stream_proc(response_data)
  proc { |stream| send_sse_event(stream, response_data, next_event_id) }
end
135
+
136
# Advances the SSE event counter and returns the next event id.
#
# @return [String] @server_instance and the new counter value joined by "-"
def next_event_id
  sequence = (@sse_event_counter = @sse_event_counter.succ)
  "#{@server_instance}-#{sequence}"
end
140
+
141
# Writes one server-sent event to +stream+.
#
# Non-String data is serialized with #to_json. Every line of the payload is
# emitted as its own "data:" field, because a raw line break inside a single
# field would end the field, and a blank line would end the whole event.
#
# @param stream [#write] destination; flushed afterwards when it responds to #flush
# @param data [String, #to_json] event payload
# @param event_id [String, nil] value of the "id:" field; omitted when nil
# @return [Object, nil] result of the flush, or nil when the stream cannot flush
def send_sse_event(stream, data, event_id = nil)
  message = data.is_a?(String) ? data : data.to_json

  # Limit -1 keeps trailing empty lines; an empty payload still yields one
  # empty "data:" field so the frame is never field-less.
  payload_lines = message.split(/\r\n|\r|\n/, -1)
  payload_lines = [""] if payload_lines.empty?

  frame = +""
  frame << "id: #{event_id}\n" if event_id
  payload_lines.each { |line| frame << "data: #{line}\n" }
  frame << "\n"

  stream.write(frame)
  stream.flush if stream.respond_to?(:flush)
end
149
+
150
+ def handle_post_request(env)
151
+ validation_error = validate_headers(env)
152
+ return validation_error if validation_error
153
+
154
+ body_string = env["rack.input"].read
77
155
  body = JSON.parse(body_string)
78
- session_id = request.headers["Mcp-Session-Id"]
156
+ session_id = env["HTTP_MCP_SESSION_ID"]
157
+ accept_header = env["HTTP_ACCEPT"] || ""
79
158
 
80
159
  case body["method"]
81
160
  when "initialize"
82
- handle_initialization(body)
161
+ handle_initialization(body, accept_header)
83
162
  else
84
- handle_regular_request(body, session_id)
163
+ handle_regular_request(body, session_id, accept_header)
85
164
  end
86
165
  rescue JSON::ParserError
87
166
  error_response = ErrorResponse[id: "", error: {code: -32700, message: "Parse error"}]
@@ -96,51 +175,122 @@ module ModelContextProtocol
96
175
  {json: error_response.serialized, status: 500}
97
176
  end
98
177
 
99
- def handle_initialization(body)
100
- session_id = SecureRandom.uuid
101
-
102
- @session_store.create_session(session_id, {
103
- server_instance: @server_instance,
104
- context: @configuration.context || {},
105
- created_at: Time.now.to_f
106
- })
107
-
178
+ def handle_initialization(body, accept_header)
108
179
  result = @router.route(body)
109
180
  response = Response[id: body["id"], result: result.serialized]
181
+ response_headers = {}
182
+
183
+ negotiated_protocol_version = result.serialized[:protocolVersion] || result.serialized["protocolVersion"]
184
+
185
+ if @require_sessions
186
+ session_id = SecureRandom.uuid
187
+ @session_store.create_session(session_id, {
188
+ server_instance: @server_instance,
189
+ context: @configuration.context || {},
190
+ created_at: Time.now.to_f,
191
+ negotiated_protocol_version: negotiated_protocol_version
192
+ })
193
+ response_headers["Mcp-Session-Id"] = session_id
194
+ @session_protocol_versions[session_id] = negotiated_protocol_version
195
+ else
196
+ @session_protocol_versions[:default] = negotiated_protocol_version
197
+ end
110
198
 
111
- {
112
- json: response.serialized,
113
- status: 200,
114
- headers: {"Mcp-Session-Id" => session_id}
115
- }
199
+ if accept_header.include?("text/event-stream") && !accept_header.include?("application/json")
200
+ response_headers.merge!({
201
+ "Content-Type" => "text/event-stream",
202
+ "Cache-Control" => "no-cache",
203
+ "Connection" => "keep-alive"
204
+ })
205
+
206
+ {
207
+ stream: true,
208
+ headers: response_headers,
209
+ stream_proc: create_initialization_sse_stream_proc(response.serialized)
210
+ }
211
+ else
212
+ response_headers["Content-Type"] = "application/json"
213
+ {
214
+ json: response.serialized,
215
+ status: 200,
216
+ headers: response_headers
217
+ }
218
+ end
116
219
  end
117
220
 
118
- def handle_regular_request(body, session_id)
119
- unless session_id && @session_store.session_exists?(session_id)
120
- error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Invalid or missing session ID"}]
121
- return {json: error_response.serialized, status: 400}
221
+ def handle_regular_request(body, session_id, accept_header)
222
+ if @require_sessions
223
+ unless session_id && @session_store.session_exists?(session_id)
224
+ if session_id && !@session_store.session_exists?(session_id)
225
+ error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Session terminated"}]
226
+ return {json: error_response.serialized, status: 404}
227
+ else
228
+ error_response = ErrorResponse[id: body["id"], error: {code: -32600, message: "Invalid or missing session ID"}]
229
+ return {json: error_response.serialized, status: 400}
230
+ end
231
+ end
122
232
  end
123
233
 
124
- result = @router.route(body)
125
- response = Response[id: body["id"], result: result.serialized]
234
+ message_type = determine_message_type(body)
126
235
 
127
- if @session_store.session_has_active_stream?(session_id)
128
- deliver_to_session_stream(session_id, response.serialized)
129
- {json: {accepted: true}, status: 200}
130
- else
131
- {json: response.serialized, status: 200}
236
+ case message_type
237
+ when :notification, :response
238
+ if session_id && @session_store.session_has_active_stream?(session_id)
239
+ deliver_to_session_stream(session_id, body)
240
+ end
241
+ {json: {}, status: 202}
242
+
243
+ when :request
244
+ result = @router.route(body)
245
+ response = Response[id: body["id"], result: result.serialized]
246
+
247
+ if session_id && @session_store.session_has_active_stream?(session_id)
248
+ deliver_to_session_stream(session_id, response.serialized)
249
+ return {json: {accepted: true}, status: 200}
250
+ end
251
+
252
+ if accept_header.include?("text/event-stream") && !accept_header.include?("application/json")
253
+ {
254
+ stream: true,
255
+ headers: {
256
+ "Content-Type" => "text/event-stream",
257
+ "Cache-Control" => "no-cache",
258
+ "Connection" => "keep-alive"
259
+ },
260
+ stream_proc: create_request_sse_stream_proc(response.serialized)
261
+ }
262
+ else
263
+ {
264
+ json: response.serialized,
265
+ status: 200,
266
+ headers: {"Content-Type" => "application/json"}
267
+ }
268
+ end
132
269
  end
133
270
  end
134
271
 
135
- def handle_sse_request(request, response)
136
- session_id = request.headers["Mcp-Session-Id"]
137
-
138
- unless session_id && @session_store.session_exists?(session_id)
139
- error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid or missing session ID"}]
272
+ def handle_sse_request(env)
273
+ accept_header = env["HTTP_ACCEPT"] || ""
274
+ unless accept_header.include?("text/event-stream")
275
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Accept header must include text/event-stream"}]
140
276
  return {json: error_response.serialized, status: 400}
141
277
  end
142
278
 
143
- @session_store.mark_stream_active(session_id, @server_instance)
279
+ session_id = env["HTTP_MCP_SESSION_ID"]
280
+ last_event_id = env["HTTP_LAST_EVENT_ID"]
281
+
282
+ if @require_sessions
283
+ unless session_id && @session_store.session_exists?(session_id)
284
+ if session_id && !@session_store.session_exists?(session_id)
285
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Session terminated"}]
286
+ return {json: error_response.serialized, status: 404}
287
+ else
288
+ error_response = ErrorResponse[id: nil, error: {code: -32600, message: "Invalid or missing session ID"}]
289
+ return {json: error_response.serialized, status: 400}
290
+ end
291
+ end
292
+ @session_store.mark_stream_active(session_id, @server_instance)
293
+ end
144
294
 
145
295
  {
146
296
  stream: true,
@@ -149,12 +299,12 @@ module ModelContextProtocol
149
299
  "Cache-Control" => "no-cache",
150
300
  "Connection" => "keep-alive"
151
301
  },
152
- stream_proc: create_sse_stream_proc(session_id)
302
+ stream_proc: create_sse_stream_proc(session_id, last_event_id)
153
303
  }
154
304
  end
155
305
 
156
- def handle_delete_request(request)
157
- session_id = request.headers["Mcp-Session-Id"]
306
+ def handle_delete_request(env)
307
+ session_id = env["HTTP_MCP_SESSION_ID"]
158
308
 
159
309
  if session_id
160
310
  cleanup_session(session_id)
@@ -163,11 +313,15 @@ module ModelContextProtocol
163
313
  {json: {success: true}, status: 200}
164
314
  end
165
315
 
166
- def create_sse_stream_proc(session_id)
316
+ def create_sse_stream_proc(session_id, last_event_id = nil)
167
317
  proc do |stream|
168
- register_local_stream(session_id, stream)
318
+ register_local_stream(session_id, stream) if session_id
169
319
 
170
- flush_notifications_to_stream(stream)
320
+ if last_event_id
321
+ replay_messages_after_event_id(stream, session_id, last_event_id)
322
+ else
323
+ flush_notifications_to_stream(stream)
324
+ end
171
325
 
172
326
  start_keepalive_thread(session_id, stream)
173
327
 
@@ -176,7 +330,7 @@ module ModelContextProtocol
176
330
  sleep 0.1
177
331
  end
178
332
  ensure
179
- cleanup_local_stream(session_id)
333
+ cleanup_local_stream(session_id) if session_id
180
334
  end
181
335
  end
182
336
 
@@ -226,9 +380,12 @@ module ModelContextProtocol
226
380
  end
227
381
 
228
382
  def send_to_stream(stream, data)
229
- message = data.is_a?(String) ? data : data.to_json
230
- stream.write("data: #{message}\n\n")
231
- stream.flush if stream.respond_to?(:flush)
383
+ event_id = next_event_id
384
+ send_sse_event(stream, data, event_id)
385
+ end
386
+
387
+ def replay_messages_after_event_id(stream, session_id, last_event_id)
388
+ flush_notifications_to_stream(stream)
232
389
  end
233
390
 
234
391
  def deliver_to_session_stream(session_id, data)
@@ -2,11 +2,16 @@ require "json-schema"
2
2
 
3
3
  module ModelContextProtocol
4
4
  class Server::Tool
5
- attr_reader :params, :context, :logger
5
+ # Raised when output schema validation fails.
6
+ class OutputSchemaValidationError < StandardError; end
6
7
 
7
- def initialize(params, logger, context = {})
8
- validate!(params)
9
- @params = params
8
+ include ModelContextProtocol::Server::ContentHelpers
9
+
10
+ attr_reader :arguments, :context, :logger
11
+
12
+ def initialize(arguments, logger, context = {})
13
+ validate!(arguments)
14
+ @arguments = arguments
10
15
  @context = context
11
16
  @logger = logger
12
17
  end
@@ -15,99 +20,106 @@ module ModelContextProtocol
15
20
  raise NotImplementedError, "Subclasses must implement the call method"
16
21
  end
17
22
 
18
- TextResponse = Data.define(:text) do
19
- def serialized
20
- {content: [{type: "text", text:}], isError: false}
21
- end
22
- end
23
- private_constant :TextResponse
24
-
25
- ImageResponse = Data.define(:data, :mime_type) do
26
- def initialize(data:, mime_type: "image/png")
27
- super
28
- end
29
-
23
+ Response = Data.define(:content) do
30
24
  def serialized
31
- {content: [{type: "image", data:, mimeType: mime_type}], isError: false}
25
+ serialized_contents = content.map(&:serialized)
26
+ {content: serialized_contents, isError: false}
32
27
  end
33
28
  end
34
- private_constant :ImageResponse
35
-
36
- ResourceResponse = Data.define(:uri, :text, :mime_type) do
37
- def initialize(uri:, text:, mime_type: "text/plain")
38
- super
39
- end
29
+ private_constant :Response
40
30
 
31
+ StructuredContentResponse = Data.define(:structured_content, :tool) do
41
32
  def serialized
42
- {content: [{type: "resource", resource: {uri:, mimeType: mime_type, text:}}], isError: false}
33
+ json_text = JSON.generate(structured_content)
34
+ text_content = ModelContextProtocol::Server::Content::Text[
35
+ meta: nil,
36
+ annotations: nil,
37
+ text: json_text
38
+ ]
39
+
40
+ validation_errors = JSON::Validator.fully_validate(
41
+ tool.class.definition[:outputSchema], structured_content
42
+ )
43
+
44
+ if validation_errors.empty?
45
+ {
46
+ structuredContent: structured_content,
47
+ content: [text_content.serialized],
48
+ isError: false
49
+ }
50
+ else
51
+ raise OutputSchemaValidationError, validation_errors.join(", ")
52
+ end
43
53
  end
44
54
  end
45
- private_constant :ResourceResponse
55
+ private_constant :StructuredContentResponse
46
56
 
47
- ToolErrorResponse = Data.define(:text) do
57
+ ErrorResponse = Data.define(:error) do
48
58
  def serialized
49
- {content: [{type: "text", text:}], isError: true}
59
+ {content: [{type: "text", text: error}], isError: true}
50
60
  end
51
61
  end
52
- private_constant :ToolErrorResponse
53
-
54
- private def respond_with(type, **options)
55
- case [type, options]
56
- in [:text, {text:}]
57
- TextResponse[text:]
58
- in [:image, {data:, mime_type:}]
59
- ImageResponse[data:, mime_type:]
60
- in [:image, {data:}]
61
- ImageResponse[data:]
62
- in [:resource, {mime_type:, text:, uri:}]
63
- ResourceResponse[mime_type:, text:, uri:]
64
- in [:resource, {text:, uri:}]
65
- ResourceResponse[text:, uri:]
66
- in [:error, {text:}]
67
- ToolErrorResponse[text:]
62
+ private_constant :ErrorResponse
63
+
64
+ private def respond_with(**kwargs)
65
+ case [kwargs]
66
+ in [{content:}]
67
+ content_array = content.is_a?(Array) ? content : [content]
68
+ Response[content: content_array]
69
+ in [{structured_content:}]
70
+ StructuredContentResponse[structured_content:, tool: self]
71
+ in [{error:}]
72
+ ErrorResponse[error:]
68
73
  else
69
- raise ModelContextProtocol::Server::ResponseArgumentsError, "Invalid arguments: #{type}, #{options}"
74
+ raise ModelContextProtocol::Server::ResponseArgumentsError, "Invalid arguments: #{kwargs.inspect}"
70
75
  end
71
76
  end
72
77
 
73
- private def validate!(params)
74
- JSON::Validator.validate!(self.class.input_schema, params)
78
+ private def validate!(arguments)
79
+ JSON::Validator.validate!(self.class.input_schema, arguments)
75
80
  end
76
81
 
77
82
  class << self
78
- attr_reader :name, :description, :input_schema
83
+ attr_reader :name, :description, :title, :input_schema, :output_schema
79
84
 
80
- def with_metadata(&block)
81
- metadata_dsl = MetadataDSL.new
82
- metadata_dsl.instance_eval(&block)
85
+ def define(&block)
86
+ definition_dsl = DefinitionDSL.new
87
+ definition_dsl.instance_eval(&block)
83
88
 
84
- @name = metadata_dsl.name
85
- @description = metadata_dsl.description
86
- @input_schema = metadata_dsl.input_schema
89
+ @name = definition_dsl.name
90
+ @description = definition_dsl.description
91
+ @title = definition_dsl.title
92
+ @input_schema = definition_dsl.input_schema
93
+ @output_schema = definition_dsl.output_schema
87
94
  end
88
95
 
89
96
  def inherited(subclass)
90
97
  subclass.instance_variable_set(:@name, @name)
91
98
  subclass.instance_variable_set(:@description, @description)
99
+ subclass.instance_variable_set(:@title, @title)
92
100
  subclass.instance_variable_set(:@input_schema, @input_schema)
101
+ subclass.instance_variable_set(:@output_schema, @output_schema)
93
102
  end
94
103
 
95
- def call(params, logger, context = {})
96
- new(params, logger, context).call
104
+ def call(arguments, logger, context = {})
105
+ new(arguments, logger, context).call
97
106
  rescue JSON::Schema::ValidationError => validation_error
98
107
  raise ModelContextProtocol::Server::ParameterValidationError, validation_error.message
99
- rescue ModelContextProtocol::Server::ResponseArgumentsError => response_arguments_error
100
- raise response_arguments_error
108
+ rescue OutputSchemaValidationError, ModelContextProtocol::Server::ResponseArgumentsError => tool_error
109
+ raise tool_error, tool_error.message
101
110
  rescue => error
102
- ToolErrorResponse[text: error.message]
111
+ ErrorResponse[error: error.message]
103
112
  end
104
113
 
105
- def metadata
106
- {name: @name, description: @description, inputSchema: @input_schema}
114
+ def definition
115
+ result = {name: @name, description: @description, inputSchema: @input_schema}
116
+ result[:title] = @title if @title
117
+ result[:outputSchema] = @output_schema if @output_schema
118
+ result
107
119
  end
108
120
  end
109
121
 
110
- class MetadataDSL
122
+ class DefinitionDSL
111
123
  def name(value = nil)
112
124
  @name = value if value
113
125
  @name
@@ -118,10 +130,20 @@ module ModelContextProtocol
118
130
  @description
119
131
  end
120
132
 
133
# Reads the title, or stores it first when a truthy value is given.
#
# @param value [Object, nil] new title; nil or false leaves the stored one untouched
# @return [Object, nil] the title held after the call
def title(value = nil)
  return @title unless value

  @title = value
end
137
+
121
138
  def input_schema(&block)
122
139
  @input_schema = instance_eval(&block) if block_given?
123
140
  @input_schema
124
141
  end
142
+
143
# Reads the output schema, or first stores the result of evaluating the
# given block against this object.
#
# @yield evaluated with instance_eval; its result becomes the stored schema
# @return [Object, nil] the schema held after the call
def output_schema(&block)
  return @output_schema unless block

  @output_schema = instance_eval(&block)
end
125
147
  end
126
148
  end
127
149
  end