langchainrb 0.17.1 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +5 -0
  4. data/lib/langchain/{assistants → assistant}/llm/adapter.rb +1 -1
  5. data/lib/langchain/assistant/llm/adapters/anthropic.rb +105 -0
  6. data/lib/langchain/assistant/llm/adapters/base.rb +63 -0
  7. data/lib/langchain/{assistants → assistant}/llm/adapters/google_gemini.rb +43 -3
  8. data/lib/langchain/{assistants → assistant}/llm/adapters/mistral_ai.rb +39 -2
  9. data/lib/langchain/assistant/llm/adapters/ollama.rb +94 -0
  10. data/lib/langchain/{assistants → assistant}/llm/adapters/openai.rb +38 -2
  11. data/lib/langchain/assistant/messages/anthropic_message.rb +77 -0
  12. data/lib/langchain/assistant/messages/base.rb +56 -0
  13. data/lib/langchain/assistant/messages/google_gemini_message.rb +92 -0
  14. data/lib/langchain/assistant/messages/mistral_ai_message.rb +98 -0
  15. data/lib/langchain/assistant/messages/ollama_message.rb +76 -0
  16. data/lib/langchain/assistant/messages/openai_message.rb +105 -0
  17. data/lib/langchain/{assistants/assistant.rb → assistant.rb} +26 -49
  18. data/lib/langchain/llm/ai21.rb +1 -1
  19. data/lib/langchain/llm/anthropic.rb +59 -4
  20. data/lib/langchain/llm/aws_bedrock.rb +6 -7
  21. data/lib/langchain/llm/azure.rb +1 -1
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/ollama.rb +0 -1
  24. data/lib/langchain/llm/openai.rb +2 -2
  25. data/lib/langchain/llm/parameters/chat.rb +1 -0
  26. data/lib/langchain/llm/replicate.rb +2 -10
  27. data/lib/langchain/version.rb +1 -1
  28. data/lib/langchain.rb +1 -14
  29. metadata +16 -16
  30. data/lib/langchain/assistants/llm/adapters/_base.rb +0 -21
  31. data/lib/langchain/assistants/llm/adapters/anthropic.rb +0 -62
  32. data/lib/langchain/assistants/llm/adapters/ollama.rb +0 -57
  33. data/lib/langchain/assistants/messages/anthropic_message.rb +0 -75
  34. data/lib/langchain/assistants/messages/base.rb +0 -54
  35. data/lib/langchain/assistants/messages/google_gemini_message.rb +0 -90
  36. data/lib/langchain/assistants/messages/mistral_ai_message.rb +0 -96
  37. data/lib/langchain/assistants/messages/ollama_message.rb +0 -74
  38. data/lib/langchain/assistants/messages/openai_message.rb +0 -103
data/lib/langchain/llm/aws_bedrock.rb CHANGED
@@ -7,12 +7,13 @@ module Langchain::LLM
   # gem 'aws-sdk-bedrockruntime', '~> 1.1'
   #
   # Usage:
-  # bedrock = Langchain::LLM::AwsBedrock.new(llm_options: {})
+  # llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
   #
   class AwsBedrock < Base
     DEFAULTS = {
+      chat_completion_model_name: "anthropic.claude-v2",
       completion_model_name: "anthropic.claude-v2",
-      embedding_model_name: "amazon.titan-embed-text-v1",
+      embeddings_model_name: "amazon.titan-embed-text-v1",
       max_tokens_to_sample: 300,
       temperature: 1,
       top_k: 250,
@@ -52,13 +53,11 @@ module Langchain::LLM
     SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
     SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
 
-    def initialize(completion_model: DEFAULTS[:completion_model_name], embedding_model: DEFAULTS[:embedding_model_name], aws_client_options: {}, default_options: {})
+    def initialize(aws_client_options: {}, default_options: {})
       depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
 
       @client = ::Aws::BedrockRuntime::Client.new(**aws_client_options)
       @defaults = DEFAULTS.merge(default_options)
-        .merge(completion_model_name: completion_model)
-        .merge(embedding_model_name: embedding_model)
 
       chat_parameters.update(
         model: {default: @defaults[:chat_completion_model_name]},
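The initializer no longer takes `completion_model:` or `embedding_model:` keyword arguments; model names now flow through `default_options`, using the renamed `embeddings_model_name` key and the new `chat_completion_model_name` default. A minimal migration sketch based on the diff above:

# 0.17.x
llm = Langchain::LLM::AwsBedrock.new(
  completion_model: "anthropic.claude-v2",
  embedding_model: "amazon.titan-embed-text-v1"
)

# 0.18.0: override the DEFAULTS keys instead
llm = Langchain::LLM::AwsBedrock.new(
  default_options: {
    completion_model_name: "anthropic.claude-v2",
    embeddings_model_name: "amazon.titan-embed-text-v1"
  }
)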
@@ -85,7 +84,7 @@ module Langchain::LLM
       parameters = compose_embedding_parameters params.merge(text:)
 
       response = client.invoke_model({
-        model_id: @defaults[:embedding_model_name],
+        model_id: @defaults[:embeddings_model_name],
         body: parameters.to_json,
         content_type: "application/json",
         accept: "application/json"
@@ -180,7 +179,7 @@ module Langchain::LLM
     end
 
     def embedding_provider
-      @defaults[:embedding_model_name].split(".").first.to_sym
+      @defaults[:embeddings_model_name].split(".").first.to_sym
     end
 
     def wrap_prompt(prompt)
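`embedding_provider` keeps deriving the provider symbol from the segment before the first dot of the model id, now read from the renamed key. For the default model:

@defaults[:embeddings_model_name]                     # => "amazon.titan-embed-text-v1"
"amazon.titan-embed-text-v1".split(".").first.to_sym  # => :amazon, one of SUPPORTED_EMBEDDING_PROVIDERS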
data/lib/langchain/llm/azure.rb CHANGED
@@ -7,7 +7,7 @@ module Langchain::LLM
   # gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  # openai = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
+  # llm = Langchain::LLM::Azure.new(api_key:, llm_options: {}, embedding_deployment_url: chat_deployment_url:)
   #
   class Azure < OpenAI
     attr_reader :embed_client
data/lib/langchain/llm/hugging_face.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "hugging-face", "~> 0.3.4"
   #
   # Usage:
-  # hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
+  # llm = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
   #
   class HuggingFace < Base
     DEFAULTS = {
data/lib/langchain/llm/ollama.rb CHANGED
@@ -5,7 +5,6 @@ module Langchain::LLM
   # Available models: https://ollama.ai/library
   #
   # Usage:
-  # llm = Langchain::LLM::Ollama.new
   # llm = Langchain::LLM::Ollama.new(url: ENV["OLLAMA_URL"], default_options: {})
   #
   class Ollama < Base
data/lib/langchain/llm/openai.rb CHANGED
@@ -7,7 +7,7 @@ module Langchain::LLM
   # gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  # openai = Langchain::LLM::OpenAI.new(
+  # llm = Langchain::LLM::OpenAI.new(
   #   api_key: ENV["OPENAI_API_KEY"],
   #   llm_options: {}, # Available options: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb#L5-L13
   #   default_options: {}
@@ -100,7 +100,7 @@ module Langchain::LLM
     # @param params [Hash] The parameters to pass to the `chat()` method
     # @return [Langchain::LLM::OpenAIResponse] Response object
     def complete(prompt:, **params)
-      warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
+      Langchain.logger.warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
 
       if params[:stop_sequences]
         params[:stop] = params.delete(:stop_sequences)
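`complete` still works but now routes its deprecation notice through `Langchain.logger` instead of `Kernel#warn`. Callers should move to `chat`, which takes a messages array; a minimal equivalent of a one-shot completion:

# llm.complete(prompt: "Tell me a joke")   # deprecated
llm.chat(messages: [{role: "user", content: "Tell me a joke"}])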
data/lib/langchain/llm/parameters/chat.rb CHANGED
@@ -34,6 +34,7 @@ module Langchain::LLM::Parameters
       # Function-calling
       tools: {default: []},
       tool_choice: {},
+      parallel_tool_calls: {},
 
       # Additional optional parameters
       logit_bias: {}
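Registering `parallel_tool_calls` in the unified chat parameters lets it pass through `chat()` to providers that accept it. A sketch, assuming an OpenAI-backed `llm` and a hypothetical `weather_tool` (OpenAI treats `parallel_tool_calls: false` as "at most one tool call per turn"):

llm.chat(
  messages: [{role: "user", content: "Weather in NYC and in SF?"}],
  tools: weather_tool.class.function_schemas.to_openai_format, # hypothetical tool
  parallel_tool_calls: false # ask the model not to emit parallel tool calls
)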
data/lib/langchain/llm/replicate.rb CHANGED
@@ -7,16 +7,8 @@ module Langchain::LLM
   # Gem requirements:
   # gem "replicate-ruby", "~> 0.2.2"
   #
-  # Use it directly:
-  # replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
-  #
-  # Or pass it to be used by a vector search DB:
-  # chroma = Langchain::Vectorsearch::Chroma.new(
-  #   url: ENV["CHROMA_URL"],
-  #   index_name: "...",
-  #   llm: replicate
-  # )
-  #
+  # Usage:
+  # llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
   class Replicate < Base
     DEFAULTS = {
       # TODO: Figure out how to send the temperature to the API
data/lib/langchain/version.rb CHANGED
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module Langchain
-  VERSION = "0.17.1"
+  VERSION = "0.18.0"
 end
data/lib/langchain.rb CHANGED
@@ -22,25 +22,12 @@ loader.inflector.inflect(
   "mistral_ai_response" => "MistralAIResponse",
   "mistral_ai_message" => "MistralAIMessage",
   "openai" => "OpenAI",
-  "openai_validator" => "OpenAIValidator",
   "openai_response" => "OpenAIResponse",
   "openai_message" => "OpenAIMessage",
   "pdf" => "PDF"
 )
+
 loader.collapse("#{__dir__}/langchain/llm/response")
-loader.collapse("#{__dir__}/langchain/assistants")
-
-loader.collapse("#{__dir__}/langchain/tool/calculator")
-loader.collapse("#{__dir__}/langchain/tool/database")
-loader.collapse("#{__dir__}/langchain/tool/docs_tool")
-loader.collapse("#{__dir__}/langchain/tool/file_system")
-loader.collapse("#{__dir__}/langchain/tool/google_search")
-loader.collapse("#{__dir__}/langchain/tool/ruby_code_interpreter")
-loader.collapse("#{__dir__}/langchain/tool/news_retriever")
-loader.collapse("#{__dir__}/langchain/tool/tavily")
-loader.collapse("#{__dir__}/langchain/tool/vectorsearch")
-loader.collapse("#{__dir__}/langchain/tool/weather")
-loader.collapse("#{__dir__}/langchain/tool/wikipedia")
 
 # RubyCodeInterpreter does not work with Ruby 3.3;
 # https://github.com/ukutaht/safe_ruby/issues/4
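With the `assistants` collapse removed and the files relocated, the message classes nest under `Langchain::Assistant` instead of the collapsed `Langchain::Messages` namespace. Presumably, based on the new file paths, constant references migrate along these lines:

# 0.17.x
Langchain::Messages::OpenAIMessage.new(role: "user", content: "Hi")

# 0.18.0
Langchain::Assistant::Messages::OpenAIMessage.new(role: "user", content: "Hi")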
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.17.1
+  version: 0.18.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-10-07 00:00:00.000000000 Z
+date: 2024-10-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -637,20 +637,20 @@ files:
 - LICENSE.txt
 - README.md
 - lib/langchain.rb
-- lib/langchain/assistants/assistant.rb
-- lib/langchain/assistants/llm/adapter.rb
-- lib/langchain/assistants/llm/adapters/_base.rb
-- lib/langchain/assistants/llm/adapters/anthropic.rb
-- lib/langchain/assistants/llm/adapters/google_gemini.rb
-- lib/langchain/assistants/llm/adapters/mistral_ai.rb
-- lib/langchain/assistants/llm/adapters/ollama.rb
-- lib/langchain/assistants/llm/adapters/openai.rb
-- lib/langchain/assistants/messages/anthropic_message.rb
-- lib/langchain/assistants/messages/base.rb
-- lib/langchain/assistants/messages/google_gemini_message.rb
-- lib/langchain/assistants/messages/mistral_ai_message.rb
-- lib/langchain/assistants/messages/ollama_message.rb
-- lib/langchain/assistants/messages/openai_message.rb
+- lib/langchain/assistant.rb
+- lib/langchain/assistant/llm/adapter.rb
+- lib/langchain/assistant/llm/adapters/anthropic.rb
+- lib/langchain/assistant/llm/adapters/base.rb
+- lib/langchain/assistant/llm/adapters/google_gemini.rb
+- lib/langchain/assistant/llm/adapters/mistral_ai.rb
+- lib/langchain/assistant/llm/adapters/ollama.rb
+- lib/langchain/assistant/llm/adapters/openai.rb
+- lib/langchain/assistant/messages/anthropic_message.rb
+- lib/langchain/assistant/messages/base.rb
+- lib/langchain/assistant/messages/google_gemini_message.rb
+- lib/langchain/assistant/messages/mistral_ai_message.rb
+- lib/langchain/assistant/messages/ollama_message.rb
+- lib/langchain/assistant/messages/openai_message.rb
 - lib/langchain/chunk.rb
 - lib/langchain/chunker/base.rb
 - lib/langchain/chunker/markdown.rb
data/lib/langchain/assistants/llm/adapters/_base.rb DELETED
@@ -1,21 +0,0 @@
-module Langchain
-  class Assistant
-    module LLM
-      module Adapters
-        class Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            raise NotImplementedError, "Subclasses must implement build_chat_params"
-          end
-
-          def extract_tool_call_args(tool_call:)
-            raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            raise NotImplementedError, "Subclasses must implement build_message"
-          end
-        end
-      end
-    end
-  end
-end
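This interface-only base class survives as data/lib/langchain/assistant/llm/adapters/base.rb (+63 lines, so the new version carries more than the bare contract shown here). A hypothetical custom adapter would still satisfy the same three methods; bodies below are illustrative only, not the shipped implementation:

class MyProviderAdapter < Langchain::Assistant::LLM::Adapters::Base
  # Assemble the provider-specific chat payload
  def build_chat_params(tools:, instructions:, messages:, tool_choice:)
    params = {messages: messages}
    params[:tools] = tools if tools.any?
    params
  end

  # Return [tool_call_id, tool_name, method_name, tool_arguments]
  def extract_tool_call_args(tool_call:)
    [tool_call["id"], *tool_call["name"].split("__"), tool_call["input"]]
  end

  # Wrap raw attributes in the provider's message type (plain hash here for brevity)
  def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
    {role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id}
  end
end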
data/lib/langchain/assistants/llm/adapters/anthropic.rb DELETED
@@ -1,62 +0,0 @@
-module Langchain
-  class Assistant
-    module LLM
-      module Adapters
-        class Anthropic < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-              params[:tool_choice] = build_tool_choice(tool_choice)
-            end
-            params[:system] = instructions if instructions
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            warn "Image URL is not supported by Anthropic currently" if image_url
-
-            Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the Anthropic tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}], "stop_reason"=>"tool_use"}
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-            function_name = tool_call.dig("name")
-            tool_name, method_name = function_name.split("__")
-            tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_anthropic_format }.flatten
-          end
-
-          def allowed_tool_choices
-            ["auto", "any"]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:name) }
-          end
-
-          private
-
-          def build_tool_choice(choice)
-            case choice
-            when "auto"
-              {type: "auto"}
-            when "any"
-              {type: "any"}
-            else
-              {type: "tool", name: choice}
-            end
-          end
-        end
-      end
-    end
-  end
-end
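Tracing `extract_tool_call_args` with the hash from the comment shows the `tool__method` naming convention being split apart and the input keys symbolized:

adapter = Langchain::Assistant::LLM::Adapters::Anthropic.new
tool_call = {
  "type" => "tool_use",
  "id" => "toolu_01TjusbFApEbwKPRWTRwzadR",
  "name" => "news_retriever__get_top_headlines",
  "input" => {"country" => "us", "page_size" => 10}
}
adapter.extract_tool_call_args(tool_call: tool_call)
# => ["toolu_01TjusbFApEbwKPRWTRwzadR", "news_retriever", "get_top_headlines",
#     {country: "us", page_size: 10}]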
data/lib/langchain/assistants/llm/adapters/ollama.rb DELETED
@@ -1,57 +0,0 @@
-module Langchain
-  class Assistant
-    module LLM
-      module Adapters
-        class Ollama < Base
-          def build_chat_params(tools:, instructions:, messages:, tool_choice:)
-            params = {messages: messages}
-            if tools.any?
-              params[:tools] = build_tools(tools)
-            end
-            params
-          end
-
-          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
-            warn "Image URL is not supported by Ollama currently" if image_url
-
-            Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
-          end
-
-          # Extract the tool call information from the OpenAI tool call hash
-          #
-          # @param tool_call [Hash] The tool call hash
-          # @return [Array] The tool call information
-          def extract_tool_call_args(tool_call:)
-            tool_call_id = tool_call.dig("id")
-
-            function_name = tool_call.dig("function", "name")
-            tool_name, method_name = function_name.split("__")
-
-            tool_arguments = tool_call.dig("function", "arguments")
-            tool_arguments = if tool_arguments.is_a?(Hash)
-              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
-            else
-              JSON.parse(tool_arguments, symbolize_names: true)
-            end
-
-            [tool_call_id, tool_name, method_name, tool_arguments]
-          end
-
-          def available_tool_names(tools)
-            build_tools(tools).map { |tool| tool.dig(:function, :name) }
-          end
-
-          def allowed_tool_choices
-            ["auto", "none"]
-          end
-
-          private
-
-          def build_tools(tools)
-            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
-          end
-        end
-      end
-    end
-  end
-end
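Note the branch on `tool_arguments`: Ollama may hand back arguments either as an already-parsed Hash or as a JSON string, and both paths normalize to a symbol-keyed Hash:

Langchain::Utils::HashTransformer.symbolize_keys({"city" => "NYC"}) # => {city: "NYC"}
JSON.parse('{"city": "NYC"}', symbolize_names: true)                # => {city: "NYC"}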
data/lib/langchain/assistants/messages/anthropic_message.rb DELETED
@@ -1,75 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Messages
-    class AnthropicMessage < Base
-      ROLES = [
-        "assistant",
-        "user",
-        "tool_result"
-      ].freeze
-
-      TOOL_ROLE = "tool_result"
-
-      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
-        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
-        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
-
-        @role = role
-        # Some Tools return content as a JSON hence `.to_s`
-        @content = content.to_s
-        @tool_calls = tool_calls
-        @tool_call_id = tool_call_id
-      end
-
-      # Convert the message to an Anthropic API-compatible hash
-      #
-      # @return [Hash] The message as an Anthropic API-compatible hash
-      def to_hash
-        {}.tap do |h|
-          h[:role] = tool? ? "user" : role
-
-          h[:content] = if tool?
-            [
-              {
-                type: "tool_result",
-                tool_use_id: tool_call_id,
-                content: content
-              }
-            ]
-          elsif tool_calls.any?
-            tool_calls
-          else
-            content
-          end
-        end
-      end
-
-      # Check if the message is a tool call
-      #
-      # @return [Boolean] true/false whether this message is a tool call
-      def tool?
-        role == "tool_result"
-      end
-
-      # Anthropic does not implement system prompts
-      def system?
-        false
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def assistant?
-        role == "assistant"
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def llm?
-        assistant?
-      end
-    end
-  end
-end
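Per the `to_hash` branches above, a tool result is sent back to Anthropic as a user-role message wrapping a `tool_result` content block (values illustrative):

msg = Langchain::Messages::AnthropicMessage.new(
  role: "tool_result",
  content: "72F and sunny",
  tool_call_id: "toolu_123"
)
msg.to_hash
# => {role: "user",
#     content: [{type: "tool_result", tool_use_id: "toolu_123", content: "72F and sunny"}]}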
data/lib/langchain/assistants/messages/base.rb DELETED
@@ -1,54 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Messages
-    class Base
-      attr_reader :role,
-        :content,
-        :image_url,
-        :tool_calls,
-        :tool_call_id
-
-      # Check if the message came from a user
-      #
-      # @return [Boolean] true/false whether the message came from a user
-      def user?
-        role == "user"
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @raise NotImplementedError if the subclass does not implement this method
-      def llm?
-        raise NotImplementedError, "Class #{self.class.name} must implement the method 'llm?'"
-      end
-
-      # Check if the message is a tool call
-      #
-      # @raise NotImplementedError if the subclass does not implement this method
-      def tool?
-        raise NotImplementedError, "Class #{self.class.name} must implement the method 'tool?'"
-      end
-
-      # Check if the message is a system prompt
-      #
-      # @raise NotImplementedError if the subclass does not implement this method
-      def system?
-        raise NotImplementedError, "Class #{self.class.name} must implement the method 'system?'"
-      end
-
-      # Returns the standardized role symbol based on the specific role methods
-      #
-      # @return [Symbol] the standardized role symbol (:system, :llm, :tool, :user, or :unknown)
-      def standard_role
-        return :user if user?
-        return :llm if llm?
-        return :tool if tool?
-        return :system if system?
-
-        # TODO: Should we return :unknown or raise an error?
-        :unknown
-      end
-    end
-  end
-end
data/lib/langchain/assistants/messages/google_gemini_message.rb DELETED
@@ -1,90 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Messages
-    class GoogleGeminiMessage < Base
-      # Google Gemini uses the following roles:
-      ROLES = [
-        "user",
-        "model",
-        "function"
-      ].freeze
-
-      TOOL_ROLE = "function"
-
-      # Initialize a new Google Gemini message
-      #
-      # @param [String] The role of the message
-      # @param [String] The content of the message
-      # @param [Array<Hash>] The tool calls made in the message
-      # @param [String] The ID of the tool call
-      def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
-        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
-        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
-
-        @role = role
-        # Some Tools return content as a JSON hence `.to_s`
-        @content = content.to_s
-        @tool_calls = tool_calls
-        @tool_call_id = tool_call_id
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def llm?
-        model?
-      end
-
-      # Convert the message to a Google Gemini API-compatible hash
-      #
-      # @return [Hash] The message as a Google Gemini API-compatible hash
-      def to_hash
-        {}.tap do |h|
-          h[:role] = role
-          h[:parts] = if function?
-            [{
-              functionResponse: {
-                name: tool_call_id,
-                response: {
-                  name: tool_call_id,
-                  content: content
-                }
-              }
-            }]
-          elsif tool_calls.any?
-            tool_calls
-          else
-            [{text: content}]
-          end
-        end
-      end
-
-      # Google Gemini does not implement system prompts
-      def system?
-        false
-      end
-
-      # Check if the message is a tool call
-      #
-      # @return [Boolean] true/false whether this message is a tool call
-      def tool?
-        function?
-      end
-
-      # Check if the message is a tool call
-      #
-      # @return [Boolean] true/false whether this message is a tool call
-      def function?
-        role == "function"
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def model?
-        role == "model"
-      end
-    end
-  end
-end
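For a `function`-role message, `to_hash` reuses `tool_call_id` as both the `functionResponse` name and the nested response name (values illustrative):

msg = Langchain::Messages::GoogleGeminiMessage.new(
  role: "function",
  content: "4",
  tool_call_id: "calculator__execute"
)
msg.to_hash
# => {role: "function",
#     parts: [{functionResponse: {name: "calculator__execute",
#              response: {name: "calculator__execute", content: "4"}}}]}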
data/lib/langchain/assistants/messages/mistral_ai_message.rb DELETED
@@ -1,96 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Messages
-    class MistralAIMessage < Base
-      # MistralAI uses the following roles:
-      ROLES = [
-        "system",
-        "assistant",
-        "user",
-        "tool"
-      ].freeze
-
-      TOOL_ROLE = "tool"
-
-      # Initialize a new MistralAI message
-      #
-      # @param role [String] The role of the message
-      # @param content [String] The content of the message
-      # @param image_url [String] The URL of the image
-      # @param tool_calls [Array<Hash>] The tool calls made in the message
-      # @param tool_call_id [String] The ID of the tool call
-      def initialize(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
-        raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
-        raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
-
-        @role = role
-        # Some Tools return content as a JSON hence `.to_s`
-        @content = content.to_s
-        # Make sure you're using the Pixtral model if you want to send image_url
-        @image_url = image_url
-        @tool_calls = tool_calls
-        @tool_call_id = tool_call_id
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def llm?
-        assistant?
-      end
-
-      # Convert the message to an MistralAI API-compatible hash
-      #
-      # @return [Hash] The message as an MistralAI API-compatible hash
-      def to_hash
-        {}.tap do |h|
-          h[:role] = role
-
-          if tool_calls.any?
-            h[:tool_calls] = tool_calls
-          else
-            h[:tool_call_id] = tool_call_id if tool_call_id
-
-            h[:content] = []
-
-            if content && !content.empty?
-              h[:content] << {
-                type: "text",
-                text: content
-              }
-            end
-
-            if image_url
-              h[:content] << {
-                type: "image_url",
-                image_url: image_url
-              }
-            end
-          end
-        end
-      end
-
-      # Check if the message came from an LLM
-      #
-      # @return [Boolean] true/false whether this message was produced by an LLM
-      def assistant?
-        role == "assistant"
-      end
-
-      # Check if the message are system instructions
-      #
-      # @return [Boolean] true/false whether this message are system instructions
-      def system?
-        role == "system"
-      end
-
-      # Check if the message is a tool call
-      #
-      # @return [Boolean] true/false whether this message is a tool call
-      def tool?
-        role == "tool"
-      end
-    end
-  end
-end
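Following the `to_hash` logic above, a user message with an image becomes an array of typed content parts (URL illustrative; per the inline comment, image input assumes a Pixtral model):

msg = Langchain::Messages::MistralAIMessage.new(
  role: "user",
  content: "What is in this picture?",
  image_url: "https://example.com/cat.jpg"
)
msg.to_hash
# => {role: "user",
#     content: [{type: "text", text: "What is in this picture?"},
#               {type: "image_url", image_url: "https://example.com/cat.jpg"}]}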