langchainrb 0.12.0 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/README.md +3 -2
  4. data/lib/langchain/assistants/assistant.rb +75 -20
  5. data/lib/langchain/assistants/messages/base.rb +16 -0
  6. data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
  7. data/lib/langchain/assistants/messages/openai_message.rb +74 -0
  8. data/lib/langchain/assistants/thread.rb +5 -5
  9. data/lib/langchain/llm/anthropic.rb +27 -49
  10. data/lib/langchain/llm/aws_bedrock.rb +30 -34
  11. data/lib/langchain/llm/azure.rb +6 -0
  12. data/lib/langchain/llm/base.rb +20 -1
  13. data/lib/langchain/llm/cohere.rb +38 -6
  14. data/lib/langchain/llm/google_gemini.rb +67 -0
  15. data/lib/langchain/llm/google_vertex_ai.rb +68 -112
  16. data/lib/langchain/llm/mistral_ai.rb +10 -19
  17. data/lib/langchain/llm/ollama.rb +23 -27
  18. data/lib/langchain/llm/openai.rb +20 -48
  19. data/lib/langchain/llm/parameters/chat.rb +51 -0
  20. data/lib/langchain/llm/response/base_response.rb +2 -2
  21. data/lib/langchain/llm/response/cohere_response.rb +16 -0
  22. data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
  23. data/lib/langchain/llm/response/openai_response.rb +5 -1
  24. data/lib/langchain/llm/unified_parameters.rb +98 -0
  25. data/lib/langchain/loader.rb +6 -0
  26. data/lib/langchain/tool/base.rb +16 -6
  27. data/lib/langchain/tool/calculator/calculator.json +1 -1
  28. data/lib/langchain/tool/database/database.json +3 -3
  29. data/lib/langchain/tool/file_system/file_system.json +3 -3
  30. data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
  31. data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
  32. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
  33. data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
  34. data/lib/langchain/tool/weather/weather.json +1 -1
  35. data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
  36. data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
  37. data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
  38. data/lib/langchain/version.rb +1 -1
  39. data/lib/langchain.rb +3 -0
  40. metadata +22 -15
  41. data/lib/langchain/assistants/message.rb +0 -58
  42. data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 7f29aad35bc35dc95eb8673b11578b51c7449a19818989d9da5e640c6fb219c7
- data.tar.gz: 4d0c4d3d424a82c7f02fb9e49ca52a5bdca5dfbce19fbfa22f2d74ef46d81eb7
+ metadata.gz: b146eb8568d30ae12aca93a25818fcff7421b7ee2e968330f3a68c5e523da148
+ data.tar.gz: 33f88d7ba03501606706314dce58f626fa0df5aab50639b5f5db3df527ee6520
  SHA512:
- metadata.gz: 91b6f4fc5056308eab9119dcfda1be16857e6e9e6e531977148b1e8f31b72090794b67e6855afb95633b8f836b8d20921bc5a069afdc745d1114892143a177e1
- data.tar.gz: f7a7949ab2efd960eacf3a93f7beaa9104403a93619b8c95ea094901c2d3d19b89980c81d293ae16035c5ff51fe021a09f2e81e2c0ed6854bff87d30e6def925
+ metadata.gz: 6518e30de12653426280f6f8cf05f37a6d4b311ad4219af52276bace8a75ec6440b8f42c208d9d5c07bb4218f3259cc95ea9edd77f8ba037e7a1600a7dfa3170
+ data.tar.gz: 7f881a4347866c8b52161adaf6b98b669e38b4e2fd1ac513f02efd1dcfe73b2552ad8e65acdd02e9e002769ad3b394850f720e0971774632c2f489c25d9ce076
data/CHANGELOG.md CHANGED
@@ -1,5 +1,17 @@
  ## [Unreleased]
 
+ ## [0.13.0] - 2024-05-14
+ - New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
+ - Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms
+ - [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
+
+ ## [0.12.1] - 2024-05-13
+ - Langchain::LLM::Ollama now uses `llama3` by default
+ - Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
+ - Updated with new OpenAI models, including `gpt-4o`
+ - New `Langchain::LLM::Cohere#chat()` method.
+ - Introducing `UnifiedParameters` to unify parameters across LLM classes
+
  ## [0.12.0] - 2024-04-22
  - [BREAKING] Rename `dimension` parameter to `dimensions` everywhere
 
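Taken together, the 0.13.0 entries above enable a Gemini-backed Assistant with the new NewsRetriever tool. A minimal, untested sketch follows: the `Assistant.new` keywords match the README excerpt in this diff, while the exact `GoogleGemini.new` and `NewsRetriever.new` keyword arguments are assumptions modeled on the gem's other LLM and tool classes.

```ruby
require "langchain"

# Assumed constructor keywords; check the class signatures in this release.
llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

assistant = Langchain::Assistant.new(
  llm: llm,
  thread: Langchain::Thread.new,
  instructions: "You are a News Assistant",
  tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
)

assistant.add_message(content: "What are today's top headlines about Ruby?")
assistant.run(auto_tool_execution: true) # keyword visible in Assistant#run below
```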
data/README.md CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
  | "file_system" | Interacts with the file system | | |
  | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
  | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
+ | "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
  | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
  | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
 
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
    thread: thread,
    instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
    tools: [
-     Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
+     Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
    ]
  )
  ```
  ### Using an Assistant
  You can now add your message to an Assistant.
  ```ruby
- assistant.add_message content: "What's the weather in New York City?"
+ assistant.add_message content: "What's the weather in New York, New York?"
  ```
 
  Run the Assistant to generate a response.
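A hedged sketch of that run call, reusing the `auto_tool_execution` keyword that appears in `Assistant#run` later in this diff:

```ruby
assistant.run(auto_tool_execution: true)
```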
data/lib/langchain/assistants/assistant.rb CHANGED
@@ -7,6 +7,12 @@ module Langchain
    attr_reader :llm, :thread, :instructions
    attr_accessor :tools
 
+   SUPPORTED_LLMS = [
+     Langchain::LLM::OpenAI,
+     Langchain::LLM::GoogleGemini,
+     Langchain::LLM::GoogleVertexAI
+   ]
+
    # Create a new assistant
    #
    # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -19,7 +25,9 @@ module Langchain
      tools: [],
      instructions: nil
    )
-     raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI)
+     unless SUPPORTED_LLMS.include?(llm.class)
+       raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
+     end
      raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
      raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
@@ -30,7 +38,10 @@ module Langchain
 
      # The first message in the thread should be the system instructions
      # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-     add_message(role: "system", content: instructions) if instructions
+     if llm.is_a?(Langchain::LLM::OpenAI)
+       add_message(role: "system", content: instructions) if instructions
+     end
+     # For Google Gemini, system instructions are added to the `system:` param in the `chat` method
    end
 
    # Add a user message to the thread
@@ -59,11 +70,12 @@ module Langchain
 
      while running
        # TODO: I think we need to look at all messages and not just the last one.
-       case (last_message = thread.messages.last).role
-       when "system"
+       last_message = thread.messages.last
+
+       if last_message.system?
          # Do nothing
          running = false
-       when "assistant"
+       elsif last_message.llm?
          if last_message.tool_calls.any?
            if auto_tool_execution
              run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@ module Langchain
            # Do nothing
            running = false
          end
-       when "user"
+       elsif last_message.user?
          # Run it!
          response = chat_with_llm
 
-         if response.tool_calls
+         if response.tool_calls.any?
            # Re-run the while(running) loop to process the tool calls
            running = true
            add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@ module Langchain
            running = false
            add_message(role: response.role, content: response.chat_completion)
          end
-       when "tool"
+       elsif last_message.tool?
          # Run it!
          response = chat_with_llm
          running = true
 
-         if response.tool_calls
+         if response.tool_calls.any?
            add_message(role: response.role, tool_calls: response.tool_calls)
          elsif response.chat_completion
            add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ module Langchain
    # @param output [String] The output of the tool
    # @return [Array<Langchain::Message>] The messages in the thread
    def submit_tool_output(tool_call_id:, output:)
-     # TODO: Validate that `tool_call_id` is valid
-     add_message(role: "tool", content: output, tool_call_id: tool_call_id)
+     tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
+       Langchain::Messages::OpenAIMessage::TOOL_ROLE
+     elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+       Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
+     end
+
+     # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
+     add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
    end
 
    # Delete all messages in the thread
@@ -156,10 +174,15 @@ module Langchain
    def chat_with_llm
      Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-     params = {messages: thread.openai_messages}
+     params = {messages: thread.array_of_message_hashes}
 
      if tools.any?
-       params[:tools] = tools.map(&:to_openai_tools).flatten
+       if llm.is_a?(Langchain::LLM::OpenAI)
+         params[:tools] = tools.map(&:to_openai_tools).flatten
+       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+         params[:tools] = tools.map(&:to_google_gemini_tools).flatten
+         params[:system] = instructions if instructions
+       end
       # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
       params[:tool_choice] = "auto"
     end
@@ -173,11 +196,11 @@ module Langchain
    def run_tools(tool_calls)
      # Iterate over each function invocation and submit tool output
      tool_calls.each do |tool_call|
-       tool_call_id = tool_call.dig("id")
-
-       function_name = tool_call.dig("function", "name")
-       tool_name, method_name = function_name.split("-")
-       tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+       tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
+         extract_openai_tool_call(tool_call: tool_call)
+       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+         extract_google_gemini_tool_call(tool_call: tool_call)
+       end
 
        tool_instance = tools.find do |t|
          t.name == tool_name
@@ -190,13 +213,41 @@ module Langchain
 
      response = chat_with_llm
 
-     if response.tool_calls
+     if response.tool_calls.any?
        add_message(role: response.role, tool_calls: response.tool_calls)
      elsif response.chat_completion
        add_message(role: response.role, content: response.chat_completion)
      end
    end
 
+   # Extract the tool call information from the OpenAI tool call hash
+   #
+   # @param tool_call [Hash] The tool call hash
+   # @return [Array] The tool call information
+   def extract_openai_tool_call(tool_call:)
+     tool_call_id = tool_call.dig("id")
+
+     function_name = tool_call.dig("function", "name")
+     tool_name, method_name = function_name.split("__")
+     tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+
+     [tool_call_id, tool_name, method_name, tool_arguments]
+   end
+
+   # Extract the tool call information from the Google Gemini tool call hash
+   #
+   # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
+   # @return [Array] The tool call information
+   def extract_google_gemini_tool_call(tool_call:)
+     tool_call_id = tool_call.dig("functionCall", "name")
+
+     function_name = tool_call.dig("functionCall", "name")
+     tool_name, method_name = function_name.split("__")
+     tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+
+     [tool_call_id, tool_name, method_name, tool_arguments]
+   end
+
    # Build a message
    #
    # @param role [String] The role of the message
@@ -205,7 +256,11 @@ module Langchain
    # @param tool_call_id [String] The ID of the tool call to include in the message
    # @return [Langchain::Message] The Message object
    def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-     Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     if llm.is_a?(Langchain::LLM::OpenAI)
+       Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+       Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     end
    end
 
    # TODO: Fix the message truncation when context window is exceeded
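For illustration, the two extractors above normalize provider-specific payloads into the same `[tool_call_id, tool_name, method_name, tool_arguments]` tuple; note the function name is now split on `__` rather than `-`. The hashes below are hypothetical but follow the formats documented in the comments:

```ruby
openai_tool_call = {
  "id" => "call_abc123", # hypothetical OpenAI tool call ID
  "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"NYC\"}"}
}
extract_openai_tool_call(tool_call: openai_tool_call)
# => ["call_abc123", "weather", "execute", {input: "NYC"}]

gemini_tool_call = {"functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}}
extract_google_gemini_tool_call(tool_call: gemini_tool_call)
# => ["weather__execute", "weather", "execute", {input: "NYC"}]
```

Gemini has no separate call ID, so the function name doubles as the `tool_call_id`, which is why `GoogleGeminiMessage` below echoes it back in the `functionResponse` payload.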
data/lib/langchain/assistants/messages/base.rb ADDED
@@ -0,0 +1,16 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class Base
+       attr_reader :role, :content, :tool_calls, :tool_call_id
+
+       # Check if the message came from a user
+       #
+       # @param [Boolean] true/false whether the message came from a user
+       def user?
+         role == "user"
+       end
+     end
+   end
+ end
data/lib/langchain/assistants/messages/google_gemini_message.rb ADDED
@@ -0,0 +1,90 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class GoogleGeminiMessage < Base
+       # Google Gemini uses the following roles:
+       ROLES = [
+         "user",
+         "model",
+         "function"
+       ].freeze
+
+       TOOL_ROLE = "function"
+
+       # Initialize a new Google Gemini message
+       #
+       # @param [String] The role of the message
+       # @param [String] The content of the message
+       # @param [Array<Hash>] The tool calls made in the message
+       # @param [String] The ID of the tool call
+       def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+         raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+         raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+         @role = role
+         # Some Tools return content as a JSON hence `.to_s`
+         @content = content.to_s
+         @tool_calls = tool_calls
+         @tool_call_id = tool_call_id
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def llm?
+         model?
+       end
+
+       # Convert the message to a Google Gemini API-compatible hash
+       #
+       # @return [Hash] The message as a Google Gemini API-compatible hash
+       def to_hash
+         {}.tap do |h|
+           h[:role] = role
+           h[:parts] = if function?
+             [{
+               functionResponse: {
+                 name: tool_call_id,
+                 response: {
+                   name: tool_call_id,
+                   content: content
+                 }
+               }
+             }]
+           elsif tool_calls.any?
+             tool_calls
+           else
+             [{text: content}]
+           end
+         end
+       end
+
+       # Google Gemini does not implement system prompts
+       def system?
+         false
+       end
+
+       # Check if the message is a tool call
+       #
+       # @return [Boolean] true/false whether this message is a tool call
+       def tool?
+         function?
+       end
+
+       # Check if the message is a tool call
+       #
+       # @return [Boolean] true/false whether this message is a tool call
+       def function?
+         role == "function"
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def model?
+         role == "model"
+       end
+     end
+   end
+ end
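A quick illustration of `to_hash` for a function-role (tool output) message, with hypothetical values:

```ruby
msg = Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: "72F and sunny",         # tool output
  tool_call_id: "weather__execute"  # Gemini reuses the function name as the ID
)
msg.to_hash
# => {role: "function",
#     parts: [{functionResponse: {name: "weather__execute",
#                                 response: {name: "weather__execute", content: "72F and sunny"}}}]}
```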
data/lib/langchain/assistants/messages/openai_message.rb ADDED
@@ -0,0 +1,74 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class OpenAIMessage < Base
+       # OpenAI uses the following roles:
+       ROLES = [
+         "system",
+         "assistant",
+         "user",
+         "tool"
+       ].freeze
+
+       TOOL_ROLE = "tool"
+
+       # Initialize a new OpenAI message
+       #
+       # @param [String] The role of the message
+       # @param [String] The content of the message
+       # @param [Array<Hash>] The tool calls made in the message
+       # @param [String] The ID of the tool call
+       def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+         raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+         raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+         @role = role
+         # Some Tools return content as a JSON hence `.to_s`
+         @content = content.to_s
+         @tool_calls = tool_calls
+         @tool_call_id = tool_call_id
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def llm?
+         assistant?
+       end
+
+       # Convert the message to an OpenAI API-compatible hash
+       #
+       # @return [Hash] The message as an OpenAI API-compatible hash
+       def to_hash
+         {}.tap do |h|
+           h[:role] = role
+           h[:content] = content if content # Content is nil for tool calls
+           h[:tool_calls] = tool_calls if tool_calls.any?
+           h[:tool_call_id] = tool_call_id if tool_call_id
+         end
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def assistant?
+         role == "assistant"
+       end
+
+       # Check if the message are system instructions
+       #
+       # @return [Boolean] true/false whether this message are system instructions
+       def system?
+         role == "system"
+       end
+
+       # Check if the message is a tool call
+       #
+       # @return [Boolean] true/false whether this message is a tool call
+       def tool?
+         role == "tool"
+       end
+     end
+   end
+ end
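And the OpenAI counterpart, again with hypothetical values:

```ruby
msg = Langchain::Messages::OpenAIMessage.new(
  role: "tool",
  content: "72F and sunny",    # tool output
  tool_call_id: "call_abc123"  # hypothetical OpenAI tool call ID
)
msg.to_hash
# => {role: "tool", content: "72F and sunny", tool_call_id: "call_abc123"}
```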
data/lib/langchain/assistants/thread.rb CHANGED
@@ -8,16 +8,16 @@ module Langchain
 
    # @param messages [Array<Langchain::Message>]
    def initialize(messages: [])
-     raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
+     raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
 
      @messages = messages
    end
 
-   # Convert the thread to an OpenAI API-compatible array of hashes
+   # Convert the thread to an LLM APIs-compatible array of hashes
    #
    # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
-   def openai_messages
-     messages.map(&:to_openai_format)
+   def array_of_message_hashes
+     messages.map(&:to_hash)
    end
 
    # Add a message to the thread
@@ -25,7 +25,7 @@ module Langchain
    # @param message [Langchain::Message] The message to add
    # @return [Array<Langchain::Message>] The updated messages array
    def add_message(message)
-     raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
+     raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
 
      # Prepend the message to the thread
      messages << message
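The renamed serializer simply maps `to_hash` over each message, so a thread now yields whichever provider format its messages implement. Continuing the `OpenAIMessage` sketch above:

```ruby
thread = Langchain::Thread.new(messages: [msg])
thread.array_of_message_hashes
# => [{role: "tool", content: "72F and sunny", tool_call_id: "call_abc123"}]
```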
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -13,7 +13,7 @@ module Langchain::LLM
  class Anthropic < Base
    DEFAULTS = {
      temperature: 0.0,
-     completion_model_name: "claude-2",
+     completion_model_name: "claude-2.1",
      chat_completion_model_name: "claude-3-sonnet-20240229",
      max_tokens_to_sample: 256
    }.freeze
@@ -32,6 +32,15 @@ module Langchain::LLM
 
      @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
      @defaults = DEFAULTS.merge(default_options)
+     chat_parameters.update(
+       model: {default: @defaults[:chat_completion_model_name]},
+       temperature: {default: @defaults[:temperature]},
+       max_tokens: {default: @defaults[:max_tokens_to_sample]},
+       metadata: {},
+       system: {}
+     )
+     chat_parameters.ignore(:n, :user)
+     chat_parameters.remap(stop: :stop_sequences)
    end
 
    # Generate a completion for a given prompt
@@ -72,66 +81,35 @@ module Langchain::LLM
      parameters[:metadata] = metadata if metadata
      parameters[:stream] = stream if stream
 
-     # TODO: Implement token length validator for Anthropic
-     # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
-
      response = client.complete(parameters: parameters)
      Langchain::LLM::AnthropicResponse.new(response)
    end
 
    # Generate a chat completion for given messages
    #
-   # @param messages [Array<String>] Input messages
-   # @param model [String] The model that will complete your prompt
-   # @param max_tokens [Integer] Maximum number of tokens to generate before stopping
-   # @param metadata [Hash] Object describing metadata about the request
-   # @param stop_sequences [Array<String>] Custom text sequences that will cause the model to stop generating
-   # @param stream [Boolean] Whether to incrementally stream the response using server-sent events
-   # @param system [String] System prompt
-   # @param temperature [Float] Amount of randomness injected into the response
-   # @param tools [Array<String>] Definitions of tools that the model may use
-   # @param top_k [Integer] Only sample from the top K options for each subsequent token
-   # @param top_p [Float] Use nucleus sampling.
+   # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+   # @option params [Array<String>] :messages Input messages
+   # @option params [String] :model The model that will complete your prompt
+   # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+   # @option params [Hash] :metadata Object describing metadata about the request
+   # @option params [Array<String>] :stop_sequences Custom text sequences that will cause the model to stop generating
+   # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+   # @option params [String] :system System prompt
+   # @option params [Float] :temperature Amount of randomness injected into the response
+   # @option params [Array<String>] :tools Definitions of tools that the model may use
+   # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+   # @option params [Float] :top_p Use nucleus sampling.
    # @return [Langchain::LLM::AnthropicResponse] The chat completion
-   def chat(
-     messages: [],
-     model: @defaults[:chat_completion_model_name],
-     max_tokens: @defaults[:max_tokens_to_sample],
-     metadata: nil,
-     stop_sequences: nil,
-     stream: nil,
-     system: nil,
-     temperature: @defaults[:temperature],
-     tools: [],
-     top_k: nil,
-     top_p: nil
-   )
-     raise ArgumentError.new("messages argument is required") if messages.empty?
-     raise ArgumentError.new("model argument is required") if model.empty?
-     raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
+   def chat(params = {})
+     parameters = chat_parameters.to_params(params)
 
-     parameters = {
-       messages: messages,
-       model: model,
-       max_tokens: max_tokens,
-       temperature: temperature
-     }
-     parameters[:metadata] = metadata if metadata
-     parameters[:stop_sequences] = stop_sequences if stop_sequences
-     parameters[:stream] = stream if stream
-     parameters[:system] = system if system
-     parameters[:tools] = tools if tools.any?
-     parameters[:top_k] = top_k if top_k
-     parameters[:top_p] = top_p if top_p
+     raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+     raise ArgumentError.new("model argument is required") if parameters[:model].empty?
+     raise ArgumentError.new("max_tokens argument is required") if parameters[:max_tokens].nil?
 
      response = client.messages(parameters: parameters)
 
     Langchain::LLM::AnthropicResponse.new(response)
    end
-
-   # TODO: Implement token length validator for Anthropic
-   # def validate_max_tokens(messages, model)
-   #   LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
-   # end
  end
end
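A hedged sketch of the new unified `chat` interface: defaults for `model`, `temperature`, and `max_tokens` come from the `chat_parameters.update` call above, and `stop:` is remapped to Anthropic's `stop_sequences` by `chat_parameters.remap`. The message-hash shape follows Anthropic's messages API.

```ruby
llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

response = llm.chat(
  messages: [{role: "user", content: "Hello, Claude!"}], # Anthropic-style message hash
  stop: ["User:"]                                        # remapped to :stop_sequences
)
response.chat_completion
```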
data/lib/langchain/llm/aws_bedrock.rb CHANGED
@@ -59,6 +59,17 @@ module Langchain::LLM
      @defaults = DEFAULTS.merge(default_options)
        .merge(completion_model_name: completion_model)
        .merge(embedding_model_name: embedding_model)
+
+     chat_parameters.update(
+       model: {default: @defaults[:chat_completion_model_name]},
+       temperature: {},
+       max_tokens: {default: @defaults[:max_tokens_to_sample]},
+       metadata: {},
+       system: {},
+       anthropic_version: {default: "bedrock-2023-05-31"}
+     )
+     chat_parameters.ignore(:n, :user)
+     chat_parameters.remap(stop: :stop_sequences)
    end
 
    #
@@ -113,43 +124,28 @@ module Langchain::LLM
    # Generate a chat completion for a given prompt
    # Currently only configured to work with the Anthropic provider and
    # the claude-3 model family
-   # @param messages [Array] The messages to generate a completion for
-   # @param system [String] The system prompt to provide instructions
-   # @param model [String] The model to use for completion defaults to @defaults[:chat_completion_model_name]
-   # @param max_tokens [Integer] The maximum number of tokens to generate
-   # @param stop_sequences [Array] The stop sequences to use for completion
-   # @param temperature [Float] The temperature to use for completion
-   # @param top_p [Float] The top p to use for completion
-   # @param top_k [Integer] The top k to use for completion
+   #
+   # @param [Hash] params unified chat parmeters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+   # @option params [Array<String>] :messages The messages to generate a completion for
+   # @option params [String] :system The system prompt to provide instructions
+   # @option params [String] :model The model to use for completion defaults to @defaults[:chat_completion_model_name]
+   # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
+   # @option params [Array<String>] :stop The stop sequences to use for completion
+   # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
+   # @option params [Float] :temperature The temperature to use for completion
+   # @option params [Float] :top_p Use nucleus sampling.
+   # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
    # @return [Langchain::LLM::AnthropicMessagesResponse] Response object
-   def chat(
-     messages: [],
-     system: nil,
-     model: defaults[:completion_model_name],
-     max_tokens: defaults[:max_tokens_to_sample],
-     stop_sequences: nil,
-     temperature: nil,
-     top_p: nil,
-     top_k: nil
-   )
-     raise ArgumentError.new("messages argument is required") if messages.empty?
-
-     raise "Model #{model} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
-
-     inference_parameters = {
-       messages: messages,
-       max_tokens: max_tokens,
-       anthropic_version: @defaults[:anthropic_version]
-     }
-     inference_parameters[:system] = system if system
-     inference_parameters[:stop_sequences] = stop_sequences if stop_sequences
-     inference_parameters[:temperature] = temperature if temperature
-     inference_parameters[:top_p] = top_p if top_p
-     inference_parameters[:top_k] = top_k if top_k
+   def chat(params = {})
+     parameters = chat_parameters.to_params(params)
+
+     raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+
+     raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
 
      response = client.invoke_model({
-       model_id: model,
-       body: inference_parameters.to_json,
+       model_id: parameters[:model],
+       body: parameters.except(:model).to_json,
        content_type: "application/json",
        accept: "application/json"
      })
data/lib/langchain/llm/azure.rb CHANGED
@@ -32,6 +32,12 @@ module Langchain::LLM
      **llm_options
    )
    @defaults = DEFAULTS.merge(default_options)
+   chat_parameters.update(
+     logprobs: {},
+     top_logprobs: {},
+     user: {}
+   )
+   chat_parameters.ignore(:top_k)
  end
 
  def embed(...)