langchainrb 0.17.1 → 0.18.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (38)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +5 -0
  4. data/lib/langchain/{assistants → assistant}/llm/adapter.rb +1 -1
  5. data/lib/langchain/assistant/llm/adapters/anthropic.rb +105 -0
  6. data/lib/langchain/assistant/llm/adapters/base.rb +63 -0
  7. data/lib/langchain/{assistants → assistant}/llm/adapters/google_gemini.rb +43 -3
  8. data/lib/langchain/{assistants → assistant}/llm/adapters/mistral_ai.rb +39 -2
  9. data/lib/langchain/assistant/llm/adapters/ollama.rb +94 -0
  10. data/lib/langchain/{assistants → assistant}/llm/adapters/openai.rb +38 -2
  11. data/lib/langchain/assistant/messages/anthropic_message.rb +77 -0
  12. data/lib/langchain/assistant/messages/base.rb +56 -0
  13. data/lib/langchain/assistant/messages/google_gemini_message.rb +92 -0
  14. data/lib/langchain/assistant/messages/mistral_ai_message.rb +98 -0
  15. data/lib/langchain/assistant/messages/ollama_message.rb +76 -0
  16. data/lib/langchain/assistant/messages/openai_message.rb +105 -0
  17. data/lib/langchain/{assistants/assistant.rb → assistant.rb} +26 -49
  18. data/lib/langchain/llm/ai21.rb +1 -1
  19. data/lib/langchain/llm/anthropic.rb +59 -4
  20. data/lib/langchain/llm/aws_bedrock.rb +6 -7
  21. data/lib/langchain/llm/azure.rb +1 -1
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/ollama.rb +0 -1
  24. data/lib/langchain/llm/openai.rb +2 -2
  25. data/lib/langchain/llm/parameters/chat.rb +1 -0
  26. data/lib/langchain/llm/replicate.rb +2 -10
  27. data/lib/langchain/version.rb +1 -1
  28. data/lib/langchain.rb +1 -14
  29. metadata +16 -16
  30. data/lib/langchain/assistants/llm/adapters/_base.rb +0 -21
  31. data/lib/langchain/assistants/llm/adapters/anthropic.rb +0 -62
  32. data/lib/langchain/assistants/llm/adapters/ollama.rb +0 -57
  33. data/lib/langchain/assistants/messages/anthropic_message.rb +0 -75
  34. data/lib/langchain/assistants/messages/base.rb +0 -54
  35. data/lib/langchain/assistants/messages/google_gemini_message.rb +0 -90
  36. data/lib/langchain/assistants/messages/mistral_ai_message.rb +0 -96
  37. data/lib/langchain/assistants/messages/ollama_message.rb +0 -74
  38. data/lib/langchain/assistants/messages/openai_message.rb +0 -103
@@ -0,0 +1,56 @@
1
+ # frozen_string_literal: true
2
+
3
module Langchain
  class Assistant
    module Messages
      # Abstract parent of the provider-specific chat message classes.
      #
      # Subclasses assign the readers below in their own initializers and answer the
      # role predicates (#llm?, #tool?, #system?) in terms of their provider's role names.
      class Base
        attr_reader :role, :content, :image_url, :tool_calls, :tool_call_id

        # Whether the message was written by the user.
        #
        # @return [Boolean] true when the role is "user"
        def user?
          role == "user"
        end

        # Whether the message was produced by the LLM. Subclasses must override this.
        #
        # @raise [NotImplementedError] unless overridden by a subclass
        def llm?
          raise NotImplementedError, "Class #{self.class.name} must implement the method 'llm?'"
        end

        # Whether the message carries a tool result. Subclasses must override this.
        #
        # @raise [NotImplementedError] unless overridden by a subclass
        def tool?
          raise NotImplementedError, "Class #{self.class.name} must implement the method 'tool?'"
        end

        # Whether the message holds system instructions. Subclasses must override this.
        #
        # @raise [NotImplementedError] unless overridden by a subclass
        def system?
          raise NotImplementedError, "Class #{self.class.name} must implement the method 'system?'"
        end

        # Map the provider-specific role onto a provider-neutral symbol.
        #
        # The predicates are consulted in the order user, llm, tool, system; the first
        # one that answers true decides the result.
        #
        # @return [Symbol] :user, :llm, :tool, :system, or :unknown when nothing matches
        def standard_role
          if user?
            :user
          elsif llm?
            :llm
          elsif tool?
            :tool
          elsif system?
            :system
          else
            # TODO: Should we return :unknown or raise an error?
            :unknown
          end
        end
      end
    end
  end
end
@@ -0,0 +1,92 @@
1
+ # frozen_string_literal: true
2
+
3
module Langchain
  class Assistant
    module Messages
      # A chat message in the shape expected by the Google Gemini API.
      class GoogleGeminiMessage < Base
        # Google Gemini uses the following roles:
        ROLES = [
          "user",
          "model",
          "function"
        ].freeze

        # Role under which tool (function) results are sent back to the model
        TOOL_ROLE = "function"

        # Initialize a new Google Gemini message
        #
        # @param role [String] The role of the message; must be one of ROLES
        # @param content [String, nil] The content of the message; coerced with `to_s`, so nil becomes ""
        # @param tool_calls [Array<Hash>] The tool calls made in the message, passed through as Gemini parts
        # @param tool_call_id [String, nil] The ID of the tool call; #to_hash sends it as the function name
        # @raise [ArgumentError] if role is not one of ROLES or tool_calls is not an Array of Hashes
        def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
          raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
          raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

          @role = role
          # Some Tools return content as a JSON hence `.to_s`
          @content = content.to_s
          @tool_calls = tool_calls
          @tool_call_id = tool_call_id
        end

        # Check if the message came from an LLM
        #
        # @return [Boolean] true/false whether this message was produced by an LLM
        def llm?
          model?
        end

        # Convert the message to a Google Gemini API-compatible hash
        #
        # Function (tool result) messages become a single `functionResponse` part keyed by
        # tool_call_id, messages carrying tool calls pass those hashes through as parts, and
        # every other message becomes a single text part.
        #
        # @return [Hash] The message as a Google Gemini API-compatible hash
        def to_hash
          {}.tap do |h|
            h[:role] = role
            h[:parts] = if function?
              [{
                functionResponse: {
                  name: tool_call_id,
                  response: {
                    name: tool_call_id,
                    content: content
                  }
                }
              }]
            elsif tool_calls.any?
              tool_calls
            else
              [{text: content}]
            end
          end
        end

        # Google Gemini does not implement system prompts
        #
        # @return [Boolean] always false
        def system?
          false
        end

        # Check if the message is a tool result
        #
        # @return [Boolean] true/false whether this message carries a tool result
        def tool?
          function?
        end

        # Check if the message carries a function (tool) result
        #
        # @return [Boolean] true/false whether the role is TOOL_ROLE ("function")
        def function?
          role == TOOL_ROLE
        end

        # Check if the message came from the model
        #
        # @return [Boolean] true/false whether the role is "model"
        def model?
          role == "model"
        end
      end
    end
  end
end
@@ -0,0 +1,98 @@
1
+ # frozen_string_literal: true
2
+
3
module Langchain
  class Assistant
    module Messages
      # Chat message serialized for the Mistral AI chat completions API.
      class MistralAIMessage < Base
        # Roles accepted by Mistral AI.
        ROLES = %w[system assistant user tool].freeze

        # Role under which tool results are sent back.
        TOOL_ROLE = "tool"

        # Build a Mistral AI message.
        #
        # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
        #
        # @param role [String] one of ROLES
        # @param content [String, nil] message text; stored via `to_s`, so nil becomes ""
        # @param image_url [String, nil] URL of an image to attach (use a Pixtral model for images)
        # @param tool_calls [Array<Hash>] tool calls requested by the assistant
        # @param tool_call_id [String, nil] ID of the tool call this message responds to
        # @raise [ArgumentError] when the role is unknown or tool_calls is not an Array of Hashes
        def initialize(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
          unless ROLES.include?(role)
            raise ArgumentError, "Role must be one of #{ROLES.join(", ")}"
          end
          unless tool_calls.is_a?(Array) && tool_calls.all?(Hash)
            raise ArgumentError, "Tool calls must be an array of hashes"
          end

          @role = role
          # Tools may hand back non-String output (e.g. JSON), so normalize to a String.
          @content = content.to_s
          @image_url = image_url
          @tool_calls = tool_calls
          @tool_call_id = tool_call_id
        end

        # Whether the message was produced by the LLM.
        #
        # @return [Boolean] true for assistant messages
        def llm?
          assistant?
        end

        # Serialize for the Mistral AI API.
        #
        # Messages carrying tool_calls are sent as role + tool_calls only; every other
        # message gets tool_call_id (when present) and an array of content parts.
        #
        # @return [Hash] the message as a Mistral AI API-compatible hash
        def to_hash
          return {role: role, tool_calls: tool_calls} if tool_calls.any?

          message = {role: role}
          message[:tool_call_id] = tool_call_id if tool_call_id
          message[:content] = content_parts
          message
        end

        # @return [Boolean] true when the role is "assistant"
        def assistant?
          role == "assistant"
        end

        # @return [Boolean] true when the message holds system instructions
        def system?
          role == "system"
        end

        # @return [Boolean] true when the message carries a tool result
        def tool?
          role == "tool"
        end

        private

        # Text and image parts for the `content` array; Mistral takes the image URL as a plain string.
        #
        # @return [Array<Hash>] zero, one or two content parts
        def content_parts
          parts = []
          parts << {type: "text", text: content} if content && !content.empty?
          parts << {type: "image_url", image_url: image_url} if image_url
          parts
        end
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
+ # frozen_string_literal: true
2
+
3
module Langchain
  class Assistant
    module Messages
      # A chat message in the shape expected by the Ollama chat API.
      class OllamaMessage < Base
        # Ollama uses the following roles:
        ROLES = [
          "system",
          "assistant",
          "user",
          "tool"
        ].freeze

        # Role under which tool results are sent back to the model
        TOOL_ROLE = "tool"

        # Initialize a new Ollama message
        #
        # @param role [String] The role of the message; must be one of ROLES
        # @param content [String, nil] The content of the message; coerced with `to_s`, so nil becomes ""
        # @param tool_calls [Array<Hash>] The tool calls made in the message
        # @param tool_call_id [String, nil] The ID of the tool call this message answers
        # @raise [ArgumentError] if role is not one of ROLES or tool_calls is not an Array of Hashes
        def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
          raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
          raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

          @role = role
          # Some Tools return content as a JSON hence `.to_s`
          @content = content.to_s
          @tool_calls = tool_calls
          @tool_call_id = tool_call_id
        end

        # Convert the message to an Ollama API-compatible hash
        #
        # @return [Hash] The message as an Ollama API-compatible hash
        def to_hash
          {}.tap do |h|
            h[:role] = role
            # `content` is coerced to a String in #initialize, so this key is normally
            # present (possibly as "") even on messages that only carry tool calls.
            h[:content] = content if content
            h[:tool_calls] = tool_calls if tool_calls.any?
            h[:tool_call_id] = tool_call_id if tool_call_id
          end
        end

        # Check if the message came from an LLM
        #
        # @return [Boolean] true/false whether this message was produced by an LLM
        def llm?
          assistant?
        end

        # Check if the message came from the assistant
        #
        # @return [Boolean] true/false whether the role is "assistant"
        def assistant?
          role == "assistant"
        end

        # Check if the message holds system instructions
        #
        # @return [Boolean] true/false whether the role is "system"
        def system?
          role == "system"
        end

        # Check if the message carries a tool result
        #
        # @return [Boolean] true/false whether the role is TOOL_ROLE ("tool")
        def tool?
          role == TOOL_ROLE
        end
      end
    end
  end
end
@@ -0,0 +1,105 @@
1
+ # frozen_string_literal: true
2
+
3
module Langchain
  class Assistant
    module Messages
      # Chat message serialized for the OpenAI Chat Completions API.
      class OpenAIMessage < Base
        # Roles accepted by OpenAI.
        ROLES = %w[system assistant user tool].freeze

        # Role under which tool results are sent back.
        TOOL_ROLE = "tool"

        # Build an OpenAI message.
        #
        # @param role [String] one of ROLES
        # @param content [String, nil] message text; stored via `to_s`, so nil becomes ""
        # @param image_url [String, nil] URL of an image to attach
        # @param tool_calls [Array<Hash>] tool calls requested by the assistant
        # @param tool_call_id [String, nil] ID of the tool call this message responds to
        # @raise [ArgumentError] when the role is unknown or tool_calls is not an Array of Hashes
        def initialize(
          role:,
          content: nil,
          image_url: nil,
          tool_calls: [],
          tool_call_id: nil
        )
          unless ROLES.include?(role)
            raise ArgumentError, "Role must be one of #{ROLES.join(", ")}"
          end
          unless tool_calls.is_a?(Array) && tool_calls.all?(Hash)
            raise ArgumentError, "Tool calls must be an array of hashes"
          end

          @role = role
          # Tools may hand back non-String output (e.g. JSON), so normalize to a String.
          @content = content.to_s
          @image_url = image_url
          @tool_calls = tool_calls
          @tool_call_id = tool_call_id
        end

        # Whether the message was produced by the LLM.
        #
        # @return [Boolean] true for assistant messages
        def llm?
          assistant?
        end

        # Serialize for the OpenAI API.
        #
        # Messages carrying tool_calls are sent as role + tool_calls only; every other
        # message gets tool_call_id (when present) and an array of content parts.
        #
        # @return [Hash] the message as an OpenAI API-compatible hash
        def to_hash
          return {role: role, tool_calls: tool_calls} if tool_calls.any?

          message = {role: role}
          message[:tool_call_id] = tool_call_id if tool_call_id
          message[:content] = content_parts
          message
        end

        # @return [Boolean] true when the role is "assistant"
        def assistant?
          role == "assistant"
        end

        # @return [Boolean] true when the message holds system instructions
        def system?
          role == "system"
        end

        # @return [Boolean] true when the message carries a tool result
        def tool?
          role == "tool"
        end

        private

        # Text and image parts for the `content` array; OpenAI nests the image URL under `url`.
        #
        # @return [Array<Hash>] zero, one or two content parts
        def content_parts
          parts = []
          parts << {type: "text", text: content} if content && !content.empty?
          parts << {type: "image_url", image_url: {url: image_url}} if image_url
          parts
        end
      end
    end
  end
end
@@ -1,7 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require_relative "llm/adapter"
4
-
5
3
  module Langchain
6
4
  # Assistants are Agent-like objects that leverage helpful instructions, LLMs, tools and knowledge to respond to user queries.
7
5
  # Assistants can be configured with an LLM of your choice, any vector search database and easily extended with additional tools.
@@ -14,9 +12,19 @@ module Langchain
14
12
  # tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
15
13
  # )
16
14
  class Assistant
17
- attr_reader :llm, :instructions, :state, :llm_adapter, :tool_choice
18
- attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens, :messages
19
- attr_accessor :tools, :add_message_callback
15
+ attr_reader :llm,
16
+ :instructions,
17
+ :state,
18
+ :llm_adapter,
19
+ :messages,
20
+ :tool_choice,
21
+ :total_prompt_tokens,
22
+ :total_completion_tokens,
23
+ :total_tokens
24
+
25
+ attr_accessor :tools,
26
+ :add_message_callback,
27
+ :parallel_tool_calls
20
28
 
21
29
  # Create a new assistant
22
30
  #
@@ -24,12 +32,15 @@ module Langchain
24
32
  # @param tools [Array<Langchain::Tool::Base>] Tools that the assistant has access to
25
33
  # @param instructions [String] The system instructions
26
34
  # @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or <specific function name>
27
- # @params add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
35
+ # @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
36
+ # @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
37
+ # @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
28
38
  def initialize(
29
39
  llm:,
30
40
  tools: [],
31
41
  instructions: nil,
32
42
  tool_choice: "auto",
43
+ parallel_tool_calls: true,
33
44
  messages: [],
34
45
  add_message_callback: nil,
35
46
  &block
@@ -49,18 +60,15 @@ module Langchain
49
60
 
50
61
  self.messages = messages
51
62
  @tools = tools
63
+ @parallel_tool_calls = parallel_tool_calls
52
64
  self.tool_choice = tool_choice
53
- @instructions = instructions
65
+ self.instructions = instructions
54
66
  @block = block
55
67
  @state = :ready
56
68
 
57
69
  @total_prompt_tokens = 0
58
70
  @total_completion_tokens = 0
59
71
  @total_tokens = 0
60
-
61
- # The first message in the messages array should be the system instructions
62
- # For Google Gemini, and Anthropic system instructions are added to the `system:` param in the `chat` method
63
- initialize_instructions
64
72
  end
65
73
 
66
74
  # Add a user message to the messages array
@@ -104,7 +112,7 @@ module Langchain
104
112
  # @param messages [Array<Langchain::Message>] The messages to set
105
113
  # @return [Array<Langchain::Message>] The messages
106
114
  def messages=(messages)
107
- raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
115
+ raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Messages::Base) }
108
116
 
109
117
  @messages = messages
110
118
  end
@@ -167,10 +175,8 @@ module Langchain
167
175
  # @param output [String] The output of the tool
168
176
  # @return [Array<Langchain::Message>] The messages
169
177
  def submit_tool_output(tool_call_id:, output:)
170
- tool_role = determine_tool_role
171
-
172
178
  # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
173
- add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
179
+ add_message(role: @llm_adapter.tool_role, content: output, tool_call_id: tool_call_id)
174
180
  end
175
181
 
176
182
  # Delete all messages
@@ -181,9 +187,6 @@ module Langchain
181
187
  @messages = []
182
188
  end
183
189
 
184
- # TODO: Remove in the next major release
185
- alias_method :clear_thread!, :clear_messages!
186
-
187
190
  # Set new instructions
188
191
  #
189
192
  # @param new_instructions [String] New instructions that will be set as a system message
@@ -191,12 +194,9 @@ module Langchain
191
194
  def instructions=(new_instructions)
192
195
  @instructions = new_instructions
193
196
 
194
- # This only needs to be done that support Message#@role="system"
195
- if !llm.is_a?(Langchain::LLM::GoogleGemini) &&
196
- !llm.is_a?(Langchain::LLM::GoogleVertexAI) &&
197
- !llm.is_a?(Langchain::LLM::Anthropic)
198
- # Find message with role: "system" in messages and delete it from the messages array
199
- replace_system_message!(content: new_instructions)
197
+ if @llm_adapter.support_system_message?
198
+ # TODO: Should we still set a system message even if @instructions is "" or nil?
199
+ replace_system_message!(content: new_instructions) if @instructions
200
200
  end
201
201
  end
202
202
 
@@ -330,30 +330,6 @@ module Langchain
330
330
  :failed
331
331
  end
332
332
 
333
- # Determine the tool role based on the LLM type
334
- #
335
- # @return [String] The tool role
336
- def determine_tool_role
337
- case llm
338
- when Langchain::LLM::Anthropic
339
- Langchain::Messages::AnthropicMessage::TOOL_ROLE
340
- when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
341
- Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
342
- when Langchain::LLM::MistralAI
343
- Langchain::Messages::MistralAIMessage::TOOL_ROLE
344
- when Langchain::LLM::Ollama
345
- Langchain::Messages::OllamaMessage::TOOL_ROLE
346
- when Langchain::LLM::OpenAI
347
- Langchain::Messages::OpenAIMessage::TOOL_ROLE
348
- end
349
- end
350
-
351
- def initialize_instructions
352
- if llm.is_a?(Langchain::LLM::OpenAI) || llm.is_a?(Langchain::LLM::MistralAI)
353
- self.instructions = @instructions if @instructions
354
- end
355
- end
356
-
357
333
  # Call to the LLM#chat() method
358
334
  #
359
335
  # @return [Langchain::LLM::BaseResponse] The LLM response object
@@ -364,7 +340,8 @@ module Langchain
364
340
  instructions: @instructions,
365
341
  messages: array_of_message_hashes,
366
342
  tools: @tools,
367
- tool_choice: tool_choice
343
+ tool_choice: tool_choice,
344
+ parallel_tool_calls: parallel_tool_calls
368
345
  )
369
346
  @llm.chat(**params, &@block)
370
347
  end
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "ai21", "~> 0.2.1"
9
9
  #
10
10
  # Usage:
11
- # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
11
+ # llm = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
12
12
  #
13
13
  class AI21 < Base
14
14
  DEFAULTS = {
@@ -5,10 +5,10 @@ module Langchain::LLM
5
5
  # Wrapper around Anthropic APIs.
6
6
  #
7
7
  # Gem requirements:
8
- # gem "anthropic", "~> 0.3.0"
8
+ # gem "anthropic", "~> 0.3.2"
9
9
  #
10
10
  # Usage:
11
- # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
11
+ # llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
12
12
  #
13
13
  class Anthropic < Base
14
14
  DEFAULTS = {
@@ -100,7 +100,7 @@ module Langchain::LLM
100
100
  # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
101
101
  # @option params [Float] :top_p Use nucleus sampling.
102
102
  # @return [Langchain::LLM::AnthropicResponse] The chat completion
103
- def chat(params = {})
103
+ def chat(params = {}, &block)
104
104
  set_extra_headers! if params[:tools]
105
105
 
106
106
  parameters = chat_parameters.to_params(params)
@@ -109,9 +109,19 @@ module Langchain::LLM
109
109
  raise ArgumentError.new("model argument is required") if parameters[:model].empty?
110
110
  raise ArgumentError.new("max_tokens argument is required") if parameters[:max_tokens].nil?
111
111
 
112
- binding.pry
112
+ if block
113
+ @response_chunks = []
114
+ parameters[:stream] = proc do |chunk|
115
+ @response_chunks << chunk
116
+ yield chunk
117
+ end
118
+ end
119
+
113
120
  response = client.messages(parameters: parameters)
114
121
 
122
+ response = response_from_chunks if block
123
+ reset_response_chunks
124
+
115
125
  Langchain::LLM::AnthropicResponse.new(response)
116
126
  end
117
127
 
@@ -124,8 +134,53 @@ module Langchain::LLM
124
134
  response
125
135
  end
126
136
 
137
# Rebuild a Messages API-shaped response hash from the events buffered in
# @response_chunks during a streamed #chat call.
#
# Anthropic streams a `message_start` event (whose "message" holds the id, model and
# input-token usage), per-content-block `content_block_*` events keyed by "index",
# and a `message_delta` event (stop reason and final output-token usage).
#
# @return [Hash, nil] a hash with "id", "model", "content", "usage", "role" and
#   "stop_reason" keys, or nil when no chunks were buffered
def response_from_chunks
  return if @response_chunks.nil? || @response_chunks.empty?

  message = @response_chunks.find { |chunk| chunk["type"] == "message_start" }&.dig("message") || {}
  message_delta = @response_chunks.find { |chunk| chunk["type"] == "message_delta" }

  # Input tokens are only reported in `message_start`; `message_delta` carries the final
  # output-token count, so it wins on overlapping keys. nil when neither reports usage.
  usage = [message["usage"], message_delta&.dig("usage")].compact.reduce(:merge)

  # Events without an "index" (message_start, ping, message_delta, ...) are message-level,
  # not content blocks.
  content_blocks = @response_chunks.group_by { |chunk| chunk["index"] }.except(nil)

  content = content_blocks.flat_map do |_index, block_chunks|
    text = block_chunks.map { |chunk| chunk.dig("delta", "text") }.join
    text.empty? ? tool_calls_from_choice_chunks(block_chunks) : [{"type" => "text", "text" => text}]
  end
  # Blocks that are neither text nor tool_use yield nil; drop them so consumers can
  # safely inspect every content entry.
  content.compact!

  # Keep any top-level metadata on the first chunk, but prefer the id/model that
  # `message_start` reports (the first chunk normally has neither at the top level).
  metadata = @response_chunks.first.slice("id", "object", "created", "model").merge(message.slice("id", "model"))

  metadata.merge(
    "content" => content,
    "usage" => usage,
    "role" => "assistant",
    "stop_reason" => message_delta&.dig("delta", "stop_reason")
  )
end
162
+
163
# Rebuild tool_use content block(s) from streamed chunks.
#
# A tool call is streamed as a `content_block_start` event whose "content_block"
# carries the tool id and name, followed by `content_block_delta` events whose
# "partial_json" fragments concatenate into the JSON-encoded tool input.
#
# @param chunks [Array<Hash>] streamed events, normally all sharing one content-block "index"
# @return [Array<Hash>, nil] one tool_use hash per index that has a tool_use start event
#   (input keys symbolized), or nil when none of the chunks starts a tool_use block
# @raise [JSON::ParserError] if the concatenated, non-empty input is not valid JSON
def tool_calls_from_choice_chunks(chunks)
  return unless chunks.any? { |chunk| chunk.dig("content_block", "type") == "tool_use" }

  chunks.group_by { |chunk| chunk["index"] }.filter_map do |_index, block_chunks|
    # Each block takes its id/name from its own start event, not from the first block seen.
    start_event = block_chunks.find { |chunk| chunk.dig("content_block", "type") == "tool_use" }
    next unless start_event

    input_json = block_chunks.filter_map { |chunk| chunk.dig("delta", "partial_json") }.join
    # A tool called without arguments may stream no JSON at all (or only empty fragments);
    # JSON.parse("") would raise, so treat that as an empty input.
    input = input_json.strip.empty? ? {} : JSON.parse(input_json)

    {
      "id" => start_event.dig("content_block", "id"),
      "type" => "tool_use",
      "name" => start_event.dig("content_block", "name"),
      "input" => input.transform_keys(&:to_sym)
    }
  end
end
177
+
127
178
  private
128
179
 
180
# Discard the stream events buffered by the last #chat call so the next call starts clean.
#
# @return [Array] the fresh, empty buffer
def reset_response_chunks
  @response_chunks = []
end
183
+
129
184
  def set_extra_headers!
130
185
  ::Anthropic.configuration.extra_headers = {"anthropic-beta": "tools-2024-05-16"}
131
186
  end