langchainrb 0.13.4 → 0.14.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -42,17 +42,17 @@ module Langchain::LLM
42
42
 
43
43
  def embed(...)
44
44
  @client = @embed_client
45
- super(...)
45
+ super
46
46
  end
47
47
 
48
48
  def complete(...)
49
49
  @client = @chat_client
50
- super(...)
50
+ super
51
51
  end
52
52
 
53
53
  def chat(...)
54
54
  @client = @chat_client
55
- super(...)
55
+ super
56
56
  end
57
57
  end
58
58
  end
@@ -8,6 +8,7 @@ module Langchain::LLM
8
8
  # Langchain.rb provides a common interface to interact with all supported LLMs:
9
9
  #
10
10
  # - {Langchain::LLM::AI21}
11
+ # - {Langchain::LLM::Anthropic}
11
12
  # - {Langchain::LLM::Azure}
12
13
  # - {Langchain::LLM::Cohere}
13
14
  # - {Langchain::LLM::GooglePalm}
@@ -74,8 +74,6 @@ module Langchain::LLM
74
74
 
75
75
  default_params.merge!(params)
76
76
 
77
- default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
78
-
79
77
  response = client.generate(**default_params)
80
78
  Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
81
79
  end
@@ -18,7 +18,9 @@ module Langchain::LLM
18
18
 
19
19
  chat_parameters.update(
20
20
  model: {default: @defaults[:chat_completion_model_name]},
21
- temperature: {default: @defaults[:temperature]}
21
+ temperature: {default: @defaults[:temperature]},
22
+ generation_config: {default: nil},
23
+ safety_settings: {default: nil}
22
24
  )
23
25
  chat_parameters.remap(
24
26
  messages: :contents,
@@ -42,13 +44,25 @@ module Langchain::LLM
42
44
  raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
43
45
 
44
46
  parameters = chat_parameters.to_params(params)
45
- parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
47
+ parameters[:generation_config] ||= {}
48
+ parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
49
+ parameters.delete(:temperature)
50
+ parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
51
+ parameters.delete(:top_p)
52
+ parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
53
+ parameters.delete(:top_k)
54
+ parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
55
+ parameters.delete(:max_tokens)
56
+ parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
57
+ parameters.delete(:response_format)
58
+ parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
59
+ parameters.delete(:stop)
46
60
 
47
61
  uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
48
62
 
49
63
  request = Net::HTTP::Post.new(uri)
50
64
  request.content_type = "application/json"
51
- request.body = parameters.to_json
65
+ request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json
52
66
 
53
67
  response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
54
68
  http.request(request)
@@ -18,7 +18,7 @@ module Langchain::LLM
18
18
  chat_completion_model_name: "chat-bison-001",
19
19
  embeddings_model_name: "embedding-gecko-001"
20
20
  }.freeze
21
- LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
21
+
22
22
  ROLE_MAPPING = {
23
23
  "assistant" => "ai"
24
24
  }
@@ -96,9 +96,6 @@ module Langchain::LLM
96
96
  examples: compose_examples(examples)
97
97
  }
98
98
 
99
- # chat-bison-001 is the only model that currently supports countMessageTokens functions
100
- LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
101
-
102
99
  if options[:stop_sequences]
103
100
  default_params[:stop] = options.delete(:stop_sequences)
104
101
  end
@@ -14,7 +14,7 @@ module Langchain::LLM
14
14
  attr_reader :url, :defaults
15
15
 
16
16
  DEFAULTS = {
17
- temperature: 0.8,
17
+ temperature: 0.0,
18
18
  completion_model_name: "llama3",
19
19
  embeddings_model_name: "llama3",
20
20
  chat_completion_model_name: "llama3"
@@ -64,7 +64,7 @@ module Langchain::LLM
64
64
  # Generate a completion for a given prompt
65
65
  #
66
66
  # @param prompt [String] The prompt to generate a completion for
67
- # @return [Langchain::LLM::ReplicateResponse] Reponse object
67
+ # @return [Langchain::LLM::ReplicateResponse] Response object
68
68
  #
69
69
  def complete(prompt:, **params)
70
70
  response = completion_model.predict(prompt: prompt)
@@ -3,7 +3,7 @@
3
3
  module Langchain::LLM
4
4
  class GoogleGeminiResponse < BaseResponse
5
5
  def initialize(raw_response, model: nil)
6
- super(raw_response, model: model)
6
+ super
7
7
  end
8
8
 
9
9
  def chat_completion
@@ -36,7 +36,7 @@ module Langchain::LLM
36
36
  end
37
37
 
38
38
  def prompt_tokens
39
- raw_response.dig("prompt_eval_count") if done?
39
+ raw_response.fetch("prompt_eval_count", 0) if done?
40
40
  end
41
41
 
42
42
  def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
47
47
  prompt_tokens + completion_tokens if done?
48
48
  end
49
49
 
50
+ def tool_calls
51
+ if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
52
+ [parsed_tool_calls]
53
+ elsif completion&.include?("[TOOL_CALLS]") && (
54
+ parsed_tool_calls = JSON.parse(
55
+ completion
56
 + # Slice out the serialized JSON
57
+ .slice(/\{.*\}/)
58
+ # Replace hash rocket with colon
59
+ .gsub("=>", ":")
60
+ )
61
+ )
62
+ [parsed_tool_calls]
63
+ else
64
+ []
65
+ end
66
+ end
67
+
50
68
  private
51
69
 
52
70
  def done?
@@ -90,7 +90,9 @@ module Langchain
90
90
  private
91
91
 
92
92
  def load_from_url
93
- URI.parse(URI::DEFAULT_PARSER.escape(@path)).open
93
+ unescaped_url = URI.decode_www_form_component(@path)
94
+ escaped_url = URI::DEFAULT_PARSER.escape(unescaped_url)
95
+ URI.parse(escaped_url).open
94
96
  end
95
97
 
96
98
  def load_from_path
@@ -0,0 +1,25 @@
1
+ module Langchain
2
+ module Utils
3
+ class HashTransformer
4
+ # Converts a string to camelCase
5
+ def self.camelize_lower(str)
6
+ str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
7
+ end
8
+
9
+ # Recursively transforms the keys of a hash to camel case
10
+ def self.deep_transform_keys(hash, &block)
11
+ case hash
12
+ when Hash
13
+ hash.each_with_object({}) do |(key, value), result|
14
+ new_key = block.call(key)
15
+ result[new_key] = deep_transform_keys(value, &block)
16
+ end
17
+ when Array
18
+ hash.map { |item| deep_transform_keys(item, &block) }
19
+ else
20
+ hash
21
+ end
22
+ end
23
+ end
24
+ end
25
+ end
@@ -64,7 +64,9 @@ module Langchain::Vectorsearch
64
64
  # @param ids [Array<String>] The list of ids to remove
65
65
  # @return [Hash] The response from the server
66
66
  def remove_texts(ids:)
67
- collection.delete(ids: ids)
67
+ collection.delete(
68
+ ids: ids.map(&:to_s)
69
+ )
68
70
  end
69
71
 
70
72
  # Create the collection with the default schema
@@ -6,7 +6,7 @@ module Langchain::Vectorsearch
6
6
  # Wrapper around Milvus REST APIs.
7
7
  #
8
8
  # Gem requirements:
9
- # gem "milvus", "~> 0.9.2"
9
+ # gem "milvus", "~> 0.9.3"
10
10
  #
11
11
  # Usage:
12
12
  # milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, api_key:)
@@ -39,6 +39,21 @@ module Langchain::Vectorsearch
39
39
  )
40
40
  end
41
41
 
42
+ # Deletes a list of texts in the index
43
+ #
44
+ # @param ids [Array<Integer>] The ids of texts to delete
45
+ # @return [Boolean] The response from the server
46
+ def remove_texts(ids:)
47
+ raise ArgumentError, "ids must be an array" unless ids.is_a?(Array)
48
+ # Convert ids to integers if strings are passed
49
+ ids = ids.map(&:to_i)
50
+
51
+ client.entities.delete(
52
+ collection_name: index_name,
53
+ expression: "id in #{ids}"
54
+ )
55
+ end
56
+
42
57
  # TODO: Add update_texts method
43
58
 
44
59
  # Create default schema
@@ -83,7 +98,7 @@ module Langchain::Vectorsearch
83
98
  # @return [Boolean] The response from the server
84
99
  def create_default_index
85
100
  client.indices.create(
86
- collection_name: "Documents",
101
+ collection_name: index_name,
87
102
  field_name: "vectors",
88
103
  extra_params: [
89
104
  {key: "metric_type", value: "L2"},
@@ -125,7 +140,7 @@ module Langchain::Vectorsearch
125
140
 
126
141
  client.search(
127
142
  collection_name: index_name,
128
- output_fields: ["id", "content", "vectors"],
143
 + output_fields: ["id", "content"], # Add "vectors" if you need full vectors returned.
129
144
  top_k: k.to_s,
130
145
  vectors: [embedding],
131
146
  dsl_type: 1,
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Langchain
4
- VERSION = "0.13.4"
4
+ VERSION = "0.14.0"
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: langchainrb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.13.4
4
+ version: 0.14.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrei Bondarev
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2024-06-16 00:00:00.000000000 Z
11
+ date: 2024-07-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: baran
@@ -212,14 +212,14 @@ dependencies:
212
212
  requirements:
213
213
  - - "~>"
214
214
  - !ruby/object:Gem::Version
215
- version: '0.2'
215
+ version: '0.3'
216
216
  type: :development
217
217
  prerelease: false
218
218
  version_requirements: !ruby/object:Gem::Requirement
219
219
  requirements:
220
220
  - - "~>"
221
221
  - !ruby/object:Gem::Version
222
- version: '0.2'
222
+ version: '0.3'
223
223
  - !ruby/object:Gem::Dependency
224
224
  name: aws-sdk-bedrockruntime
225
225
  requirement: !ruby/object:Gem::Requirement
@@ -408,14 +408,14 @@ dependencies:
408
408
  requirements:
409
409
  - - "~>"
410
410
  - !ruby/object:Gem::Version
411
- version: 0.9.2
411
+ version: 0.9.3
412
412
  type: :development
413
413
  prerelease: false
414
414
  version_requirements: !ruby/object:Gem::Requirement
415
415
  requirements:
416
416
  - - "~>"
417
417
  - !ruby/object:Gem::Version
418
- version: 0.9.2
418
+ version: 0.9.3
419
419
  - !ruby/object:Gem::Dependency
420
420
  name: llama_cpp
421
421
  requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
682
682
  - - "~>"
683
683
  - !ruby/object:Gem::Version
684
684
  version: 0.1.0
685
- - !ruby/object:Gem::Dependency
686
- name: tiktoken_ruby
687
- requirement: !ruby/object:Gem::Requirement
688
- requirements:
689
- - - "~>"
690
- - !ruby/object:Gem::Version
691
- version: 0.0.9
692
- type: :development
693
- prerelease: false
694
- version_requirements: !ruby/object:Gem::Requirement
695
- requirements:
696
- - - "~>"
697
- - !ruby/object:Gem::Version
698
- version: 0.0.9
699
685
  description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
700
686
  email:
701
687
  - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
711
697
  - lib/langchain/assistants/messages/anthropic_message.rb
712
698
  - lib/langchain/assistants/messages/base.rb
713
699
  - lib/langchain/assistants/messages/google_gemini_message.rb
700
+ - lib/langchain/assistants/messages/ollama_message.rb
714
701
  - lib/langchain/assistants/messages/openai_message.rb
715
702
  - lib/langchain/assistants/thread.rb
716
703
  - lib/langchain/chunk.rb
@@ -809,12 +796,7 @@ files:
809
796
  - lib/langchain/tool/wikipedia/wikipedia.json
810
797
  - lib/langchain/tool/wikipedia/wikipedia.rb
811
798
  - lib/langchain/utils/cosine_similarity.rb
812
- - lib/langchain/utils/token_length/ai21_validator.rb
813
- - lib/langchain/utils/token_length/base_validator.rb
814
- - lib/langchain/utils/token_length/cohere_validator.rb
815
- - lib/langchain/utils/token_length/google_palm_validator.rb
816
- - lib/langchain/utils/token_length/openai_validator.rb
817
- - lib/langchain/utils/token_length/token_limit_exceeded.rb
799
+ - lib/langchain/utils/hash_transformer.rb
818
800
  - lib/langchain/vectorsearch/base.rb
819
801
  - lib/langchain/vectorsearch/chroma.rb
820
802
  - lib/langchain/vectorsearch/elasticsearch.rb
@@ -852,7 +834,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
852
834
  - !ruby/object:Gem::Version
853
835
  version: '0'
854
836
  requirements: []
855
- rubygems_version: 3.5.11
837
+ rubygems_version: 3.5.14
856
838
  signing_key:
857
839
  specification_version: 4
858
840
  summary: Build LLM-backed Ruby applications with Ruby's Langchain.rb
@@ -1,41 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Langchain
4
- module Utils
5
- module TokenLength
6
- #
7
- # This class is meant to validate the length of the text passed in to AI21's API.
8
- # It is used to validate the token length before the API call is made
9
- #
10
-
11
- class AI21Validator < BaseValidator
12
- TOKEN_LIMITS = {
13
- "j2-ultra" => 8192,
14
- "j2-mid" => 8192,
15
- "j2-light" => 8192
16
- }.freeze
17
-
18
- #
19
- # Calculate token length for a given text and model name
20
- #
21
- # @param text [String] The text to calculate the token length for
22
- # @param model_name [String] The model name to validate against
23
- # @return [Integer] The token length of the text
24
- #
25
- def self.token_length(text, model_name, options = {})
26
- res = options[:llm].tokenize(text)
27
- res.dig(:tokens).length
28
- end
29
-
30
- def self.token_limit(model_name)
31
- TOKEN_LIMITS[model_name]
32
- end
33
- singleton_class.alias_method :completion_token_limit, :token_limit
34
-
35
- def self.token_length_from_messages(messages, model_name, options)
36
- messages.sum { |message| token_length(message.to_json, model_name, options) }
37
- end
38
- end
39
- end
40
- end
41
- end
@@ -1,42 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Langchain
4
- module Utils
5
- module TokenLength
6
- #
7
- # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
8
- #
9
- # @param content [String | Array<String>] The text or array of texts to validate
10
- # @param model_name [String] The model name to validate against
11
- # @return [Integer] Whether the text is valid or not
12
- # @raise [TokenLimitExceeded] If the text is too long
13
- #
14
- class BaseValidator
15
- def self.validate_max_tokens!(content, model_name, options = {})
16
- text_token_length = if content.is_a?(Array)
17
- token_length_from_messages(content, model_name, options)
18
- else
19
- token_length(content, model_name, options)
20
- end
21
-
22
- leftover_tokens = token_limit(model_name) - text_token_length
23
-
24
- # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
25
- # We want the lower of the two limits
26
- max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
27
-
28
- # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
29
- if max_tokens < 0
30
- raise limit_exceeded_exception(token_limit(model_name), text_token_length)
31
- end
32
-
33
- max_tokens
34
- end
35
-
36
- def self.limit_exceeded_exception(limit, length)
37
- TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
38
- end
39
- end
40
- end
41
- end
42
- end
@@ -1,49 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Langchain
4
- module Utils
5
- module TokenLength
6
- #
7
- # This class is meant to validate the length of the text passed in to Cohere's API.
8
- # It is used to validate the token length before the API call is made
9
- #
10
-
11
- class CohereValidator < BaseValidator
12
- TOKEN_LIMITS = {
13
- # Source:
14
- # https://docs.cohere.com/docs/models
15
- "command-light" => 4096,
16
- "command" => 4096,
17
- "base-light" => 2048,
18
- "base" => 2048,
19
- "embed-english-light-v2.0" => 512,
20
- "embed-english-v2.0" => 512,
21
- "embed-multilingual-v2.0" => 256,
22
- "summarize-medium" => 2048,
23
- "summarize-xlarge" => 2048
24
- }.freeze
25
-
26
- #
27
- # Calculate token length for a given text and model name
28
- #
29
- # @param text [String] The text to calculate the token length for
30
- # @param model_name [String] The model name to validate against
31
- # @return [Integer] The token length of the text
32
- #
33
- def self.token_length(text, model_name, options = {})
34
- res = options[:llm].tokenize(text: text)
35
- res["tokens"].length
36
- end
37
-
38
- def self.token_limit(model_name)
39
- TOKEN_LIMITS[model_name]
40
- end
41
- singleton_class.alias_method :completion_token_limit, :token_limit
42
-
43
- def self.token_length_from_messages(messages, model_name, options)
44
- messages.sum { |message| token_length(message.to_json, model_name, options) }
45
- end
46
- end
47
- end
48
- end
49
- end
@@ -1,57 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- module Langchain
4
- module Utils
5
- module TokenLength
6
- #
7
- # This class is meant to validate the length of the text passed in to Google Palm's API.
8
- # It is used to validate the token length before the API call is made
9
- #
10
- class GooglePalmValidator < BaseValidator
11
- TOKEN_LIMITS = {
12
- # Source:
13
- # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
14
-
15
- # chat-bison-001 is the only model that currently supports countMessageTokens functions
16
- "chat-bison-001" => {
17
- "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
18
- "output_token_limit" => 1024
19
- }
20
- # "text-bison-001" => {
21
- # "input_token_limit" => 8196,
22
- # "output_token_limit" => 1024
23
- # },
24
- # "embedding-gecko-001" => {
25
- # "input_token_limit" => 1024
26
- # }
27
- }.freeze
28
-
29
- #
30
- # Calculate token length for a given text and model name
31
- #
32
- # @param text [String] The text to calculate the token length for
33
- # @param model_name [String] The model name to validate against
34
- # @param options [Hash] the options to create a message with
35
- # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
36
- # @return [Integer] The token length of the text
37
- #
38
- def self.token_length(text, model_name = "chat-bison-001", options = {})
39
- response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
40
-
41
- raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
42
-
43
- response.dig("tokenCount")
44
- end
45
-
46
- def self.token_length_from_messages(messages, model_name, options = {})
47
- messages.sum { |message| token_length(message.to_json, model_name, options) }
48
- end
49
-
50
- def self.token_limit(model_name)
51
- TOKEN_LIMITS.dig(model_name, "input_token_limit")
52
- end
53
- singleton_class.alias_method :completion_token_limit, :token_limit
54
- end
55
- end
56
- end
57
- end