langchainrb 0.12.1 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/base.rb +2 -1
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/tool/base.rb +11 -1
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +14 -9
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: b146eb8568d30ae12aca93a25818fcff7421b7ee2e968330f3a68c5e523da148
+  data.tar.gz: 33f88d7ba03501606706314dce58f626fa0df5aab50639b5f5db3df527ee6520
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6518e30de12653426280f6f8cf05f37a6d4b311ad4219af52276bace8a75ec6440b8f42c208d9d5c07bb4218f3259cc95ea9edd77f8ba037e7a1600a7dfa3170
+  data.tar.gz: 7f881a4347866c8b52161adaf6b98b669e38b4e2fd1ac513f02efd1dcfe73b2552ad8e65acdd02e9e002769ad3b394850f720e0971774632c2f489c25d9ce076
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,10 @@
 ## [Unreleased]
 
+## [0.13.0] - 2024-05-14
+- New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
+- Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms
+- [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
+
 ## [0.12.1] - 2024-05-13
 - Langchain::LLM::Ollama now uses `llama3` by default
 - Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
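Taken together, the 0.13.0 entries mean an Assistant can now be driven by Gemini and armed with the new news tool. A minimal sketch of that wiring; the `api_key:` keyword for `NewsRetriever` is an assumption based on the `ENV["NEWS_API_KEY"]` row in the README diff below, not something this diff shows directly:

```ruby
require "langchain"

llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])
thread = Langchain::Thread.new

assistant = Langchain::Assistant.new(
  llm: llm,
  thread: thread,
  instructions: "You are a news assistant",
  # api_key: is assumed; the README only documents ENV["NEWS_API_KEY"]
  tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
)
```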
data/README.md
CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
 | "file_system" | Interacts with the file system | | |
 | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
 | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
+| "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
 | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
 | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
 
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
   thread: thread,
   instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
   tools: [
-    Langchain::Tool::
+    Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
   ]
 )
 ```
 ### Using an Assistant
 You can now add your message to an Assistant.
 ```ruby
-assistant.add_message content: "What's the weather in New York
+assistant.add_message content: "What's the weather in New York, New York?"
 ```
 
 Run the Assistant to generate a response.
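The README snippet above stops at "Run the Assistant to generate a response." Based on the `run(auto_tool_execution:)` parameter visible in the assistant.rb diff below, that step looks roughly like this sketch:

```ruby
# Let the Assistant execute any tool calls the LLM makes
assistant.run(auto_tool_execution: true)

# The conversation accumulates on the thread
assistant.thread.messages
```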
data/lib/langchain/assistants/assistant.rb
CHANGED
@@ -7,6 +7,12 @@ module Langchain
     attr_reader :llm, :thread, :instructions
     attr_accessor :tools
 
+    SUPPORTED_LLMS = [
+      Langchain::LLM::OpenAI,
+      Langchain::LLM::GoogleGemini,
+      Langchain::LLM::GoogleVertexAI
+    ]
+
     # Create a new assistant
     #
     # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
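The new `SUPPORTED_LLMS` allow-list is enforced in the constructor (next hunk), so passing any other LLM class should now raise. A sketch of the expected failure mode, using Anthropic as an example of an unsupported class:

```ruby
llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

# Anthropic is not in SUPPORTED_LLMS, so this raises
# ArgumentError ("Invalid LLM; currently only ... are supported"):
Langchain::Assistant.new(llm: llm, thread: Langchain::Thread.new)
```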
@@ -19,7 +25,9 @@ module Langchain
       tools: [],
       instructions: nil
     )
-
+      unless SUPPORTED_LLMS.include?(llm.class)
+        raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
+      end
       raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
       raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
@@ -30,7 +38,10 @@ module Langchain
 
       # The first message in the thread should be the system instructions
       # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        add_message(role: "system", content: instructions) if instructions
+      end
+      # For Google Gemini, system instructions are added to the `system:` param in the `chat` method
     end
 
     # Add a user message to the thread
@@ -59,11 +70,12 @@ module Langchain
 
       while running
         # TODO: I think we need to look at all messages and not just the last one.
-
-
+        last_message = thread.messages.last
+
+        if last_message.system?
          # Do nothing
          running = false
-
+        elsif last_message.llm?
          if last_message.tool_calls.any?
            if auto_tool_execution
              run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@ module Langchain
            # Do nothing
            running = false
          end
-
+        elsif last_message.user?
          # Run it!
          response = chat_with_llm
 
-          if response.tool_calls
+          if response.tool_calls.any?
            # Re-run the while(running) loop to process the tool calls
            running = true
            add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@ module Langchain
            running = false
            add_message(role: response.role, content: response.chat_completion)
          end
-
+        elsif last_message.tool?
          # Run it!
          response = chat_with_llm
          running = true
 
-          if response.tool_calls
+          if response.tool_calls.any?
            add_message(role: response.role, tool_calls: response.tool_calls)
          elsif response.chat_completion
            add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ module Langchain
     # @param output [String] The output of the tool
     # @return [Array<Langchain::Message>] The messages in the thread
     def submit_tool_output(tool_call_id:, output:)
-
-
+      tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
+        Langchain::Messages::OpenAIMessage::TOOL_ROLE
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
+      end
+
+      # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
+      add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
     end
 
     # Delete all messages in the thread
@@ -156,10 +174,15 @@ module Langchain
     def chat_with_llm
       Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-      params = {messages: thread.
+      params = {messages: thread.array_of_message_hashes}
 
       if tools.any?
-
+        if llm.is_a?(Langchain::LLM::OpenAI)
+          params[:tools] = tools.map(&:to_openai_tools).flatten
+        elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+          params[:tools] = tools.map(&:to_google_gemini_tools).flatten
+          params[:system] = instructions if instructions
+        end
        # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
        params[:tool_choice] = "auto"
      end
@@ -173,11 +196,11 @@ module Langchain
     def run_tools(tool_calls)
       # Iterate over each function invocation and submit tool output
       tool_calls.each do |tool_call|
-        tool_call_id =
-
-
-
-
+        tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
+          extract_openai_tool_call(tool_call: tool_call)
+        elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+          extract_google_gemini_tool_call(tool_call: tool_call)
+        end
 
         tool_instance = tools.find do |t|
           t.name == tool_name
@@ -190,13 +213,41 @@ module Langchain
 
       response = chat_with_llm
 
-      if response.tool_calls
+      if response.tool_calls.any?
         add_message(role: response.role, tool_calls: response.tool_calls)
       elsif response.chat_completion
         add_message(role: response.role, content: response.chat_completion)
       end
     end
 
+    # Extract the tool call information from the OpenAI tool call hash
+    #
+    # @param tool_call [Hash] The tool call hash
+    # @return [Array] The tool call information
+    def extract_openai_tool_call(tool_call:)
+      tool_call_id = tool_call.dig("id")
+
+      function_name = tool_call.dig("function", "name")
+      tool_name, method_name = function_name.split("__")
+      tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+
+      [tool_call_id, tool_name, method_name, tool_arguments]
+    end
+
+    # Extract the tool call information from the Google Gemini tool call hash
+    #
+    # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
+    # @return [Array] The tool call information
+    def extract_google_gemini_tool_call(tool_call:)
+      tool_call_id = tool_call.dig("functionCall", "name")
+
+      function_name = tool_call.dig("functionCall", "name")
+      tool_name, method_name = function_name.split("__")
+      tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+
+      [tool_call_id, tool_name, method_name, tool_arguments]
+    end
+
     # Build a message
     #
     # @param role [String] The role of the message
@@ -205,7 +256,11 @@ module Langchain
     # @param tool_call_id [String] The ID of the tool call to include in the message
     # @return [Langchain::Message] The Message object
     def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-
+      if llm.is_a?(Langchain::LLM::OpenAI)
+        Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+        Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+      end
     end
 
     # TODO: Fix the message truncation when context window is exceeded
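The two `extract_*` helpers above differ only in where each provider hides the pieces. The Google Gemini path can be traced standalone, using the hash format documented in the method comment:

```ruby
tool_call = {"functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}}

function_name = tool_call.dig("functionCall", "name")  # => "weather__execute"
tool_name, method_name = function_name.split("__")     # => ["weather", "execute"]
tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
# => {input: "NYC"}
```

Note that Gemini tool calls carry no separate ID, so the function name doubles as the `tool_call_id`.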
data/lib/langchain/assistants/messages/base.rb
ADDED
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class Base
+      attr_reader :role, :content, :tool_calls, :tool_call_id
+
+      # Check if the message came from a user
+      #
+      # @param [Boolean] true/false whether the message came from a user
+      def user?
+        role == "user"
+      end
+    end
+  end
+end
data/lib/langchain/assistants/messages/google_gemini_message.rb
ADDED
@@ -0,0 +1,90 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class GoogleGeminiMessage < Base
+      # Google Gemini uses the following roles:
+      ROLES = [
+        "user",
+        "model",
+        "function"
+      ].freeze
+
+      TOOL_ROLE = "function"
+
+      # Initialize a new Google Gemini message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        model?
+      end
+
+      # Convert the message to a Google Gemini API-compatible hash
+      #
+      # @return [Hash] The message as a Google Gemini API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:parts] = if function?
+            [{
+              functionResponse: {
+                name: tool_call_id,
+                response: {
+                  name: tool_call_id,
+                  content: content
+                }
+              }
+            }]
+          elsif tool_calls.any?
+            tool_calls
+          else
+            [{text: content}]
+          end
+        end
+      end
+
+      # Google Gemini does not implement system prompts
+      def system?
+        false
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        function?
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def function?
+        role == "function"
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def model?
+        role == "model"
+      end
+    end
+  end
+end
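A quick sketch of the two hash shapes `to_hash` produces, following the branches above (the `tool_call_id` value is illustrative):

```ruby
user_msg = Langchain::Messages::GoogleGeminiMessage.new(role: "user", content: "What's the weather in NYC?")
user_msg.to_hash
# => {role: "user", parts: [{text: "What's the weather in NYC?"}]}

tool_msg = Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: '{"temperature": "70F"}',
  tool_call_id: "weather__execute" # illustrative; Gemini reuses the function name as the ID
)
tool_msg.to_hash
# => {role: "function",
#     parts: [{functionResponse: {name: "weather__execute",
#              response: {name: "weather__execute", content: "{\"temperature\": \"70F\"}"}}}]}
```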
data/lib/langchain/assistants/messages/openai_message.rb
ADDED
@@ -0,0 +1,74 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Messages
+    class OpenAIMessage < Base
+      # OpenAI uses the following roles:
+      ROLES = [
+        "system",
+        "assistant",
+        "user",
+        "tool"
+      ].freeze
+
+      TOOL_ROLE = "tool"
+
+      # Initialize a new OpenAI message
+      #
+      # @param [String] The role of the message
+      # @param [String] The content of the message
+      # @param [Array<Hash>] The tool calls made in the message
+      # @param [String] The ID of the tool call
+      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+        @role = role
+        # Some Tools return content as a JSON hence `.to_s`
+        @content = content.to_s
+        @tool_calls = tool_calls
+        @tool_call_id = tool_call_id
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def llm?
+        assistant?
+      end
+
+      # Convert the message to an OpenAI API-compatible hash
+      #
+      # @return [Hash] The message as an OpenAI API-compatible hash
+      def to_hash
+        {}.tap do |h|
+          h[:role] = role
+          h[:content] = content if content # Content is nil for tool calls
+          h[:tool_calls] = tool_calls if tool_calls.any?
+          h[:tool_call_id] = tool_call_id if tool_call_id
+        end
+      end
+
+      # Check if the message came from an LLM
+      #
+      # @return [Boolean] true/false whether this message was produced by an LLM
+      def assistant?
+        role == "assistant"
+      end
+
+      # Check if the message are system instructions
+      #
+      # @return [Boolean] true/false whether this message are system instructions
+      def system?
+        role == "system"
+      end
+
+      # Check if the message is a tool call
+      #
+      # @return [Boolean] true/false whether this message is a tool call
+      def tool?
+        role == "tool"
+      end
+    end
+  end
+end
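The OpenAI variant flattens to the familiar chat-message shape, emitting keys only when present (the `tool_call_id` is illustrative):

```ruby
msg = Langchain::Messages::OpenAIMessage.new(
  role: "tool",
  content: "42",
  tool_call_id: "call_abc123" # illustrative ID
)
msg.to_hash
# => {role: "tool", content: "42", tool_call_id: "call_abc123"}
```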
data/lib/langchain/assistants/thread.rb
CHANGED
@@ -8,16 +8,16 @@ module Langchain
 
     # @param messages [Array<Langchain::Message>]
     def initialize(messages: [])
-      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::
+      raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
 
       @messages = messages
     end
 
-    # Convert the thread to an
+    # Convert the thread to an LLM APIs-compatible array of hashes
     #
     # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
-    def
-      messages.map(&:
+    def array_of_message_hashes
+      messages.map(&:to_hash)
     end
 
     # Add a message to the thread
@@ -25,7 +25,7 @@ module Langchain
     # @param message [Langchain::Message] The message to add
     # @return [Array<Langchain::Message>] The updated messages array
     def add_message(message)
-      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::
+      raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
 
       # Prepend the message to the thread
       messages << message
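Since `array_of_message_hashes` just maps `to_hash` over the messages, a thread built from the new message classes serializes directly into an API payload:

```ruby
thread = Langchain::Thread.new(messages: [
  Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hello")
])
thread.array_of_message_hashes
# => [{role: "user", content: "Hello"}]
```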
data/lib/langchain/llm/base.rb
CHANGED
@@ -11,7 +11,8 @@ module Langchain::LLM
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
-  # - {Langchain::LLM::
+  # - {Langchain::LLM::GoogleVertexAI}
+  # - {Langchain::LLM::GoogleGemini}
   # - {Langchain::LLM::HuggingFace}
   # - {Langchain::LLM::LlamaCpp}
   # - {Langchain::LLM::OpenAI}
data/lib/langchain/llm/google_gemini.rb
ADDED
@@ -0,0 +1,67 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # Usage:
+  #    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
+  class GoogleGemini < Base
+    DEFAULTS = {
+      chat_completion_model_name: "gemini-1.5-pro-latest",
+      temperature: 0.0
+    }
+
+    attr_reader :defaults, :api_key
+
+    def initialize(api_key:, default_options: {})
+      @api_key = api_key
+      @defaults = DEFAULTS.merge(default_options)
+
+      chat_parameters.update(
+        model: {default: @defaults[:chat_completion_model_name]},
+        temperature: {default: @defaults[:temperature]}
+      )
+      chat_parameters.remap(
+        messages: :contents,
+        system: :system_instruction,
+        tool_choice: :tool_config
+      )
+    end
+
+    # Generate a chat completion for a given prompt
+    #
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] The model to use
+    # @param tools [Array<Hash>] A list of Tools the model may use to generate the next response
+    # @param tool_choice [String] Specifies the mode in which function calling should execute. If unspecified, the default value will be set to AUTO. Possible values: AUTO, ANY, NONE
+    # @param system [String] Developer set system instruction
+    def chat(params = {})
+      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
+      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
+      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
+
+      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
+
+      parameters = chat_parameters.to_params(params)
+      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+
+      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request.body = parameters.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
+      end
+
+      parsed_response = JSON.parse(response.body)
+
+      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
+
+      if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?
+        wrapped_response
+      else
+        raise StandardError.new(response)
+      end
+    end
+  end
+end
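A minimal sketch of calling the new class directly, assuming a valid API key; the message hash follows the `{role:, parts: [{text:}]}` shape that `GoogleGeminiMessage#to_hash` produces, and the return value is the `Langchain::LLM::GoogleGeminiResponse` wrapper added in this release:

```ruby
llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

response = llm.chat(
  messages: [{role: "user", parts: [{text: "Name three Ruby web frameworks."}]}],
  system: "Answer concisely"
)
response.chat_completion # => the model's text reply
```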