langchainrb 0.12.0 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/anthropic.rb +27 -49
- data/lib/langchain/llm/aws_bedrock.rb +30 -34
- data/lib/langchain/llm/azure.rb +6 -0
- data/lib/langchain/llm/base.rb +20 -1
- data/lib/langchain/llm/cohere.rb +38 -6
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/mistral_ai.rb +10 -19
- data/lib/langchain/llm/ollama.rb +23 -27
- data/lib/langchain/llm/openai.rb +20 -48
- data/lib/langchain/llm/parameters/chat.rb +51 -0
- data/lib/langchain/llm/response/base_response.rb +2 -2
- data/lib/langchain/llm/response/cohere_response.rb +16 -0
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/llm/unified_parameters.rb +98 -0
- data/lib/langchain/loader.rb +6 -0
- data/lib/langchain/tool/base.rb +16 -6
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +22 -15
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: b146eb8568d30ae12aca93a25818fcff7421b7ee2e968330f3a68c5e523da148
+  data.tar.gz: 33f88d7ba03501606706314dce58f626fa0df5aab50639b5f5db3df527ee6520
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6518e30de12653426280f6f8cf05f37a6d4b311ad4219af52276bace8a75ec6440b8f42c208d9d5c07bb4218f3259cc95ea9edd77f8ba037e7a1600a7dfa3170
+  data.tar.gz: 7f881a4347866c8b52161adaf6b98b669e38b4e2fd1ac513f02efd1dcfe73b2552ad8e65acdd02e9e002769ad3b394850f720e0971774632c2f489c25d9ce076
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,17 @@
 ## [Unreleased]
 
+## [0.13.0] - 2024-05-14
+- New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
+- Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms
+- [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
+
+## [0.12.1] - 2024-05-13
+- Langchain::LLM::Ollama now uses `llama3` by default
+- Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
+- Updated with new OpenAI models, including `gpt-4o`
+- New `Langchain::LLM::Cohere#chat()` method.
+- Introducing `UnifiedParameters` to unify parameters across LLM classes
+
 ## [0.12.0] - 2024-04-22
 - [BREAKING] Rename `dimension` parameter to `dimensions` everywhere
 
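For context, the 0.12.1 and 0.13.0 entries above all feed the same workflow: building an Assistant on a supported LLM and handing it tools. Below is a minimal sketch of that workflow. The `GoogleGemini.new` and `NewsRetriever.new` constructor arguments and the `run(auto_tool_execution:)` keyword come from the wider library and are not shown in this diff excerpt, so treat them as assumptions.

```ruby
require "langchain"

# Assumed constructors; only the class names are confirmed by this diff.
llm  = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])
news = Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])

assistant = Langchain::Assistant.new(
  llm: llm, # GoogleGemini and GoogleVertexAI are newly allowed in SUPPORTED_LLMS (see assistant.rb below)
  thread: Langchain::Thread.new,
  instructions: "You are a news assistant",
  tools: [news]
)

assistant.add_message(content: "What are today's top technology headlines?")
assistant.run(auto_tool_execution: true)
```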
data/README.md
CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
 | "file_system" | Interacts with the file system | | |
 | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
 | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
+| "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
 | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
 | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
 
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
   thread: thread,
   instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
   tools: [
-    Langchain::Tool::
+    Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
   ]
 )
 ```
 ### Using an Assistant
 You can now add your message to an Assistant.
 ```ruby
-assistant.add_message content: "What's the weather in New York
+assistant.add_message content: "What's the weather in New York, New York?"
 ```
 
 Run the Assistant to generate a response.
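That run step now places `Langchain::Messages::*` objects on the thread instead of the old single `Langchain::Message` class. A rough sketch of running and inspecting the result, assuming the `run` keyword visible in the Assistant source below; the printed format is illustrative:

```ruby
assistant.run(auto_tool_execution: true)

# Messages are now Langchain::Messages::OpenAIMessage / GoogleGeminiMessage
# instances (the new 0.13.0 abstraction) rather than Langchain::Message.
assistant.thread.messages.each do |message|
  puts "#{message.role}: #{message.content}"
end
```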
data/lib/langchain/assistants/assistant.rb
CHANGED
@@ -7,6 +7,12 @@ module Langchain
     attr_reader :llm, :thread, :instructions
     attr_accessor :tools
 
+    SUPPORTED_LLMS = [
+      Langchain::LLM::OpenAI,
+      Langchain::LLM::GoogleGemini,
+      Langchain::LLM::GoogleVertexAI
+    ]
+
     # Create a new assistant
     #
     # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -19,7 +25,9 @@ module Langchain
       tools: [],
       instructions: nil
     )
-
+      unless SUPPORTED_LLMS.include?(llm.class)
+        raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
+      end
       raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
       raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
@@ -30,7 +38,10 @@ module Langchain
 
       # The first message in the thread should be the system instructions
       # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        add_message(role: "system", content: instructions) if instructions
+      end
+      # For Google Gemini, system instructions are added to the `system:` param in the `chat` method
     end
 
     # Add a user message to the thread
@@ -59,11 +70,12 @@ module Langchain
 
       while running
         # TODO: I think we need to look at all messages and not just the last one.
-
-
+        last_message = thread.messages.last
+
+        if last_message.system?
           # Do nothing
           running = false
-
+        elsif last_message.llm?
           if last_message.tool_calls.any?
             if auto_tool_execution
               run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@ module Langchain
             # Do nothing
             running = false
           end
-
+        elsif last_message.user?
           # Run it!
           response = chat_with_llm
 
-          if response.tool_calls
+          if response.tool_calls.any?
             # Re-run the while(running) loop to process the tool calls
             running = true
             add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@ module Langchain
             running = false
             add_message(role: response.role, content: response.chat_completion)
           end
-
+        elsif last_message.tool?
           # Run it!
           response = chat_with_llm
           running = true
 
-          if response.tool_calls
+          if response.tool_calls.any?
             add_message(role: response.role, tool_calls: response.tool_calls)
           elsif response.chat_completion
             add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ module Langchain
     # @param output [String] The output of the tool
     # @return [Array<Langchain::Message>] The messages in the thread
     def submit_tool_output(tool_call_id:, output:)
-
-
+      tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
+        Langchain::Messages::OpenAIMessage::TOOL_ROLE
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
+      end
+
+      # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
+      add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
     end
 
     # Delete all messages in the thread
@@ -156,10 +174,15 @@ module Langchain
     def chat_with_llm
       Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-      params = {messages: thread.
+      params = {messages: thread.array_of_message_hashes}
 
       if tools.any?
-
+        if llm.is_a?(Langchain::LLM::OpenAI)
+          params[:tools] = tools.map(&:to_openai_tools).flatten
+        elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+          params[:tools] = tools.map(&:to_google_gemini_tools).flatten
+          params[:system] = instructions if instructions
+        end
         # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
         params[:tool_choice] = "auto"
       end
@@ -173,11 +196,11 @@ module Langchain
     def run_tools(tool_calls)
       # Iterate over each function invocation and submit tool output
       tool_calls.each do |tool_call|
-        tool_call_id =
-
-
-
-
+        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
+          extract_openai_tool_call(tool_call: tool_call)
+        elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+          extract_google_gemini_tool_call(tool_call: tool_call)
+        end
 
         tool_instance = tools.find do |t|
           t.name == tool_name
@@ -190,13 +213,41 @@ module Langchain
 
       response = chat_with_llm
 
-      if response.tool_calls
+      if response.tool_calls.any?
         add_message(role: response.role, tool_calls: response.tool_calls)
       elsif response.chat_completion
         add_message(role: response.role, content: response.chat_completion)
       end
     end
 
+    # Extract the tool call information from the OpenAI tool call hash
+    #
+    # @param tool_call [Hash] The tool call hash
+    # @return [Array] The tool call information
+    def extract_openai_tool_call(tool_call:)
+      tool_call_id = tool_call.dig("id")
+
+      function_name = tool_call.dig("function", "name")
+      tool_name, method_name = function_name.split("__")
+      tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+
+      [tool_call_id, tool_name, method_name, tool_arguments]
+    end
+
+    # Extract the tool call information from the Google Gemini tool call hash
+    #
+    # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
+    # @return [Array] The tool call information
+    def extract_google_gemini_tool_call(tool_call:)
+      tool_call_id = tool_call.dig("functionCall", "name")
+
+      function_name = tool_call.dig("functionCall", "name")
+      tool_name, method_name = function_name.split("__")
+      tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+
+      [tool_call_id, tool_name, method_name, tool_arguments]
+    end
+
     # Build a message
     #
     # @param role [String] The role of the message
@@ -205,7 +256,11 @@ module Langchain
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Langchain::Message] The Message object
     def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      end
     end
 
     # TODO: Fix the message truncation when context window is exceeded
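The `extract_google_gemini_tool_call` helper added above is what lets the Assistant route Gemini function calls to Ruby tools. Tracing the same logic by hand with the hash format from its own comment (the method itself is private to the Assistant, so this is just the extraction inlined):

```ruby
tool_call = {"functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}}

tool_call_id   = tool_call.dig("functionCall", "name")                          # => "weather__execute"
function_name  = tool_call.dig("functionCall", "name")
tool_name, method_name = function_name.split("__")                              # => ["weather", "execute"]
tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym) # => {:input=>"NYC"}
```

The surrounding `run_tools` loop then looks up the tool whose `name` matches "weather" and feeds the tool's output back onto the thread via `submit_tool_output`.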
data/lib/langchain/assistants/messages/base.rb
ADDED
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class Base
+      attr_reader :role, :content, :tool_calls, :tool_call_id
+
+      # Check if the message came from a user
+      #
+      # @param [Boolean] true/false whether the message came from a user
+      def user?
+        role == "user"
+      end
+    end
+  end
+end
data/lib/langchain/assistants/messages/google_gemini_message.rb
ADDED
@@ -0,0 +1,90 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class GoogleGeminiMessage < Base
+      # Google Gemini uses the following roles:
+      ROLES = [
+        "user",
+        "model",
+        "function"
+      ].freeze
+
+      TOOL_ROLE = "function"
+
+      # Initialize a new Google Gemini message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        model?
+      end
+
+      # Convert the message to a Google Gemini API-compatible hash
+      #
+      # @return [Hash] The message as a Google Gemini API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:parts] = if function?
+            [{
+              functionResponse: {
+                name: tool_call_id,
+                response: {
+                  name: tool_call_id,
+                  content: content
+                }
+              }
+            }]
+          elsif tool_calls.any?
+            tool_calls
+          else
+            [{text: content}]
+          end
+        end
+      end
+
+      # Google Gemini does not implement system prompts
+      def system?
+        false
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        function?
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def function?
+        role == "function"
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def model?
+        role == "model"
+      end
+    end
+  end
+end
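For reference, `to_hash` above maps the message shapes Gemini accepts. A quick sketch of what it returns, derived only from the class as added in this diff (the values are illustrative):

```ruby
user_msg = Langchain::Messages::GoogleGeminiMessage.new(role: "user", content: "What's the weather in NYC?")
user_msg.to_hash
# => {role: "user", parts: [{text: "What's the weather in NYC?"}]}

tool_msg = Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: "72F and sunny",
  tool_call_id: "weather__execute"
)
tool_msg.to_hash
# => {role: "function",
#     parts: [{functionResponse: {name: "weather__execute",
#                                 response: {name: "weather__execute", content: "72F and sunny"}}}]}
```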
data/lib/langchain/assistants/messages/openai_message.rb
ADDED
@@ -0,0 +1,74 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class OpenAIMessage < Base
+      # OpenAI uses the following roles:
+      ROLES = [
+        "system",
+        "assistant",
+        "user",
+        "tool"
+      ].freeze
+
+      TOOL_ROLE = "tool"
+
+      # Initialize a new OpenAI message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        assistant?
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      #
+      # @return [Hash] The message as an OpenAI API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:content] = content if content # Content is nil for tool calls
+          h[:tool_calls] = tool_calls if tool_calls.any?
+          h[:tool_call_id] = tool_call_id if tool_call_id
+        end
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def assistant?
+        role == "assistant"
+      end
+
+      # Check if the message are system instructions
+      #
+      # @return [Boolean] true/false whether this message are system instructions
+      def system?
+        role == "system"
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        role == "tool"
+      end
+    end
+  end
+end
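The OpenAI variant mirrors the Gemini one but emits the flat hash the Chat Completions API expects. Illustrative output, derived from the `to_hash` method above (the tool call ID is a made-up placeholder):

```ruby
Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hi").to_hash
# => {role: "user", content: "Hi"}

Langchain::Messages::OpenAIMessage.new(
  role: "tool",
  content: "72F and sunny",
  tool_call_id: "call_abc123" # hypothetical ID returned by a prior assistant tool_call
).to_hash
# => {role: "tool", content: "72F and sunny", tool_call_id: "call_abc123"}
```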
data/lib/langchain/assistants/thread.rb
CHANGED
@@ -8,16 +8,16 @@ module Langchain
 
     # @param messages [Array<Langchain::Message>]
     def initialize(messages: [])
-      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::
+      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
 
       @messages = messages
     end
 
-    # Convert the thread to an
+    # Convert the thread to an LLM APIs-compatible array of hashes
     #
     # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
-    def
-      messages.map(&:
+    def array_of_message_hashes
+      messages.map(&:to_hash)
     end
 
     # Add a message to the thread
@@ -25,7 +25,7 @@ module Langchain
     # @param message [Langchain::Message] The message to add
     # @return [Array<Langchain::Message>] The updated messages array
     def add_message(message)
-      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::
+      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
 
       # Prepend the message to the thread
       messages << message
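The renamed `array_of_message_hashes` is what the Assistant's `chat_with_llm` sends as `messages:`. A small sketch of the round trip, using only methods visible in this diff:

```ruby
thread = Langchain::Thread.new
thread.add_message(
  Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hello")
)

thread.array_of_message_hashes
# => [{role: "user", content: "Hello"}]
```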
data/lib/langchain/llm/anthropic.rb
CHANGED
@@ -13,7 +13,7 @@ module Langchain::LLM
   class Anthropic < Base
     DEFAULTS = {
       temperature: 0.0,
-      completion_model_name: "claude-2",
+      completion_model_name: "claude-2.1",
       chat_completion_model_name: "claude-3-sonnet-20240229",
       max_tokens_to_sample: 256
     }.freeze
@@ -32,6 +32,15 @@ module Langchain::LLM
 
       @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
       @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]},
+        max_tokens: {default: @defaults[:max_tokens_to_sample]},
+        metadata: {},
+        system: {}
+      )
+      chat_parameters.ignore(:n, :user)
+      chat_parameters.remap(stop: :stop_sequences)
     end
 
     # Generate a completion for a given prompt
@@ -72,66 +81,35 @@ module Langchain::LLM
       parameters[:metadata] = metadata if metadata
       parameters[:stream] = stream if stream
 
-      # TODO: Implement token length validator for Anthropic
-      # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
-
       response = client.complete(parameters: parameters)
       Langchain::LLM::AnthropicResponse.new(response)
     end
 
     # Generate a chat completion for given messages
     #
-    # @param
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
+    # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<String>] :messages Input messages
+    # @option params [String] :model The model that will complete your prompt
+    # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+    # @option params [Hash] :metadata Object describing metadata about the request
+    # @option params [Array<String>] :stop_sequences Custom text sequences that will cause the model to stop generating
+    # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+    # @option params [String] :system System prompt
+    # @option params [Float] :temperature Amount of randomness injected into the response
+    # @option params [Array<String>] :tools Definitions of tools that the model may use
+    # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+    # @option params [Float] :top_p Use nucleus sampling.
     # @return [Langchain::LLM::AnthropicResponse] The chat completion
-    def chat(
-
-      model: @defaults[:chat_completion_model_name],
-      max_tokens: @defaults[:max_tokens_to_sample],
-      metadata: nil,
-      stop_sequences: nil,
-      stream: nil,
-      system: nil,
-      temperature: @defaults[:temperature],
-      tools: [],
-      top_k: nil,
-      top_p: nil
-    )
-      raise ArgumentError.new("messages argument is required") if messages.empty?
-      raise ArgumentError.new("model argument is required") if model.empty?
-      raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
 
-
-
-
-        max_tokens: max_tokens,
-        temperature: temperature
-      }
-      parameters[:metadata] = metadata if metadata
-      parameters[:stop_sequences] = stop_sequences if stop_sequences
-      parameters[:stream] = stream if stream
-      parameters[:system] = system if system
-      parameters[:tools] = tools if tools.any?
-      parameters[:top_k] = top_k if top_k
-      parameters[:top_p] = top_p if top_p
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+      raise ArgumentError.new("model argument is required") if parameters[:model].empty?
+      raise ArgumentError.new("max_tokens argument is required") if parameters[:max_tokens].nil?
 
       response = client.messages(parameters: parameters)
 
       Langchain::LLM::AnthropicResponse.new(response)
     end
-
-    # TODO: Implement token length validator for Anthropic
-    # def validate_max_tokens(messages, model)
-    #   LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
-    # end
   end
 end
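The net effect of `chat_parameters` plus `chat(params = {})` is that callers pass one flat hash and the class fills in defaults, drops ignored keys, and remaps provider-specific names. A rough usage sketch, assuming the constructor takes `api_key:` as implied by the `::Anthropic::Client.new(access_token: api_key, ...)` call above:

```ruby
llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

# :stop is remapped to Anthropic's :stop_sequences; :model, :temperature and
# :max_tokens fall back to the defaults registered via chat_parameters.update.
response = llm.chat(
  messages: [{role: "user", content: "Hello, Claude"}],
  stop: ["\n\nHuman:"]
)
response.chat_completion
```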
data/lib/langchain/llm/aws_bedrock.rb
CHANGED
@@ -59,6 +59,17 @@ module Langchain::LLM
       @defaults = DEFAULTS.merge(default_options)
         .merge(completion_model_name: completion_model)
         .merge(embedding_model_name: embedding_model)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {},
+        max_tokens: {default: @defaults[:max_tokens_to_sample]},
+        metadata: {},
+        system: {},
+        anthropic_version: {default: "bedrock-2023-05-31"}
+      )
+      chat_parameters.ignore(:n, :user)
+      chat_parameters.remap(stop: :stop_sequences)
     end
 
     #
@@ -113,43 +124,28 @@ module Langchain::LLM
     # Generate a chat completion for a given prompt
     # Currently only configured to work with the Anthropic provider and
     # the claude-3 model family
-    #
-    # @param
-    # @
-    # @
-    # @
-    # @
-    # @
-    # @
+    #
+    # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array<String>] :messages The messages to generate a completion for
+    # @option params [String] :system The system prompt to provide instructions
+    # @option params [String] :model The model to use for completion defaults to @defaults[:chat_completion_model_name]
+    # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
+    # @option params [Array<String>] :stop The stop sequences to use for completion
+    # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
+    # @option params [Float] :temperature The temperature to use for completion
+    # @option params [Float] :top_p Use nucleus sampling.
+    # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
     # @return [Langchain::LLM::AnthropicMessagesResponse] Response object
-    def chat(
-
-
-
-
-
-      temperature: nil,
-      top_p: nil,
-      top_k: nil
-    )
-      raise ArgumentError.new("messages argument is required") if messages.empty?
-
-      raise "Model #{model} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
-
-      inference_parameters = {
-        messages: messages,
-        max_tokens: max_tokens,
-        anthropic_version: @defaults[:anthropic_version]
-      }
-      inference_parameters[:system] = system if system
-      inference_parameters[:stop_sequences] = stop_sequences if stop_sequences
-      inference_parameters[:temperature] = temperature if temperature
-      inference_parameters[:top_p] = top_p if top_p
-      inference_parameters[:top_k] = top_k if top_k
+    def chat(params = {})
+      parameters = chat_parameters.to_params(params)
+
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+
+      raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
 
       response = client.invoke_model({
-        model_id: model,
-        body:
+        model_id: parameters[:model],
+        body: parameters.except(:model).to_json,
         content_type: "application/json",
         accept: "application/json"
       })
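The Bedrock class gets the same unified-parameter treatment: everything except `:model` is JSON-encoded into the `invoke_model` body, with `anthropic_version` injected from the default registered above. A rough sketch, with the constructor arguments omitted because they are not shown in this diff:

```ruby
llm = Langchain::LLM::AwsBedrock.new # constructor options (model, AWS client) not shown here

response = llm.chat(
  messages: [{role: "user", content: "Summarize the 0.13.0 changes in one sentence."}],
  system: "You are terse."
)
```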
data/lib/langchain/llm/azure.rb
CHANGED