rubyrana 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +9 -0
- data/LICENSE +190 -0
- data/NOTICE +4 -0
- data/README.md +179 -0
- data/REPORT.md +156 -0
- data/docs/CHECKLIST.md +32 -0
- data/docs/RELEASE.md +16 -0
- data/docs/USAGE.md +54 -0
- data/examples/mcp.rb +22 -0
- data/examples/quick_start.rb +23 -0
- data/examples/streaming.rb +19 -0
- data/examples/tools_loader.rb +18 -0
- data/lib/rubyrana/agent.rb +153 -0
- data/lib/rubyrana/config.rb +15 -0
- data/lib/rubyrana/errors.rb +11 -0
- data/lib/rubyrana/mcp/client.rb +124 -0
- data/lib/rubyrana/multi_agent.rb +21 -0
- data/lib/rubyrana/persistence/base.rb +15 -0
- data/lib/rubyrana/persistence/file_store.rb +38 -0
- data/lib/rubyrana/persistence/redis_store.rb +36 -0
- data/lib/rubyrana/providers/anthropic.rb +176 -0
- data/lib/rubyrana/providers/base.rb +25 -0
- data/lib/rubyrana/providers/bedrock.rb +11 -0
- data/lib/rubyrana/providers/openai.rb +11 -0
- data/lib/rubyrana/routing/keyword_router.rb +20 -0
- data/lib/rubyrana/safety/filter.rb +28 -0
- data/lib/rubyrana/tool.rb +30 -0
- data/lib/rubyrana/tool_registry.rb +26 -0
- data/lib/rubyrana/tooling.rb +33 -0
- data/lib/rubyrana/tools/code_interpreter.rb +68 -0
- data/lib/rubyrana/tools/loader.rb +20 -0
- data/lib/rubyrana/tools/mcp_web_search.rb +36 -0
- data/lib/rubyrana/tools/web_search.rb +79 -0
- data/lib/rubyrana/tools.rb +22 -0
- data/lib/rubyrana/version.rb +5 -0
- data/lib/rubyrana.rb +33 -0
- metadata +149 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "dotenv/load"
|
|
4
|
+
require "rubyrana"
|
|
5
|
+
|
|
6
|
+
# Point Rubyrana at Anthropic with debug logging enabled. ANTHROPIC_API_KEY is
# required; ANTHROPIC_MODEL falls back to "claude-3-haiku-20240307".
Rubyrana.configure do |settings|
  settings.debug = true
  api_key = ENV.fetch("ANTHROPIC_API_KEY")
  model_id = ENV.fetch("ANTHROPIC_MODEL", "claude-3-haiku-20240307")
  settings.default_provider = Rubyrana::Providers::Anthropic.new(api_key: api_key, model: model_id)
end

# A custom tool; its block receives the text to measure as the `text:` keyword.
word_count = Rubyrana::Tool.new("word_count") { |text:| text.split.size }

# Offer the custom tool alongside the bundled code interpreter.
agent_tools = [word_count, Rubyrana::Tools.code_interpreter]

agent = Rubyrana::Agent.new(tools: agent_tools)
reply = agent.call("How many words are in this sentence?")
puts reply
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "dotenv/load"
|
|
4
|
+
require "rubyrana"
|
|
5
|
+
|
|
6
|
+
# Configure Anthropic as the default provider; ANTHROPIC_API_KEY is required.
Rubyrana.configure do |settings|
  api_key = ENV.fetch("ANTHROPIC_API_KEY")
  model_id = ENV.fetch("ANTHROPIC_MODEL", "claude-3-haiku-20240307")
  settings.default_provider = Rubyrana::Providers::Anthropic.new(api_key: api_key, model: model_id)
end

agent = Rubyrana::Agent.new

# Print each chunk as soon as it arrives, then finish the line.
agent.stream("Give me a one-line summary of Ruby.") { |piece| print piece }
puts
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "dotenv/load"
|
|
4
|
+
require "rubyrana"
|
|
5
|
+
|
|
6
|
+
# Tools can live in their own directory, each file defining one with
# Rubyrana.tool. For example, ./tools/hello_tool.rb could contain:
#   Rubyrana.tool("hello") { |name:| "Hello, #{name}!" }

Rubyrana.configure do |settings|
  api_key = ENV.fetch("ANTHROPIC_API_KEY")
  model_id = ENV.fetch("ANTHROPIC_MODEL", "claude-3-haiku-20240307")
  settings.default_provider = Rubyrana::Providers::Anthropic.new(api_key: api_key, model: model_id)
end

tools_directory = "./tools"
agent = Rubyrana::Agent.new(load_tools_from: tools_directory)
answer = agent.call("Use the hello tool to greet Ajay")
puts answer
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Rubyrana
  # Conversational agent: sends prompts to a provider, runs any tools the
  # provider requests, and optionally remembers and persists the transcript.
  class Agent
    # Provider round-trips #call allows before giving up on a tool loop.
    DEFAULT_MAX_ITERATIONS = 5

    # @param model [Object, nil] provider responding to #complete and #stream;
    #   falls back to Rubyrana.config.default_provider
    # @param tools [Array] tools made available to the provider
    # @param load_tools_from [String, nil] directory passed to Rubyrana::Tools::Loader
    # @param memory [Boolean] when true, completed turns are kept for later calls
    # @param history [Array<Hash>] initial transcript (duplicated); it is replaced by
    #   the store's transcript only when the store already holds messages for session_id
    # @param safety_filters [Array] objects whose #enforce! receives prompts and outputs
    # @param store [Object, nil] persistence backend responding to #load and #save
    # @param session_id [Object, nil] key under which the store keeps the transcript
    # @raise [ConfigurationError] when no provider is given or configured
    def initialize(model: nil, tools: [], load_tools_from: nil, memory: true, history: [], safety_filters: [], store: nil, session_id: nil)
      @model = model || Rubyrana.config.default_provider
      raise ConfigurationError, "No provider configured" unless @model

      @tool_registry = ToolRegistry.new(tools)
      @load_tools_from = load_tools_from
      @memory = memory
      @safety_filters = safety_filters
      @store = store
      @session_id = session_id
      @messages = history.dup
      load_tools_from_directory
      load_persisted_messages
    end

    # Runs one user turn, looping while the provider keeps requesting tools.
    #
    # @param prompt [String]
    # @param opts [Hash] forwarded to the provider; :max_iterations (default
    #   DEFAULT_MAX_ITERATIONS) is consumed here and not forwarded
    # @return [String] the provider's final text
    # @raise [ToolError] for an unknown tool or when max_iterations is exhausted
    def call(prompt, **opts)
      apply_safety_filters(prompt)
      transcript = start_turn(prompt)
      max_iterations = opts.delete(:max_iterations) || DEFAULT_MAX_ITERATIONS

      max_iterations.times do
        response = @model.complete(messages: transcript, tools: @tool_registry.definitions, **opts)
        text = indifferent(response, :text).to_s
        tool_calls = indifferent(response, :tool_calls) || []
        @last_usage = indifferent(response, :usage)

        return finish_turn(transcript, text) if tool_calls.empty?

        log_debug("Tool calls requested", tool_calls: tool_calls)
        append_tool_request(transcript, indifferent(response, :assistant_content), text)
        tool_calls.each { |tool_call| transcript << run_tool(tool_call) }
      end

      # The transcript of an abandoned tool loop is intentionally not kept.
      raise ToolError, "Tool loop exceeded max iterations"
    end

    # Streams one user turn. Tool calls are not executed in streaming mode.
    #
    # With a block, each chunk is yielded as it arrives and nil is returned.
    # Without a block, returns an Enumerator; the turn is recorded (and persisted)
    # only once the Enumerator has been fully consumed. Safety filters check the
    # output after it has already been streamed.
    #
    # @param prompt [String]
    # @param opts [Hash] forwarded to the provider's #stream
    def stream(prompt, **opts, &block)
      apply_safety_filters(prompt)
      transcript = start_turn(prompt)

      if block
        output = +""
        @model.stream(messages: transcript, tools: @tool_registry.definitions, **opts) do |chunk|
          output << chunk.to_s
          block.call(chunk)
        end
        finish_turn(transcript, output)
        return
      end

      Enumerator.new do |yielder|
        # Fresh copies per enumeration so consuming the Enumerator twice does not
        # append duplicate assistant turns or concatenate both outputs.
        turn = transcript.dup
        output = +""
        @model.stream(messages: turn, tools: @tool_registry.definitions, **opts) do |chunk|
          output << chunk.to_s
          yielder << chunk.to_s
        end
        finish_turn(turn, output)
      end
    end

    # @return [Array<Hash>] a shallow copy of the remembered transcript
    def messages
      @messages.dup
    end

    # @return [Object, nil] usage reported by the most recent #complete response
    def last_usage
      @last_usage
    end

    # Clears the remembered transcript and persists the empty state.
    def reset!
      @messages.clear
      persist_messages
    end

    private

    # Builds the message list for a new turn: prior history (when memory is on)
    # followed by the user prompt.
    def start_turn(prompt)
      transcript = @memory ? @messages.dup : []
      transcript << { role: "user", content: prompt }
    end

    # Checks the final output, appends it, commits it to memory, persists, and
    # returns the text.
    def finish_turn(transcript, text)
      apply_safety_filters(text)
      transcript << { role: "assistant", content: text }
      @messages = transcript if @memory
      persist_messages
      text
    end

    # Records the assistant turn that requested tools, preferring the provider's
    # raw content blocks and falling back to plain text when it has any.
    def append_tool_request(transcript, assistant_content, text)
      if assistant_content
        transcript << { role: "assistant", content: assistant_content }
      elsif !text.empty?
        transcript << { role: "assistant", content: text }
      end
    end

    # Executes one requested tool call and returns the "tool" message to append.
    def run_tool(tool_call)
      tool_name = indifferent(tool_call, :name)
      tool_call_id = indifferent(tool_call, :id)
      arguments = indifferent(tool_call, :arguments) || {}
      tool = @tool_registry.fetch(tool_name)
      raise ToolError, "Unknown tool: #{tool_name}" unless tool

      log_debug("Running tool", tool: tool_name, arguments: arguments)
      result = tool.call(**symbolize_keys(arguments))
      message = { role: "tool", name: tool.name, content: result.to_s }
      message[:tool_call_id] = tool_call_id if tool_call_id
      message
    end

    # Reads a key that may be stored as a Symbol or as a String.
    def indifferent(hash, key)
      hash[key] || hash[key.to_s]
    end

    def symbolize_keys(hash)
      hash.each_with_object({}) do |(key, value), acc|
        acc[key.to_sym] = value
      end
    end

    def log_debug(message, **data)
      return unless Rubyrana.config.debug

      Rubyrana.config.logger.debug({ message: message, **data })
    end

    def apply_safety_filters(text)
      value = text.to_s
      @safety_filters.each { |filter| filter.enforce!(value) }
    end

    # Restores a saved transcript. An empty or missing one keeps the caller's
    # `history:` instead of silently discarding it.
    def load_persisted_messages
      return unless @store && @session_id

      persisted = @store.load(@session_id)
      @messages = persisted if persisted && !persisted.empty?
    end

    def load_tools_from_directory
      return unless @load_tools_from

      loader = Rubyrana::Tools::Loader.new(@load_tools_from)
      loader.load.each { |tool| @tool_registry.register(tool) }
    end

    def persist_messages
      return unless @store && @session_id

      @store.save(@session_id, @messages)
    end
  end
end
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "logger"
|
|
4
|
+
|
|
5
|
+
module Rubyrana
  # Process-wide settings read by agents (debug flag, logger, default provider).
  class Config
    # Logger used for debug output; writes to $stdout unless replaced.
    attr_accessor :logger
    # Provider used by agents created without an explicit model.
    attr_accessor :default_provider
    # When truthy, agents log tool activity through #logger.
    attr_accessor :debug

    def initialize
      @logger = Logger.new($stdout)
      @default_provider = nil
      @debug = false
    end
  end
end
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Rubyrana
  # Root of the gem's exception hierarchy; rescue it to catch any Rubyrana failure.
  Error = Class.new(StandardError)

  # Missing or invalid configuration, e.g. an agent built without any provider.
  ConfigurationError = Class.new(Error)
  # A model provider or MCP server request failed or returned unusable data.
  ProviderError = Class.new(Error)
  # Tool execution problems: unknown tool names or a runaway tool loop.
  ToolError = Class.new(Error)
  # Presumably raised by safety filters that reject content (filters not shown here).
  SafetyError = Class.new(Error)
  # Loading or saving a session transcript failed.
  PersistenceError = Class.new(Error)
  # No agent could be selected for a prompt.
  RoutingError = Class.new(Error)
end
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "open3"
|
|
5
|
+
|
|
6
|
+
module Rubyrana
  module MCP
    # Minimal Model Context Protocol client speaking newline-delimited JSON-RPC 2.0,
    # either over a spawned server process's stdio or over caller-supplied IO objects.
    class Client
      PROTOCOL_VERSION = "2024-11-05"
      # Read size used while discarding the spawned server's stderr.
      STDERR_READ_BYTES = 4096

      # @param command [String, nil] executable to spawn as the MCP server
      # @param args [Array<String>] arguments for command
      # @param io [Hash, nil] pre-connected streams (:stdin, :stdout, :stderr and an
      #   optional :wait_thr). When given, no process is spawned and the streams are
      #   left open for the caller to close.
      def initialize(command: nil, args: [], io: nil)
        @command = command
        @args = args
        @io = io
        @id = 0
      end

      # Opens a session, performs the MCP handshake, and yields the server's tools
      # wrapped as Rubyrana::Tool objects. The tools call back into this session,
      # so with a spawned process they only work inside the block.
      #
      # @yieldparam tools [Array<Rubyrana::Tool>]
      # @return the block's value
      # @raise [Rubyrana::ConfigurationError] when neither io nor command is set
      # @raise [Rubyrana::ProviderError] on JSON-RPC errors, bad JSON, or a closed server
      def with_session(&block)
        if @io
          use_io(@io)
          return run_session(&block)
        end

        raise Rubyrana::ConfigurationError, "MCP command is required" unless @command

        Open3.popen3(@command, *@args) do |stdin, stdout, stderr, wait_thr|
          use_io({ stdin: stdin, stdout: stdout, stderr: stderr, wait_thr: wait_thr })
          # Open3 warns that an unread stderr pipe can fill up and block the child;
          # discard it in the background for the life of the session.
          @stderr_drain = Thread.new { drain(stderr) }
          run_session(&block)
        ensure
          cleanup
        end
      end

      private

      # Handshake, tool discovery, then hands the tools to the caller's block.
      def run_session
        initialize_session
        tools = list_tools.map { |tool_def| tool_from_mcp(tool_def) }
        yield tools
      end

      # Sends `initialize`, then the `notifications/initialized` notification the
      # MCP lifecycle requires before any other request.
      def initialize_session
        result = send_request("initialize", {
          protocolVersion: PROTOCOL_VERSION,
          capabilities: {},
          clientInfo: { name: "rubyrana", version: Rubyrana::VERSION }
        })
        send_notification("notifications/initialized")
        result
      end

      # Collects every page of `tools/list`, following `nextCursor` until absent.
      def list_tools
        tools = []
        cursor = nil
        loop do
          params = cursor ? { cursor: cursor } : {}
          result = send_request("tools/list", params)
          tools.concat(result.fetch("tools", []))
          cursor = result["nextCursor"]
          break unless cursor
        end
        tools
      end

      # @return [Hash] the raw `tools/call` result
      def call_tool(name, arguments)
        send_request("tools/call", { name: name, arguments: arguments })
      end

      # Wraps one MCP tool definition in a Rubyrana::Tool that calls back here.
      def tool_from_mcp(defn)
        tool_name = defn.fetch("name")
        Rubyrana::Tool.new(
          tool_name,
          description: defn["description"],
          schema: defn["inputSchema"]
        ) do |**args|
          call_tool(tool_name, args)
        end
      end

      def send_request(method, params)
        @id += 1
        write_message({ jsonrpc: "2.0", id: @id, method: method, params: params })
        read_response(@id)
      end

      # JSON-RPC notification: no id, so the server sends no reply.
      def send_notification(method, params = nil)
        message = { jsonrpc: "2.0", method: method }
        message[:params] = params if params
        write_message(message)
      end

      def write_message(message)
        @stdin.write(JSON.dump(message))
        @stdin.write("\n")
        @stdin.flush
      end

      # Reads lines until the reply to request_id arrives; other messages
      # (notifications, replies with other ids) are skipped.
      def read_response(request_id)
        loop do
          line = @stdout.gets
          raise Rubyrana::ProviderError, "MCP server closed" unless line

          response = JSON.parse(line)
          next unless response["id"] == request_id

          error = response["error"]
          raise Rubyrana::ProviderError, error_message(error) if error

          return response.fetch("result", {})
        end
      rescue JSON::ParserError
        raise Rubyrana::ProviderError, "Invalid MCP response"
      end

      # Formats a JSON-RPC error object as "MCP error <code>: <message>" when possible.
      def error_message(error)
        return error.to_s unless error.is_a?(Hash) && error["message"]

        code = error["code"]
        code ? "MCP error #{code}: #{error['message']}" : "MCP error: #{error['message']}"
      end

      # Reads and discards everything until EOF or until the stream is closed.
      def drain(io)
        loop { io.readpartial(STDERR_READ_BYTES) }
      rescue IOError, SystemCallError
        nil
      end

      def cleanup
        return unless @stdin

        @stdin.close unless @stdin.closed?
        @stdout.close unless @stdout.closed?
        # Closing stderr wakes the drain thread (IOError), so the join cannot hang.
        @stderr.close unless @stderr.closed?
        @stderr_drain&.join
        @wait_thr&.value
      rescue StandardError
        # Best-effort teardown: a failure here must not mask the session's result.
        nil
      ensure
        @stderr_drain = nil
      end

      def use_io(io)
        @stdin = io.fetch(:stdin)
        @stdout = io.fetch(:stdout)
        @stderr = io.fetch(:stderr)
        @wait_thr = io.fetch(:wait_thr, nil)
      end
    end
  end
end
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Rubyrana
  # Dispatches prompts across several agents: to one selected agent, or to all.
  class MultiAgent
    # @param agents [Array] agents responding to #call(prompt, **opts)
    # @param router [#route, nil] picks an agent; without one, the first agent is used
    def initialize(agents:, router: nil)
      @agents = agents
      @router = router
    end

    # @return whatever the selected agent's #call returns
    # @raise [Rubyrana::RoutingError] when no agent is selected
    def call(prompt, **opts)
      selected =
        if @router
          @router.route(prompt, agents: @agents)
        else
          @agents.first
        end
      raise Rubyrana::RoutingError, "No agents configured" unless selected

      selected.call(prompt, **opts)
    end

    # Sends the prompt to every agent in order.
    # @return [Array] one result per agent
    def broadcast(prompt, **opts)
      @agents.each_with_object([]) { |member, replies| replies << member.call(prompt, **opts) }
    end
  end
end
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "fileutils"
|
|
5
|
+
|
|
6
|
+
module Rubyrana
  module Persistence
    # Stores each session transcript as pretty-printed JSON in
    # "<directory>/<session_id>.json".
    class FileStore < Base
      # Characters that would let a session id escape the store directory.
      INVALID_ID_CHARS = ["/", "\\", "\0"].freeze

      # @param directory [String] created (including parents) if missing
      def initialize(directory: ".rubyrana")
        @directory = directory
        FileUtils.mkdir_p(@directory)
      end

      # @return [Array] the parsed transcript, or [] when the session has no file yet
      # @raise [Rubyrana::PersistenceError] on an invalid session id or read/parse failure
      def load(session_id)
        path = file_path(session_id)
        return [] unless File.exist?(path)

        JSON.parse(File.read(path))
      rescue StandardError => e
        raise Rubyrana::PersistenceError, e.message
      end

      # Writes to a temporary sibling file and renames it into place, so a crash
      # mid-write cannot leave a truncated transcript that later breaks #load.
      #
      # @return [true]
      # @raise [Rubyrana::PersistenceError] on an invalid session id or write failure
      def save(session_id, messages)
        path = file_path(session_id)
        tmp_path = "#{path}.#{Process.pid}.#{Thread.current.object_id}.tmp"
        File.write(tmp_path, JSON.pretty_generate(messages))
        File.rename(tmp_path, path)
        true
      rescue StandardError => e
        FileUtils.rm_f(tmp_path) if tmp_path
        raise Rubyrana::PersistenceError, e.message
      end

      private

      # Rejects ids containing path separators or NUL so an id can never
      # address a file outside @directory.
      def file_path(session_id)
        name = session_id.to_s
        if INVALID_ID_CHARS.any? { |char| name.include?(char) }
          raise Rubyrana::PersistenceError, "Invalid session id: #{session_id.inspect}"
        end

        File.join(@directory, "#{name}.json")
      end
    end
  end
end
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module Rubyrana
  module Persistence
    # Session store that keeps each transcript as one JSON string in Redis
    # under "<namespace>:<session_id>".
    class RedisStore < Base
      # @param redis [#get, #set] Redis client (assumed redis-rb compatible — confirm)
      # @param namespace [String] key prefix
      def initialize(redis:, namespace: "rubyrana")
        @redis = redis
        @namespace = namespace
      end

      # @return [Array] the parsed transcript, or [] when nothing is stored
      # @raise [Rubyrana::PersistenceError] on Redis or JSON errors
      def load(session_id)
        payload = @redis.get(key(session_id))
        payload ? JSON.parse(payload) : []
      rescue StandardError => e
        raise Rubyrana::PersistenceError, e.message
      end

      # @return [true]
      # @raise [Rubyrana::PersistenceError] on Redis or serialization errors
      def save(session_id, messages)
        redis_key = key(session_id)
        @redis.set(redis_key, JSON.dump(messages))
        true
      rescue StandardError => e
        raise Rubyrana::PersistenceError, e.message
      end

      private

      def key(session_id)
        format("%s:%s", @namespace, session_id)
      end
    end
  end
end
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "faraday"
|
|
4
|
+
require "json"
|
|
5
|
+
|
|
6
|
+
module Rubyrana
  module Providers
    # Provider backed by Anthropic's Messages API (POST /v1/messages), with a
    # blocking mode (#complete) and a server-sent-event streaming mode (#stream).
    class Anthropic < Base
      API_URL = "https://api.anthropic.com"
      API_VERSION = "2023-06-01"
      MESSAGES_PATH = "/v1/messages"
      DEFAULT_MAX_TOKENS = 1024
      # Upper bound on how much of a failed streaming body is quoted in errors.
      ERROR_PREVIEW_BYTES = 2048

      # @param api_key [String] sent as the x-api-key header
      # @param model [String] Anthropic model id
      # @param client [Object, nil] injected HTTP client; defaults to a Faraday
      #   connection to API_URL
      # @param max_tokens [Integer] default max_tokens per request; a per-call
      #   :max_tokens option overrides it
      def initialize(api_key:, model:, client: nil, max_tokens: DEFAULT_MAX_TOKENS)
        @api_key = api_key
        @model = model
        @client = client
        @max_tokens = max_tokens
      end

      # Sends one non-streaming request.
      #
      # @param prompt [String, nil] used as a single user message when messages is nil
      # @param messages [Array<Hash>, nil] transcript with symbol or string keys
      # @param tools [Array<Hash>] definitions with :name and optional :description / :input_schema
      # @param opts [Hash] :max_tokens overrides the default; other keys are ignored
      # @return [Hash] :text, :tool_calls, :usage, :assistant_content
      # @raise [ProviderError] on transport failure, non-2xx status, or an unparseable body
      def complete(prompt: nil, messages: nil, tools: [], **opts)
        payload = build_payload(prompt, messages, tools, opts)
        response = request_with_retries do
          client.post(MESSAGES_PATH) do |req|
            apply_headers(req)
            req.body = JSON.dump(payload)
          end
        end

        parse_response(response)
      end

      # Streams text deltas to the block. Without a block, falls back to Base#stream.
      #
      # @raise [ProviderError] on transport failure, non-2xx status, or an SSE `error` event
      def stream(prompt: nil, messages: nil, tools: [], **opts, &block)
        return super unless block_given?

        payload = build_payload(prompt, messages, tools, opts).merge(stream: true)
        stream_request(MESSAGES_PATH, payload, &block)
      end

      private

      def client
        @client ||= Faraday.new(url: API_URL)
      end

      # Request body shared by #complete and #stream.
      def build_payload(prompt, messages, tools, opts)
        payload = {
          model: @model,
          max_tokens: opts[:max_tokens] || @max_tokens,
          messages: format_messages(messages || [{ role: "user", content: prompt }])
        }
        payload[:tools] = tools.map { |tool| format_tool(tool) } if tools.any?
        payload
      end

      # Converts a tool definition to Anthropic's shape, defaulting to an empty object schema.
      def format_tool(tool)
        entry = {
          name: tool[:name],
          input_schema: tool[:input_schema] || { type: "object", properties: {}, required: [] }
        }
        entry[:description] = tool[:description] if tool[:description]
        entry
      end

      def apply_headers(req)
        req.headers["x-api-key"] = @api_key
        req.headers["anthropic-version"] = API_VERSION
        req.headers["Content-Type"] = "application/json"
      end

      def parse_response(response)
        unless response.success?
          raise ProviderError, "Anthropic request failed (status #{response.status}): #{response.body}"
        end

        body = JSON.parse(response.body.to_s)
        assistant_content = body["content"] || []
        text = assistant_content.select { |item| item["type"] == "text" }.map { |item| item["text"] }.join
        tool_calls = assistant_content.select { |item| item["type"] == "tool_use" }.map do |item|
          {
            id: item["id"],
            name: item["name"],
            arguments: item["input"] || {}
          }
        end

        {
          text: text,
          tool_calls: tool_calls,
          usage: body["usage"],
          assistant_content: assistant_content
        }
      rescue JSON::ParserError
        raise ProviderError, "Invalid response from Anthropic"
      end

      # Posts a streaming request and yields each text delta.
      # @return the HTTP client's response object
      def stream_request(path, payload, &block)
        delivered = false
        raw_head = nil

        # Retrying after text has reached the caller would replay it, so only
        # failures that happen before the first delta are retried.
        response = request_with_retries(retry_if: -> { !delivered }) do
          # Fresh per attempt so a partial line from a failed try is not reused.
          buffer = String.new
          raw_head = String.new
          client.post(path) do |req|
            apply_headers(req)
            req.options.on_data = proc do |chunk, _received_bytes|
              bytes = chunk.to_s.b
              room = ERROR_PREVIEW_BYTES - raw_head.bytesize
              raw_head << bytes.byteslice(0, room) if room.positive?
              buffer << bytes
              while (line = buffer.slice!(/.*\n/))
                delta = parse_stream_line(line)
                next if delta.nil? || delta.empty?

                delivered = true
                block.call(delta)
              end
            end
            req.body = JSON.dump(payload)
          end
        end

        check_stream_status(response, raw_head)
        response
      end

      # Returns the text delta carried by one SSE line, or nil for non-data lines,
      # non-text events, and unparseable data.
      # @raise [ProviderError] for an SSE `error` event
      def parse_stream_line(line)
        line = line.strip
        return unless line.start_with?("data:")

        data = line.delete_prefix("data:").strip
        return if data == "[DONE]"

        event = JSON.parse(data)
        return unless event.is_a?(Hash)

        if event["type"] == "error"
          error = event["error"]
          message = (error.is_a?(Hash) ? error["message"] : error) || data
          raise ProviderError, "Anthropic stream error: #{message}"
        end

        extract_stream_delta(event)
      rescue JSON::ParserError
        nil
      end

      # A streamed error response would otherwise look like an empty success.
      # Clients without #success? (e.g. test doubles) are not checked.
      def check_stream_status(response, raw_head)
        return unless response.respond_to?(:success?) && !response.success?

        preview = raw_head.to_s.dup.force_encoding(Encoding::UTF_8).scrub
        raise ProviderError, "Anthropic request failed (status #{response.status}): #{preview}"
      end

      def extract_stream_delta(event)
        return event.dig("delta", "text") if event["type"] == "content_block_delta"
        return event.dig("content_block", "text") if event["type"] == "content_block_start"

        nil
      end

      # Maps Rubyrana messages to Anthropic's format; "tool" messages become
      # user-role tool_result blocks.
      def format_messages(messages)
        messages.map do |message|
          role = message[:role] || message["role"]
          content = message[:content] || message["content"]

          if role == "tool"
            tool_use_id = message[:tool_call_id] || message["tool_call_id"]
            {
              role: "user",
              content: [
                {
                  type: "tool_result",
                  tool_use_id: tool_use_id,
                  content: content.to_s
                }
              ]
            }
          else
            { role: role, content: content }
          end
        end
      end

      # Retries Faraday transport errors up to max_retries times, and only while
      # retry_if (when given) returns true.
      # @raise [ProviderError] once retries are exhausted or disallowed
      def request_with_retries(max_retries: 2, retry_if: nil)
        attempts = 0
        begin
          attempts += 1
          yield
        rescue Faraday::Error => e
          retry if attempts <= max_retries && (retry_if.nil? || retry_if.call)
          raise ProviderError, e.message
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Rubyrana
  module Providers
    # Abstract provider. Subclasses must implement #complete; #stream has a
    # default that emits the whole completion as a single chunk.
    class Base
      # @abstract
      # @raise [NotImplementedError] until a subclass overrides it
      def complete(prompt: nil, messages: nil, tools: [], **_opts)
        raise NotImplementedError, "Provider must implement #complete"
      end

      # Non-incremental fallback built on #complete. A Hash result contributes
      # its :text / "text" entry; any other result is used as the text directly.
      #
      # @yieldparam chunk [String] the full completion text, yielded once
      # @return [nil, Enumerator] nil with a block, otherwise an Enumerator over that one chunk
      def stream(prompt: nil, messages: nil, tools: [], **_opts, &block)
        result = complete(prompt: prompt, messages: messages, tools: tools, **_opts)
        text = result.is_a?(Hash) ? (result[:text] || result["text"]) : result

        return Enumerator.new { |yielder| yielder << text.to_s } unless block

        block.call(text.to_s)
        nil
      end
    end
  end
end
|