langchainrb 0.12.1 → 0.13.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/base.rb +2 -1
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/tool/base.rb +11 -1
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +14 -9
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: b146eb8568d30ae12aca93a25818fcff7421b7ee2e968330f3a68c5e523da148
|
4
|
+
data.tar.gz: 33f88d7ba03501606706314dce58f626fa0df5aab50639b5f5db3df527ee6520
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6518e30de12653426280f6f8cf05f37a6d4b311ad4219af52276bace8a75ec6440b8f42c208d9d5c07bb4218f3259cc95ea9edd77f8ba037e7a1600a7dfa3170
|
7
|
+
data.tar.gz: 7f881a4347866c8b52161adaf6b98b669e38b4e2fd1ac513f02efd1dcfe73b2552ad8e65acdd02e9e002769ad3b394850f720e0971774632c2f489c25d9ce076
|
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,10 @@
|
|
1
1
|
## [Unreleased]
|
2
2
|
|
3
|
+
## [0.13.0] - 2024-05-14
|
4
|
+
- New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
|
5
|
+
- Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms
|
6
|
+
- [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
|
7
|
+
|
3
8
|
## [0.12.1] - 2024-05-13
|
4
9
|
- Langchain::LLM::Ollama now uses `llama3` by default
|
5
10
|
- Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
|
data/README.md
CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
|
|
412
412
|
| "file_system" | Interacts with the file system | | |
|
413
413
|
| "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
|
414
414
|
| "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
|
415
|
+
| "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
|
415
416
|
| "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
|
416
417
|
| "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
|
417
418
|
|
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
|
|
445
446
|
thread: thread,
|
446
447
|
instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
|
447
448
|
tools: [
|
448
|
-
Langchain::Tool::
|
449
|
+
Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
|
449
450
|
]
|
450
451
|
)
|
451
452
|
```
|
452
453
|
### Using an Assistant
|
453
454
|
You can now add your message to an Assistant.
|
454
455
|
```ruby
|
455
|
-
assistant.add_message content: "What's the weather in New York
|
456
|
+
assistant.add_message content: "What's the weather in New York, New York?"
|
456
457
|
```
|
457
458
|
|
458
459
|
Run the Assistant to generate a response.
|
@@ -7,6 +7,12 @@ module Langchain
|
|
7
7
|
attr_reader :llm, :thread, :instructions
|
8
8
|
attr_accessor :tools
|
9
9
|
|
10
|
+
SUPPORTED_LLMS = [
|
11
|
+
Langchain::LLM::OpenAI,
|
12
|
+
Langchain::LLM::GoogleGemini,
|
13
|
+
Langchain::LLM::GoogleVertexAI
|
14
|
+
]
|
15
|
+
|
10
16
|
# Create a new assistant
|
11
17
|
#
|
12
18
|
# @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
|
@@ -19,7 +25,9 @@ module Langchain
|
|
19
25
|
tools: [],
|
20
26
|
instructions: nil
|
21
27
|
)
|
22
|
-
|
28
|
+
unless SUPPORTED_LLMS.include?(llm.class)
|
29
|
+
raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
|
30
|
+
end
|
23
31
|
raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
|
24
32
|
raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
|
25
33
|
|
@@ -30,7 +38,10 @@ module Langchain
|
|
30
38
|
|
31
39
|
# The first message in the thread should be the system instructions
|
32
40
|
# TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
|
33
|
-
|
41
|
+
if llm.is_a?(Langchain::LLM::OpenAI)
|
42
|
+
add_message(role: "system", content: instructions) if instructions
|
43
|
+
end
|
44
|
+
# For Google Gemini, system instructions are added to the `system:` param in the `chat` method
|
34
45
|
end
|
35
46
|
|
36
47
|
# Add a user message to the thread
|
@@ -59,11 +70,12 @@ module Langchain
|
|
59
70
|
|
60
71
|
while running
|
61
72
|
# TODO: I think we need to look at all messages and not just the last one.
|
62
|
-
|
63
|
-
|
73
|
+
last_message = thread.messages.last
|
74
|
+
|
75
|
+
if last_message.system?
|
64
76
|
# Do nothing
|
65
77
|
running = false
|
66
|
-
|
78
|
+
elsif last_message.llm?
|
67
79
|
if last_message.tool_calls.any?
|
68
80
|
if auto_tool_execution
|
69
81
|
run_tools(last_message.tool_calls)
|
@@ -76,11 +88,11 @@ module Langchain
|
|
76
88
|
# Do nothing
|
77
89
|
running = false
|
78
90
|
end
|
79
|
-
|
91
|
+
elsif last_message.user?
|
80
92
|
# Run it!
|
81
93
|
response = chat_with_llm
|
82
94
|
|
83
|
-
if response.tool_calls
|
95
|
+
if response.tool_calls.any?
|
84
96
|
# Re-run the while(running) loop to process the tool calls
|
85
97
|
running = true
|
86
98
|
add_message(role: response.role, tool_calls: response.tool_calls)
|
@@ -89,12 +101,12 @@ module Langchain
|
|
89
101
|
running = false
|
90
102
|
add_message(role: response.role, content: response.chat_completion)
|
91
103
|
end
|
92
|
-
|
104
|
+
elsif last_message.tool?
|
93
105
|
# Run it!
|
94
106
|
response = chat_with_llm
|
95
107
|
running = true
|
96
108
|
|
97
|
-
if response.tool_calls
|
109
|
+
if response.tool_calls.any?
|
98
110
|
add_message(role: response.role, tool_calls: response.tool_calls)
|
99
111
|
elsif response.chat_completion
|
100
112
|
add_message(role: response.role, content: response.chat_completion)
|
@@ -121,8 +133,14 @@ module Langchain
|
|
121
133
|
# @param output [String] The output of the tool
|
122
134
|
# @return [Array<Langchain::Message>] The messages in the thread
|
123
135
|
def submit_tool_output(tool_call_id:, output:)
|
124
|
-
|
125
|
-
|
136
|
+
tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
|
137
|
+
Langchain::Messages::OpenAIMessage::TOOL_ROLE
|
138
|
+
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
|
139
|
+
Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
|
140
|
+
end
|
141
|
+
|
142
|
+
# TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
|
143
|
+
add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
|
126
144
|
end
|
127
145
|
|
128
146
|
# Delete all messages in the thread
|
@@ -156,10 +174,15 @@ module Langchain
|
|
156
174
|
def chat_with_llm
|
157
175
|
Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
|
158
176
|
|
159
|
-
params = {messages: thread.
|
177
|
+
params = {messages: thread.array_of_message_hashes}
|
160
178
|
|
161
179
|
if tools.any?
|
162
|
-
|
180
|
+
if llm.is_a?(Langchain::LLM::OpenAI)
|
181
|
+
params[:tools] = tools.map(&:to_openai_tools).flatten
|
182
|
+
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
|
183
|
+
params[:tools] = tools.map(&:to_google_gemini_tools).flatten
|
184
|
+
params[:system] = instructions if instructions
|
185
|
+
end
|
163
186
|
# TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
|
164
187
|
params[:tool_choice] = "auto"
|
165
188
|
end
|
@@ -173,11 +196,11 @@ module Langchain
|
|
173
196
|
def run_tools(tool_calls)
|
174
197
|
# Iterate over each function invocation and submit tool output
|
175
198
|
tool_calls.each do |tool_call|
|
176
|
-
tool_call_id =
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
199
|
+
tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
|
200
|
+
extract_openai_tool_call(tool_call: tool_call)
|
201
|
+
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
|
202
|
+
extract_google_gemini_tool_call(tool_call: tool_call)
|
203
|
+
end
|
181
204
|
|
182
205
|
tool_instance = tools.find do |t|
|
183
206
|
t.name == tool_name
|
@@ -190,13 +213,41 @@ module Langchain
|
|
190
213
|
|
191
214
|
response = chat_with_llm
|
192
215
|
|
193
|
-
if response.tool_calls
|
216
|
+
if response.tool_calls.any?
|
194
217
|
add_message(role: response.role, tool_calls: response.tool_calls)
|
195
218
|
elsif response.chat_completion
|
196
219
|
add_message(role: response.role, content: response.chat_completion)
|
197
220
|
end
|
198
221
|
end
|
199
222
|
|
223
|
+
# Extract the tool call information from the OpenAI tool call hash
|
224
|
+
#
|
225
|
+
# @param tool_call [Hash] The tool call hash
|
226
|
+
# @return [Array] The tool call information
|
227
|
+
def extract_openai_tool_call(tool_call:)
  # Function names are "<tool name>__<method name>"; arguments arrive as a JSON string
  qualified_name = tool_call.dig("function", "name")
  raw_arguments = tool_call.dig("function", "arguments")

  tool_name, method_name = qualified_name.split("__")
  parsed_arguments = JSON.parse(raw_arguments, symbolize_names: true)

  [tool_call.dig("id"), tool_name, method_name, parsed_arguments]
end
|
236
|
+
|
237
|
+
# Extract the tool call information from the Google Gemini tool call hash
|
238
|
+
#
|
239
|
+
# @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
|
240
|
+
# @return [Array] The tool call information
|
241
|
+
def extract_google_gemini_tool_call(tool_call:)
  function_name = tool_call.dig("functionCall", "name")
  # Gemini function calls carry no ID, so the function name doubles as the tool call ID;
  # it is later sent back as the `functionResponse` name for the tool output.
  tool_call_id = function_name

  # Function names are "<tool name>__<method name>"
  tool_name, method_name = function_name.split("__")
  # `args` is optional in a Gemini FunctionCall (e.g. for zero-argument functions)
  tool_arguments = (tool_call.dig("functionCall", "args") || {}).transform_keys(&:to_sym)

  [tool_call_id, tool_name, method_name, tool_arguments]
end
|
250
|
+
|
200
251
|
# Build a message
|
201
252
|
#
|
202
253
|
# @param role [String] The role of the message
|
@@ -205,7 +256,11 @@ module Langchain
|
|
205
256
|
# @param tool_call_id [String] The ID of the tool call to include in the message
|
206
257
|
# @return [Langchain::Message] The Message object
|
207
258
|
def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
|
208
|
-
|
259
|
+
if llm.is_a?(Langchain::LLM::OpenAI)
|
260
|
+
Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
|
261
|
+
elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
|
262
|
+
Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
|
263
|
+
end
|
209
264
|
end
|
210
265
|
|
211
266
|
# TODO: Fix the message truncation when context window is exceeded
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
  module Messages
    # Common base for provider-specific chat messages stored in a Langchain::Thread.
    #
    # Subclasses assign @role, @content, @tool_calls and @tool_call_id. Callers also
    # rely on them answering llm?, system?, tool? and #to_hash.
    class Base
      attr_reader :role, :content, :tool_calls, :tool_call_id

      # Check if the message came from a user
      #
      # @return [Boolean] true/false whether the message came from a user
      def user?
        role == "user"
      end
    end
  end
end
|
@@ -0,0 +1,90 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
  module Messages
    # A single message in a conversation with Google Gemini / Google Vertex AI.
    #
    # Gemini calls the LLM role "model" and the tool-result role "function".
    class GoogleGeminiMessage < Base
      # Google Gemini uses the following roles:
      ROLES = %w[
        user
        model
        function
      ].freeze

      # Role used for messages that return a tool's output to the model
      TOOL_ROLE = "function"

      # Initialize a new Google Gemini message
      #
      # @param role [String] The role of the message; must be one of ROLES
      # @param content [String, nil] The content of the message (converted with `to_s`, so nil becomes "")
      # @param tool_calls [Array<Hash>] The tool calls made in the message (raw `functionCall` parts)
      # @param tool_call_id [String, nil] The ID of the tool call this message answers; sent as the `functionResponse` name
      # @raise [ArgumentError] If the role is unknown or tool_calls is not an array of hashes
      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

        @role = role
        # Some Tools return content as a JSON hence `.to_s`
        @content = content.to_s
        @tool_calls = tool_calls
        @tool_call_id = tool_call_id
      end

      # Check if the message came from an LLM
      #
      # @return [Boolean] true/false whether this message was produced by an LLM
      def llm?
        model?
      end

      # Convert the message to a Google Gemini API-compatible hash
      #
      # Tool results become a `functionResponse` part, model tool calls are passed
      # through as-is, and everything else becomes a single text part.
      #
      # @return [Hash] The message as a Google Gemini API-compatible hash
      def to_hash
        {}.tap do |h|
          h[:role] = role
          h[:parts] = if function?
            [{
              functionResponse: {
                name: tool_call_id,
                response: {
                  name: tool_call_id,
                  content: content
                }
              }
            }]
          elsif tool_calls.any?
            tool_calls
          else
            [{text: content}]
          end
        end
      end

      # Gemini has no "system" role (system instructions are sent separately),
      # so no Gemini message is ever a system message.
      #
      # @return [Boolean] always false
      def system?
        false
      end

      # Check if the message carries a tool's output back to the model
      #
      # @return [Boolean] true/false whether this message is a tool result
      def tool?
        function?
      end

      # Check if the message carries a function (tool) result
      #
      # @return [Boolean] true/false whether the role is "function"
      def function?
        role == "function"
      end

      # Check if the message was produced by the model
      #
      # @return [Boolean] true/false whether the role is "model"
      def model?
        role == "model"
      end
    end
  end
end
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain
  module Messages
    # A single message in a conversation with the OpenAI chat API.
    class OpenAIMessage < Base
      # OpenAI uses the following roles:
      ROLES = %w[
        system
        assistant
        user
        tool
      ].freeze

      # Role used for messages that return a tool's output to the model
      TOOL_ROLE = "tool"

      # Initialize a new OpenAI message
      #
      # @param role [String] The role of the message; must be one of ROLES
      # @param content [String, nil] The content of the message (converted with `to_s`, so nil becomes "")
      # @param tool_calls [Array<Hash>] The tool calls made in the message
      # @param tool_call_id [String, nil] The ID of the tool call this message answers
      # @raise [ArgumentError] If the role is unknown or tool_calls is not an array of hashes
      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

        @role = role
        # Some Tools return content as a JSON hence `.to_s`
        @content = content.to_s
        @tool_calls = tool_calls
        @tool_call_id = tool_call_id
      end

      # Check if the message came from an LLM
      #
      # @return [Boolean] true/false whether this message was produced by an LLM
      def llm?
        assistant?
      end

      # Convert the message to an OpenAI API-compatible hash
      #
      # @return [Hash] The message as an OpenAI API-compatible hash
      def to_hash
        {}.tap do |h|
          h[:role] = role
          # @content is always a String (see #initialize), so a nil check never fires;
          # a message that only carries tool calls omits `content` instead of sending "".
          h[:content] = content unless content.empty? && tool_calls.any?
          h[:tool_calls] = tool_calls if tool_calls.any?
          h[:tool_call_id] = tool_call_id if tool_call_id
        end
      end

      # Check if the message came from the assistant (the LLM)
      #
      # @return [Boolean] true/false whether the role is "assistant"
      def assistant?
        role == "assistant"
      end

      # Check if the message contains system instructions
      #
      # @return [Boolean] true/false whether the role is "system"
      def system?
        role == "system"
      end

      # Check if the message carries a tool's output back to the model
      #
      # @return [Boolean] true/false whether the role is "tool"
      def tool?
        role == "tool"
      end
    end
  end
end
|
@@ -8,16 +8,16 @@ module Langchain
|
|
8
8
|
|
9
9
|
# @param messages [Array<Langchain::Message>]
|
10
10
|
def initialize(messages: [])
|
11
|
-
raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::
|
11
|
+
raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
|
12
12
|
|
13
13
|
@messages = messages
|
14
14
|
end
|
15
15
|
|
16
|
-
# Convert the thread to an
|
16
|
+
# Convert the thread to an LLM APIs-compatible array of hashes
|
17
17
|
#
|
18
18
|
# @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
|
19
|
-
def
|
20
|
-
messages.map(&:
|
19
|
+
def array_of_message_hashes
|
20
|
+
messages.map(&:to_hash)
|
21
21
|
end
|
22
22
|
|
23
23
|
# Add a message to the thread
|
@@ -25,7 +25,7 @@ module Langchain
|
|
25
25
|
# @param message [Langchain::Message] The message to add
|
26
26
|
# @return [Array<Langchain::Message>] The updated messages array
|
27
27
|
def add_message(message)
|
28
|
-
raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::
|
28
|
+
raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
|
29
29
|
|
30
30
|
# Prepend the message to the thread
|
31
31
|
messages << message
|
data/lib/langchain/llm/base.rb
CHANGED
@@ -11,7 +11,8 @@ module Langchain::LLM
|
|
11
11
|
# - {Langchain::LLM::Azure}
|
12
12
|
# - {Langchain::LLM::Cohere}
|
13
13
|
# - {Langchain::LLM::GooglePalm}
|
14
|
-
# - {Langchain::LLM::
|
14
|
+
# - {Langchain::LLM::GoogleVertexAI}
|
15
|
+
# - {Langchain::LLM::GoogleGemini}
|
15
16
|
# - {Langchain::LLM::HuggingFace}
|
16
17
|
# - {Langchain::LLM::LlamaCpp}
|
17
18
|
# - {Langchain::LLM::OpenAI}
|
@@ -0,0 +1,67 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::LLM
  # Usage:
  #     llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
  class GoogleGemini < Base
    DEFAULTS = {
      chat_completion_model_name: "gemini-1.5-pro-latest",
      temperature: 0.0
    }

    attr_reader :defaults, :api_key

    # @param api_key [String] API key appended to the Gemini REST endpoint URL
    # @param default_options [Hash] Overrides merged over DEFAULTS (:chat_completion_model_name, :temperature)
    def initialize(api_key:, default_options: {})
      @api_key = api_key
      @defaults = DEFAULTS.merge(default_options)

      chat_parameters.update(
        model: {default: @defaults[:chat_completion_model_name]},
        temperature: {default: @defaults[:temperature]}
      )
      # Map the generic parameter names onto the Gemini request field names
      chat_parameters.remap(
        messages: :contents,
        system: :system_instruction,
        tool_choice: :tool_config
      )
    end

    # Generate a chat completion for a given prompt
    #
    # @param params [Hash] Chat parameters; the caller's hash is left unmodified
    # @option params [Array<Hash>] :messages List of messages comprising the conversation so far (required)
    # @option params [String] :model The model to use
    # @option params [Float] :temperature Sampling temperature (sent as generation_config.temperature)
    # @option params [Array<Hash>] :tools A list of Tools the model may use to generate the next response
    # @option params [String] :tool_choice Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
    # @option params [String] :system Developer set system instruction
    # @return [Langchain::LLM::GoogleGeminiResponse] The wrapped API response
    # @raise [ArgumentError] If no messages are given
    # @raise [StandardError] If the response contains neither a chat completion nor tool calls
    def chat(params = {})
      # Work on a copy: the wrapping below must not leak into the caller's hash,
      # otherwise re-using that hash would wrap the values a second time.
      params = params.dup
      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]

      raise ArgumentError, "messages argument is required" if Array(params[:messages]).empty?

      parameters = chat_parameters.to_params(params)
      # Gemini expects sampling settings nested under generation_config
      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]

      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")

      request = Net::HTTP::Post.new(uri)
      request.content_type = "application/json"
      request.body = parameters.to_json

      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
        http.request(request)
      end

      parsed_response = JSON.parse(response.body)

      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])

      return wrapped_response if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?

      # Include the status and the API's error payload; the bare Net::HTTPResponse
      # would only stringify to its class name and status line.
      raise StandardError, "#{response.code} #{response.message}: #{response.body}"
    end
  end
end
|