langchainrb 0.12.0 → 0.13.0

Files changed (42)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/README.md +3 -2
  4. data/lib/langchain/assistants/assistant.rb +75 -20
  5. data/lib/langchain/assistants/messages/base.rb +16 -0
  6. data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
  7. data/lib/langchain/assistants/messages/openai_message.rb +74 -0
  8. data/lib/langchain/assistants/thread.rb +5 -5
  9. data/lib/langchain/llm/anthropic.rb +27 -49
  10. data/lib/langchain/llm/aws_bedrock.rb +30 -34
  11. data/lib/langchain/llm/azure.rb +6 -0
  12. data/lib/langchain/llm/base.rb +20 -1
  13. data/lib/langchain/llm/cohere.rb +38 -6
  14. data/lib/langchain/llm/google_gemini.rb +67 -0
  15. data/lib/langchain/llm/google_vertex_ai.rb +68 -112
  16. data/lib/langchain/llm/mistral_ai.rb +10 -19
  17. data/lib/langchain/llm/ollama.rb +23 -27
  18. data/lib/langchain/llm/openai.rb +20 -48
  19. data/lib/langchain/llm/parameters/chat.rb +51 -0
  20. data/lib/langchain/llm/response/base_response.rb +2 -2
  21. data/lib/langchain/llm/response/cohere_response.rb +16 -0
  22. data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
  23. data/lib/langchain/llm/response/openai_response.rb +5 -1
  24. data/lib/langchain/llm/unified_parameters.rb +98 -0
  25. data/lib/langchain/loader.rb +6 -0
  26. data/lib/langchain/tool/base.rb +16 -6
  27. data/lib/langchain/tool/calculator/calculator.json +1 -1
  28. data/lib/langchain/tool/database/database.json +3 -3
  29. data/lib/langchain/tool/file_system/file_system.json +3 -3
  30. data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
  31. data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
  32. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
  33. data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
  34. data/lib/langchain/tool/weather/weather.json +1 -1
  35. data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
  36. data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
  37. data/lib/langchain/utils/token_length/openai_validator.rb +6 -1
  38. data/lib/langchain/version.rb +1 -1
  39. data/lib/langchain.rb +3 -0
  40. metadata +22 -15
  41. data/lib/langchain/assistants/message.rb +0 -58
  42. data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 7f29aad35bc35dc95eb8673b11578b51c7449a19818989d9da5e640c6fb219c7
- data.tar.gz: 4d0c4d3d424a82c7f02fb9e49ca52a5bdca5dfbce19fbfa22f2d74ef46d81eb7
+ metadata.gz: b146eb8568d30ae12aca93a25818fcff7421b7ee2e968330f3a68c5e523da148
+ data.tar.gz: 33f88d7ba03501606706314dce58f626fa0df5aab50639b5f5db3df527ee6520
  SHA512:
- metadata.gz: 91b6f4fc5056308eab9119dcfda1be16857e6e9e6e531977148b1e8f31b72090794b67e6855afb95633b8f836b8d20921bc5a069afdc745d1114892143a177e1
- data.tar.gz: f7a7949ab2efd960eacf3a93f7beaa9104403a93619b8c95ea094901c2d3d19b89980c81d293ae16035c5ff51fe021a09f2e81e2c0ed6854bff87d30e6def925
+ metadata.gz: 6518e30de12653426280f6f8cf05f37a6d4b311ad4219af52276bace8a75ec6440b8f42c208d9d5c07bb4218f3259cc95ea9edd77f8ba037e7a1600a7dfa3170
+ data.tar.gz: 7f881a4347866c8b52161adaf6b98b669e38b4e2fd1ac513f02efd1dcfe73b2552ad8e65acdd02e9e002769ad3b394850f720e0971774632c2f489c25d9ce076
data/CHANGELOG.md CHANGED
@@ -1,5 +1,17 @@
  ## [Unreleased]
 
+ ## [0.13.0] - 2024-05-14
+ - New 🛠️ `Langchain::Tool::NewsRetriever` tool to fetch news via newsapi.org
+ - Langchain::Assistant works with `Langchain::LLM::GoogleVertexAI` and `Langchain::LLM::GoogleGemini` llms
+ - [BREAKING] Introduce new `Langchain::Messages::Base` abstraction
+
+ ## [0.12.1] - 2024-05-13
+ - Langchain::LLM::Ollama now uses `llama3` by default
+ - Langchain::LLM::Anthropic#complete() now uses `claude-2.1` by default
+ - Updated with new OpenAI models, including `gpt-4o`
+ - New `Langchain::LLM::Cohere#chat()` method.
+ - Introducing `UnifiedParameters` to unify parameters across LLM classes
+
  ## [0.12.0] - 2024-04-22
  - [BREAKING] Rename `dimension` parameter to `dimensions` everywhere
 
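Taken together, the 0.13.0 entries mean an Assistant can now be driven by Gemini and given the new news tool. A minimal sketch, assuming both constructors take an `api_key:` keyword (illustrative; check each class for its actual options):

```ruby
# Hypothetical wiring of the two headline 0.13.0 features.
llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"]) # api_key: is an assumption

assistant = Langchain::Assistant.new(
  llm: llm,                       # now accepted; 0.12.x allowed only Langchain::LLM::OpenAI
  thread: Langchain::Thread.new,
  instructions: "You are a news assistant",
  tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])] # api_key: assumed, per the README env var
)
```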
data/README.md CHANGED
@@ -412,6 +412,7 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
  | "file_system" | Interacts with the file system | | |
  | "ruby_code_interpreter" | Interprets Ruby expressions | | `gem "safe_ruby", "~> 1.0.4"` |
  | "google_search" | A wrapper around Google Search | `ENV["SERPAPI_API_KEY"]` (https://serpapi.com/manage-api-key) | `gem "google_search_results", "~> 2.0.0"` |
+ | "news_retriever" | A wrapper around NewsApi.org | `ENV["NEWS_API_KEY"]` (https://newsapi.org/) | |
  | "weather" | Calls Open Weather API to retrieve the current weather | `ENV["OPEN_WEATHER_API_KEY"]` (https://home.openweathermap.org/api_keys) | `gem "open-weather-ruby-client", "~> 0.3.0"` |
  | "wikipedia" | Calls Wikipedia API to retrieve the summary | | `gem "wikipedia-client", "~> 1.17.0"` |
 
@@ -445,14 +446,14 @@ assistant = Langchain::Assistant.new(
    thread: thread,
    instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
    tools: [
-     Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
+     Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
    ]
  )
  ```
  ### Using an Assistant
  You can now add your message to an Assistant.
  ```ruby
- assistant.add_message content: "What's the weather in New York City?"
+ assistant.add_message content: "What's the weather in New York, New York?"
  ```
 
  Run the Assistant to generate a response.
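A sketch of the step the README describes next, assuming `run` accepts the `auto_tool_execution:` flag that the `assistant.rb` run loop below branches on:

```ruby
assistant.add_message content: "What's the weather in New York, New York?"
assistant.run auto_tool_execution: true  # assumed signature; lets the run loop invoke the Weather tool itself
assistant.thread.messages.last.content   # the assistant's final reply
```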
data/lib/langchain/assistants/assistant.rb CHANGED
@@ -7,6 +7,12 @@ module Langchain
    attr_reader :llm, :thread, :instructions
    attr_accessor :tools
 
+   SUPPORTED_LLMS = [
+     Langchain::LLM::OpenAI,
+     Langchain::LLM::GoogleGemini,
+     Langchain::LLM::GoogleVertexAI
+   ]
+
    # Create a new assistant
    #
    # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -19,7 +25,9 @@ module Langchain
      tools: [],
      instructions: nil
    )
-     raise ArgumentError, "Invalid LLM; currently only Langchain::LLM::OpenAI is supported" unless llm.instance_of?(Langchain::LLM::OpenAI)
+     unless SUPPORTED_LLMS.include?(llm.class)
+       raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
+     end
      raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless thread.is_a?(Langchain::Thread)
      raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
@@ -30,7 +38,10 @@ module Langchain
 
      # The first message in the thread should be the system instructions
      # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
-     add_message(role: "system", content: instructions) if instructions
+     if llm.is_a?(Langchain::LLM::OpenAI)
+       add_message(role: "system", content: instructions) if instructions
+     end
+     # For Google Gemini, system instructions are added to the `system:` param in the `chat` method
    end
 
    # Add a user message to the thread
@@ -59,11 +70,12 @@ module Langchain
 
      while running
        # TODO: I think we need to look at all messages and not just the last one.
-       case (last_message = thread.messages.last).role
-       when "system"
+       last_message = thread.messages.last
+
+       if last_message.system?
          # Do nothing
          running = false
-       when "assistant"
+       elsif last_message.llm?
          if last_message.tool_calls.any?
            if auto_tool_execution
              run_tools(last_message.tool_calls)
@@ -76,11 +88,11 @@ module Langchain
            # Do nothing
            running = false
          end
-       when "user"
+       elsif last_message.user?
          # Run it!
          response = chat_with_llm
 
-         if response.tool_calls
+         if response.tool_calls.any?
            # Re-run the while(running) loop to process the tool calls
            running = true
            add_message(role: response.role, tool_calls: response.tool_calls)
@@ -89,12 +101,12 @@ module Langchain
            running = false
            add_message(role: response.role, content: response.chat_completion)
          end
-       when "tool"
+       elsif last_message.tool?
          # Run it!
          response = chat_with_llm
          running = true
 
-         if response.tool_calls
+         if response.tool_calls.any?
            add_message(role: response.role, tool_calls: response.tool_calls)
          elsif response.chat_completion
            add_message(role: response.role, content: response.chat_completion)
@@ -121,8 +133,14 @@ module Langchain
    # @param output [String] The output of the tool
    # @return [Array<Langchain::Message>] The messages in the thread
    def submit_tool_output(tool_call_id:, output:)
-     # TODO: Validate that `tool_call_id` is valid
-     add_message(role: "tool", content: output, tool_call_id: tool_call_id)
+     tool_role = if llm.is_a?(Langchain::LLM::OpenAI)
+       Langchain::Messages::OpenAIMessage::TOOL_ROLE
+     elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+       Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
+     end
+
+     # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
+     add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
    end
 
    # Delete all messages in the thread
@@ -156,10 +174,15 @@ module Langchain
    def chat_with_llm
      Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
-     params = {messages: thread.openai_messages}
+     params = {messages: thread.array_of_message_hashes}
 
      if tools.any?
-       params[:tools] = tools.map(&:to_openai_tools).flatten
+       if llm.is_a?(Langchain::LLM::OpenAI)
+         params[:tools] = tools.map(&:to_openai_tools).flatten
+       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+         params[:tools] = tools.map(&:to_google_gemini_tools).flatten
+         params[:system] = instructions if instructions
+       end
        # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
        params[:tool_choice] = "auto"
      end
@@ -173,11 +196,11 @@ module Langchain
    def run_tools(tool_calls)
      # Iterate over each function invocation and submit tool output
      tool_calls.each do |tool_call|
-       tool_call_id = tool_call.dig("id")
-
-       function_name = tool_call.dig("function", "name")
-       tool_name, method_name = function_name.split("-")
-       tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+       tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
+         extract_openai_tool_call(tool_call: tool_call)
+       elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+         extract_google_gemini_tool_call(tool_call: tool_call)
+       end
 
        tool_instance = tools.find do |t|
          t.name == tool_name
@@ -190,13 +213,41 @@ module Langchain
 
      response = chat_with_llm
 
-     if response.tool_calls
+     if response.tool_calls.any?
        add_message(role: response.role, tool_calls: response.tool_calls)
      elsif response.chat_completion
        add_message(role: response.role, content: response.chat_completion)
      end
    end
 
+   # Extract the tool call information from the OpenAI tool call hash
+   #
+   # @param tool_call [Hash] The tool call hash
+   # @return [Array] The tool call information
+   def extract_openai_tool_call(tool_call:)
+     tool_call_id = tool_call.dig("id")
+
+     function_name = tool_call.dig("function", "name")
+     tool_name, method_name = function_name.split("__")
+     tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
+
+     [tool_call_id, tool_name, method_name, tool_arguments]
+   end
+
+   # Extract the tool call information from the Google Gemini tool call hash
+   #
+   # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
+   # @return [Array] The tool call information
+   def extract_google_gemini_tool_call(tool_call:)
+     tool_call_id = tool_call.dig("functionCall", "name")
+
+     function_name = tool_call.dig("functionCall", "name")
+     tool_name, method_name = function_name.split("__")
+     tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+
+     [tool_call_id, tool_name, method_name, tool_arguments]
+   end
+
    # Build a message
    #
    # @param role [String] The role of the message
@@ -205,7 +256,11 @@ module Langchain
    # @param tool_call_id [String] The ID of the tool call to include in the message
    # @return [Langchain::Message] The Message object
    def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
-     Message.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     if llm.is_a?(Langchain::LLM::OpenAI)
+       Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+       Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+     end
    end
 
    # TODO: Fix the message truncation when context window is exceeded
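The two extractors normalize very different wire formats into the same `[tool_call_id, tool_name, method_name, tool_arguments]` tuple; note the function-name separator change from `-` to `__`, and that Gemini has no separate call ID, so the function name doubles as one. Worked examples, based on the formats documented in the diff above:

```ruby
openai_tool_call = {
  "id" => "call_abc123",
  "function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"NYC\"}"}
}
# extract_openai_tool_call(tool_call: openai_tool_call)
# => ["call_abc123", "weather", "execute", {input: "NYC"}]

gemini_tool_call = {"functionCall" => {"name" => "weather__execute", "args" => {"input" => "NYC"}}}
# extract_google_gemini_tool_call(tool_call: gemini_tool_call)
# => ["weather__execute", "weather", "execute", {input: "NYC"}]
```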
data/lib/langchain/assistants/messages/base.rb ADDED
@@ -0,0 +1,16 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class Base
+       attr_reader :role, :content, :tool_calls, :tool_call_id
+
+       # Check if the message came from a user
+       #
+       # @return [Boolean] true/false whether the message came from a user
+       def user?
+         role == "user"
+       end
+     end
+   end
+ end
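The payoff of this small base class is visible in the run loop above: the Assistant branches on `system?`/`llm?`/`user?`/`tool?` without knowing which provider produced the message. For example:

```ruby
message = Langchain::Messages::OpenAIMessage.new(role: "assistant", content: "Hi")
message.llm?  # => true; OpenAI maps its "assistant" role to the generic llm? predicate
message.user? # => false; inherited from Langchain::Messages::Base
```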
data/lib/langchain/assistants/messages/google_gemini_message.rb ADDED
@@ -0,0 +1,90 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class GoogleGeminiMessage < Base
+       # Google Gemini uses the following roles:
+       ROLES = [
+         "user",
+         "model",
+         "function"
+       ].freeze
+
+       TOOL_ROLE = "function"
+
+       # Initialize a new Google Gemini message
+       #
+       # @param role [String] The role of the message
+       # @param content [String] The content of the message
+       # @param tool_calls [Array<Hash>] The tool calls made in the message
+       # @param tool_call_id [String] The ID of the tool call
+       def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+         raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+         raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+         @role = role
+         # Some Tools return content as a JSON hence `.to_s`
+         @content = content.to_s
+         @tool_calls = tool_calls
+         @tool_call_id = tool_call_id
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def llm?
+         model?
+       end
+
+       # Convert the message to a Google Gemini API-compatible hash
+       #
+       # @return [Hash] The message as a Google Gemini API-compatible hash
+       def to_hash
+         {}.tap do |h|
+           h[:role] = role
+           h[:parts] = if function?
+             [{
+               functionResponse: {
+                 name: tool_call_id,
+                 response: {
+                   name: tool_call_id,
+                   content: content
+                 }
+               }
+             }]
+           elsif tool_calls.any?
+             tool_calls
+           else
+             [{text: content}]
+           end
+         end
+       end
+
+       # Google Gemini does not implement system prompts
+       def system?
+         false
+       end
+
+       # Check if the message is a tool call
+       #
+       # @return [Boolean] true/false whether this message is a tool call
+       def tool?
+         function?
+       end
+
+       # Check if the message is a function call
+       #
+       # @return [Boolean] true/false whether this message is a function call
+       def function?
+         role == "function"
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def model?
+         role == "model"
+       end
+     end
+   end
+ end
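For reference, the shapes `to_hash` can produce, derived from the code above:

```ruby
Langchain::Messages::GoogleGeminiMessage.new(role: "user", content: "Hi").to_hash
# => {role: "user", parts: [{text: "Hi"}]}

Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: "72F and sunny",         # illustrative tool output
  tool_call_id: "weather__execute"  # Gemini reuses the function name as the call ID
).to_hash
# => {role: "function",
#     parts: [{functionResponse: {name: "weather__execute",
#               response: {name: "weather__execute", content: "72F and sunny"}}}]}
```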
data/lib/langchain/assistants/messages/openai_message.rb ADDED
@@ -0,0 +1,74 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Messages
+     class OpenAIMessage < Base
+       # OpenAI uses the following roles:
+       ROLES = [
+         "system",
+         "assistant",
+         "user",
+         "tool"
+       ].freeze
+
+       TOOL_ROLE = "tool"
+
+       # Initialize a new OpenAI message
+       #
+       # @param role [String] The role of the message
+       # @param content [String] The content of the message
+       # @param tool_calls [Array<Hash>] The tool calls made in the message
+       # @param tool_call_id [String] The ID of the tool call
+       def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+         raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+         raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+         @role = role
+         # Some Tools return content as a JSON hence `.to_s`
+         @content = content.to_s
+         @tool_calls = tool_calls
+         @tool_call_id = tool_call_id
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def llm?
+         assistant?
+       end
+
+       # Convert the message to an OpenAI API-compatible hash
+       #
+       # @return [Hash] The message as an OpenAI API-compatible hash
+       def to_hash
+         {}.tap do |h|
+           h[:role] = role
+           h[:content] = content if content # Content is nil for tool calls
+           h[:tool_calls] = tool_calls if tool_calls.any?
+           h[:tool_call_id] = tool_call_id if tool_call_id
+         end
+       end
+
+       # Check if the message came from an LLM
+       #
+       # @return [Boolean] true/false whether this message was produced by an LLM
+       def assistant?
+         role == "assistant"
+       end
+
+       # Check if the message contains system instructions
+       #
+       # @return [Boolean] true/false whether this message contains system instructions
+       def system?
+         role == "system"
+       end
+
+       # Check if the message is a tool call
+       #
+       # @return [Boolean] true/false whether this message is a tool call
+       def tool?
+         role == "tool"
+       end
+     end
+   end
+ end
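By contrast, the OpenAI hash keeps the flat chat-completions shape, omitting empty fields:

```ruby
Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hi").to_hash
# => {role: "user", content: "Hi"}

Langchain::Messages::OpenAIMessage.new(
  role: "tool",
  content: "72F and sunny",   # illustrative tool output
  tool_call_id: "call_abc123"
).to_hash
# => {role: "tool", content: "72F and sunny", tool_call_id: "call_abc123"}
```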
data/lib/langchain/assistants/thread.rb CHANGED
@@ -8,16 +8,16 @@ module Langchain
 
    # @param messages [Array<Langchain::Message>]
    def initialize(messages: [])
-     raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Message) }
+     raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
 
      @messages = messages
    end
 
-   # Convert the thread to an OpenAI API-compatible array of hashes
+   # Convert the thread to an LLM APIs-compatible array of hashes
    #
    # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
-   def openai_messages
-     messages.map(&:to_openai_format)
+   def array_of_message_hashes
+     messages.map(&:to_hash)
    end
 
    # Add a message to the thread
@@ -25,7 +25,7 @@ module Langchain
    # @param message [Langchain::Message] The message to add
    # @return [Array<Langchain::Message>] The updated messages array
    def add_message(message)
-     raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Message)
+     raise ArgumentError, "message must be a Langchain::Message instance" unless message.is_a?(Langchain::Messages::Base)
 
      # Prepend the message to the thread
      messages << message
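The renamed method is provider-agnostic: it simply maps each message's `to_hash`, so the same thread serializes correctly whichever message class it holds:

```ruby
thread = Langchain::Thread.new(
  messages: [Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hi")]
)
thread.array_of_message_hashes
# => [{role: "user", content: "Hi"}]
```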
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -13,7 +13,7 @@ module Langchain::LLM
  class Anthropic < Base
    DEFAULTS = {
      temperature: 0.0,
-     completion_model_name: "claude-2",
+     completion_model_name: "claude-2.1",
      chat_completion_model_name: "claude-3-sonnet-20240229",
      max_tokens_to_sample: 256
    }.freeze
@@ -32,6 +32,15 @@ module Langchain::LLM
 
      @client = ::Anthropic::Client.new(access_token: api_key, **llm_options)
      @defaults = DEFAULTS.merge(default_options)
+     chat_parameters.update(
+       model: {default: @defaults[:chat_completion_model_name]},
+       temperature: {default: @defaults[:temperature]},
+       max_tokens: {default: @defaults[:max_tokens_to_sample]},
+       metadata: {},
+       system: {}
+     )
+     chat_parameters.ignore(:n, :user)
+     chat_parameters.remap(stop: :stop_sequences)
    end
 
    # Generate a completion for a given prompt
@@ -72,66 +81,35 @@ module Langchain::LLM
      parameters[:metadata] = metadata if metadata
      parameters[:stream] = stream if stream
 
-     # TODO: Implement token length validator for Anthropic
-     # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
-
      response = client.complete(parameters: parameters)
      Langchain::LLM::AnthropicResponse.new(response)
    end
 
    # Generate a chat completion for given messages
    #
-   # @param messages [Array<String>] Input messages
-   # @param model [String] The model that will complete your prompt
-   # @param max_tokens [Integer] Maximum number of tokens to generate before stopping
-   # @param metadata [Hash] Object describing metadata about the request
-   # @param stop_sequences [Array<String>] Custom text sequences that will cause the model to stop generating
-   # @param stream [Boolean] Whether to incrementally stream the response using server-sent events
-   # @param system [String] System prompt
-   # @param temperature [Float] Amount of randomness injected into the response
-   # @param tools [Array<String>] Definitions of tools that the model may use
-   # @param top_k [Integer] Only sample from the top K options for each subsequent token
-   # @param top_p [Float] Use nucleus sampling.
+   # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+   # @option params [Array<String>] :messages Input messages
+   # @option params [String] :model The model that will complete your prompt
+   # @option params [Integer] :max_tokens Maximum number of tokens to generate before stopping
+   # @option params [Hash] :metadata Object describing metadata about the request
+   # @option params [Array<String>] :stop_sequences Custom text sequences that will cause the model to stop generating
+   # @option params [Boolean] :stream Whether to incrementally stream the response using server-sent events
+   # @option params [String] :system System prompt
+   # @option params [Float] :temperature Amount of randomness injected into the response
+   # @option params [Array<String>] :tools Definitions of tools that the model may use
+   # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
+   # @option params [Float] :top_p Use nucleus sampling.
    # @return [Langchain::LLM::AnthropicResponse] The chat completion
-   def chat(
-     messages: [],
-     model: @defaults[:chat_completion_model_name],
-     max_tokens: @defaults[:max_tokens_to_sample],
-     metadata: nil,
-     stop_sequences: nil,
-     stream: nil,
-     system: nil,
-     temperature: @defaults[:temperature],
-     tools: [],
-     top_k: nil,
-     top_p: nil
-   )
-     raise ArgumentError.new("messages argument is required") if messages.empty?
-     raise ArgumentError.new("model argument is required") if model.empty?
-     raise ArgumentError.new("max_tokens argument is required") if max_tokens.nil?
+   def chat(params = {})
+     parameters = chat_parameters.to_params(params)
 
-     parameters = {
-       messages: messages,
-       model: model,
-       max_tokens: max_tokens,
-       temperature: temperature
-     }
-     parameters[:metadata] = metadata if metadata
-     parameters[:stop_sequences] = stop_sequences if stop_sequences
-     parameters[:stream] = stream if stream
-     parameters[:system] = system if system
-     parameters[:tools] = tools if tools.any?
-     parameters[:top_k] = top_k if top_k
-     parameters[:top_p] = top_p if top_p
+     raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+     raise ArgumentError.new("model argument is required") if parameters[:model].empty?
+     raise ArgumentError.new("max_tokens argument is required") if parameters[:max_tokens].nil?
 
      response = client.messages(parameters: parameters)
 
      Langchain::LLM::AnthropicResponse.new(response)
    end
-
-   # TODO: Implement token length validator for Anthropic
-   # def validate_max_tokens(messages, model)
-   #   LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
-   # end
  end
end
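The net effect for callers: the long keyword list is gone, and `chat` takes the unified parameter set, with the constructor declarations supplying defaults and the `stop:` → `stop_sequences:` remap. A minimal sketch (env var name illustrative):

```ruby
llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

llm.chat(
  messages: [{role: "user", content: "Hello, Claude"}],
  stop: ["\n\nHuman:"] # unified key; remapped to :stop_sequences per the declaration above
)
# :model, :temperature and :max_tokens fall back to the declared defaults;
# :n and :user are declared ignored, since Anthropic's API does not accept them.
```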
data/lib/langchain/llm/aws_bedrock.rb CHANGED
@@ -59,6 +59,17 @@ module Langchain::LLM
      @defaults = DEFAULTS.merge(default_options)
        .merge(completion_model_name: completion_model)
        .merge(embedding_model_name: embedding_model)
+
+     chat_parameters.update(
+       model: {default: @defaults[:chat_completion_model_name]},
+       temperature: {},
+       max_tokens: {default: @defaults[:max_tokens_to_sample]},
+       metadata: {},
+       system: {},
+       anthropic_version: {default: "bedrock-2023-05-31"}
+     )
+     chat_parameters.ignore(:n, :user)
+     chat_parameters.remap(stop: :stop_sequences)
    end
 
    #
@@ -113,43 +124,28 @@ module Langchain::LLM
    # Generate a chat completion for a given prompt
    # Currently only configured to work with the Anthropic provider and
    # the claude-3 model family
-   # @param messages [Array] The messages to generate a completion for
-   # @param system [String] The system prompt to provide instructions
-   # @param model [String] The model to use for completion defaults to @defaults[:chat_completion_model_name]
-   # @param max_tokens [Integer] The maximum number of tokens to generate
-   # @param stop_sequences [Array] The stop sequences to use for completion
-   # @param temperature [Float] The temperature to use for completion
-   # @param top_p [Float] The top p to use for completion
-   # @param top_k [Integer] The top k to use for completion
+   #
+   # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+   # @option params [Array<String>] :messages The messages to generate a completion for
+   # @option params [String] :system The system prompt to provide instructions
+   # @option params [String] :model The model to use for completion defaults to @defaults[:chat_completion_model_name]
+   # @option params [Integer] :max_tokens The maximum number of tokens to generate defaults to @defaults[:max_tokens_to_sample]
+   # @option params [Array<String>] :stop The stop sequences to use for completion
+   # @option params [Array<String>] :stop_sequences The stop sequences to use for completion
+   # @option params [Float] :temperature The temperature to use for completion
+   # @option params [Float] :top_p Use nucleus sampling.
+   # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
    # @return [Langchain::LLM::AnthropicMessagesResponse] Response object
-   def chat(
-     messages: [],
-     system: nil,
-     model: defaults[:completion_model_name],
-     max_tokens: defaults[:max_tokens_to_sample],
-     stop_sequences: nil,
-     temperature: nil,
-     top_p: nil,
-     top_k: nil
-   )
-     raise ArgumentError.new("messages argument is required") if messages.empty?
-
-     raise "Model #{model} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
-
-     inference_parameters = {
-       messages: messages,
-       max_tokens: max_tokens,
-       anthropic_version: @defaults[:anthropic_version]
-     }
-     inference_parameters[:system] = system if system
-     inference_parameters[:stop_sequences] = stop_sequences if stop_sequences
-     inference_parameters[:temperature] = temperature if temperature
-     inference_parameters[:top_p] = top_p if top_p
-     inference_parameters[:top_k] = top_k if top_k
+   def chat(params = {})
+     parameters = chat_parameters.to_params(params)
+
+     raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+
+     raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
 
      response = client.invoke_model({
-       model_id: model,
-       body: inference_parameters.to_json,
+       model_id: parameters[:model],
+       body: parameters.except(:model).to_json,
        content_type: "application/json",
        accept: "application/json"
      })
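Note how the Bedrock request is now assembled: `:model` is peeled off as the `invoke_model` `model_id`, and everything else in the unified parameter hash, including the defaulted `anthropic_version`, is serialized as the JSON body. A sketch of that split (the model ID is illustrative; the `UnifiedParameters` internals live in data/lib/langchain/llm/unified_parameters.rb, not shown in this diff):

```ruby
parameters = {
  model: "anthropic.claude-3-sonnet-20240229-v1:0", # illustrative Bedrock model ID
  messages: [{role: "user", content: "Hi"}],
  max_tokens: 256,
  anthropic_version: "bedrock-2023-05-31"
}
parameters.except(:model).to_json # becomes the request body; :model is sent separately as model_id
```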
data/lib/langchain/llm/azure.rb CHANGED
@@ -32,6 +32,12 @@ module Langchain::LLM
        **llm_options
      )
      @defaults = DEFAULTS.merge(default_options)
+     chat_parameters.update(
+       logprobs: {},
+       top_logprobs: {},
+       user: {}
+     )
+     chat_parameters.ignore(:top_k)
    end
 
    def embed(...)