model-context-protocol-rb 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ require_relative "mcp_logger"
2
+
1
3
  module ModelContextProtocol
2
4
  class Server::Configuration
3
5
  # Raised when configured with invalid name.
@@ -12,16 +14,74 @@ module ModelContextProtocol
12
14
  # Raised when a required environment variable is not set
13
15
  class MissingRequiredEnvironmentVariable < StandardError; end
14
16
 
15
- attr_accessor :enable_log, :name, :registry, :version
17
+ # Raised when transport configuration is invalid
18
+ class InvalidTransportError < StandardError; end
19
+
20
+ # Raised when an invalid log level is provided
21
+ class InvalidLogLevelError < StandardError; end
22
+
23
+ # Valid MCP log levels per the specification
24
+ VALID_LOG_LEVELS = %w[debug info notice warning error critical alert emergency].freeze
25
+
26
+ attr_accessor :name, :registry, :version, :transport
27
+ attr_reader :logger
28
+
29
+ def initialize
30
+ # Always create a logger - enabled by default
31
+ @logging_enabled = true
32
+ @default_log_level = "info"
33
+ @logger = ModelContextProtocol::Server::MCPLogger.new(
34
+ logger_name: "server",
35
+ level: @default_log_level,
36
+ enabled: @logging_enabled
37
+ )
38
+ end
16
39
 
17
40
  def logging_enabled?
18
- enable_log || false
41
+ @logging_enabled
42
+ end
43
+
44
+ def logging_enabled=(value)
45
+ @logging_enabled = value
46
+ @logger = ModelContextProtocol::Server::MCPLogger.new(
47
+ logger_name: "server",
48
+ level: @default_log_level,
49
+ enabled: value
50
+ )
51
+ end
52
+
53
+ def default_log_level=(level)
54
+ unless VALID_LOG_LEVELS.include?(level.to_s)
55
+ raise InvalidLogLevelError, "Invalid log level: #{level}. Valid levels are: #{VALID_LOG_LEVELS.join(", ")}"
56
+ end
57
+
58
+ @default_log_level = level.to_s
59
+ @logger.set_mcp_level(@default_log_level)
60
+ end
61
+
62
# Returns the configured transport type as a Symbol (e.g. :stdio,
# :streamable_http), or nil when no type is configured.
#
# `transport` may be a bare Symbol/String (`:stdio`) or a Hash with a
# :type or "type" key. The type is normalized to a Symbol in both forms, so a
# String type inside a Hash (e.g. loaded from JSON/YAML) still matches the
# Symbol cases in validate_transport!. Values that cannot be converted are
# returned unchanged, so validation still reports them as unknown.
def transport_type
  case transport
  when Hash
    type = transport[:type] || transport["type"]
    type.respond_to?(:to_sym) ? type.to_sym : type
  when Symbol, String
    transport.to_sym
  end
end
70
+
71
+ def transport_options
72
+ case transport
73
+ when Hash
74
+ transport.except(:type, "type").transform_keys(&:to_sym)
75
+ else
76
+ {}
77
+ end
19
78
  end
20
79
 
21
80
  def validate!
22
81
  raise InvalidServerNameError unless valid_name?
23
82
  raise InvalidRegistryError unless valid_registry?
24
83
  raise InvalidServerVersionError unless valid_version?
84
+ validate_transport!
25
85
 
26
86
  validate_environment_variables!
27
87
  end
@@ -51,6 +111,14 @@ module ModelContextProtocol
51
111
  environment_variables[key.to_s.upcase] = value
52
112
  end
53
113
 
114
+ def context
115
+ @context ||= {}
116
+ end
117
+
118
+ def context=(context_hash = {})
119
+ @context = context_hash
120
+ end
121
+
54
122
  private
55
123
 
56
124
  def required_environment_variables
@@ -74,5 +142,29 @@ module ModelContextProtocol
74
142
  def valid_version?
75
143
  version&.is_a?(String)
76
144
  end
145
+
146
+ def validate_transport!
147
+ case transport_type
148
+ when :streamable_http
149
+ validate_streamable_http_transport!
150
+ when :stdio, nil
151
+ # stdio transport has no required options
152
+ else
153
+ raise InvalidTransportError, "Unknown transport type: #{transport_type}" if transport_type
154
+ end
155
+ end
156
+
157
+ def validate_streamable_http_transport!
158
+ options = transport_options
159
+
160
+ unless options[:redis_client]
161
+ raise InvalidTransportError, "streamable_http transport requires redis_client option"
162
+ end
163
+
164
+ redis_client = options[:redis_client]
165
+ unless redis_client.respond_to?(:hset) && redis_client.respond_to?(:expire)
166
+ raise InvalidTransportError, "redis_client must be a Redis-compatible client"
167
+ end
168
+ end
77
169
  end
78
170
  end
@@ -0,0 +1,109 @@
1
+ require "logger"
2
+ require "forwardable"
3
+ require "json"
4
+
5
module ModelContextProtocol
  # Nested (rather than `class Server::MCPLogger`) so this file can be loaded
  # without Server having been defined first; session_store.rb reopens
  # `class Server` the same way.
  class Server
    # Logger-compatible object that turns log calls into MCP
    # `notifications/message` notifications.
    #
    # Messages logged before a transport is attached are queued and delivered
    # by #connect_transport. An internal stdlib Logger writing to nil is used
    # only for severity filtering and the delegated formatter/progname settings.
    class MCPLogger
      extend Forwardable

      def_delegators :@internal_logger, :datetime_format=, :formatter=, :progname, :progname=

      # MCP level names mapped onto the closest Ruby Logger severity.
      LEVEL_MAP = {
        "debug" => Logger::DEBUG,
        "info" => Logger::INFO,
        "notice" => Logger::INFO,
        "warning" => Logger::WARN,
        "error" => Logger::ERROR,
        "critical" => Logger::FATAL,
        "alert" => Logger::FATAL,
        "emergency" => Logger::UNKNOWN
      }.freeze

      # Ruby Logger severities mapped back to MCP level names for outgoing
      # notifications. Lossy by design: "notice" is reported as "info" and
      # "alert" as "critical".
      REVERSE_LEVEL_MAP = {
        Logger::DEBUG => "debug",
        Logger::INFO => "info",
        Logger::WARN => "warning",
        Logger::ERROR => "error",
        Logger::FATAL => "critical",
        Logger::UNKNOWN => "emergency"
      }.freeze

      attr_accessor :transport
      attr_reader :logger_name, :enabled

      # @param logger_name [String] sent as the `logger` field of every notification
      # @param level [String, Symbol] MCP level name; unknown names fall back to "info"
      # @param enabled [Boolean] when false, every logging call is a no-op
      def initialize(logger_name: "server", level: "info", enabled: true)
        @logger_name = logger_name
        @enabled = enabled
        @internal_logger = Logger.new(nil)
        @internal_logger.level = severity_for(level)
        @transport = nil
        @queued_messages = []
      end

      # debug/info/warn/error/fatal/unknown: `logger.info("msg", key: value)`.
      # Keyword arguments become extra fields of the notification's `data`.
      %i[debug info warn error fatal unknown].each do |severity|
        define_method(severity) do |message = nil, **data, &block|
          return true unless @enabled
          add(Logger.const_get(severity.to_s.upcase), message, data, &block)
        end
      end

      # Emits (or queues) a notification if +severity+ passes the current level.
      #
      # @param severity [Integer, nil] a Logger severity constant; nil is treated
      #   as Logger::UNKNOWN, as stdlib Logger#add does
      # @param message [Object, nil] message text; when nil, the block's value is used
      # @param data [Hash, nil] extra fields merged into the notification's `data`
      # @return [true]
      def add(severity, message = nil, data = {}, &block)
        return true unless @enabled

        severity ||= Logger::UNKNOWN
        return true if severity < @internal_logger.level

        message = block.call if message.nil? && block_given?
        send_notification(severity, message, data || {})
        true
      end

      # Sets the threshold as a Ruby Logger severity integer.
      def level=(value)
        @internal_logger.level = value
      end

      # @return [Integer] the current threshold as a Ruby Logger severity
      def level
        @internal_logger.level
      end

      # Sets the threshold from an MCP level name (String or Symbol);
      # unknown names fall back to "info".
      def set_mcp_level(mcp_level)
        self.level = severity_for(mcp_level)
      end

      # Attaches +transport+ (anything responding to
      # `send_notification(method, params)`) and delivers queued messages.
      def connect_transport(transport)
        @transport = transport
        flush_queued_messages if @enabled
      end

      private

      # Normalizes String/Symbol MCP level names to a Logger severity.
      def severity_for(mcp_level)
        LEVEL_MAP[mcp_level.to_s] || Logger::INFO
      end

      def send_notification(severity, message, data)
        return unless @enabled

        notification_params = {
          level: REVERSE_LEVEL_MAP[severity] || "info",
          logger: @logger_name,
          data: format_data(message, data)
        }

        if @transport
          @transport.send_notification("notifications/message", notification_params)
        else
          @queued_messages << notification_params
        end
      end

      # Builds the notification `data` hash: `message` (stringified) plus any
      # extra fields; extra fields win on key collisions.
      def format_data(message, additional_data)
        data = {}
        data[:message] = message.to_s if message
        data.merge!(additional_data) unless additional_data.empty?
        data
      end

      def flush_queued_messages
        return unless @transport && @enabled

        # Drop each message only after it was sent: a transport error part-way
        # through keeps the failing message queued without re-sending the ones
        # already delivered on the next flush.
        until @queued_messages.empty?
          @transport.send_notification("notifications/message", @queued_messages.first)
          @queued_messages.shift
        end
      end
    end
  end
end
@@ -1,32 +1,33 @@
1
1
  module ModelContextProtocol
2
2
  class Server::Prompt
3
- attr_reader :params, :description
3
+ attr_reader :params, :context, :logger
4
4
 
5
- def initialize(params)
5
+ def initialize(params, logger, context = {})
6
6
  validate!(params)
7
- @description = self.class.description
8
7
  @params = params
8
+ @context = context
9
+ @logger = logger
9
10
  end
10
11
 
11
12
  def call
12
13
  raise NotImplementedError, "Subclasses must implement the call method"
13
14
  end
14
15
 
15
- Response = Data.define(:messages, :prompt) do
16
+ Response = Data.define(:messages, :description) do
16
17
  def serialized
17
- {description: prompt.description, messages:}
18
+ {description:, messages:}
18
19
  end
19
20
  end
20
21
  private_constant :Response
21
22
 
22
23
  private def respond_with(messages:)
23
- Response[messages:, prompt: self]
24
+ Response[messages:, description: self.class.description]
24
25
  end
25
26
 
26
27
  private def validate!(params = {})
27
28
  arguments = self.class.arguments || []
28
- required_args = arguments.select { |arg| arg[:required] }.map { |arg| arg[:name] }
29
- valid_arg_names = arguments.map { |arg| arg[:name] }
29
+ required_args = arguments.select { |arg| arg[:required] }.map { |arg| arg[:name].to_sym }
30
+ valid_arg_names = arguments.map { |arg| arg[:name].to_sym }
30
31
 
31
32
  missing_args = required_args - params.keys
32
33
  unless missing_args.empty?
@@ -45,21 +46,37 @@ module ModelContextProtocol
45
46
  attr_reader :name, :description, :arguments
46
47
 
47
48
  def with_metadata(&block)
48
- metadata = instance_eval(&block)
49
+ @arguments ||= []
49
50
 
50
- @name = metadata[:name]
51
- @description = metadata[:description]
52
- @arguments = metadata[:arguments]
51
+ metadata_dsl = MetadataDSL.new
52
+ metadata_dsl.instance_eval(&block)
53
+
54
+ @name = metadata_dsl.name
55
+ @description = metadata_dsl.description
56
+ end
57
+
58
+ def with_argument(&block)
59
+ @arguments ||= []
60
+
61
+ argument_dsl = ArgumentDSL.new
62
+ argument_dsl.instance_eval(&block)
63
+
64
+ @arguments << {
65
+ name: argument_dsl.name,
66
+ description: argument_dsl.description,
67
+ required: argument_dsl.required,
68
+ completion: argument_dsl.completion
69
+ }
53
70
  end
54
71
 
55
72
  def inherited(subclass)
56
73
  subclass.instance_variable_set(:@name, @name)
57
74
  subclass.instance_variable_set(:@description, @description)
58
- subclass.instance_variable_set(:@arguments, @arguments)
75
+ subclass.instance_variable_set(:@arguments, @arguments&.dup)
59
76
  end
60
77
 
61
- def call(params)
62
- new(params).call
78
+ def call(params, logger, context = {})
79
+ new(params, logger, context).call
63
80
  rescue ArgumentError => error
64
81
  raise ModelContextProtocol::Server::ParameterValidationError, error.message
65
82
  end
@@ -67,6 +84,46 @@ module ModelContextProtocol
67
84
  def metadata
68
85
  {name: @name, description: @description, arguments: @arguments}
69
86
  end
87
+
88
+ def complete_for(arg_name, value)
89
+ arg = @arguments&.find { |a| a[:name] == arg_name.to_s }
90
+ completion = (arg && arg[:completion]) ? arg[:completion] : ModelContextProtocol::Server::NullCompletion
91
+ completion.call(arg_name.to_s, value)
92
+ end
93
+ end
94
+
95
+ class MetadataDSL
96
+ def name(value = nil)
97
+ @name = value if value
98
+ @name
99
+ end
100
+
101
+ def description(value = nil)
102
+ @description = value if value
103
+ @description
104
+ end
105
+ end
106
+
107
+ class ArgumentDSL
108
+ def name(value = nil)
109
+ @name = value if value
110
+ @name
111
+ end
112
+
113
+ def description(value = nil)
114
+ @description = value if value
115
+ @description
116
+ end
117
+
118
+ def required(value = nil)
119
+ @required = value unless value.nil?
120
+ @required
121
+ end
122
+
123
+ def completion(klass = nil)
124
+ @completion = klass unless klass.nil?
125
+ @completion
126
+ end
70
127
  end
71
128
  end
72
129
  end
@@ -12,6 +12,7 @@ module ModelContextProtocol
12
12
  def initialize
13
13
  @prompts = []
14
14
  @resources = []
15
+ @resource_templates = []
15
16
  @tools = []
16
17
  @prompts_options = {}
17
18
  @resources_options = {}
@@ -28,6 +29,10 @@ module ModelContextProtocol
28
29
  instance_eval(&block) if block
29
30
  end
30
31
 
32
+ def resource_templates(&block)
33
+ instance_eval(&block) if block
34
+ end
35
+
31
36
  def tools(options = {}, &block)
32
37
  @tools_options = options
33
38
  instance_eval(&block) if block
@@ -42,6 +47,8 @@ module ModelContextProtocol
42
47
  @prompts << entry
43
48
  when ->(ancestors) { ancestors.include?(ModelContextProtocol::Server::Resource) }
44
49
  @resources << entry
50
+ when ->(ancestors) { ancestors.include?(ModelContextProtocol::Server::ResourceTemplate) }
51
+ @resource_templates << entry
45
52
  when ->(ancestors) { ancestors.include?(ModelContextProtocol::Server::Tool) }
46
53
  @tools << entry
47
54
  else
@@ -58,6 +65,11 @@ module ModelContextProtocol
58
65
  entry ? entry[:klass] : nil
59
66
  end
60
67
 
68
+ def find_resource_template(uri)
69
+ entry = @resource_templates.find { |r| uri == r[:uriTemplate] }
70
+ entry ? entry[:klass] : nil
71
+ end
72
+
61
73
  def find_tool(name)
62
74
  find_by_name(@tools, name)
63
75
  end
@@ -70,6 +82,10 @@ module ModelContextProtocol
70
82
  ResourcesData[resources: @resources.map { |entry| entry.except(:klass) }]
71
83
  end
72
84
 
85
+ def resource_templates_data
86
+ ResourceTemplatesData[resource_templates: @resource_templates.map { |entry| entry.except(:klass, :completions) }]
87
+ end
88
+
73
89
  def tools_data
74
90
  ToolsData[tools: @tools.map { |entry| entry.except(:klass) }]
75
91
  end
@@ -88,6 +104,12 @@ module ModelContextProtocol
88
104
  end
89
105
  end
90
106
 
107
+ ResourceTemplatesData = Data.define(:resource_templates) do
108
+ def serialized
109
+ {resourceTemplates: resource_templates}
110
+ end
111
+ end
112
+
91
113
  ToolsData = Data.define(:tools) do
92
114
  def serialized
93
115
  {tools:}
@@ -1,10 +1,12 @@
1
1
  module ModelContextProtocol
2
2
  class Server::Resource
3
- attr_reader :mime_type, :uri
3
+ attr_reader :mime_type, :uri, :context, :logger
4
4
 
5
- def initialize
5
+ def initialize(logger, context = {})
6
6
  @mime_type = self.class.mime_type
7
7
  @uri = self.class.uri
8
+ @context = context
9
+ @logger = logger
8
10
  end
9
11
 
10
12
  def call
@@ -40,12 +42,13 @@ module ModelContextProtocol
40
42
  attr_reader :name, :description, :mime_type, :uri
41
43
 
42
44
  def with_metadata(&block)
43
- metadata = instance_eval(&block)
45
+ metadata_dsl = MetadataDSL.new
46
+ metadata_dsl.instance_eval(&block)
44
47
 
45
- @name = metadata[:name]
46
- @description = metadata[:description]
47
- @mime_type = metadata[:mime_type]
48
- @uri = metadata[:uri]
48
+ @name = metadata_dsl.name
49
+ @description = metadata_dsl.description
50
+ @mime_type = metadata_dsl.mime_type
51
+ @uri = metadata_dsl.uri
49
52
  end
50
53
 
51
54
  def inherited(subclass)
@@ -55,12 +58,34 @@ module ModelContextProtocol
55
58
  subclass.instance_variable_set(:@uri, @uri)
56
59
  end
57
60
 
58
- def call
59
- new.call
61
+ def call(logger, context = {})
62
+ new(logger, context).call
60
63
  end
61
64
 
62
65
  def metadata
63
- {name: @name, description: @description, mime_type: @mime_type, uri: @uri}
66
+ {name: @name, description: @description, mimeType: @mime_type, uri: @uri}
67
+ end
68
+ end
69
+
70
+ class MetadataDSL
71
+ def name(value = nil)
72
+ @name = value if value
73
+ @name
74
+ end
75
+
76
+ def description(value = nil)
77
+ @description = value if value
78
+ @description
79
+ end
80
+
81
+ def mime_type(value = nil)
82
+ @mime_type = value if value
83
+ @mime_type
84
+ end
85
+
86
+ def uri(value = nil)
87
+ @uri = value if value
88
+ @uri
64
89
  end
65
90
  end
66
91
  end
@@ -0,0 +1,93 @@
1
module ModelContextProtocol
  # Base class for MCP resource templates. All state is class-level: subclasses
  # describe themselves through the `with_metadata` DSL (name, description, MIME
  # type, URI template, and optional per-parameter completion handlers).
  class Server::ResourceTemplate
    class << self
      attr_reader :name, :description, :mime_type, :uri_template, :completions

      # Runs the metadata DSL block and records its results on this class.
      def with_metadata(&block)
        dsl = MetadataDSL.new.tap { |builder| builder.instance_eval(&block) }

        @name = dsl.name
        @description = dsl.description
        @mime_type = dsl.mime_type
        @uri_template = dsl.uri_template
        @completions = dsl.completions
      end

      # Copies the parent's metadata into a new subclass; the completions hash
      # is shallow-copied so the subclass can change it independently.
      def inherited(subclass)
        %i[@name @description @mime_type @uri_template].each do |ivar|
          subclass.instance_variable_set(ivar, instance_variable_get(ivar))
        end
        subclass.instance_variable_set(:@completions, @completions&.dup)
      end

      # Dispatches a completion request for +param_name+ to its registered
      # handler, or to NullCompletion when none is registered.
      def complete_for(param_name, value)
        key = param_name.to_s
        handler = @completions&.fetch(key, nil) || ModelContextProtocol::Server::NullCompletion
        handler.call(key, value)
      end

      # Metadata hash using MCP's camelCase field names.
      def metadata
        completions_by_name = @completions&.transform_keys(&:to_s)

        {
          name: @name,
          description: @description,
          mimeType: @mime_type,
          uriTemplate: @uri_template,
          completions: completions_by_name
        }
      end
    end

    # Receiver for the `with_metadata` block. Each attribute method acts as a
    # setter when given a truthy value and always returns the current value.
    class MetadataDSL
      attr_reader :completions

      def initialize
        @completions = {}
      end

      %i[name description mime_type].each do |attribute|
        define_method(attribute) do |value = nil|
          instance_variable_set(:"@#{attribute}", value) if value
          instance_variable_get(:"@#{attribute}")
        end
      end

      # Sets the URI template; an optional block declares completions for the
      # template's parameters and replaces any previously declared ones.
      def uri_template(value = nil, &block)
        @uri_template = value if value
        return @uri_template unless block

        completion_dsl = CompletionDSL.new
        completion_dsl.instance_eval(&block)
        @completions = completion_dsl.completions
        @uri_template
      end
    end

    # Receiver for the `uri_template` block: collects `completion` declarations
    # keyed by parameter name (as a String).
    class CompletionDSL
      attr_reader :completions

      def initialize
        @completions = {}
      end

      def completion(param_name, completion_class)
        @completions.store(param_name.to_s, completion_class)
      end
    end
  end
end
@@ -0,0 +1,108 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "securerandom"
5
+
6
module ModelContextProtocol
  class Server
    # Redis-backed registry of MCP sessions for the streamable HTTP transport.
    #
    # Each session is one Redis hash at "session:<id>". Every field value is
    # JSON-encoded, and each write refreshes the key's TTL. Messages for a
    # session are routed over a pub/sub channel named after the server
    # instance that owns the session's active stream.
    class SessionStore
      SESSION_KEY_PREFIX = "session:"
      private_constant :SESSION_KEY_PREFIX

      # @param redis_client [Object] a redis-rb compatible client (hset/hget/expire/multi/...)
      # @param ttl [Integer] seconds a session key lives after its last write
      def initialize(redis_client, ttl: 3600)
        @redis = redis_client
        @ttl = ttl
      end

      # Stores a new session hash and sets its TTL.
      #
      # @param session_id [String]
      # @param data [Hash] reads :server_instance, :context and :created_at
      # @return [String] +session_id+
      def create_session(session_id, data)
        now = Time.now.to_f
        session_data = {
          id: session_id,
          server_instance: data[:server_instance],
          context: data[:context] || {},
          created_at: data[:created_at] || now,
          last_activity: now,
          active_stream: false
        }

        key = session_key(session_id)
        # Write and expire in one transaction (like mark_stream_*), so a crash
        # between the two calls cannot leave a session key that never expires.
        @redis.multi do |multi|
          multi.hset(key, session_data.transform_values(&:to_json))
          multi.expire(key, @ttl)
        end
        session_id
      end

      # Records that +server_instance+ now holds the session's open stream.
      def mark_stream_active(session_id, server_instance)
        key = session_key(session_id)
        @redis.multi do |multi|
          multi.hset(key,
            "active_stream", true.to_json,
            "stream_server", server_instance.to_json,
            "last_activity", Time.now.to_f.to_json)
          multi.expire(key, @ttl)
        end
      end

      # Clears the session's stream ownership.
      def mark_stream_inactive(session_id)
        key = session_key(session_id)
        @redis.multi do |multi|
          multi.hset(key,
            "active_stream", false.to_json,
            "stream_server", nil.to_json,
            "last_activity", Time.now.to_f.to_json)
          multi.expire(key, @ttl)
        end
      end

      # @return [Object, nil] the server instance owning the stream, if any
      def get_session_server(session_id)
        server_data = @redis.hget(session_key(session_id), "stream_server")
        server_data ? JSON.parse(server_data) : nil
      end

      # Works with both redis-rb APIs: `exists?` (boolean, redis-rb >= 4.2) and
      # `exists`, which returns a boolean on older clients and an Integer
      # count on newer ones.
      def session_exists?(session_id)
        key = session_key(session_id)
        return @redis.exists?(key) if @redis.respond_to?(:exists?)

        result = @redis.exists(key)
        result == true || (result.is_a?(Integer) && result.positive?)
      end

      def session_has_active_stream?(session_id)
        stream_data = @redis.hget(session_key(session_id), "active_stream")
        stream_data ? JSON.parse(stream_data) : false
      end

      # @return [Hash] the stored context, or {} when absent
      def get_session_context(session_id)
        context_data = @redis.hget(session_key(session_id), "context")
        context_data ? JSON.parse(context_data) : {}
      end

      def cleanup_session(session_id)
        @redis.del(session_key(session_id))
      end

      # Publishes +message+ on the channel of the server owning the session's stream.
      #
      # @return [Boolean] false when no server currently owns the stream
      def route_message_to_session(session_id, message)
        server_instance = get_session_server(session_id)
        return false unless server_instance

        @redis.publish("server:#{server_instance}:messages", {
          session_id: session_id,
          message: message
        }.to_json)
        true
      end

      # Blocks, yielding each JSON-decoded message published for +server_instance+.
      def subscribe_to_server(server_instance, &block)
        @redis.subscribe("server:#{server_instance}:messages") do |on|
          on.message do |_channel, message|
            data = JSON.parse(message)
            yield(data)
          end
        end
      end

      # @return [Array<String>] ids of sessions whose stream is currently active
      def get_all_active_sessions
        session_keys.filter_map do |key|
          session_id = key.delete_prefix(SESSION_KEY_PREFIX)
          session_id if session_has_active_stream?(session_id)
        end
      end

      private

      def session_key(session_id)
        "#{SESSION_KEY_PREFIX}#{session_id}"
      end

      # Lists session keys with SCAN when available: KEYS blocks the Redis
      # server while it walks the whole keyspace. SCAN may return a key more
      # than once, so the result is de-duplicated.
      def session_keys
        pattern = "#{SESSION_KEY_PREFIX}*"
        if @redis.respond_to?(:scan_each)
          @redis.scan_each(match: pattern).to_a.uniq
        else
          @redis.keys(pattern)
        end
      end
    end
  end
end