langchainrb 0.12.1 → 0.13.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (30) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/README.md +3 -2
  4. data/lib/langchain/assistants/assistant.rb +75 -20
  5. data/lib/langchain/assistants/messages/base.rb +16 -0
  6. data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
  7. data/lib/langchain/assistants/messages/openai_message.rb +74 -0
  8. data/lib/langchain/assistants/thread.rb +5 -5
  9. data/lib/langchain/evals/ragas/faithfulness.rb +2 -0
  10. data/lib/langchain/llm/base.rb +2 -1
  11. data/lib/langchain/llm/google_gemini.rb +67 -0
  12. data/lib/langchain/llm/google_vertex_ai.rb +75 -108
  13. data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
  14. data/lib/langchain/llm/response/openai_response.rb +5 -1
  15. data/lib/langchain/tool/base.rb +11 -1
  16. data/lib/langchain/tool/calculator/calculator.json +1 -1
  17. data/lib/langchain/tool/database/database.json +3 -3
  18. data/lib/langchain/tool/file_system/file_system.json +3 -3
  19. data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
  20. data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
  21. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
  22. data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
  23. data/lib/langchain/tool/weather/weather.json +1 -1
  24. data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
  25. data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
  26. data/lib/langchain/version.rb +1 -1
  27. data/lib/langchain.rb +3 -0
  28. metadata +14 -9
  29. data/lib/langchain/assistants/message.rb +0 -58
  30. data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 6f106f178a5641f17ca723eec7272d120fbdd0e3046e44152eaf575af2985724
4
- data.tar.gz: 417d5b671a6783d0854c05c43871b1e931133eaf65e6d39ac6ac4989a0124e36
3
+ metadata.gz: 31daa3b09f92561f783122c10c1b48482bba75eac67e01550c71f7d76af36551
4
+ data.tar.gz: 355e21f33fbc3d21ac364ce046b0d2908ef111d2aa17996605df953ca25d0640
5
5
  SHA512:
6
- metadata.gz: 184b139d3e9d54fcd42665e25692d9fdb14e92c3d0bb23713647dfa2f1a87854215b13cc9573f10872d3e2e09d042e8c05640356276bc69a32423119e99c129b
7
- data.tar.gz: 4503a8498018b53a9068345186efe03a72da4340d82a0f7806d4007e220c748814ab2da49194e09fd11ef33de3a35f38048e835dd43cee2c0b967bb810bfb512
6
+ metadata.gz: f2bbf794a223f9b0da303f9b65a1a309213db00d45227ce6e9d5a9bc039d1150e06b786ff9730c1e4f2f2fd6d6566687d4a04d3c39f5dcd8d9e66c8e84e097ba
7
+ data.tar.gz: b406738ff1be88c7c545ec284d3050a3b5c0bb34a747f345ff18cbaeb63a3abf9763ec723913bd58ddd62be261c6abd88a87448fd2b9d3bde00eb53d795931e2
data/CHANGELOG.md CHANGED
@@ -1,5 +1,13 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.13.1] - 2024-05-14
4
+ - Better error handling for `Langchain::LLM::GoogleVertexAI`
5
+
6
+ ## [0.13.0] - 2024-05-14
7
+ - New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
8
+ - `Langchain::Assistant` now works with the `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` LLMs
9
+ - [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
10
+
3
11
  ## [0.12.1] - 2024-05-13
4
12
  - Langchain::LLM::Ollama now uses `llama3` by default
5
13
  - Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
data/README.md CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
412
412
  | "file_system" | Interacts with the file system | | |
413
413
  | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
414
414
  | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
415
+ | "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
415
416
  | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
416
417
  | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
417
418
 
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
445
446
  thread: thread,
446
447
  instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
447
448
  tools: [
448
- Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
449
+ Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
449
450
  ]
450
451
  )
451
452
  ```
452
453
  ### Using an Assistant
453
454
  You can now add your message to an Assistant.
454
455
  ```ruby
455
- assistant.add_message content: "What's the weather in New York City?"
456
+ assistant.add_message content: "What's the weather in New York, New York?"
456
457
  ```
457
458
 
458
459
  Run the Assistant to generate a response.
@@ -7,6 +7,12 @@ module Langchain
7
7
  attr_reader :llm, :thread, :instructions
8
8
  attr_accessor :tools
9
9
 
10
+ SUPPORTED_LLMS = [
11
+ Langchain::LLM::OpenAI,
12
+ Langchain::LLM::GoogleGemini,
13
+ Langchain::LLM::GoogleVertexAI
14
+ ]
15
+
10
16
  # Create a new assistant
11
17
  #
12
18
  # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -19,7 +25,9 @@ module Langchain
19
25
  tools: [],
20
26
  instructions: nil
21
27
  )
22
- raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI)
28
+ unless SUPPORTED_LLMS.include?(llm.class)
29
+ raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
30
+ end
23
31
  raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
24
32
  raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
25
33
 
@@ -30,7 +38,10 @@ module Langchain
30
38
 
31
39
  # The first message in the thread should be the system instructions
32
40
  # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
33
- add_message(role: "system", content: instructions) if instructions
41
+ if llm.is_a?(Langchain::LLM::OpenAI)
42
+ add_message(role: "system", content: instructions) if instructions
43
+ end
44
+ # For Google Gemini, system instructions are added to the `system:` param in the `chat` method
34
45
  end
35
46
 
36
47
  # Add a user message to the thread
@@ -59,11 +70,12 @@ module Langchain
59
70
 
60
71
  while running
61
72
  # TODO: I think we need to look at all messages and not just the last one.
62
- case (last_message = thread.messages.last).role
63
- when "system"
73
+ last_message = thread.messages.last
74
+
75
+ if last_message.system?
64
76
  # Do nothing
65
77
  running = false
66
- when "assistant"
78
+ elsif last_message.llm?
67
79
  if last_message.tool_calls.any?
68
80
  if auto_tool_execution
69
81
  run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@ module Langchain
76
88
  # Do nothing
77
89
  running = false
78
90
  end
79
- when "user"
91
+ elsif last_message.user?
80
92
  # Run it!
81
93
  response = chat_with_llm
82
94
 
83
- if response.tool_calls
95
+ if response.tool_calls.any?
84
96
  # Re-run the while(running) loop to process the tool calls
85
97
  running = true
86
98
  add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@ module Langchain
89
101
  running = false
90
102
  add_message(role: response.role, content: response.chat_completion)
91
103
  end
92
- when "tool"
104
+ elsif last_message.tool?
93
105
  # Run it!
94
106
  response = chat_with_llm
95
107
  running = true
96
108
 
97
- if response.tool_calls
109
+ if response.tool_calls.any?
98
110
  add_message(role: response.role, tool_calls: response.tool_calls)
99
111
  elsif response.chat_completion
100
112
  add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ module Langchain
121
133
  # @param output [String] The output of the tool
122
134
  # @return [Array<Langchain::Message>] The messages in the thread
123
135
  def submit_tool_output(tool_call_id:, output:)
124
- # TODO: Validate that `tool_call_id` is valid
125
- add_message(role: "tool", content: output, tool_call_id: tool_call_id)
136
+ tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
137
+ Langchain::Messages::OpenAIMessage::TOOL_ROLE
138
+ elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
139
+ Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
140
+ end
141
+
142
+ # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
143
+ add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
126
144
  end
127
145
 
128
146
  # Delete all messages in the thread
@@ -156,10 +174,15 @@ module Langchain
156
174
  def chat_with_llm
157
175
  Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
158
176
 
159
- params = {messages: thread.openai_messages}
177
+ params = {messages: thread.array_of_message_hashes}
160
178
 
161
179
  if tools.any?
162
- params[:tools] = tools.map(&:to_openai_tools).flatten
180
+ if llm.is_a?(Langchain::LLM::OpenAI)
181
+ params[:tools] = tools.map(&:to_openai_tools).flatten
182
+ elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
183
+ params[:tools] = tools.map(&:to_google_gemini_tools).flatten
184
+ params[:system] = instructions if instructions
185
+ end
163
186
  # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
164
187
  params[:tool_choice] = "auto"
165
188
  end
@@ -173,11 +196,11 @@ module Langchain
173
196
  def run_tools(tool_calls)
174
197
  # Iterate over each function invocation and submit tool output
175
198
  tool_calls.each do |tool_call|
176
- tool_call_id = tool_call.dig("id")
177
-
178
- function_name = tool_call.dig("function", "name")
179
- tool_name, method_name = function_name.split("-")
180
- tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
199
+ tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
200
+ extract_openai_tool_call(tool_call: tool_call)
201
+ elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
202
+ extract_google_gemini_tool_call(tool_call: tool_call)
203
+ end
181
204
 
182
205
  tool_instance = tools.find do |t|
183
206
  t.name == tool_name
@@ -190,13 +213,41 @@ module Langchain
190
213
 
191
214
  response = chat_with_llm
192
215
 
193
- if response.tool_calls
216
+ if response.tool_calls.any?
194
217
  add_message(role: response.role, tool_calls: response.tool_calls)
195
218
  elsif response.chat_completion
196
219
  add_message(role: response.role, content: response.chat_completion)
197
220
  end
198
221
  end
199
222
 
223
# Pull the pieces the assistant needs out of a single OpenAI tool call.
#
# The function name is expected in the "<tool name>__<method name>" form, and
# the arguments arrive as a JSON-encoded string.
#
# @param tool_call [Hash] One entry of an OpenAI `tool_calls` array, e.g.
#   {"id" => "call_1", "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"NYC\"}"}}
# @return [Array] [tool_call_id, tool_name, method_name, tool_arguments], where
#   tool_arguments is a Hash with symbol keys
def extract_openai_tool_call(tool_call:)
  qualified_name = tool_call.dig("function", "name")
  raw_arguments = tool_call.dig("function", "arguments")
  tool_name, method_name = qualified_name.split("__")

  [
    tool_call.dig("id"),
    tool_name,
    method_name,
    JSON.parse(raw_arguments, symbolize_names: true)
  ]
end
236
+
237
# Extract the tool call information from a Google Gemini function call part
#
# Gemini function calls carry no separate ID, so the function name doubles as
# the tool call ID (it is later sent back as the `functionResponse` name).
#
# @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
# @return [Array] [tool_call_id, tool_name, method_name, tool_arguments], where
#   tool_arguments is a Hash with symbol keys
def extract_google_gemini_tool_call(tool_call:)
  function_name = tool_call.dig("functionCall", "name")
  tool_call_id = function_name

  tool_name, method_name = function_name.split("__")
  # Guard against a missing "args" key (e.g. a call without arguments): nil.to_h is {}
  tool_arguments = tool_call.dig("functionCall", "args").to_h.transform_keys(&:to_sym)

  [tool_call_id, tool_name, method_name, tool_arguments]
end
250
+
200
251
  # Build a message
201
252
  #
202
253
  # @param role [String] The role of the message
@@ -205,7 +256,11 @@ module Langchain
205
256
  # @param tool_call_id [String] The ID of the tool call to include in the message
206
257
  # @return [Langchain::Message] The Message object
207
258
  def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
208
- Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
259
+ if llm.is_a?(Langchain::LLM::OpenAI)
260
+ Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
261
+ elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
262
+ Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
263
+ end
209
264
  end
210
265
 
211
266
  # TODO: Fix the message truncation when context window is exceeded
@@ -0,0 +1,16 @@
# frozen_string_literal: true

module Langchain
  module Messages
    # Common parent of the provider-specific message classes
    # (Langchain::Messages::OpenAIMessage, Langchain::Messages::GoogleGeminiMessage).
    # It only exposes readers; the subclasses assign the instance variables
    # in their own initializers.
    class Base
      attr_reader :role, :content, :tool_calls, :tool_call_id

      # Check if the message came from a user
      #
      # @return [Boolean] true if the message role is "user", false otherwise
      def user?
        role == "user"
      end
    end
  end
end
@@ -0,0 +1,90 @@
# frozen_string_literal: true

module Langchain
  module Messages
    # A single chat message shaped for the Google Gemini (also used for
    # Google Vertex AI) `contents` array.
    class GoogleGeminiMessage < Base
      # Roles accepted by Google Gemini:
      # - "user":     input from the end user
      # - "model":    output produced by the LLM
      # - "function": the output of a tool (function) call sent back to the model
      ROLES = [
        "user",
        "model",
        "function"
      ].freeze

      # Role used when submitting tool output back to Gemini
      TOOL_ROLE = "function"

      # Initialize a new Google Gemini message
      #
      # @param role [String] The role of the message; must be one of ROLES
      # @param content [Object, nil] The content of the message; stored as `content.to_s`, so nil becomes ""
      # @param tool_calls [Array<Hash>] The tool (function) calls requested in the message
      # @param tool_call_id [String, nil] The ID of the tool call this message answers; used as the
      #   `functionResponse` name for "function" messages
      # @raise [ArgumentError] If role is not one of ROLES, or tool_calls is not an Array of Hashes
      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

        @role = role
        # Some Tools return content as a JSON hence `.to_s`
        @content = content.to_s
        @tool_calls = tool_calls
        @tool_call_id = tool_call_id
      end

      # Check if the message came from an LLM
      #
      # @return [Boolean] true if the message was produced by the model (role "model")
      def llm?
        model?
      end

      # Convert the message to a Google Gemini API-compatible hash
      #
      # - "function" messages become a single `functionResponse` part wrapping the tool output
      # - messages with tool calls pass the tool call hashes through as their parts
      # - anything else becomes a single text part
      #
      # @return [Hash] A {role:, parts:} hash for the Gemini `contents` array
      def to_hash
        {}.tap do |h|
          h[:role] = role
          h[:parts] = if function?
            [{
              functionResponse: {
                name: tool_call_id,
                response: {
                  name: tool_call_id,
                  content: content
                }
              }
            }]
          elsif tool_calls.any?
            tool_calls
          else
            [{text: content}]
          end
        end
      end

      # Gemini has no "system" role in `contents`; system instructions are passed
      # separately, so a Gemini message is never a system message.
      #
      # @return [Boolean] always false
      def system?
        false
      end

      # Check if the message carries the output of a tool call
      #
      # @return [Boolean] true if the role is "function" (Gemini's tool-output role)
      def tool?
        function?
      end

      # Check if the message carries the output of a function (tool) call
      #
      # @return [Boolean] true if the role is "function"
      def function?
        role == "function"
      end

      # Check if the message was produced by the model
      #
      # @return [Boolean] true if the role is "model"
      def model?
        role == "model"
      end
    end
  end
end
@@ -0,0 +1,74 @@
# frozen_string_literal: true

module Langchain
  module Messages
    # A single chat message shaped for the OpenAI chat `messages` array.
    class OpenAIMessage < Base
      # Roles accepted by OpenAI:
      # - "system":    system instructions
      # - "assistant": output produced by the LLM
      # - "user":      input from the end user
      # - "tool":      the output of a tool call sent back to the model
      ROLES = [
        "system",
        "assistant",
        "user",
        "tool"
      ].freeze

      # Role used when submitting tool output back to OpenAI
      TOOL_ROLE = "tool"

      # Initialize a new OpenAI message
      #
      # @param role [String] The role of the message; must be one of ROLES
      # @param content [Object, nil] The content of the message; stored as `content.to_s`, so nil becomes ""
      # @param tool_calls [Array<Hash>] The tool calls requested in the message
      # @param tool_call_id [String, nil] The ID of the tool call this "tool" message answers
      # @raise [ArgumentError] If role is not one of ROLES, or tool_calls is not an Array of Hashes
      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

        @role = role
        # Some Tools return content as a JSON hence `.to_s`
        @content = content.to_s
        @tool_calls = tool_calls
        @tool_call_id = tool_call_id
      end

      # Check if the message came from an LLM
      #
      # @return [Boolean] true if the message was produced by the LLM (role "assistant")
      def llm?
        assistant?
      end

      # Convert the message to an OpenAI API-compatible hash
      #
      # @return [Hash] The message as an OpenAI API-compatible hash
      def to_hash
        {}.tap do |h|
          h[:role] = role
          # `content` is always a String (initialize applies `to_s`), so it is always
          # sent; for messages that only carry tool calls it is "".
          h[:content] = content
          h[:tool_calls] = tool_calls if tool_calls.any?
          h[:tool_call_id] = tool_call_id if tool_call_id
        end
      end

      # Check if the message was produced by the LLM
      #
      # @return [Boolean] true if the role is "assistant"
      def assistant?
        role == "assistant"
      end

      # Check if the message holds system instructions
      #
      # @return [Boolean] true if the role is "system"
      def system?
        role == "system"
      end

      # Check if the message carries the output of a tool call
      #
      # @return [Boolean] true if the role is "tool"
      def tool?
        role == "tool"
      end
    end
  end
end
@@ -8,16 +8,16 @@ module Langchain
8
8
 
9
9
  # @param messages [Array<Langchain::Message>]
10
10
  def initialize(messages: [])
11
- raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
11
+ raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
12
12
 
13
13
  @messages = messages
14
14
  end
15
15
 
16
- # Convert the thread to an OpenAI API-compatible array of hashes
16
+ # Convert the thread to an array of LLM API-compatible message hashes
17
17
  #
18
18
  # @return [Array<Hash>] The thread's messages, each converted with its own `to_hash`
19
- def openai_messages
20
- messages.map(&:to_openai_format)
19
+ def array_of_message_hashes
20
+ messages.map(&:to_hash)
21
21
  end
22
22
 
23
23
  # Add a message to the thread
@@ -25,7 +25,7 @@ module Langchain
25
25
  # @param message [Langchain::Message] The message to add
26
26
  # @return [Array<Langchain::Message>] The updated messages array
27
27
  def add_message(message)
28
- raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
28
+ raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
29
29
 
30
30
  # Prepend the message to the thread
31
31
  messages << message
@@ -42,6 +42,8 @@ module Langchain
42
42
 
43
43
  def count_verified_statements(verifications)
44
44
  match = verifications.match(/Final verdict for each statement in order:\s*(.*)/)
45
+ return 0.0 unless match # no verified statements found
46
+
45
47
  verdicts = match.captures.first
46
48
  verdicts
47
49
  .split(".")
@@ -11,7 +11,8 @@ module Langchain::LLM
11
11
  # - {Langchain::LLM::Azure}
12
12
  # - {Langchain::LLM::Cohere}
13
13
  # - {Langchain::LLM::GooglePalm}
14
- # - {Langchain::LLM::GoogleVertexAi}
14
+ # - {Langchain::LLM::GoogleVertexAI}
15
+ # - {Langchain::LLM::GoogleGemini}
15
16
  # - {Langchain::LLM::HuggingFace}
16
17
  # - {Langchain::LLM::LlamaCpp}
17
18
  # - {Langchain::LLM::OpenAI}
@@ -0,0 +1,67 @@
# frozen_string_literal: true

module Langchain::LLM
  # LLM wrapper around the Google Gemini REST API (generativelanguage.googleapis.com).
  #
  # Usage:
  #   llm = Langchain::LLM::GoogleGemini.new(api_key: ENV['GOOGLE_GEMINI_API_KEY'])
  class GoogleGemini < Base
    DEFAULTS = {
      chat_completion_model_name: "gemini-1.5-pro-latest",
      temperature: 0.0
    }.freeze

    attr_reader :defaults, :api_key

    # @param api_key [String] Google Gemini API key; sent as the `key` query parameter
    # @param default_options [Hash] Overrides for DEFAULTS (:chat_completion_model_name, :temperature)
    def initialize(api_key:, default_options: {})
      @api_key = api_key
      @defaults = DEFAULTS.merge(default_options)

      chat_parameters.update(
        model: {default: @defaults[:chat_completion_model_name]},
        temperature: {default: @defaults[:temperature]}
      )
      # Map the provider-agnostic parameter names onto Gemini's request field names
      chat_parameters.remap(
        messages: :contents,
        system: :system_instruction,
        tool_choice: :tool_config
      )
    end

    # Generate a chat completion for a given conversation
    #
    # @param params [Hash] The chat parameters:
    #   - :messages [Array<Hash>] (required) List of messages comprising the conversation so far
    #   - :model [String] The model to use (defaults to defaults[:chat_completion_model_name])
    #   - :temperature [Float] Sampling temperature (defaults to defaults[:temperature])
    #   - :tools [Array<Hash>] Function declarations the model may call
    #   - :tool_choice [String] Function calling mode: AUTO, ANY or NONE (upcased before sending;
    #     Gemini defaults to AUTO when unspecified)
    #   - :system [String] Developer-set system instruction
    # @return [Langchain::LLM::GoogleGeminiResponse] The wrapped API response
    # @raise [ArgumentError] If no messages are given
    # @raise [StandardError] If the response contains neither a chat completion nor tool calls
    def chat(params = {})
      raise ArgumentError, "messages argument is required" if Array(params[:messages]).empty?

      # Build the Gemini-specific shapes on a copy: rewriting the caller's hash would
      # double-wrap :system/:tools (and break on :tool_choice) if the hash were reused.
      params = params.dup
      params[:system] = {parts: [{text: params[:system]}]} if params[:system]
      params[:tools] = {function_declarations: params[:tools]} if params[:tools]
      params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]

      parameters = chat_parameters.to_params(params)
      # Gemini expects sampling options nested under generation_config
      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]

      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")

      request = Net::HTTP::Post.new(uri)
      request.content_type = "application/json"
      request.body = parameters.to_json

      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
        http.request(request)
      end

      parsed_response = JSON.parse(response.body)

      wrapped_response = Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])

      return wrapped_response if wrapped_response.chat_completion || Array(wrapped_response.tool_calls).any?

      # Report the HTTP status and the API's payload: the bare response object
      # stringifies to little more than its class name and hides the reason.
      raise StandardError, "Google Gemini API returned no chat completion or tool calls (HTTP #{response.code}): #{parsed_response}"
    end
  end
end