deepagents 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/LICENSE +21 -0
- data/README.md +237 -0
- data/Rakefile +17 -0
- data/examples/langchain_integration.rb +58 -0
- data/examples/research_agent.rb +180 -0
- data/lib/deepagents/deepagentsrb/errors.rb +80 -0
- data/lib/deepagents/deepagentsrb/graph.rb +213 -0
- data/lib/deepagents/deepagentsrb/models.rb +167 -0
- data/lib/deepagents/deepagentsrb/state.rb +139 -0
- data/lib/deepagents/deepagentsrb/sub_agent.rb +125 -0
- data/lib/deepagents/deepagentsrb/tools.rb +205 -0
- data/lib/deepagents/deepagentsrb/version.rb +3 -0
- data/lib/deepagents/errors.rb +80 -0
- data/lib/deepagents/graph.rb +207 -0
- data/lib/deepagents/models.rb +217 -0
- data/lib/deepagents/state.rb +139 -0
- data/lib/deepagents/sub_agent.rb +130 -0
- data/lib/deepagents/tools.rb +152 -0
- data/lib/deepagents/version.rb +3 -0
- data/lib/deepagents.rb +61 -0
- metadata +150 -0
@@ -0,0 +1,205 @@
|
|
1
|
+
module DeepAgentsRb
  # Built-in tools (todo tracking plus a virtual, state-backed file system)
  # and the small value classes those tools return.
  module Tools
    # Tool descriptions
    WRITE_TODOS_DESCRIPTION = "Use this tool to manage and track tasks. It helps you plan and break down complex tasks."
    EDIT_DESCRIPTION = "Edit a file by replacing specific content."
    TOOL_DESCRIPTION = "Read a file from the virtual file system."

    # A named, described callable wrapping the block given to ::new.
    class Tool
      attr_reader :name, :description, :function

      # @param name [String] tool name
      # @param description [String] human-readable description
      # @param block [Proc] implementation invoked by #call
      def initialize(name, description, &block)
        @name = name
        @description = description
        @function = block
      end

      # Invokes the tool implementation, forwarding all arguments.
      #
      # Implementations are plain (non-lambda) procs, so they must exit early
      # with `next`, not `return`: a `return` would target the already-finished
      # method that built the proc and raise LocalJumpError.
      #
      # NOTE(review): because the proc is not a lambda, on Ruby < 3.2 a single
      # Array positional argument passed alongside keywords may be auto-splatted
      # (e.g. write_todos could receive only the first todo) — confirm on the
      # Ruby versions you support.
      #
      # @raise [ToolError] wrapping any error from the implementation;
      #   ArgumentErrors are reported as "Invalid arguments".
      def call(*args, **kwargs)
        @function.call(*args, **kwargs)
      rescue ArgumentError => e
        raise ToolError.new(@name, "Invalid arguments: #{e.message}")
      rescue => e
        raise ToolError.new(@name, e.message)
      end
    end

    # Command class to represent tool results: a hash of state updates.
    class Command
      attr_reader :update

      # Accepts the update hash either positionally (Command.new(hash)) or
      # under the +update:+ keyword (Command.new(update: hash)). The keyword
      # form is unwrapped so #update holds the state delta itself rather than
      # {update: {...}}.
      #
      # @raise [ArgumentError] if the resulting update is not a Hash
      def initialize(update = nil, **kwargs)
        if update.nil? && !kwargs.empty?
          update = kwargs.key?(:update) ? kwargs[:update] : kwargs
        end
        raise ArgumentError, "Update must be a hash" unless update.is_a?(Hash)

        @update = update
      end
    end

    # Tool message class: the message recorded for a tool invocation.
    class ToolMessage
      attr_reader :content, :tool_call_id

      def initialize(content, tool_call_id:)
        @content = content
        @tool_call_id = tool_call_id
      end

      # @return [Hash] {role: "tool", content:, tool_call_id:}
      def to_h
        {
          role: "tool",
          content: @content,
          tool_call_id: @tool_call_id
        }
      end
    end

    # Define the write_todos tool: replaces the todo list in state.
    def self.write_todos
      Tool.new("write_todos", WRITE_TODOS_DESCRIPTION) do |todos, tool_call_id:|
        Command.new(
          update: {
            todos: todos,
            messages: [
              ToolMessage.new("Updated todo list to #{todos}", tool_call_id: tool_call_id)
            ]
          }
        )
      end
    end

    # Define the ls tool: lists the paths stored under state "files".
    def self.ls
      Tool.new("ls", "List all files") do |state|
        state.get("files", {}).keys
      end
    end

    # Define the read_file tool: returns file lines in `cat -n` format
    # (6-wide right-aligned 1-based line number, a tab, then the line),
    # starting at line index +offset+ and returning at most +limit+ lines.
    # Lines longer than 2000 characters are truncated.
    def self.read_file
      Tool.new("read_file", TOOL_DESCRIPTION) do |file_path, state, offset: 0, limit: 2000|
        mock_filesystem = state.get("files", {})
        raise FileNotFoundError.new(file_path) unless mock_filesystem.key?(file_path)

        content = mock_filesystem[file_path]
        next "System reminder: File exists but has empty contents" if content.nil? || content.strip.empty?

        lines = content.split("\n")

        # A negative offset would index from the end and print line number 0.
        next "Error: Line offset #{offset} must be non-negative" if offset.negative?

        if offset >= lines.length
          next "Error: Line offset #{offset} exceeds file length (#{lines.length} lines)"
        end

        end_idx = [offset + limit, lines.length].min
        (offset...end_idx).map do |i|
          # [0, 2000] truncates long lines; line numbers are 1-based.
          "#{(i + 1).to_s.rjust(6)}\t#{lines[i][0, 2000]}"
        end.join("\n")
      end
    end

    # Define the write_file tool: stores +content+ at +file_path+ in state "files".
    def self.write_file
      Tool.new("write_file", "Write to a file") do |file_path, content, state, tool_call_id:|
        files = state.get("files", {})
        files[file_path] = content

        Command.new(
          update: {
            files: files,
            messages: [
              ToolMessage.new("Updated file #{file_path}", tool_call_id: tool_call_id)
            ]
          }
        )
      end
    end

    # Define the edit_file tool: replaces +old_string+ with +new_string+.
    # Without +replace_all+, +old_string+ must occur exactly once; otherwise an
    # error string is returned and nothing is changed.
    def self.edit_file
      Tool.new("edit_file", EDIT_DESCRIPTION) do |file_path, old_string, new_string, state, tool_call_id:, replace_all: false|
        mock_filesystem = state.get("files", {})
        raise FileNotFoundError.new(file_path) unless mock_filesystem.key?(file_path)

        content = mock_filesystem[file_path]

        # String#scan with a String pattern matches literally (no regex).
        occurrences = content.scan(old_string).length
        next "Error: String not found in file: '#{old_string}'" if occurrences.zero?

        if !replace_all && occurrences > 1
          next "Error: String '#{old_string}' appears #{occurrences} times in file. Use replace_all=true to replace all instances, or provide a more specific string with surrounding context."
        end

        # Block form inserts new_string literally; the string-replacement form
        # would interpret \0, \1, \&, \\ ... inside new_string as backreferences.
        if replace_all
          new_content = content.gsub(old_string) { new_string }
          result_msg = "Successfully replaced #{occurrences} instance(s) of the string in '#{file_path}'"
        else
          new_content = content.sub(old_string) { new_string } # first occurrence only
          result_msg = "Successfully replaced string in '#{file_path}'"
        end

        mock_filesystem[file_path] = new_content

        Command.new(
          update: {
            files: mock_filesystem,
            messages: [ToolMessage.new(result_msg, tool_call_id: tool_call_id)]
          }
        )
      end
    end

    # Get all built-in tools
    def self.built_in_tools
      [
        write_todos,
        ls,
        read_file,
        write_file,
        edit_file
      ]
    end
  end
end
|
@@ -0,0 +1,80 @@
|
|
1
|
+
module DeepAgents
  # Base error class for all DeepAgents errors
  class Error < StandardError; end

  # Error raised when there's an issue with the state
  class StateError < Error; end

  # Error raised when there's an issue with a tool.
  # Message format: "Error in tool '<tool_name>': <message>".
  class ToolError < Error
    attr_reader :tool_name

    # @param tool_name [String] name of the failing tool
    # @param message [String] detail appended after the tool-name prefix
    def initialize(tool_name, message)
      @tool_name = tool_name
      super("Error in tool '#{tool_name}': #{message}")
    end
  end

  # Error raised when a tool is not found
  class ToolNotFoundError < ToolError
    def initialize(tool_name)
      super(tool_name, "Tool not found")
    end
  end

  # Error raised when there's an issue with a file operation
  class FileError < Error; end

  # Error raised when a file is not found
  class FileNotFoundError < FileError
    attr_reader :path

    def initialize(path)
      @path = path
      super("File not found: #{path}")
    end
  end

  # Error raised when there's an issue with a model
  class ModelError < Error; end

  # Error raised when there's an issue with API credentials
  class CredentialsError < ModelError
    # Name of the model/provider whose credentials were missing (e.g. "Claude"),
    # exposed like the subjects of the sibling error classes.
    attr_reader :model_name

    def initialize(model_name)
      @model_name = model_name
      super("Missing API credentials for #{model_name}")
    end
  end

  # Error raised when there's an issue with a sub-agent
  class SubAgentError < Error; end

  # Error raised when a sub-agent is not found
  class SubAgentNotFoundError < SubAgentError
    attr_reader :agent_name

    def initialize(agent_name)
      @agent_name = agent_name
      super("Sub-agent not found: #{agent_name}")
    end
  end

  # Error raised when there's an issue with parsing a tool call
  class ToolCallParseError < Error
    # The unparsed payload, when the caller supplied it.
    attr_reader :raw_content

    def initialize(message, raw_content = nil)
      @raw_content = raw_content
      super(message)
    end
  end

  # Error raised when there's an issue with the LLM response
  class LLMResponseError < Error; end

  # Error raised when the maximum number of iterations is reached
  class MaxIterationsError < Error
    # The iteration limit that was hit.
    attr_reader :iterations

    def initialize(iterations)
      @iterations = iterations
      super("Maximum number of iterations (#{iterations}) reached")
    end
  end
end
|
@@ -0,0 +1,207 @@
|
|
1
|
+
require_relative 'errors'
|
2
|
+
require_relative 'state'
|
3
|
+
require_relative 'tools'
|
4
|
+
require_relative 'sub_agent'
|
5
|
+
require 'langgraph_rb'
|
6
|
+
|
7
|
+
module DeepAgents
  # ReactAgent class for implementing the React pattern: on each turn the
  # model either requests a tool (only the first requested call is executed and
  # its result fed back) or answers directly, which ends the run.
  class ReactAgent
    attr_reader :tools, :instructions, :model, :state

    # @param tools [Array] tools registered in the agent's ToolRegistry
    # @param instructions [String] instructions embedded in every prompt
    # @param model [#generate] adapter returning {content:, tool_calls:}
    # @param state_schema [Class, nil] state class to instantiate; DeepAgentState when nil
    def initialize(tools, instructions, model, state_schema = nil)
      @tools = tools
      @instructions = instructions
      @model = model
      # Honour state_schema the same way Graph.create_deep_agent does
      # (it used to be accepted but ignored).
      @state = state_schema ? state_schema.new : DeepAgentState.new
      @tool_registry = ToolRegistry.new

      tools.each { |tool| @tool_registry.register(tool) }
    end

    # Runs the loop until the model answers without a tool call.
    #
    # @param max_iterations [Integer] maximum number of model calls
    # @return [Object] the +:content+ of the model's final response
    # @raise [MaxIterationsError] if no final answer arrives within max_iterations calls
    def run(max_iterations = 10)
      messages = []

      max_iterations.times do
        response = @model.generate(build_prompt(messages), @tools)
        tool_calls = response[:tool_calls]

        # A final answer ends the run, even on the last allowed iteration.
        return response[:content] if tool_calls.nil? || tool_calls.empty?

        tool_call = tool_calls.first
        tool_name = tool_call[:name]

        begin
          result = execute_tool(@tool_registry.get(tool_name), tool_call[:arguments])
          messages << { role: "assistant", content: "I'll use the #{tool_name} tool." }
          messages << { role: "system", content: "Tool result: #{result}" }
        rescue => e
          messages << { role: "system", content: "Error: #{e.message}" }
        end
      end

      raise MaxIterationsError.new(max_iterations)
    end

    private

    # Assembles the text prompt: instructions, tool list, current todos/files
    # and the running message transcript.
    def build_prompt(messages)
      prompt = +"You are a deep agent with the following instructions:\n\n"
      prompt << "#{@instructions}\n\n"

      if @tools && !@tools.empty?
        prompt << "You have access to the following tools:\n\n"

        @tools.each do |tool|
          prompt << "#{tool.name}: #{tool.description}\n"
          prompt << "Parameters: #{tool.parameters.join(', ')}\n\n"
        end
      end

      if @state && !@state.to_h.empty?
        prompt << "Current state:\n\n"
        prompt << "Todos:\n"

        @state.todos.each_with_index do |todo, index|
          prompt << "#{index + 1}. [#{todo.status}] #{todo.content}\n"
        end

        prompt << "\n"

        unless @state.files.empty?
          prompt << "Files:\n"
          @state.files.each do |path, _content|
            prompt << "- #{path}\n"
          end
          prompt << "\n"
        end
      end

      if messages && !messages.empty?
        prompt << "Previous messages:\n\n"

        messages.each do |message|
          prompt << "#{message[:role]}: #{message[:content]}\n\n"
        end
      end

      prompt
    end

    # Calls the tool, splatting Hash arguments as keywords. Any StandardError
    # is turned into an "Error executing tool: ..." string so the loop can
    # report it back to the model.
    def execute_tool(tool, args)
      if args.is_a?(Hash)
        tool.call(**args)
      else
        tool.call(args)
      end
    rescue => e
      "Error executing tool: #{e.message}"
    end
  end

  # Graph module for creating deep agents
  module Graph
    # Builds a LangGraphRb-backed agent with the standard file-system, todo and
    # planning tools, the caller's tools, and one "task_<name>" tool per subagent.
    #
    # @param tools [Array, nil] user tools (must respond to name/description/parameters/call)
    # @param instructions [String] system prompt for the agent
    # @param model [Models::BaseModel, nil] defaults to Claude, then OpenAI, then MockModel
    # @param subagents [Array, nil] sub-agents (must respond to name/description/invoke)
    # @param state_schema [Class, nil] state class to instantiate; DeepAgentState when nil
    # @return [ReactAgentWrapper]
    def self.create_deep_agent(tools, instructions, model: nil, subagents: nil, state_schema: nil)
      model = default_model if model.nil?

      file_system = FileSystem.new
      state = state_schema ? state_schema.new : DeepAgentState.new

      standard_tools = StandardTools.file_system_tools(file_system) +
                       StandardTools.todo_tools(state) + [StandardTools.planning_tool]

      # Array() lets callers pass nil when they have no tools of their own.
      all_tools = Array(tools) + standard_tools

      langgraph_tools = all_tools.map do |tool|
        LangGraphRb::Tool.new(
          name: tool.name,
          description: tool.description,
          parameters: tool.parameters,
          function: ->(args) { tool.call(**args) }
        )
      end

      if subagents && !subagents.empty?
        subagent_registry = SubAgentRegistry.new

        subagents.each do |subagent|
          subagent_registry.register(subagent)

          langgraph_tools << LangGraphRb::Tool.new(
            name: "task_#{subagent.name}",
            description: "Use the #{subagent.name} subagent for tasks related to: #{subagent.description}",
            parameters: { task: "The task to perform" },
            # NOTE(review): accept string keys too, in case the runner passes
            # JSON-decoded arguments — confirm LangGraphRb's calling convention.
            function: ->(args) { subagent.invoke(args[:task] || args["task"]) }
          )
        end
      end

      agent = LangGraphRb::Agent.new(
        model: model.to_langchain_model,
        tools: langgraph_tools,
        system_prompt: instructions,
        state: state
      )

      # Wrap to keep the ReactAgent-style interface (#run, readers).
      ReactAgentWrapper.new(agent, all_tools, instructions, model, state)
    end

    # Returns the first model adapter that can be constructed: Claude, then
    # OpenAI, then MockModel. The adapters raise LoadError (a ScriptError, not
    # caught by a bare `rescue`) when their gem is missing, so it is rescued
    # explicitly to keep the fallback chain working.
    def self.default_model
      Models::Claude.new
    rescue StandardError, LoadError
      begin
        Models::OpenAI.new
      rescue StandardError, LoadError
        Models::MockModel.new
      end
    end
    private_class_method :default_model

    # Wrapper class to maintain compatibility with ReactAgent interface
    class ReactAgentWrapper
      attr_reader :tools, :instructions, :model, :state

      def initialize(langgraph_agent, tools, instructions, model, state)
        @langgraph_agent = langgraph_agent
        @tools = tools
        @instructions = instructions
        @model = model
        @state = state
      end

      # Delegates to the LangGraphRb agent's run.
      def run(max_iterations = 10)
        @langgraph_agent.run(max_iterations: max_iterations)
      end
    end
  end
end
|
@@ -0,0 +1,217 @@
|
|
1
|
+
require 'json'
require_relative 'errors'
|
2
|
+
|
3
|
+
module DeepAgents
  module Models
    # Base model class for LLM models. Subclasses implement #generate, which
    # returns {content:, tool_calls: [{name:, arguments:}]}, and
    # #to_langchain_model.
    class BaseModel
      attr_reader :model

      def initialize(model:)
        @model = model
      end

      def generate(prompt, tools = nil)
        raise NotImplementedError, "Subclasses must implement the generate method"
      end

      def to_langchain_model
        raise NotImplementedError, "Subclasses must implement the to_langchain_model method"
      end

      private

      # JSON schema describing every entry of tool.parameters as a required
      # string property; shared by the provider adapters.
      def parameters_schema(tool)
        {
          type: "object",
          properties: tool.parameters.map { |param| [param, { type: "string" }] }.to_h,
          required: tool.parameters
        }
      end
    end

    # Claude model adapter
    class Claude < BaseModel
      # @raise [LoadError] if the 'anthropic' gem is unavailable
      # @raise [CredentialsError] if no API key is given or set in ANTHROPIC_API_KEY
      def initialize(api_key: nil, model: "claude-3-sonnet-20240229")
        super(model: model)
        @api_key = api_key || ENV["ANTHROPIC_API_KEY"]

        begin
          require 'anthropic'
        rescue LoadError
          raise LoadError, "The 'anthropic' gem is required for Claude models"
        end

        if @api_key.nil? || @api_key.empty?
          raise CredentialsError.new("Claude")
        end

        @client = Anthropic::Client.new(api_key: @api_key)
      end

      # Sends +prompt+ as a single user message.
      #
      # @return [Hash] {content: nil, tool_calls: [{name:, arguments:}]} when
      #   Claude requests a tool, otherwise {content: text, tool_calls: []}
      # @raise [ModelError] wrapping any failure
      def generate(prompt, tools = nil)
        params = {
          model: @model,
          messages: [{ role: "user", content: prompt }],
          max_tokens: 4096,
          temperature: 0.7
        }

        use_tools = tools && !tools.empty?
        if use_tools
          # The Anthropic Messages API takes the JSON schema under input_schema.
          params[:tools] = tools.map do |tool|
            { name: tool.name, description: tool.description, input_schema: parameters_schema(tool) }
          end
        end

        response = @client.messages.create(**params)
        blocks = response.content

        # Claude often emits a text block before the tool_use block, so search
        # every block. `type` may be a Symbol or a String, hence to_s.
        tool_use = use_tools && blocks.find { |block| block.type.to_s == "tool_use" }

        if tool_use
          { content: nil, tool_calls: [{ name: tool_use.name, arguments: tool_use.input }] }
        else
          text = blocks.select { |block| block.type.to_s == "text" }.map(&:text).join
          { content: text, tool_calls: [] }
        end
      rescue => e
        raise ModelError, "Claude API error: #{e.message}"
      end

      def to_langchain_model
        # Create and return a langchainrb Claude model
        require 'langchain'
        Langchain::LLM::Anthropic.new(
          api_key: @api_key,
          model_name: @model
        )
      end
    end

    # OpenAI model adapter
    class OpenAI < BaseModel
      # @raise [LoadError] if the 'openai' gem is unavailable
      # @raise [CredentialsError] if no API key is given or set in OPENAI_API_KEY
      def initialize(api_key: nil, model: "gpt-4o")
        super(model: model)
        @api_key = api_key || ENV["OPENAI_API_KEY"]

        begin
          require 'openai'
        rescue LoadError
          raise LoadError, "The 'openai' gem is required for OpenAI models"
        end

        if @api_key.nil? || @api_key.empty?
          raise CredentialsError.new("OpenAI")
        end

        @client = ::OpenAI::Client.new(access_token: @api_key)
      end

      # Sends +prompt+ as a single user message via the chat endpoint.
      #
      # @return [Hash] {content: nil, tool_calls: [{name:, arguments:}]} for the
      #   first requested tool call, otherwise {content: text, tool_calls: []}
      # @raise [ModelError] wrapping any failure
      def generate(prompt, tools = nil)
        params = {
          model: @model,
          messages: [{ role: "user", content: prompt }],
          temperature: 0.7,
          max_tokens: 4096
        }

        use_tools = tools && !tools.empty?
        if use_tools
          params[:tools] = tools.map do |tool|
            {
              type: "function",
              function: {
                name: tool.name,
                description: tool.description,
                parameters: parameters_schema(tool)
              }
            }
          end
        end

        response = @client.chat(parameters: params)
        tool_calls = response.dig("choices", 0, "message", "tool_calls")

        # An empty tool_calls array is treated like no tool call at all.
        if use_tools && tool_calls && !tool_calls.empty?
          tool_call = tool_calls.first
          {
            content: nil,
            tool_calls: [{
              name: tool_call["function"]["name"],
              # Function arguments arrive as a JSON-encoded string.
              arguments: JSON.parse(tool_call["function"]["arguments"])
            }]
          }
        else
          { content: response.dig("choices", 0, "message", "content"), tool_calls: [] }
        end
      rescue => e
        raise ModelError, "OpenAI API error: #{e.message}"
      end

      def to_langchain_model
        # Create and return a langchainrb OpenAI model
        require 'langchain'
        Langchain::LLM::OpenAI.new(
          api_key: @api_key,
          model_name: @model
        )
      end
    end

    # Mock model for testing: requests the first tool (every parameter set to
    # "test_value") when the prompt contains "use tool", else echoes the prompt.
    class MockModel < BaseModel
      def initialize(model: "mock-model")
        super(model: model)
      end

      def generate(prompt, tools = nil)
        if tools && !tools.empty? && prompt.include?("use tool")
          tool = tools.first
          {
            content: nil,
            tool_calls: [{
              name: tool.name,
              arguments: tool.parameters.map { |param| [param, "test_value"] }.to_h
            }]
          }
        else
          {
            content: "This is a mock response for: #{prompt}",
            tool_calls: []
          }
        end
      end

      def to_langchain_model
        # Create and return a langchainrb mock model
        require 'langchain'
        Langchain::LLM::Base.new.tap do |model|
          # Override the complete method to return mock responses
          def model.complete(prompt:, **_kwargs)
            "This is a mock response from langchain for: #{prompt}"
          end
        end
      end
    end
  end
end
|