langchainrb 0.13.4 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -42,17 +42,17 @@ module Langchain::LLM

     def embed(...)
       @client = @embed_client
-      super(...)
+      super
     end

     def complete(...)
       @client = @chat_client
-      super(...)
+      super
     end

     def chat(...)
       @client = @chat_client
-      super(...)
+      super
     end
   end
 end
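Note: the `super(...)` → `super` changes above are behavior-preserving. Inside a method defined with the `...` argument-forwarding shorthand (Ruby 3.0+), a bare `super` already forwards every argument the method received, so the explicit `(...)` was redundant. A minimal sketch (class names are illustrative, not from the gem):

```ruby
class Base
  def embed(text:, model: "default")
    "embedding #{text} with #{model}"
  end
end

class Child < Base
  def embed(...)
    # Bare `super` forwards all positional, keyword, and block
    # arguments -- equivalent to `super(...)` here.
    super
  end
end

Child.new.embed(text: "hello") # => "embedding hello with default"
```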
@@ -8,6 +8,7 @@ module Langchain::LLM
   # Langchain.rb provides a common interface to interact with all supported LLMs:
   #
   # - {Langchain::LLM::AI21}
+  # - {Langchain::LLM::Anthropic}
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
@@ -74,8 +74,6 @@ module Langchain::LLM

       default_params.merge!(params)

-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
@@ -18,7 +18,9 @@ module Langchain::LLM

       chat_parameters.update(
         model: {default: @defaults[:chat_completion_model_name]},
-        temperature: {default: @defaults[:temperature]}
+        temperature: {default: @defaults[:temperature]},
+        generation_config: {default: nil},
+        safety_settings: {default: nil}
       )
       chat_parameters.remap(
         messages: :contents,
@@ -42,13 +44,25 @@ module Langchain::LLM
       raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?

       parameters = chat_parameters.to_params(params)
-      parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+      parameters[:generation_config] ||= {}
+      parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
+      parameters.delete(:temperature)
+      parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
+      parameters.delete(:top_p)
+      parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
+      parameters.delete(:top_k)
+      parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
+      parameters.delete(:max_tokens)
+      parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
+      parameters.delete(:response_format)
+      parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
+      parameters.delete(:stop)

       uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")

       request = Net::HTTP::Post.new(uri)
       request.content_type = "application/json"
-      request.body = parameters.to_json
+      request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json

       response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
         http.request(request)
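Note: the reworked Gemini `chat` folds top-level sampling options into `generation_config` with `||=` guards, so a value the caller already set inside `generation_config` takes precedence over the top-level shorthand. A hedged sketch of the precedence (the message content is invented; `llm` is assumed to be a `Langchain::LLM::GoogleGemini` instance):

```ruby
llm.chat(
  messages: [{role: "user", parts: [{text: "Hi"}]}],
  temperature: 0.9,
  generation_config: {temperature: 0.2}
)
# The request is sent with generation_config: {temperature: 0.2};
# the top-level :temperature is deleted after the merge. The body is
# then serialized with camelCase keys (generationConfig, topP, ...)
# via the new HashTransformer, matching the Gemini REST API.
```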
@@ -18,7 +18,7 @@ module Langchain::LLM
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
     }.freeze
-    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
+
     ROLE_MAPPING = {
       "assistant" => "ai"
     }
@@ -96,9 +96,6 @@ module Langchain::LLM
         examples: compose_examples(examples)
       }

-      # chat-bison-001 is the only model that currently supports countMessageTokens functions
-      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
       end
@@ -14,7 +14,7 @@ module Langchain::LLM
     attr_reader :url, :defaults

     DEFAULTS = {
-      temperature: 0.8,
+      temperature: 0.0,
       completion_model_name: "llama3",
       embeddings_model_name: "llama3",
       chat_completion_model_name: "llama3"
@@ -64,7 +64,7 @@ module Langchain::LLM
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @return [Langchain::LLM::ReplicateResponse] Reponse object
+    # @return [Langchain::LLM::ReplicateResponse] Response object
     #
     def complete(prompt:, **params)
       response = completion_model.predict(prompt: prompt)
@@ -3,7 +3,7 @@
 module Langchain::LLM
   class GoogleGeminiResponse < BaseResponse
     def initialize(raw_response, model: nil)
-      super(raw_response, model: model)
+      super
     end

     def chat_completion
@@ -36,7 +36,7 @@ module Langchain::LLM
     end

     def prompt_tokens
-      raw_response.dig("prompt_eval_count") if done?
+      raw_response.fetch("prompt_eval_count", 0) if done?
     end

     def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
       prompt_tokens + completion_tokens if done?
     end

+    def tool_calls
+      if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+        [parsed_tool_calls]
+      elsif completion&.include?("[TOOL_CALLS]") && (
+        parsed_tool_calls = JSON.parse(
+          completion
+            # Slice out the serialized JSON
+            .slice(/\{.*\}/)
+            # Replace hash rockets with colons
+            .gsub("=>", ":")
+        )
+      )
+        [parsed_tool_calls]
+      else
+        []
+      end
+    end
+
     private

     def done?
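Note: the new `OllamaResponse#tool_calls` handles two shapes: a chat completion that is itself a JSON object, and a plain completion containing a `[TOOL_CALLS]` marker followed by Ruby-hash-style output that must be normalized before `JSON.parse`. A rough illustration of the second path (the sample string is invented):

```ruby
require "json"

completion = '[TOOL_CALLS] {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}'

JSON.parse(
  completion
    .slice(/\{.*\}/)  # keep only the {...} payload
    .gsub("=>", ":")  # hash rockets -> JSON colons
)
# => {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}
```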
@@ -90,7 +90,9 @@ module Langchain
     private

     def load_from_url
-      URI.parse(URI::DEFAULT_PARSER.escape(@path)).open
+      unescaped_url = URI.decode_www_form_component(@path)
+      escaped_url = URI::DEFAULT_PARSER.escape(unescaped_url)
+      URI.parse(escaped_url).open
     end

     def load_from_path
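Note: decoding before escaping makes `load_from_url` safe for URLs that arrive already percent-encoded; escaping such a URL directly would double-encode it (`%20` becomes `%2520`). A small demonstration (the URL is illustrative):

```ruby
require "uri"

path = "https://example.com/my%20file.pdf"

# Old behavior: escaping the already-escaped URL double-encodes it.
URI::DEFAULT_PARSER.escape(path)
# => "https://example.com/my%2520file.pdf"

# New behavior: normalize first, then escape exactly once.
unescaped = URI.decode_www_form_component(path) # => ".../my file.pdf"
URI::DEFAULT_PARSER.escape(unescaped)
# => "https://example.com/my%20file.pdf"
```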
@@ -0,0 +1,25 @@
+module Langchain
+  module Utils
+    class HashTransformer
+      # Converts a string to camelCase
+      def self.camelize_lower(str)
+        str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
+      end
+
+      # Recursively transforms the keys of a hash to camel case
+      def self.deep_transform_keys(hash, &block)
+        case hash
+        when Hash
+          hash.each_with_object({}) do |(key, value), result|
+            new_key = block.call(key)
+            result[new_key] = deep_transform_keys(value, &block)
+          end
+        when Array
+          hash.map { |item| deep_transform_keys(item, &block) }
+        else
+          hash
+        end
+      end
+    end
+  end
+end
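The new helper behaves like a small, dependency-free subset of ActiveSupport's `deep_transform_keys`, recursing through nested hashes and arrays. Example usage (the hash is invented):

```ruby
Langchain::Utils::HashTransformer.camelize_lower("max_output_tokens")
# => "maxOutputTokens"

hash = {safety_settings: [{category: "x", threshold: 1}], top_p: 0.9}
Langchain::Utils::HashTransformer.deep_transform_keys(hash) do |key|
  Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym
end
# => {safetySettings: [{category: "x", threshold: 1}], topP: 0.9}
```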
@@ -64,7 +64,9 @@ module Langchain::Vectorsearch
     # @param ids [Array<String>] The list of ids to remove
     # @return [Hash] The response from the server
     def remove_texts(ids:)
-      collection.delete(ids: ids)
+      collection.delete(
+        ids: ids.map(&:to_s)
+      )
     end

     # Create the collection with the default schema
@@ -6,7 +6,7 @@ module Langchain::Vectorsearch
   # Wrapper around Milvus REST APIs.
   #
   # Gem requirements:
-  # gem "milvus", "~> 0.9.2"
+  # gem "milvus", "~> 0.9.3"
   #
   # Usage:
   # milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, api_key:)
@@ -39,6 +39,21 @@ module Langchain::Vectorsearch
       )
     end

+    # Deletes a list of texts in the index
+    #
+    # @param ids [Array<Integer>] The ids of texts to delete
+    # @return [Boolean] The response from the server
+    def remove_texts(ids:)
+      raise ArgumentError, "ids must be an array" unless ids.is_a?(Array)
+      # Convert ids to integers if strings are passed
+      ids = ids.map(&:to_i)
+
+      client.entities.delete(
+        collection_name: index_name,
+        expression: "id in #{ids}"
+      )
+    end
+
     # TODO: Add update_texts method

     # Create default schema
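Note: the new Milvus `remove_texts` builds its boolean filter by interpolating the Ruby array; `Array#to_s` happens to produce exactly the `[1, 2, 3]` form Milvus expects in an `id in ...` expression. For example (ids are invented):

```ruby
ids = ["1", "2", "3"].map(&:to_i) # => [1, 2, 3]
"id in #{ids}"                    # => "id in [1, 2, 3]"
```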
@@ -83,7 +98,7 @@ module Langchain::Vectorsearch
     # @return [Boolean] The response from the server
     def create_default_index
       client.indices.create(
-        collection_name: "Documents",
+        collection_name: index_name,
         field_name: "vectors",
         extra_params: [
           {key: "metric_type", value: "L2"},
@@ -125,7 +140,7 @@ module Langchain::Vectorsearch

       client.search(
         collection_name: index_name,
-        output_fields: ["id", "content", "vectors"],
+        output_fields: ["id", "content"], # Add "vectors" if you need full vectors returned.
         top_k: k.to_s,
         vectors: [embedding],
         dsl_type: 1,
@@ -1,5 +1,5 @@
 # frozen_string_literal: true

 module Langchain
-  VERSION = "0.13.4"
+  VERSION = "0.14.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.13.4
+  version: 0.14.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-06-16 00:00:00.000000000 Z
+date: 2024-07-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -212,14 +212,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.2'
+        version: '0.3'
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.2'
+        version: '0.3'
 - !ruby/object:Gem::Dependency
   name: aws-sdk-bedrockruntime
   requirement: !ruby/object:Gem::Requirement
@@ -408,14 +408,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.2
+        version: 0.9.3
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.2
+        version: 0.9.3
 - !ruby/object:Gem::Dependency
   name: llama_cpp
   requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
     - - "~>"
       - !ruby/object:Gem::Version
         version: 0.1.0
-- !ruby/object:Gem::Dependency
-  name: tiktoken_ruby
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
 description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
 email:
 - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
 - lib/langchain/assistants/messages/anthropic_message.rb
 - lib/langchain/assistants/messages/base.rb
 - lib/langchain/assistants/messages/google_gemini_message.rb
+- lib/langchain/assistants/messages/ollama_message.rb
 - lib/langchain/assistants/messages/openai_message.rb
 - lib/langchain/assistants/thread.rb
 - lib/langchain/chunk.rb
@@ -809,12 +796,7 @@ files:
 - lib/langchain/tool/wikipedia/wikipedia.json
 - lib/langchain/tool/wikipedia/wikipedia.rb
 - lib/langchain/utils/cosine_similarity.rb
-- lib/langchain/utils/token_length/ai21_validator.rb
-- lib/langchain/utils/token_length/base_validator.rb
-- lib/langchain/utils/token_length/cohere_validator.rb
-- lib/langchain/utils/token_length/google_palm_validator.rb
-- lib/langchain/utils/token_length/openai_validator.rb
-- lib/langchain/utils/token_length/token_limit_exceeded.rb
+- lib/langchain/utils/hash_transformer.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
 - lib/langchain/vectorsearch/elasticsearch.rb
@@ -852,7 +834,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.5.11
+rubygems_version: 3.5.14
 signing_key:
 specification_version: 4
 summary: Build LLM-backed Ruby applications with Ruby's Langchain.rb
@@ -1,41 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to AI21's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class AI21Validator < BaseValidator
-        TOKEN_LIMITS = {
-          "j2-ultra" => 8192,
-          "j2-mid" => 8192,
-          "j2-light" => 8192
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text)
-          res.dig(:tokens).length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
@@ -1,42 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
-      #
-      # @param content [String | Array<String>] The text or array of texts to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] Whether the text is valid or not
-      # @raise [TokenLimitExceeded] If the text is too long
-      #
-      class BaseValidator
-        def self.validate_max_tokens!(content, model_name, options = {})
-          text_token_length = if content.is_a?(Array)
-            token_length_from_messages(content, model_name, options)
-          else
-            token_length(content, model_name, options)
-          end
-
-          leftover_tokens = token_limit(model_name) - text_token_length
-
-          # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
-          # We want the lower of the two limits
-          max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if max_tokens < 0
-            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
-          end
-
-          max_tokens
-        end
-
-        def self.limit_exceeded_exception(limit, length)
-          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
-        end
-      end
-    end
-  end
-end
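For context on what was removed: `BaseValidator.validate_max_tokens!` computed a completion budget as the model's context window minus the prompt length, capped by any separate completion limit, and raised `TokenLimitExceeded` when the prompt alone overflowed. A worked example with invented numbers:

```ruby
token_limit            = 4096  # model context window
completion_token_limit = 4096  # models without a separate cap reuse it
text_token_length      = 3000  # tokens already used by the prompt

leftover_tokens = token_limit - text_token_length          # => 1096
max_tokens = [leftover_tokens, completion_token_limit].min # => 1096
# A 5000-token prompt would make leftover_tokens negative and
# raise TokenLimitExceeded instead of returning a budget.
```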
@@ -1,49 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Cohere's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class CohereValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # https://docs.cohere.com/docs/models
-          "command-light" => 4096,
-          "command" => 4096,
-          "base-light" => 2048,
-          "base" => 2048,
-          "embed-english-light-v2.0" => 512,
-          "embed-english-v2.0" => 512,
-          "embed-multilingual-v2.0" => 256,
-          "summarize-medium" => 2048,
-          "summarize-xlarge" => 2048
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text: text)
-          res["tokens"].length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
@@ -1,57 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Google Palm's API.
-      # It is used to validate the token length before the API call is made
-      #
-      class GooglePalmValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
-
-          # chat-bison-001 is the only model that currently supports countMessageTokens functions
-          "chat-bison-001" => {
-            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
-            "output_token_limit" => 1024
-          }
-          # "text-bison-001" => {
-          #   "input_token_limit" => 8196,
-          #   "output_token_limit" => 1024
-          # },
-          # "embedding-gecko-001" => {
-          #   "input_token_limit" => 1024
-          # }
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @param options [Hash] the options to create a message with
-        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name = "chat-bison-001", options = {})
-          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
-
-          raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
-
-          response.dig("tokenCount")
-        end
-
-        def self.token_length_from_messages(messages, model_name, options = {})
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS.dig(model_name, "input_token_limit")
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-      end
-    end
-  end
-end