langchainrb 0.13.4 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +3 -18
- data/lib/langchain/assistants/assistant.rb +204 -79
- data/lib/langchain/assistants/messages/base.rb +35 -1
- data/lib/langchain/assistants/messages/ollama_message.rb +86 -0
- data/lib/langchain/assistants/thread.rb +8 -1
- data/lib/langchain/llm/ai21.rb +0 -4
- data/lib/langchain/llm/anthropic.rb +15 -6
- data/lib/langchain/llm/azure.rb +3 -3
- data/lib/langchain/llm/base.rb +1 -0
- data/lib/langchain/llm/cohere.rb +0 -2
- data/lib/langchain/llm/google_gemini.rb +17 -3
- data/lib/langchain/llm/google_palm.rb +1 -4
- data/lib/langchain/llm/ollama.rb +1 -1
- data/lib/langchain/llm/replicate.rb +1 -1
- data/lib/langchain/llm/response/google_gemini_response.rb +1 -1
- data/lib/langchain/llm/response/ollama_response.rb +19 -1
- data/lib/langchain/loader.rb +3 -1
- data/lib/langchain/utils/hash_transformer.rb +25 -0
- data/lib/langchain/vectorsearch/chroma.rb +3 -1
- data/lib/langchain/vectorsearch/milvus.rb +18 -3
- data/lib/langchain/version.rb +1 -1
- metadata +9 -27
- data/lib/langchain/utils/token_length/ai21_validator.rb +0 -41
- data/lib/langchain/utils/token_length/base_validator.rb +0 -42
- data/lib/langchain/utils/token_length/cohere_validator.rb +0 -49
- data/lib/langchain/utils/token_length/google_palm_validator.rb +0 -57
- data/lib/langchain/utils/token_length/openai_validator.rb +0 -138
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +0 -17
data/lib/langchain/llm/azure.rb CHANGED
@@ -42,17 +42,17 @@ module Langchain::LLM
 
     def embed(...)
       @client = @embed_client
-      super
+      super
     end
 
     def complete(...)
       @client = @chat_client
-      super
+      super
     end
 
     def chat(...)
       @client = @chat_client
-      super
+      super
     end
   end
 end
data/lib/langchain/llm/base.rb CHANGED
@@ -8,6 +8,7 @@ module Langchain::LLM
   # Langchain.rb provides a common interface to interact with all supported LLMs:
   #
   # - {Langchain::LLM::AI21}
+  # - {Langchain::LLM::Anthropic}
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
data/lib/langchain/llm/cohere.rb CHANGED
@@ -74,8 +74,6 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
data/lib/langchain/llm/google_gemini.rb CHANGED
@@ -18,7 +18,9 @@ module Langchain::LLM
 
       chat_parameters.update(
         model: {default: @defaults[:chat_completion_model_name]},
-        temperature: {default: @defaults[:temperature]}
+        temperature: {default: @defaults[:temperature]},
+        generation_config: {default: nil},
+        safety_settings: {default: nil}
       )
       chat_parameters.remap(
         messages: :contents,
@@ -42,13 +44,25 @@ module Langchain::LLM
       raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
 
       parameters = chat_parameters.to_params(params)
-      parameters[:generation_config]
+      parameters[:generation_config] ||= {}
+      parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
+      parameters.delete(:temperature)
+      parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
+      parameters.delete(:top_p)
+      parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
+      parameters.delete(:top_k)
+      parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
+      parameters.delete(:max_tokens)
+      parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
+      parameters.delete(:response_format)
+      parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
+      parameters.delete(:stop)
 
       uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
 
       request = Net::HTTP::Post.new(uri)
       request.content_type = "application/json"
-      request.body = parameters.to_json
+      request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json
 
       response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
         http.request(request)
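For orientation: with these changes a caller can keep passing generic options like temperature: or max_tokens: to chat, and they are folded into generation_config before the body is camelized for the Gemini REST API. A minimal sketch of the folding, with hypothetical values and the delete calls compressed for brevity:

    parameters = {temperature: 0.5, max_tokens: 256, generation_config: nil}
    parameters[:generation_config] ||= {}
    parameters[:generation_config][:temperature] ||= parameters.delete(:temperature)
    parameters[:generation_config][:max_output_tokens] ||= parameters.delete(:max_tokens)
    parameters # => {generation_config: {temperature: 0.5, max_output_tokens: 256}}

After Langchain::Utils::HashTransformer camelizes the keys (see the new file below), the JSON body carries generationConfig/maxOutputTokens as the API expects.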
data/lib/langchain/llm/google_palm.rb CHANGED
@@ -18,7 +18,7 @@ module Langchain::LLM
       chat_completion_model_name: "chat-bison-001",
       embeddings_model_name: "embedding-gecko-001"
     }.freeze
-
+
     ROLE_MAPPING = {
       "assistant" => "ai"
     }
@@ -96,9 +96,6 @@ module Langchain::LLM
         examples: compose_examples(examples)
       }
 
-      # chat-bison-001 is the only model that currently supports countMessageTokens functions
-      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
       end
data/lib/langchain/llm/replicate.rb CHANGED
@@ -64,7 +64,7 @@ module Langchain::LLM
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @return [Langchain::LLM::ReplicateResponse]
+    # @return [Langchain::LLM::ReplicateResponse] Response object
     #
     def complete(prompt:, **params)
       response = completion_model.predict(prompt: prompt)
data/lib/langchain/llm/response/ollama_response.rb CHANGED
@@ -36,7 +36,7 @@ module Langchain::LLM
     end
 
     def prompt_tokens
-      raw_response.
+      raw_response.fetch("prompt_eval_count", 0) if done?
     end
 
     def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
       prompt_tokens + completion_tokens if done?
     end
 
+    def tool_calls
+      if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+        [parsed_tool_calls]
+      elsif completion&.include?("[TOOL_CALLS]") && (
+        parsed_tool_calls = JSON.parse(
+          completion
+            # Slice out the serialize JSON
+            .slice(/\{.*\}/)
+            # Replace hash rocket with colon
+            .gsub("=>", ":")
+        )
+      )
+        [parsed_tool_calls]
+      else
+        []
+      end
+    end
+
     private
 
     def done?
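The new tool_calls reader also handles Mistral-style tool calls that arrive inline in a plain completion after a [TOOL_CALLS] marker, by slicing the JSON object out of the text and normalizing Ruby hash rockets. A standalone illustration — the sample completion string below is hypothetical:

    require "json"

    # Hypothetical raw completion from a tool-calling model:
    completion = 'calling a tool [TOOL_CALLS] {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}'
    parsed = JSON.parse(completion.slice(/\{.*\}/).gsub("=>", ":"))
    # => {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}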
data/lib/langchain/loader.rb CHANGED
@@ -90,7 +90,9 @@ module Langchain
     private
 
     def load_from_url
-      URI.
+      unescaped_url = URI.decode_www_form_component(@path)
+      escaped_url = URI::DEFAULT_PARSER.escape(unescaped_url)
+      URI.parse(escaped_url).open
     end
 
     def load_from_path
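Decoding before re-escaping makes the handling idempotent: an already percent-encoded URL and its raw form normalize to the same string. For example, using only the standard library (the URLs are hypothetical):

    require "uri"

    URI::DEFAULT_PARSER.escape(URI.decode_www_form_component("https://example.com/a file.pdf"))
    # => "https://example.com/a%20file.pdf"
    URI::DEFAULT_PARSER.escape(URI.decode_www_form_component("https://example.com/a%20file.pdf"))
    # => "https://example.com/a%20file.pdf"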
data/lib/langchain/utils/hash_transformer.rb ADDED
@@ -0,0 +1,25 @@
+module Langchain
+  module Utils
+    class HashTransformer
+      # Converts a string to camelCase
+      def self.camelize_lower(str)
+        str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
+      end
+
+      # Recursively transforms the keys of a hash to camel case
+      def self.deep_transform_keys(hash, &block)
+        case hash
+        when Hash
+          hash.each_with_object({}) do |(key, value), result|
+            new_key = block.call(key)
+            result[new_key] = deep_transform_keys(value, &block)
+          end
+        when Array
+          hash.map { |item| deep_transform_keys(item, &block) }
+        else
+          hash
+        end
+      end
+    end
+  end
+end
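The transformer is generic: keys are rewritten by whatever block the caller supplies. Used the way google_gemini.rb uses it, snake_case symbol keys become camelCase symbols, recursing through nested hashes and arrays:

    params = {generation_config: {max_output_tokens: 256}, safety_settings: []}
    Langchain::Utils::HashTransformer.deep_transform_keys(params) do |key|
      Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym
    end
    # => {generationConfig: {maxOutputTokens: 256}, safetySettings: []}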
data/lib/langchain/vectorsearch/chroma.rb CHANGED
@@ -64,7 +64,9 @@ module Langchain::Vectorsearch
     # @param ids [Array<String>] The list of ids to remove
     # @return [Hash] The response from the server
     def remove_texts(ids:)
-      collection.delete(
+      collection.delete(
+        ids: ids.map(&:to_s)
+      )
     end
 
     # Create the collection with the default schema
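Chroma expects string ids, so the fix coerces whatever the caller passes. A one-line usage sketch, assuming `store` is an already configured Langchain::Vectorsearch::Chroma instance:

    store.remove_texts(ids: [1, 2, 3]) # sent to Chroma as ids: ["1", "2", "3"]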
data/lib/langchain/vectorsearch/milvus.rb CHANGED
@@ -6,7 +6,7 @@ module Langchain::Vectorsearch
   # Wrapper around Milvus REST APIs.
   #
   # Gem requirements:
-  #     gem "milvus", "~> 0.9.
+  #     gem "milvus", "~> 0.9.3"
   #
   # Usage:
   #     milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, api_key:)
@@ -39,6 +39,21 @@ module Langchain::Vectorsearch
       )
     end
 
+    # Deletes a list of texts in the index
+    #
+    # @param ids [Array<Integer>] The ids of texts to delete
+    # @return [Boolean] The response from the server
+    def remove_texts(ids:)
+      raise ArgumentError, "ids must be an array" unless ids.is_a?(Array)
+      # Convert ids to integers if strings are passed
+      ids = ids.map(&:to_i)
+
+      client.entities.delete(
+        collection_name: index_name,
+        expression: "id in #{ids}"
+      )
+    end
+
     # TODO: Add update_texts method
 
     # Create default schema
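Unlike Chroma, the Milvus delete builds a boolean expression by interpolating the coerced ids; Ruby's Array#to_s happens to produce the bracketed list Milvus's expression syntax expects:

    ids = ["1", "2"].map(&:to_i) # string ids coerced => [1, 2]
    "id in #{ids}"               # => "id in [1, 2]"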
@@ -83,7 +98,7 @@ module Langchain::Vectorsearch
     # @return [Boolean] The response from the server
     def create_default_index
       client.indices.create(
-        collection_name:
+        collection_name: index_name,
         field_name: "vectors",
         extra_params: [
           {key: "metric_type", value: "L2"},
@@ -125,7 +140,7 @@ module Langchain::Vectorsearch
 
       client.search(
         collection_name: index_name,
-        output_fields: ["id", "content", "vectors"
+        output_fields: ["id", "content"], # Add "vectors" if need to have full vectors returned.
         top_k: k.to_s,
         vectors: [embedding],
         dsl_type: 1,
data/lib/langchain/version.rb CHANGED
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module Langchain
-  VERSION = "0.13.4"
+  VERSION = "0.14.0"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.13.4
+  version: 0.14.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-
+date: 2024-07-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -212,14 +212,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
 - !ruby/object:Gem::Dependency
   name: aws-sdk-bedrockruntime
   requirement: !ruby/object:Gem::Requirement
@@ -408,14 +408,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.
+        version: 0.9.3
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.
+        version: 0.9.3
 - !ruby/object:Gem::Dependency
   name: llama_cpp
   requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
   - - "~>"
     - !ruby/object:Gem::Version
       version: 0.1.0
-- !ruby/object:Gem::Dependency
-  name: tiktoken_ruby
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
 description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
 email:
 - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
 - lib/langchain/assistants/messages/anthropic_message.rb
 - lib/langchain/assistants/messages/base.rb
 - lib/langchain/assistants/messages/google_gemini_message.rb
+- lib/langchain/assistants/messages/ollama_message.rb
 - lib/langchain/assistants/messages/openai_message.rb
 - lib/langchain/assistants/thread.rb
 - lib/langchain/chunk.rb
@@ -809,12 +796,7 @@ files:
 - lib/langchain/tool/wikipedia/wikipedia.json
 - lib/langchain/tool/wikipedia/wikipedia.rb
 - lib/langchain/utils/cosine_similarity.rb
-- lib/langchain/utils/token_length/ai21_validator.rb
-- lib/langchain/utils/token_length/base_validator.rb
-- lib/langchain/utils/token_length/cohere_validator.rb
-- lib/langchain/utils/token_length/google_palm_validator.rb
-- lib/langchain/utils/token_length/openai_validator.rb
-- lib/langchain/utils/token_length/token_limit_exceeded.rb
+- lib/langchain/utils/hash_transformer.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
 - lib/langchain/vectorsearch/elasticsearch.rb
@@ -852,7 +834,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.5.
+rubygems_version: 3.5.14
 signing_key:
 specification_version: 4
 summary: Build LLM-backed Ruby applications with Ruby's Langchain.rb
data/lib/langchain/utils/token_length/ai21_validator.rb DELETED
@@ -1,41 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to AI21's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class AI21Validator < BaseValidator
-        TOKEN_LIMITS = {
-          "j2-ultra" => 8192,
-          "j2-mid" => 8192,
-          "j2-light" => 8192
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text)
-          res.dig(:tokens).length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/base_validator.rb DELETED
@@ -1,42 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
-      #
-      # @param content [String | Array<String>] The text or array of texts to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] Whether the text is valid or not
-      # @raise [TokenLimitExceeded] If the text is too long
-      #
-      class BaseValidator
-        def self.validate_max_tokens!(content, model_name, options = {})
-          text_token_length = if content.is_a?(Array)
-            token_length_from_messages(content, model_name, options)
-          else
-            token_length(content, model_name, options)
-          end
-
-          leftover_tokens = token_limit(model_name) - text_token_length
-
-          # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
-          # We want the lower of the two limits
-          max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if max_tokens < 0
-            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
-          end
-
-          max_tokens
-        end
-
-        def self.limit_exceeded_exception(limit, length)
-          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
-        end
-      end
-    end
-  end
-end
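For reference, the removed validator derived max_tokens from whatever context remained after the prompt, capped by the model's completion limit. A worked example of that arithmetic with hypothetical limits:

    token_limit = 4096             # model context window (hypothetical)
    completion_token_limit = 4096  # separate completion cap, where a model has one
    text_token_length = 1000       # tokens consumed by the prompt
    leftover_tokens = token_limit - text_token_length          # => 3096
    max_tokens = [leftover_tokens, completion_token_limit].min # => 3096
    # A prompt longer than the context window makes max_tokens negative,
    # which raised TokenLimitExceeded.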
data/lib/langchain/utils/token_length/cohere_validator.rb DELETED
@@ -1,49 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Cohere's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class CohereValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # https://docs.cohere.com/docs/models
-          "command-light" => 4096,
-          "command" => 4096,
-          "base-light" => 2048,
-          "base" => 2048,
-          "embed-english-light-v2.0" => 512,
-          "embed-english-v2.0" => 512,
-          "embed-multilingual-v2.0" => 256,
-          "summarize-medium" => 2048,
-          "summarize-xlarge" => 2048
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text: text)
-          res["tokens"].length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/google_palm_validator.rb DELETED
@@ -1,57 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Google Palm's API.
-      # It is used to validate the token length before the API call is made
-      #
-      class GooglePalmValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
-
-          # chat-bison-001 is the only model that currently supports countMessageTokens functions
-          "chat-bison-001" => {
-            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
-            "output_token_limit" => 1024
-          }
-          # "text-bison-001" => {
-          #   "input_token_limit" => 8196,
-          #   "output_token_limit" => 1024
-          # },
-          # "embedding-gecko-001" => {
-          #   "input_token_limit" => 1024
-          # }
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @param options [Hash] the options to create a message with
-        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name = "chat-bison-001", options = {})
-          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
-
-          raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
-
-          response.dig("tokenCount")
-        end
-
-        def self.token_length_from_messages(messages, model_name, options = {})
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS.dig(model_name, "input_token_limit")
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-      end
-    end
-  end
-end