langchainrb 0.17.1 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +5 -0
  4. data/lib/langchain/{assistants → assistant}/llm/adapter.rb +1 -1
  5. data/lib/langchain/assistant/llm/adapters/anthropic.rb +105 -0
  6. data/lib/langchain/assistant/llm/adapters/base.rb +63 -0
  7. data/lib/langchain/{assistants → assistant}/llm/adapters/google_gemini.rb +43 -3
  8. data/lib/langchain/{assistants → assistant}/llm/adapters/mistral_ai.rb +39 -2
  9. data/lib/langchain/assistant/llm/adapters/ollama.rb +94 -0
  10. data/lib/langchain/{assistants → assistant}/llm/adapters/openai.rb +38 -2
  11. data/lib/langchain/assistant/messages/anthropic_message.rb +77 -0
  12. data/lib/langchain/assistant/messages/base.rb +56 -0
  13. data/lib/langchain/assistant/messages/google_gemini_message.rb +92 -0
  14. data/lib/langchain/assistant/messages/mistral_ai_message.rb +98 -0
  15. data/lib/langchain/assistant/messages/ollama_message.rb +76 -0
  16. data/lib/langchain/assistant/messages/openai_message.rb +105 -0
  17. data/lib/langchain/{assistants/assistant.rb → assistant.rb} +26 -49
  18. data/lib/langchain/llm/ai21.rb +1 -1
  19. data/lib/langchain/llm/anthropic.rb +59 -4
  20. data/lib/langchain/llm/aws_bedrock.rb +6 -7
  21. data/lib/langchain/llm/azure.rb +1 -1
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/ollama.rb +0 -1
  24. data/lib/langchain/llm/openai.rb +2 -2
  25. data/lib/langchain/llm/parameters/chat.rb +1 -0
  26. data/lib/langchain/llm/replicate.rb +2 -10
  27. data/lib/langchain/version.rb +1 -1
  28. data/lib/langchain.rb +1 -14
  29. metadata +16 -16
  30. data/lib/langchain/assistants/llm/adapters/_base.rb +0 -21
  31. data/lib/langchain/assistants/llm/adapters/anthropic.rb +0 -62
  32. data/lib/langchain/assistants/llm/adapters/ollama.rb +0 -57
  33. data/lib/langchain/assistants/messages/anthropic_message.rb +0 -75
  34. data/lib/langchain/assistants/messages/base.rb +0 -54
  35. data/lib/langchain/assistants/messages/google_gemini_message.rb +0 -90
  36. data/lib/langchain/assistants/messages/mistral_ai_message.rb +0 -96
  37. data/lib/langchain/assistants/messages/ollama_message.rb +0 -74
  38. data/lib/langchain/assistants/messages/openai_message.rb +0 -103
data/lib/langchain/assistant/messages/base.rb
@@ -0,0 +1,56 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Assistant
+     module Messages
+       class Base
+         attr_reader :role,
+           :content,
+           :image_url,
+           :tool_calls,
+           :tool_call_id
+
+         # Check if the message came from a user
+         #
+         # @return [Boolean] true/false whether the message came from a user
+         def user?
+           role == "user"
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @raise NotImplementedError if the subclass does not implement this method
+         def llm?
+           raise NotImplementedError, "Class #{self.class.name} must implement the method 'llm?'"
+         end
+
+         # Check if the message is a tool call
+         #
+         # @raise NotImplementedError if the subclass does not implement this method
+         def tool?
+           raise NotImplementedError, "Class #{self.class.name} must implement the method 'tool?'"
+         end
+
+         # Check if the message is a system prompt
+         #
+         # @raise NotImplementedError if the subclass does not implement this method
+         def system?
+           raise NotImplementedError, "Class #{self.class.name} must implement the method 'system?'"
+         end
+
+         # Returns the standardized role symbol based on the specific role methods
+         #
+         # @return [Symbol] the standardized role symbol (:system, :llm, :tool, :user, or :unknown)
+         def standard_role
+           return :user if user?
+           return :llm if llm?
+           return :tool if tool?
+           return :system if system?
+
+           # TODO: Should we return :unknown or raise an error?
+           :unknown
+         end
+       end
+     end
+   end
+ end
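
The new `Messages::Base` is effectively an abstract class: subclasses implement the role predicates, and `standard_role` maps them onto a provider-agnostic symbol. A minimal sketch of the contract, using a hypothetical `FakeMessage` class that is not part of the gem:

    require "langchain"

    # Hypothetical subclass, for illustration only.
    class FakeMessage < Langchain::Assistant::Messages::Base
      def initialize(role:, content: nil)
        @role = role
        @content = content.to_s
      end

      # Role predicates that Base requires subclasses to implement
      def llm?
        role == "assistant"
      end

      def tool?
        role == "tool"
      end

      def system?
        role == "system"
      end
    end

    message = FakeMessage.new(role: "assistant", content: "Hello!")
    message.standard_role # => :llm
    message.user?         # => false (inherited from Base)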
data/lib/langchain/assistant/messages/google_gemini_message.rb
@@ -0,0 +1,92 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Assistant
+     module Messages
+       class GoogleGeminiMessage < Base
+         # Google Gemini uses the following roles:
+         ROLES = [
+           "user",
+           "model",
+           "function"
+         ].freeze
+
+         TOOL_ROLE = "function"
+
+         # Initialize a new Google Gemini message
+         #
+         # @param role [String] The role of the message
+         # @param content [String] The content of the message
+         # @param tool_calls [Array<Hash>] The tool calls made in the message
+         # @param tool_call_id [String] The ID of the tool call
+         def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+           raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+           raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+           @role = role
+           # Some Tools return content as JSON, hence the `.to_s`
+           @content = content.to_s
+           @tool_calls = tool_calls
+           @tool_call_id = tool_call_id
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def llm?
+           model?
+         end
+
+         # Convert the message to a Google Gemini API-compatible hash
+         #
+         # @return [Hash] The message as a Google Gemini API-compatible hash
+         def to_hash
+           {}.tap do |h|
+             h[:role] = role
+             h[:parts] = if function?
+               [{
+                 functionResponse: {
+                   name: tool_call_id,
+                   response: {
+                     name: tool_call_id,
+                     content: content
+                   }
+                 }
+               }]
+             elsif tool_calls.any?
+               tool_calls
+             else
+               [{text: content}]
+             end
+           end
+         end
+
+         # Google Gemini does not implement system prompts
+         def system?
+           false
+         end
+
+         # Check if the message is a tool call
+         #
+         # @return [Boolean] true/false whether this message is a tool call
+         def tool?
+           function?
+         end
+
+         # Check if the message is a function call
+         #
+         # @return [Boolean] true/false whether this message is a function call
+         def function?
+           role == "function"
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def model?
+           role == "model"
+         end
+       end
+     end
+   end
+ end
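
For reference, a tool-result message built with this class serializes into Gemini's `functionResponse` parts format. A quick sketch (the tool name and payload are made up; the output shape follows the `to_hash` implementation above):

    message = Langchain::Assistant::Messages::GoogleGeminiMessage.new(
      role: "function",
      content: '{"temperature":22}', # tool output; coerced to a String by the initializer
      tool_call_id: "get_weather"
    )

    message.to_hash
    # => {role: "function",
    #     parts: [{functionResponse: {name: "get_weather",
    #                                 response: {name: "get_weather", content: "{\"temperature\":22}"}}}]}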
data/lib/langchain/assistant/messages/mistral_ai_message.rb
@@ -0,0 +1,98 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Assistant
+     module Messages
+       class MistralAIMessage < Base
+         # MistralAI uses the following roles:
+         ROLES = [
+           "system",
+           "assistant",
+           "user",
+           "tool"
+         ].freeze
+
+         TOOL_ROLE = "tool"
+
+         # Initialize a new MistralAI message
+         #
+         # @param role [String] The role of the message
+         # @param content [String] The content of the message
+         # @param image_url [String] The URL of the image
+         # @param tool_calls [Array<Hash>] The tool calls made in the message
+         # @param tool_call_id [String] The ID of the tool call
+         def initialize(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
+           raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+           raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+           @role = role
+           # Some Tools return content as JSON, hence the `.to_s`
+           @content = content.to_s
+           # Make sure you're using the Pixtral model if you want to send image_url
+           @image_url = image_url
+           @tool_calls = tool_calls
+           @tool_call_id = tool_call_id
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def llm?
+           assistant?
+         end
+
+         # Convert the message to a MistralAI API-compatible hash
+         #
+         # @return [Hash] The message as a MistralAI API-compatible hash
+         def to_hash
+           {}.tap do |h|
+             h[:role] = role
+
+             if tool_calls.any?
+               h[:tool_calls] = tool_calls
+             else
+               h[:tool_call_id] = tool_call_id if tool_call_id
+
+               h[:content] = []
+
+               if content && !content.empty?
+                 h[:content] << {
+                   type: "text",
+                   text: content
+                 }
+               end
+
+               if image_url
+                 h[:content] << {
+                   type: "image_url",
+                   image_url: image_url
+                 }
+               end
+             end
+           end
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def assistant?
+           role == "assistant"
+         end
+
+         # Check if the message contains system instructions
+         #
+         # @return [Boolean] true/false whether this message contains system instructions
+         def system?
+           role == "system"
+         end
+
+         # Check if the message is a tool call
+         #
+         # @return [Boolean] true/false whether this message is a tool call
+         def tool?
+           role == "tool"
+         end
+       end
+     end
+   end
+ end
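
Note the new multimodal support: content now serializes as an array of typed parts, and `image_url` rides along as its own part. A sketch with a made-up URL (per the inline comment, image input assumes a Pixtral model on the Mistral side):

    message = Langchain::Assistant::Messages::MistralAIMessage.new(
      role: "user",
      content: "What is in this image?",
      image_url: "https://example.com/photo.jpg" # hypothetical URL
    )

    message.to_hash
    # => {role: "user",
    #     content: [{type: "text", text: "What is in this image?"},
    #               {type: "image_url", image_url: "https://example.com/photo.jpg"}]}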
data/lib/langchain/assistant/messages/ollama_message.rb
@@ -0,0 +1,76 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Assistant
+     module Messages
+       class OllamaMessage < Base
+         # Ollama uses the following roles:
+         ROLES = [
+           "system",
+           "assistant",
+           "user",
+           "tool"
+         ].freeze
+
+         TOOL_ROLE = "tool"
+
+         # Initialize a new Ollama message
+         #
+         # @param role [String] The role of the message
+         # @param content [String] The content of the message
+         # @param tool_calls [Array<Hash>] The tool calls made in the message
+         # @param tool_call_id [String] The ID of the tool call
+         def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+           raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+           raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+           @role = role
+           # Some Tools return content as JSON, hence the `.to_s`
+           @content = content.to_s
+           @tool_calls = tool_calls
+           @tool_call_id = tool_call_id
+         end
+
+         # Convert the message to an Ollama API-compatible hash
+         #
+         # @return [Hash] The message as an Ollama API-compatible hash
+         def to_hash
+           {}.tap do |h|
+             h[:role] = role
+             h[:content] = content if content # Content is nil for tool calls
+             h[:tool_calls] = tool_calls if tool_calls.any?
+             h[:tool_call_id] = tool_call_id if tool_call_id
+           end
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def llm?
+           assistant?
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def assistant?
+           role == "assistant"
+         end
+
+         # Check if the message contains system instructions
+         #
+         # @return [Boolean] true/false whether this message contains system instructions
+         def system?
+           role == "system"
+         end
+
+         # Check if the message is a tool call
+         #
+         # @return [Boolean] true/false whether this message is a tool call
+         def tool?
+           role == "tool"
+         end
+       end
+     end
+   end
+ end
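
In contrast to the MistralAI and OpenAI variants, `OllamaMessage#to_hash` keeps `content` as a flat string rather than an array of typed parts, matching Ollama's chat endpoint. A quick sketch:

    message = Langchain::Assistant::Messages::OllamaMessage.new(role: "user", content: "Hello!")
    message.to_hash # => {role: "user", content: "Hello!"}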
data/lib/langchain/assistant/messages/openai_message.rb
@@ -0,0 +1,105 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Assistant
+     module Messages
+       class OpenAIMessage < Base
+         # OpenAI uses the following roles:
+         ROLES = [
+           "system",
+           "assistant",
+           "user",
+           "tool"
+         ].freeze
+
+         TOOL_ROLE = "tool"
+
+         # Initialize a new OpenAI message
+         #
+         # @param role [String] The role of the message
+         # @param content [String] The content of the message
+         # @param image_url [String] The URL of the image
+         # @param tool_calls [Array<Hash>] The tool calls made in the message
+         # @param tool_call_id [String] The ID of the tool call
+         def initialize(
+           role:,
+           content: nil,
+           image_url: nil,
+           tool_calls: [],
+           tool_call_id: nil
+         )
+           raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+           raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+           @role = role
+           # Some Tools return content as JSON, hence the `.to_s`
+           @content = content.to_s
+           @image_url = image_url
+           @tool_calls = tool_calls
+           @tool_call_id = tool_call_id
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def llm?
+           assistant?
+         end
+
+         # Convert the message to an OpenAI API-compatible hash
+         #
+         # @return [Hash] The message as an OpenAI API-compatible hash
+         def to_hash
+           {}.tap do |h|
+             h[:role] = role
+
+             if tool_calls.any?
+               h[:tool_calls] = tool_calls
+             else
+               h[:tool_call_id] = tool_call_id if tool_call_id
+
+               h[:content] = []
+
+               if content && !content.empty?
+                 h[:content] << {
+                   type: "text",
+                   text: content
+                 }
+               end
+
+               if image_url
+                 h[:content] << {
+                   type: "image_url",
+                   image_url: {
+                     url: image_url
+                   }
+                 }
+               end
+             end
+           end
+         end
+
+         # Check if the message came from an LLM
+         #
+         # @return [Boolean] true/false whether this message was produced by an LLM
+         def assistant?
+           role == "assistant"
+         end
+
+         # Check if the message contains system instructions
+         #
+         # @return [Boolean] true/false whether this message contains system instructions
+         def system?
+           role == "system"
+         end
+
+         # Check if the message is a tool call
+         #
+         # @return [Boolean] true/false whether this message is a tool call
+         def tool?
+           role == "tool"
+         end
+       end
+     end
+   end
+ end
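
The OpenAI variant differs from the MistralAI one in a single detail: the image part wraps the URL in a nested `{url: ...}` hash, as OpenAI's vision-capable chat API expects. A sketch with a made-up URL:

    message = Langchain::Assistant::Messages::OpenAIMessage.new(
      role: "user",
      content: "Describe this image",
      image_url: "https://example.com/photo.jpg" # hypothetical URL
    )

    message.to_hash
    # => {role: "user",
    #     content: [{type: "text", text: "Describe this image"},
    #               {type: "image_url", image_url: {url: "https://example.com/photo.jpg"}}]}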
data/lib/langchain/{assistants/assistant.rb → assistant.rb}
@@ -1,7 +1,5 @@
  # frozen_string_literal: true
 
- require_relative "llm/adapter"
-
  module Langchain
    # Assistants are Agent-like objects that leverage helpful instructions, LLMs, tools and knowledge to respond to user queries.
    # Assistants can be configured with an LLM of your choice, any vector search database and easily extended with additional tools.
@@ -14,9 +12,19 @@ module Langchain
    #     tools: [Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])]
    #   )
    class Assistant
-     attr_reader :llm, :instructions, :state, :llm_adapter, :tool_choice
-     attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens, :messages
-     attr_accessor :tools, :add_message_callback
+     attr_reader :llm,
+       :instructions,
+       :state,
+       :llm_adapter,
+       :messages,
+       :tool_choice,
+       :total_prompt_tokens,
+       :total_completion_tokens,
+       :total_tokens
+
+     attr_accessor :tools,
+       :add_message_callback,
+       :parallel_tool_calls
 
      # Create a new assistant
      #
@@ -24,12 +32,15 @@ module Langchain
      # @param tools [Array<Langchain::Tool::Base>] Tools that the assistant has access to
      # @param instructions [String] The system instructions
      # @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or <specific function name>
-     # @params add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
+     # @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
+     # @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
+     # @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
      def initialize(
        llm:,
        tools: [],
        instructions: nil,
        tool_choice: "auto",
+       parallel_tool_calls: true,
        messages: [],
        add_message_callback: nil,
        &block
@@ -49,18 +60,15 @@ module Langchain
 
        self.messages = messages
        @tools = tools
+       @parallel_tool_calls = parallel_tool_calls
        self.tool_choice = tool_choice
-       @instructions = instructions
+       self.instructions = instructions
        @block = block
        @state = :ready
 
        @total_prompt_tokens = 0
        @total_completion_tokens = 0
        @total_tokens = 0
-
-       # The first message in the messages array should be the system instructions
-       # For Google Gemini, and Anthropic system instructions are added to the `system:` param in the `chat` method
-       initialize_instructions
      end
 
      # Add a user message to the messages array
@@ -104,7 +112,7 @@ module Langchain
      # @param messages [Array<Langchain::Message>] The messages to set
      # @return [Array<Langchain::Message>] The messages
      def messages=(messages)
-       raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Langchain::Messages::Base) }
+       raise ArgumentError, "messages array must only contain Langchain::Message instance(s)" unless messages.is_a?(Array) && messages.all? { |m| m.is_a?(Messages::Base) }
 
        @messages = messages
      end
@@ -167,10 +175,8 @@ module Langchain
      # @param output [String] The output of the tool
      # @return [Array<Langchain::Message>] The messages
      def submit_tool_output(tool_call_id:, output:)
-       tool_role = determine_tool_role
-
        # TODO: Validate that `tool_call_id` is valid by scanning messages and checking if this tool call ID was invoked
-       add_message(role: tool_role, content: output, tool_call_id: tool_call_id)
+       add_message(role: @llm_adapter.tool_role, content: output, tool_call_id: tool_call_id)
      end
 
      # Delete all messages
@@ -181,9 +187,6 @@ module Langchain
        @messages = []
      end
 
-     # TODO: Remove in the next major release
-     alias_method :clear_thread!, :clear_messages!
-
      # Set new instructions
      #
      # @param new_instructions [String] New instructions that will be set as a system message
@@ -191,12 +194,9 @@ module Langchain
      def instructions=(new_instructions)
        @instructions = new_instructions
 
-       # This only needs to be done that support Message#@role="system"
-       if !llm.is_a?(Langchain::LLM::GoogleGemini) &&
-           !llm.is_a?(Langchain::LLM::GoogleVertexAI) &&
-           !llm.is_a?(Langchain::LLM::Anthropic)
-         # Find message with role: "system" in messages and delete it from the messages array
-         replace_system_message!(content: new_instructions)
+       if @llm_adapter.support_system_message?
+         # TODO: Should we still set a system message even if @instructions is "" or nil?
+         replace_system_message!(content: new_instructions) if @instructions
        end
      end
 
@@ -330,30 +330,6 @@ module Langchain
        :failed
      end
 
-     # Determine the tool role based on the LLM type
-     #
-     # @return [String] The tool role
-     def determine_tool_role
-       case llm
-       when Langchain::LLM::Anthropic
-         Langchain::Messages::AnthropicMessage::TOOL_ROLE
-       when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
-         Langchain::Messages::GoogleGeminiMessage::TOOL_ROLE
-       when Langchain::LLM::MistralAI
-         Langchain::Messages::MistralAIMessage::TOOL_ROLE
-       when Langchain::LLM::Ollama
-         Langchain::Messages::OllamaMessage::TOOL_ROLE
-       when Langchain::LLM::OpenAI
-         Langchain::Messages::OpenAIMessage::TOOL_ROLE
-       end
-     end
-
-     def initialize_instructions
-       if llm.is_a?(Langchain::LLM::OpenAI) || llm.is_a?(Langchain::LLM::MistralAI)
-         self.instructions = @instructions if @instructions
-       end
-     end
-
      # Call to the LLM#chat() method
      #
      # @return [Langchain::LLM::BaseResponse] The LLM response object
@@ -364,7 +340,8 @@ module Langchain
          instructions: @instructions,
          messages: array_of_message_hashes,
          tools: @tools,
-         tool_choice: tool_choice
+         tool_choice: tool_choice,
+         parallel_tool_calls: parallel_tool_calls
        )
        @llm.chat(**params, &@block)
      end
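
Taken together, the Assistant changes delegate provider-specific behavior (tool role, system-message support) to `@llm_adapter` and expose a new `parallel_tool_calls:` option that is forwarded to `LLM#chat`. A usage sketch, assuming an OpenAI-backed assistant (the tool choice is illustrative, and the Calculator tool has its own gem dependency):

    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    assistant = Langchain::Assistant.new(
      llm: llm,
      instructions: "You are a helpful assistant",
      tools: [Langchain::Tool::Calculator.new],
      parallel_tool_calls: false # request at most one tool call per model turn
    )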
data/lib/langchain/llm/ai21.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
    # gem "ai21", "~> 0.2.1"
    #
    # Usage:
-   # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
+   # llm = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
    #
    class AI21 < Base
      DEFAULTS = {
data/lib/langchain/llm/anthropic.rb
@@ -5,10 +5,10 @@
    # Wrapper around Anthropic APIs.
    #
    # Gem requirements:
-   # gem "anthropic", "~> 0.3.0"
+   # gem "anthropic", "~> 0.3.2"
    #
    # Usage:
-   # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+   # llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
    #
    class Anthropic < Base
      DEFAULTS = {
@@ -100,7 +100,7 @@ module Langchain::LLM
      # @option params [Integer] :top_k Only sample from the top K options for each subsequent token
      # @option params [Float] :top_p Use nucleus sampling.
      # @return [Langchain::LLM::AnthropicResponse] The chat completion
-     def chat(params = {})
+     def chat(params = {}, &block)
        set_extra_headers! if params[:tools]
 
        parameters = chat_parameters.to_params(params)
@@ -109,9 +109,19 @@ module Langchain::LLM
        raise ArgumentError.new("model argument is required") if parameters[:model].empty?
        raise ArgumentError.new("max_tokens argument is required") if parameters[:max_tokens].nil?
 
-       binding.pry
+       if block
+         @response_chunks = []
+         parameters[:stream] = proc do |chunk|
+           @response_chunks << chunk
+           yield chunk
+         end
+       end
+
        response = client.messages(parameters: parameters)
 
+       response = response_from_chunks if block
+       reset_response_chunks
+
        Langchain::LLM::AnthropicResponse.new(response)
      end
 
@@ -124,8 +134,53 @@ module Langchain::LLM
        response
      end
 
+     def response_from_chunks
+       grouped_chunks = @response_chunks.group_by { |chunk| chunk["index"] }.except(nil)
+
+       usage = @response_chunks.find { |chunk| chunk["type"] == "message_delta" }&.dig("usage")
+       stop_reason = @response_chunks.find { |chunk| chunk["type"] == "message_delta" }&.dig("delta", "stop_reason")
+
+       content = grouped_chunks.map do |_index, chunks|
+         text = chunks.map { |chunk| chunk.dig("delta", "text") }.join
+         if !text.nil? && !text.empty?
+           {"type" => "text", "text" => text}
+         else
+           tool_calls_from_choice_chunks(chunks)
+         end
+       end.flatten
+
+       @response_chunks.first&.slice("id", "object", "created", "model")
+         &.merge!(
+           {
+             "content" => content,
+             "usage" => usage,
+             "role" => "assistant",
+             "stop_reason" => stop_reason
+           }
+         )
+     end
+
+     def tool_calls_from_choice_chunks(chunks)
+       return unless (first_block = chunks.find { |chunk| chunk.dig("content_block", "type") == "tool_use" })
+
+       chunks.group_by { |chunk| chunk["index"] }.map do |index, chunks|
+         input = chunks.select { |chunk| chunk.dig("delta", "partial_json") }
+           .map! { |chunk| chunk.dig("delta", "partial_json") }.join
+         {
+           "id" => first_block.dig("content_block", "id"),
+           "type" => "tool_use",
+           "name" => first_block.dig("content_block", "name"),
+           "input" => JSON.parse(input).transform_keys(&:to_sym)
+         }
+       end.compact
+     end
+
      private
 
+     def reset_response_chunks
+       @response_chunks = []
+     end
+
      def set_extra_headers!
        ::Anthropic.configuration.extra_headers = {"anthropic-beta": "tools-2024-05-16"}
      end
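
With `chat` now accepting a block, Anthropic responses can be streamed: each raw server event is yielded as it arrives, and `response_from_chunks` stitches the accumulated chunks back into a single response hash afterwards. A usage sketch (the model name is illustrative):

    llm = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])

    response = llm.chat(
      model: "claude-3-5-sonnet-20240620", # illustrative model name
      max_tokens: 256,
      messages: [{role: "user", content: "Tell me a short joke"}]
    ) do |chunk|
      # chunk is a raw stream event, e.g. a content_block_delta
      print chunk.dig("delta", "text")
    end

    response.chat_completion # full reassembled completion text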