vector_mcp 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +75 -0
- data/lib/vector_mcp/definitions.rb +25 -9
- data/lib/vector_mcp/errors.rb +2 -6
- data/lib/vector_mcp/handlers/core.rb +12 -10
- data/lib/vector_mcp/image_util.rb +27 -2
- data/lib/vector_mcp/log_filter.rb +48 -0
- data/lib/vector_mcp/middleware/base.rb +1 -7
- data/lib/vector_mcp/middleware/manager.rb +3 -15
- data/lib/vector_mcp/request_context.rb +182 -0
- data/lib/vector_mcp/sampling/result.rb +11 -1
- data/lib/vector_mcp/security/middleware.rb +2 -28
- data/lib/vector_mcp/security/strategies/api_key.rb +29 -28
- data/lib/vector_mcp/security/strategies/jwt_token.rb +10 -5
- data/lib/vector_mcp/server/capabilities.rb +5 -7
- data/lib/vector_mcp/server/message_handling.rb +11 -5
- data/lib/vector_mcp/server.rb +21 -10
- data/lib/vector_mcp/session.rb +96 -6
- data/lib/vector_mcp/transport/base_session_manager.rb +320 -0
- data/lib/vector_mcp/transport/http_stream/event_store.rb +157 -0
- data/lib/vector_mcp/transport/http_stream/session_manager.rb +191 -0
- data/lib/vector_mcp/transport/http_stream/stream_handler.rb +270 -0
- data/lib/vector_mcp/transport/http_stream.rb +961 -0
- data/lib/vector_mcp/transport/sse/client_connection.rb +1 -1
- data/lib/vector_mcp/transport/sse/stream_manager.rb +1 -1
- data/lib/vector_mcp/transport/sse.rb +74 -19
- data/lib/vector_mcp/transport/sse_session_manager.rb +188 -0
- data/lib/vector_mcp/transport/stdio.rb +70 -13
- data/lib/vector_mcp/transport/stdio_session_manager.rb +181 -0
- data/lib/vector_mcp/util.rb +39 -1
- data/lib/vector_mcp/version.rb +1 -1
- data/lib/vector_mcp.rb +1 -0
- metadata +10 -1
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
+
require "openssl"
|
|
4
|
+
|
|
3
5
|
module VectorMCP
|
|
4
6
|
module Security
|
|
5
7
|
module Strategies
|
|
@@ -10,8 +12,10 @@ module VectorMCP
|
|
|
10
12
|
|
|
11
13
|
# Initialize with a list of valid API keys
|
|
12
14
|
# @param keys [Array<String>] array of valid API keys
|
|
13
|
-
|
|
15
|
+
# @param allow_query_params [Boolean] whether to accept API keys from query parameters (default: false)
|
|
16
|
+
def initialize(keys: [], allow_query_params: false)
|
|
14
17
|
@valid_keys = Set.new(keys.map(&:to_s))
|
|
18
|
+
@allow_query_params = allow_query_params
|
|
15
19
|
end
|
|
16
20
|
|
|
17
21
|
# Add a valid API key
|
|
@@ -33,7 +37,7 @@ module VectorMCP
|
|
|
33
37
|
api_key = extract_api_key(request)
|
|
34
38
|
return false unless api_key&.length&.positive?
|
|
35
39
|
|
|
36
|
-
if
|
|
40
|
+
if secure_key_match?(api_key)
|
|
37
41
|
{
|
|
38
42
|
api_key: api_key,
|
|
39
43
|
strategy: "api_key",
|
|
@@ -58,14 +62,33 @@ module VectorMCP
|
|
|
58
62
|
|
|
59
63
|
private
|
|
60
64
|
|
|
65
|
+
# Constant-time comparison of API key against all valid keys.
|
|
66
|
+
# Iterates all keys to prevent timing side-channels.
|
|
67
|
+
# @param candidate [String] the API key to check
|
|
68
|
+
# @return [Boolean] true if the candidate matches a valid key
|
|
69
|
+
def secure_key_match?(candidate)
|
|
70
|
+
matched = false
|
|
71
|
+
@valid_keys.each do |valid_key|
|
|
72
|
+
next unless candidate.bytesize == valid_key.bytesize
|
|
73
|
+
|
|
74
|
+
matched = true if OpenSSL.fixed_length_secure_compare(candidate, valid_key)
|
|
75
|
+
end
|
|
76
|
+
matched
|
|
77
|
+
end
|
|
78
|
+
|
|
61
79
|
# Extract API key from various request formats
|
|
62
80
|
# @param request [Hash] the request object
|
|
63
81
|
# @return [String, nil] the extracted API key
|
|
64
82
|
def extract_api_key(request)
|
|
65
83
|
headers = normalize_headers(request)
|
|
66
|
-
params = normalize_params(request)
|
|
67
84
|
|
|
68
|
-
extract_from_headers(headers)
|
|
85
|
+
from_headers = extract_from_headers(headers)
|
|
86
|
+
return from_headers if from_headers
|
|
87
|
+
|
|
88
|
+
return nil unless @allow_query_params
|
|
89
|
+
|
|
90
|
+
params = normalize_params(request)
|
|
91
|
+
extract_from_params(params)
|
|
69
92
|
end
|
|
70
93
|
|
|
71
94
|
# Normalize headers to handle different formats
|
|
@@ -96,36 +119,14 @@ module VectorMCP
|
|
|
96
119
|
# @param env [Hash] the Rack environment
|
|
97
120
|
# @return [Hash] normalized headers
|
|
98
121
|
def extract_headers_from_rack_env(env)
|
|
99
|
-
|
|
100
|
-
env.each do |key, value|
|
|
101
|
-
next unless key.start_with?("HTTP_")
|
|
102
|
-
|
|
103
|
-
# Convert HTTP_X_API_KEY to X-API-Key format
|
|
104
|
-
header_name = key[5..].split("_").map do |part|
|
|
105
|
-
case part.upcase
|
|
106
|
-
when "API" then "API" # Keep API in all caps
|
|
107
|
-
else part.capitalize
|
|
108
|
-
end
|
|
109
|
-
end.join("-")
|
|
110
|
-
headers[header_name] = value
|
|
111
|
-
end
|
|
112
|
-
|
|
113
|
-
# Add special headers
|
|
114
|
-
headers["Authorization"] = env["HTTP_AUTHORIZATION"] if env["HTTP_AUTHORIZATION"]
|
|
115
|
-
headers["Content-Type"] = env["CONTENT_TYPE"] if env["CONTENT_TYPE"]
|
|
116
|
-
headers
|
|
122
|
+
VectorMCP::Util.extract_headers_from_rack_env(env)
|
|
117
123
|
end
|
|
118
124
|
|
|
119
125
|
# Extract params from Rack environment
|
|
120
126
|
# @param env [Hash] the Rack environment
|
|
121
127
|
# @return [Hash] normalized params
|
|
122
128
|
def extract_params_from_rack_env(env)
|
|
123
|
-
|
|
124
|
-
if env["QUERY_STRING"]
|
|
125
|
-
require "uri"
|
|
126
|
-
params = URI.decode_www_form(env["QUERY_STRING"]).to_h
|
|
127
|
-
end
|
|
128
|
-
params
|
|
129
|
+
VectorMCP::Util.extract_params_from_rack_env(env)
|
|
129
130
|
end
|
|
130
131
|
|
|
131
132
|
# Extract API key from headers
|
|
@@ -17,12 +17,14 @@ module VectorMCP
|
|
|
17
17
|
# Initialize JWT strategy
|
|
18
18
|
# @param secret [String] the secret key for JWT verification
|
|
19
19
|
# @param algorithm [String] the JWT algorithm (default: HS256)
|
|
20
|
+
# @param allow_query_params [Boolean] whether to accept JWT tokens from query parameters (default: false)
|
|
20
21
|
# @param options [Hash] additional JWT verification options
|
|
21
|
-
def initialize(secret:, algorithm: "HS256", **options)
|
|
22
|
+
def initialize(secret:, algorithm: "HS256", allow_query_params: false, **options)
|
|
22
23
|
raise LoadError, "JWT gem is required for JWT authentication strategy" unless defined?(JWT)
|
|
23
24
|
|
|
24
25
|
@secret = secret
|
|
25
26
|
@algorithm = algorithm
|
|
27
|
+
@allow_query_params = allow_query_params
|
|
26
28
|
@options = {
|
|
27
29
|
algorithm: @algorithm,
|
|
28
30
|
verify_expiration: true,
|
|
@@ -82,11 +84,14 @@ module VectorMCP
|
|
|
82
84
|
# @return [String, nil] the extracted token
|
|
83
85
|
def extract_token(request)
|
|
84
86
|
headers = request[:headers] || request["headers"] || {}
|
|
85
|
-
params = request[:params] || request["params"] || {}
|
|
86
87
|
|
|
87
|
-
extract_from_auth_header(headers) ||
|
|
88
|
-
|
|
89
|
-
|
|
88
|
+
from_headers = extract_from_auth_header(headers) || extract_from_jwt_header(headers)
|
|
89
|
+
return from_headers if from_headers
|
|
90
|
+
|
|
91
|
+
return nil unless @allow_query_params
|
|
92
|
+
|
|
93
|
+
params = request[:params] || request["params"] || {}
|
|
94
|
+
extract_from_params(params)
|
|
90
95
|
end
|
|
91
96
|
|
|
92
97
|
# Extract token from Authorization header
|
|
@@ -34,7 +34,6 @@ module VectorMCP
|
|
|
34
34
|
# @return [void]
|
|
35
35
|
def clear_prompts_list_changed
|
|
36
36
|
@prompts_list_changed = false
|
|
37
|
-
logger.debug("Prompts listChanged flag cleared.")
|
|
38
37
|
end
|
|
39
38
|
|
|
40
39
|
# Notifies connected clients that the list of available prompts has changed.
|
|
@@ -45,10 +44,10 @@ module VectorMCP
|
|
|
45
44
|
notification_method = "notifications/prompts/list_changed"
|
|
46
45
|
begin
|
|
47
46
|
if transport.respond_to?(:broadcast_notification)
|
|
48
|
-
logger.
|
|
47
|
+
logger.debug("Broadcasting prompts list changed notification.")
|
|
49
48
|
transport.broadcast_notification(notification_method)
|
|
50
49
|
elsif transport.respond_to?(:send_notification)
|
|
51
|
-
logger.
|
|
50
|
+
logger.debug("Sending prompts list changed notification (transport may broadcast or send to first client).")
|
|
52
51
|
transport.send_notification(notification_method)
|
|
53
52
|
else
|
|
54
53
|
logger.warn("Transport does not support sending notifications/prompts/list_changed.")
|
|
@@ -62,7 +61,6 @@ module VectorMCP
|
|
|
62
61
|
# @return [void]
|
|
63
62
|
def clear_roots_list_changed
|
|
64
63
|
@roots_list_changed = false
|
|
65
|
-
logger.debug("Roots listChanged flag cleared.")
|
|
66
64
|
end
|
|
67
65
|
|
|
68
66
|
# Notifies connected clients that the list of available roots has changed.
|
|
@@ -73,10 +71,10 @@ module VectorMCP
|
|
|
73
71
|
notification_method = "notifications/roots/list_changed"
|
|
74
72
|
begin
|
|
75
73
|
if transport.respond_to?(:broadcast_notification)
|
|
76
|
-
logger.
|
|
74
|
+
logger.debug("Broadcasting roots list changed notification.")
|
|
77
75
|
transport.broadcast_notification(notification_method)
|
|
78
76
|
elsif transport.respond_to?(:send_notification)
|
|
79
|
-
logger.
|
|
77
|
+
logger.debug("Sending roots list changed notification (transport may broadcast or send to first client).")
|
|
80
78
|
transport.send_notification(notification_method)
|
|
81
79
|
else
|
|
82
80
|
logger.warn("Transport does not support sending notifications/roots/list_changed.")
|
|
@@ -90,7 +88,7 @@ module VectorMCP
|
|
|
90
88
|
# @api private
|
|
91
89
|
def subscribe_prompts(session)
|
|
92
90
|
@prompt_subscribers << session unless @prompt_subscribers.include?(session)
|
|
93
|
-
|
|
91
|
+
# Session subscribed to prompt list changes
|
|
94
92
|
end
|
|
95
93
|
|
|
96
94
|
private
|
|
@@ -21,10 +21,10 @@ module VectorMCP
|
|
|
21
21
|
params = message["params"] || {}
|
|
22
22
|
|
|
23
23
|
if id && method # Request
|
|
24
|
-
logger.
|
|
24
|
+
logger.debug("[#{session_id}] Request [#{id}]: #{method} with params: #{VectorMCP::LogFilter.filter_hash(params).inspect}")
|
|
25
25
|
handle_request(id, method, params, session)
|
|
26
26
|
elsif method # Notification
|
|
27
|
-
logger.
|
|
27
|
+
logger.debug("[#{session_id}] Notification: #{method} with params: #{VectorMCP::LogFilter.filter_hash(params).inspect}")
|
|
28
28
|
handle_notification(method, params, session)
|
|
29
29
|
nil # Notifications do not have a return value to send back to client
|
|
30
30
|
elsif id # Invalid: Has ID but no method
|
|
@@ -74,7 +74,9 @@ module VectorMCP
|
|
|
74
74
|
# Validates that the session is properly initialized for the given request.
|
|
75
75
|
# @api private
|
|
76
76
|
def validate_session_initialization(id, method, _params, session)
|
|
77
|
-
|
|
77
|
+
# Handle both direct VectorMCP::Session and BaseSessionManager::Session wrapper
|
|
78
|
+
actual_session = session.respond_to?(:context) ? session.context : session
|
|
79
|
+
return if actual_session.initialized?
|
|
78
80
|
|
|
79
81
|
# Allow "initialize" even if not marked initialized yet by server
|
|
80
82
|
return if method == "initialize"
|
|
@@ -113,7 +115,9 @@ module VectorMCP
|
|
|
113
115
|
# Internal handler for JSON-RPC notifications.
|
|
114
116
|
# @api private
|
|
115
117
|
def handle_notification(method, params, session)
|
|
116
|
-
|
|
118
|
+
# Handle both direct VectorMCP::Session and BaseSessionManager::Session wrapper
|
|
119
|
+
actual_session = session.respond_to?(:context) ? session.context : session
|
|
120
|
+
unless actual_session.initialized? || method == "initialized"
|
|
117
121
|
logger.warn("Ignoring notification '#{method}' before session is initialized. Params: #{params.inspect}")
|
|
118
122
|
return
|
|
119
123
|
end
|
|
@@ -158,7 +162,9 @@ module VectorMCP
|
|
|
158
162
|
# @api private
|
|
159
163
|
def session_method(method_name)
|
|
160
164
|
lambda do |params, session, _server|
|
|
161
|
-
|
|
165
|
+
# Handle both direct VectorMCP::Session and BaseSessionManager::Session wrapper
|
|
166
|
+
actual_session = session.respond_to?(:context) ? session.context : session
|
|
167
|
+
actual_session.public_send(method_name, params)
|
|
162
168
|
end
|
|
163
169
|
end
|
|
164
170
|
end
|
data/lib/vector_mcp/server.rb
CHANGED
|
@@ -134,11 +134,12 @@ module VectorMCP
|
|
|
134
134
|
|
|
135
135
|
# Runs the server using the specified transport mechanism.
|
|
136
136
|
#
|
|
137
|
-
# @param transport [:stdio, :sse, VectorMCP::Transport::Base] The transport to use.
|
|
138
|
-
# Can be a symbol (`:stdio`, `:sse`) or an initialized transport instance.
|
|
137
|
+
# @param transport [:stdio, :sse, :http_stream, VectorMCP::Transport::Base] The transport to use.
|
|
138
|
+
# Can be a symbol (`:stdio`, `:sse`, `:http_stream`) or an initialized transport instance.
|
|
139
139
|
# If a symbol is provided, the method will instantiate the corresponding transport class.
|
|
140
|
-
# If `:sse` is chosen, it uses Puma as the HTTP server.
|
|
141
|
-
#
|
|
140
|
+
# If `:sse` is chosen, it uses Puma as the HTTP server (deprecated).
|
|
141
|
+
# If `:http_stream` is chosen, it uses the MCP-compliant streamable HTTP transport.
|
|
142
|
+
# @param options [Hash] Transport-specific options (e.g., `:host`, `:port` for HTTP transports).
|
|
142
143
|
# These are passed to the transport's constructor if a symbol is provided for `transport`.
|
|
143
144
|
# @return [void]
|
|
144
145
|
# @raise [ArgumentError] if an unsupported transport symbol is given.
|
|
@@ -150,11 +151,20 @@ module VectorMCP
|
|
|
150
151
|
when :sse
|
|
151
152
|
begin
|
|
152
153
|
require_relative "transport/sse"
|
|
154
|
+
logger.warn("SSE transport is deprecated. Please use :http_stream instead.")
|
|
153
155
|
VectorMCP::Transport::SSE.new(self, **options)
|
|
154
156
|
rescue LoadError => e
|
|
155
157
|
logger.fatal("SSE transport requires additional dependencies.")
|
|
156
158
|
raise NotImplementedError, "SSE transport dependencies not available: #{e.message}"
|
|
157
159
|
end
|
|
160
|
+
when :http_stream
|
|
161
|
+
begin
|
|
162
|
+
require_relative "transport/http_stream"
|
|
163
|
+
VectorMCP::Transport::HttpStream.new(self, **options)
|
|
164
|
+
rescue LoadError => e
|
|
165
|
+
logger.fatal("HttpStream transport requires additional dependencies.")
|
|
166
|
+
raise NotImplementedError, "HttpStream transport dependencies not available: #{e.message}"
|
|
167
|
+
end
|
|
158
168
|
when VectorMCP::Transport::Base # Allow passing an initialized transport instance
|
|
159
169
|
transport.server = self if transport.respond_to?(:server=) && transport.server.nil? # Ensure server is set
|
|
160
170
|
transport
|
|
@@ -180,7 +190,7 @@ module VectorMCP
|
|
|
180
190
|
|
|
181
191
|
case strategy
|
|
182
192
|
when :api_key
|
|
183
|
-
add_api_key_auth(options[:keys] || [])
|
|
193
|
+
add_api_key_auth(options[:keys] || [], allow_query_params: options[:allow_query_params] || false)
|
|
184
194
|
when :jwt
|
|
185
195
|
add_jwt_auth(options)
|
|
186
196
|
when :custom
|
|
@@ -277,14 +287,14 @@ module VectorMCP
|
|
|
277
287
|
# server.use_middleware(LoggingMiddleware, :after_tool_call, conditions: { only_operations: ['important_tool'] })
|
|
278
288
|
def use_middleware(middleware_class, hooks, priority: Middleware::Hook::DEFAULT_PRIORITY, conditions: {})
|
|
279
289
|
@middleware_manager.register(middleware_class, hooks, priority: priority, conditions: conditions)
|
|
280
|
-
@logger.
|
|
290
|
+
@logger.debug("Registered middleware: #{middleware_class.name}")
|
|
281
291
|
end
|
|
282
292
|
|
|
283
293
|
# Remove all middleware hooks for a specific class
|
|
284
294
|
# @param middleware_class [Class] Middleware class to remove
|
|
285
295
|
def remove_middleware(middleware_class)
|
|
286
296
|
@middleware_manager.unregister(middleware_class)
|
|
287
|
-
@logger.
|
|
297
|
+
@logger.debug("Removed middleware: #{middleware_class.name}")
|
|
288
298
|
end
|
|
289
299
|
|
|
290
300
|
# Get middleware statistics
|
|
@@ -296,16 +306,17 @@ module VectorMCP
|
|
|
296
306
|
# Clear all middleware (useful for testing)
|
|
297
307
|
def clear_middleware!
|
|
298
308
|
@middleware_manager.clear!
|
|
299
|
-
@logger.
|
|
309
|
+
@logger.debug("Cleared all middleware")
|
|
300
310
|
end
|
|
301
311
|
|
|
302
312
|
private
|
|
303
313
|
|
|
304
314
|
# Add API key authentication strategy
|
|
305
315
|
# @param keys [Array<String>] array of valid API keys
|
|
316
|
+
# @param allow_query_params [Boolean] whether to accept API keys from query parameters
|
|
306
317
|
# @return [void]
|
|
307
|
-
def add_api_key_auth(keys)
|
|
308
|
-
strategy = Security::Strategies::ApiKey.new(keys: keys)
|
|
318
|
+
def add_api_key_auth(keys, allow_query_params: false)
|
|
319
|
+
strategy = Security::Strategies::ApiKey.new(keys: keys, allow_query_params: allow_query_params)
|
|
309
320
|
@auth_manager.add_strategy(:api_key, strategy)
|
|
310
321
|
end
|
|
311
322
|
|
data/lib/vector_mcp/session.rb
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
require_relative "sampling/request"
|
|
4
4
|
require_relative "sampling/result"
|
|
5
5
|
require_relative "errors"
|
|
6
|
+
require_relative "request_context"
|
|
6
7
|
|
|
7
8
|
module VectorMCP
|
|
8
9
|
# Represents the state of a single client-server connection session in MCP.
|
|
@@ -13,8 +14,9 @@ module VectorMCP
|
|
|
13
14
|
# @attr_reader protocol_version [String] The MCP protocol version used by the server.
|
|
14
15
|
# @attr_reader client_info [Hash, nil] Information about the client, received during initialization.
|
|
15
16
|
# @attr_reader client_capabilities [Hash, nil] Capabilities supported by the client, received during initialization.
|
|
17
|
+
# @attr_reader request_context [RequestContext] The request context for this session.
|
|
16
18
|
class Session
|
|
17
|
-
attr_reader :server_info, :server_capabilities, :protocol_version, :client_info, :client_capabilities, :server, :transport, :id
|
|
19
|
+
attr_reader :server_info, :server_capabilities, :protocol_version, :client_info, :client_capabilities, :server, :transport, :id, :request_context
|
|
18
20
|
attr_accessor :data # For user-defined session-specific storage
|
|
19
21
|
|
|
20
22
|
# Initializes a new session.
|
|
@@ -22,7 +24,8 @@ module VectorMCP
|
|
|
22
24
|
# @param server [VectorMCP::Server] The server instance managing this session.
|
|
23
25
|
# @param transport [VectorMCP::Transport::Base, nil] The transport handling this session. Required for sampling.
|
|
24
26
|
# @param id [String] A unique identifier for this session (e.g., from transport layer).
|
|
25
|
-
|
|
27
|
+
# @param request_context [RequestContext, Hash, nil] The request context for this session.
|
|
28
|
+
def initialize(server, transport = nil, id: SecureRandom.uuid, request_context: nil)
|
|
26
29
|
@server = server
|
|
27
30
|
@transport = transport # Store the transport for sending requests
|
|
28
31
|
@id = id
|
|
@@ -31,6 +34,16 @@ module VectorMCP
|
|
|
31
34
|
@client_capabilities = nil
|
|
32
35
|
@data = {} # Initialize user data hash
|
|
33
36
|
@logger = server.logger
|
|
37
|
+
|
|
38
|
+
# Initialize request context
|
|
39
|
+
@request_context = case request_context
|
|
40
|
+
when RequestContext
|
|
41
|
+
request_context
|
|
42
|
+
when Hash
|
|
43
|
+
RequestContext.new(**request_context)
|
|
44
|
+
else
|
|
45
|
+
RequestContext.new
|
|
46
|
+
end
|
|
34
47
|
end
|
|
35
48
|
|
|
36
49
|
# Marks the session as initialized using parameters from the client's `initialize` request.
|
|
@@ -75,6 +88,75 @@ module VectorMCP
|
|
|
75
88
|
@initialized_state == :succeeded
|
|
76
89
|
end
|
|
77
90
|
|
|
91
|
+
# Sets the request context for this session.
|
|
92
|
+
# This method should be called by transport layers to populate request-specific data.
|
|
93
|
+
#
|
|
94
|
+
# @param context [RequestContext, Hash] The request context to set.
|
|
95
|
+
# Can be a RequestContext object or a hash of attributes.
|
|
96
|
+
# @return [RequestContext] The newly set request context.
|
|
97
|
+
# @raise [ArgumentError] If the context is not a RequestContext or Hash.
|
|
98
|
+
def request_context=(context)
|
|
99
|
+
@request_context = case context
|
|
100
|
+
when RequestContext
|
|
101
|
+
context
|
|
102
|
+
when Hash
|
|
103
|
+
RequestContext.new(**context)
|
|
104
|
+
else
|
|
105
|
+
raise ArgumentError, "Request context must be a RequestContext or Hash, got #{context.class}"
|
|
106
|
+
end
|
|
107
|
+
end
|
|
108
|
+
|
|
109
|
+
# Updates the request context with new data.
|
|
110
|
+
# This merges the provided attributes with the existing context.
|
|
111
|
+
#
|
|
112
|
+
# @param attributes [Hash] The attributes to merge into the request context.
|
|
113
|
+
# @return [RequestContext] The updated request context.
|
|
114
|
+
def update_request_context(**attributes)
|
|
115
|
+
current_attrs = @request_context.to_h
|
|
116
|
+
|
|
117
|
+
# Deep merge nested hashes like headers and params
|
|
118
|
+
merged_attrs = current_attrs.dup
|
|
119
|
+
attributes.each do |key, value|
|
|
120
|
+
merged_attrs[key] = if value.is_a?(Hash) && current_attrs[key].is_a?(Hash)
|
|
121
|
+
current_attrs[key].merge(value)
|
|
122
|
+
else
|
|
123
|
+
value
|
|
124
|
+
end
|
|
125
|
+
end
|
|
126
|
+
|
|
127
|
+
@request_context = RequestContext.new(**merged_attrs)
|
|
128
|
+
end
|
|
129
|
+
|
|
130
|
+
# Convenience method to check if the session has request headers.
|
|
131
|
+
#
|
|
132
|
+
# @return [Boolean] True if the request context has headers, false otherwise.
|
|
133
|
+
def request_headers?
|
|
134
|
+
@request_context.headers?
|
|
135
|
+
end
|
|
136
|
+
|
|
137
|
+
# Convenience method to check if the session has request parameters.
|
|
138
|
+
#
|
|
139
|
+
# @return [Boolean] True if the request context has parameters, false otherwise.
|
|
140
|
+
def request_params?
|
|
141
|
+
@request_context.params?
|
|
142
|
+
end
|
|
143
|
+
|
|
144
|
+
# Convenience method to get a request header value.
|
|
145
|
+
#
|
|
146
|
+
# @param name [String] The header name.
|
|
147
|
+
# @return [String, nil] The header value or nil if not found.
|
|
148
|
+
def request_header(name)
|
|
149
|
+
@request_context.header(name)
|
|
150
|
+
end
|
|
151
|
+
|
|
152
|
+
# Convenience method to get a request parameter value.
|
|
153
|
+
#
|
|
154
|
+
# @param name [String] The parameter name.
|
|
155
|
+
# @return [String, nil] The parameter value or nil if not found.
|
|
156
|
+
def request_param(name)
|
|
157
|
+
@request_context.param(name)
|
|
158
|
+
end
|
|
159
|
+
|
|
78
160
|
# Helper to check client capabilities later if needed
|
|
79
161
|
# def supports?(capability_key)
|
|
80
162
|
# @client_capabilities.key?(capability_key.to_s)
|
|
@@ -111,7 +193,7 @@ module VectorMCP
|
|
|
111
193
|
|
|
112
194
|
begin
|
|
113
195
|
sampling_req_obj = VectorMCP::Sampling::Request.new(request_params)
|
|
114
|
-
@logger.
|
|
196
|
+
@logger.debug("[Session #{@id}] Sending sampling/createMessage request to client.")
|
|
115
197
|
|
|
116
198
|
result = send_sampling_request(sampling_req_obj, timeout)
|
|
117
199
|
|
|
@@ -161,17 +243,25 @@ module VectorMCP
|
|
|
161
243
|
send_request_kwargs = {}
|
|
162
244
|
send_request_kwargs[:timeout] = timeout if timeout
|
|
163
245
|
|
|
164
|
-
|
|
246
|
+
# For HTTP transport, we need to use send_request_to_session to target this specific session
|
|
247
|
+
raw_result = if @transport.respond_to?(:send_request_to_session)
|
|
248
|
+
@transport.send_request_to_session(@id, *send_request_args, **send_request_kwargs)
|
|
249
|
+
else
|
|
250
|
+
# Fallback to generic send_request for other transports
|
|
251
|
+
@transport.send_request(*send_request_args, **send_request_kwargs)
|
|
252
|
+
end
|
|
253
|
+
|
|
165
254
|
VectorMCP::Sampling::Result.new(raw_result)
|
|
166
255
|
rescue ArgumentError => e
|
|
167
256
|
@logger.error("[Session #{@id}] Invalid parameters for sampling request or result: #{e.message}")
|
|
168
|
-
raise VectorMCP::SamplingError
|
|
257
|
+
raise VectorMCP::SamplingError.new("Invalid sampling parameters or malformed client response: #{e.message}",
|
|
258
|
+
details: { original_error: e.to_s })
|
|
169
259
|
rescue VectorMCP::SamplingError => e
|
|
170
260
|
@logger.warn("[Session #{@id}] Sampling request failed: #{e.message}")
|
|
171
261
|
raise e
|
|
172
262
|
rescue StandardError => e
|
|
173
263
|
@logger.error("[Session #{@id}] Unexpected error during sampling: #{e.class.name}: #{e.message}")
|
|
174
|
-
raise VectorMCP::SamplingError
|
|
264
|
+
raise VectorMCP::SamplingError.new("An unexpected error occurred during sampling: #{e.message}", details: { original_error: e.to_s })
|
|
175
265
|
end
|
|
176
266
|
end
|
|
177
267
|
end
|