langchainrb 0.13.4 → 0.14.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +3 -18
- data/lib/langchain/assistants/assistant.rb +204 -79
- data/lib/langchain/assistants/messages/base.rb +35 -1
- data/lib/langchain/assistants/messages/ollama_message.rb +86 -0
- data/lib/langchain/assistants/thread.rb +8 -1
- data/lib/langchain/llm/ai21.rb +0 -4
- data/lib/langchain/llm/anthropic.rb +15 -6
- data/lib/langchain/llm/azure.rb +3 -3
- data/lib/langchain/llm/base.rb +1 -0
- data/lib/langchain/llm/cohere.rb +0 -2
- data/lib/langchain/llm/google_gemini.rb +17 -3
- data/lib/langchain/llm/google_palm.rb +1 -4
- data/lib/langchain/llm/ollama.rb +1 -1
- data/lib/langchain/llm/replicate.rb +1 -1
- data/lib/langchain/llm/response/google_gemini_response.rb +1 -1
- data/lib/langchain/llm/response/ollama_response.rb +19 -1
- data/lib/langchain/loader.rb +3 -1
- data/lib/langchain/utils/hash_transformer.rb +25 -0
- data/lib/langchain/vectorsearch/chroma.rb +3 -1
- data/lib/langchain/vectorsearch/milvus.rb +18 -3
- data/lib/langchain/version.rb +1 -1
- metadata +9 -27
- data/lib/langchain/utils/token_length/ai21_validator.rb +0 -41
- data/lib/langchain/utils/token_length/base_validator.rb +0 -42
- data/lib/langchain/utils/token_length/cohere_validator.rb +0 -49
- data/lib/langchain/utils/token_length/google_palm_validator.rb +0 -57
- data/lib/langchain/utils/token_length/openai_validator.rb +0 -138
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +0 -17
data/lib/langchain/llm/azure.rb
CHANGED
@@ -42,17 +42,17 @@ module Langchain::LLM
 
     def embed(...)
       @client = @embed_client
-      super
+      super
     end
 
     def complete(...)
       @client = @chat_client
-      super
+      super
     end
 
     def chat(...)
       @client = @chat_client
-      super
+      super
     end
   end
 end
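For orientation, the three overrides above just point @client at the right underlying client before delegating via super. A minimal usage sketch; the constructor keywords embedding_deployment_url and chat_deployment_url are assumptions based on the wrapper's documented interface, so check your installed version:

    require "langchain"

    # Hypothetical endpoints; substitute your own Azure OpenAI deployment URLs.
    llm = Langchain::LLM::Azure.new(
      api_key: ENV["AZURE_OPENAI_API_KEY"],
      embedding_deployment_url: ENV["AZURE_EMBED_URL"], # backs @embed_client
      chat_deployment_url: ENV["AZURE_CHAT_URL"]        # backs @chat_client
    )

    llm.embed(text: "hello")                            # embed() swaps in @embed_client
    llm.chat(messages: [{role: "user", content: "hi"}]) # chat() swaps in @chat_client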
data/lib/langchain/llm/base.rb
CHANGED
@@ -8,6 +8,7 @@ module Langchain::LLM
   # Langchain.rb provides a common interface to interact with all supported LLMs:
   #
   # - {Langchain::LLM::AI21}
+  # - {Langchain::LLM::Anthropic}
   # - {Langchain::LLM::Azure}
   # - {Langchain::LLM::Cohere}
   # - {Langchain::LLM::GooglePalm}
data/lib/langchain/llm/cohere.rb
CHANGED
@@ -74,8 +74,6 @@ module Langchain::LLM
 
       default_params.merge!(params)
 
-      default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
       response = client.generate(**default_params)
       Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
data/lib/langchain/llm/google_gemini.rb
CHANGED
@@ -18,7 +18,9 @@ module Langchain::LLM
 
     chat_parameters.update(
       model: {default: @defaults[:chat_completion_model_name]},
-      temperature: {default: @defaults[:temperature]}
+      temperature: {default: @defaults[:temperature]},
+      generation_config: {default: nil},
+      safety_settings: {default: nil}
     )
     chat_parameters.remap(
       messages: :contents,
@@ -42,13 +44,25 @@ module Langchain::LLM
     raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
 
     parameters = chat_parameters.to_params(params)
-    parameters[:generation_config]
+    parameters[:generation_config] ||= {}
+    parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
+    parameters.delete(:temperature)
+    parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
+    parameters.delete(:top_p)
+    parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
+    parameters.delete(:top_k)
+    parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
+    parameters.delete(:max_tokens)
+    parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
+    parameters.delete(:response_format)
+    parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
+    parameters.delete(:stop)
 
     uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")
 
     request = Net::HTTP::Post.new(uri)
     request.content_type = "application/json"
-    request.body = parameters.to_json
+    request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json
 
     response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
       http.request(request)
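The net effect of the new chat() handling: top-level sampling options are folded under generation_config, and every key is then camelized before the POST to the Gemini REST API. A condensed, runnable sketch of the folding for two of the six options, with illustrative values:

    params = {model: "gemini-1.5-pro", temperature: 0.7, stop: ["\n\n"]}

    # Same statements as in the hunk above:
    params[:generation_config] ||= {}
    params[:generation_config][:temperature] ||= params[:temperature] if params[:temperature]
    params.delete(:temperature)
    params[:generation_config][:stop_sequences] ||= params[:stop] if params[:stop]
    params.delete(:stop)

    params
    # => {model: "gemini-1.5-pro",
    #     generation_config: {temperature: 0.7, stop_sequences: ["\n\n"]}}
    # Keys are then camelized (generationConfig, stopSequences) for the request body.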
data/lib/langchain/llm/google_palm.rb
CHANGED
@@ -18,7 +18,7 @@ module Langchain::LLM
     chat_completion_model_name: "chat-bison-001",
     embeddings_model_name: "embedding-gecko-001"
   }.freeze
-
+
   ROLE_MAPPING = {
     "assistant" => "ai"
   }
@@ -96,9 +96,6 @@ module Langchain::LLM
       examples: compose_examples(examples)
     }
 
-    # chat-bison-001 is the only model that currently supports countMessageTokens functions
-    LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
     if options[:stop_sequences]
       default_params[:stop] = options.delete(:stop_sequences)
     end
data/lib/langchain/llm/ollama.rb
CHANGED
data/lib/langchain/llm/replicate.rb
CHANGED
@@ -64,7 +64,7 @@ module Langchain::LLM
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @return [Langchain::LLM::ReplicateResponse]
+    # @return [Langchain::LLM::ReplicateResponse] Response object
     #
     def complete(prompt:, **params)
       response = completion_model.predict(prompt: prompt)
data/lib/langchain/llm/response/ollama_response.rb
CHANGED
@@ -36,7 +36,7 @@ module Langchain::LLM
     end
 
     def prompt_tokens
-      raw_response.
+      raw_response.fetch("prompt_eval_count", 0) if done?
     end
 
     def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
       prompt_tokens + completion_tokens if done?
     end
 
+    def tool_calls
+      if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+        [parsed_tool_calls]
+      elsif completion&.include?("[TOOL_CALLS]") && (
+        parsed_tool_calls = JSON.parse(
+          completion
+            # Slice out the serialized JSON
+            .slice(/\{.*\}/)
+            # Replace hash rocket with colon
+            .gsub("=>", ":")
+        )
+      )
+        [parsed_tool_calls]
+      else
+        []
+      end
+    end
+
     private
 
     def done?
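The second branch of the new tool_calls reader recovers a tool call from a raw completion that embeds a Ruby-hash-style payload behind a [TOOL_CALLS] marker (the marker some Mistral models served via Ollama emit). A standalone illustration of that parsing, with a made-up payload:

    require "json"

    completion = '[TOOL_CALLS] {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}'

    parsed = JSON.parse(
      completion
        .slice(/\{.*\}/)  # keep only the {...} payload
        .gsub("=>", ":")  # hash rockets become JSON colons
    )
    parsed
    # => {"name"=>"get_weather", "arguments"=>{"city"=>"Boston"}}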
data/lib/langchain/loader.rb
CHANGED
@@ -90,7 +90,9 @@ module Langchain
     private
 
     def load_from_url
-      URI.
+      unescaped_url = URI.decode_www_form_component(@path)
+      escaped_url = URI::DEFAULT_PARSER.escape(unescaped_url)
+      URI.parse(escaped_url).open
     end
 
     def load_from_path
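The decode-then-escape round trip makes load_from_url tolerant of both raw and already-escaped URLs: pre-escaped input is decoded first, so it is not escaped twice. The same calls in isolation:

    require "uri"
    require "open-uri"

    path = "https://example.com/my report.pdf"  # raw URL with a literal space

    unescaped = URI.decode_www_form_component(path)
    escaped   = URI::DEFAULT_PARSER.escape(unescaped)
    # => "https://example.com/my%20report.pdf"

    # URI.parse(escaped).open now succeeds; passing the pre-escaped
    # "https://example.com/my%20report.pdf" produces the same result.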
data/lib/langchain/utils/hash_transformer.rb
ADDED
@@ -0,0 +1,25 @@
+module Langchain
+  module Utils
+    class HashTransformer
+      # Converts a string to camelCase
+      def self.camelize_lower(str)
+        str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
+      end
+
+      # Recursively transforms the keys of a hash to camel case
+      def self.deep_transform_keys(hash, &block)
+        case hash
+        when Hash
+          hash.each_with_object({}) do |(key, value), result|
+            new_key = block.call(key)
+            result[new_key] = deep_transform_keys(value, &block)
+          end
+        when Array
+          hash.map { |item| deep_transform_keys(item, &block) }
+        else
+          hash
+        end
+      end
+    end
+  end
+end
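This helper exists to serve the Google Gemini request body above; assuming the gem is loaded, its usage there amounts to:

    input = {generation_config: {max_output_tokens: 256, stop_sequences: ["\n"]}}

    Langchain::Utils::HashTransformer.deep_transform_keys(input) do |key|
      Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym
    end
    # => {generationConfig: {maxOutputTokens: 256, stopSequences: ["\n"]}}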
data/lib/langchain/vectorsearch/chroma.rb
CHANGED
@@ -64,7 +64,9 @@ module Langchain::Vectorsearch
     # @param ids [Array<String>] The list of ids to remove
     # @return [Hash] The response from the server
     def remove_texts(ids:)
-      collection.delete(
+      collection.delete(
+        ids: ids.map(&:to_s)
+      )
     end
 
     # Create the collection with the default schema
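The fix passes an explicit ids: keyword and coerces each id to a string. A usage sketch, with setup as documented in the gem's README (url, index_name, and the OpenAI key are placeholders):

    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    chroma = Langchain::Vectorsearch::Chroma.new(
      url: ENV["CHROMA_URL"], index_name: "documents", llm: llm
    )

    chroma.remove_texts(ids: [1, 2, 3])  # delivered to the server as ["1", "2", "3"]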
data/lib/langchain/vectorsearch/milvus.rb
CHANGED
@@ -6,7 +6,7 @@ module Langchain::Vectorsearch
   # Wrapper around Milvus REST APIs.
   #
   # Gem requirements:
-  #     gem "milvus", "~> 0.9.
+  #     gem "milvus", "~> 0.9.3"
   #
   # Usage:
   #     milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, api_key:)
@@ -39,6 +39,21 @@ module Langchain::Vectorsearch
       )
     end
 
+    # Deletes a list of texts in the index
+    #
+    # @param ids [Array<Integer>] The ids of texts to delete
+    # @return [Boolean] The response from the server
+    def remove_texts(ids:)
+      raise ArgumentError, "ids must be an array" unless ids.is_a?(Array)
+      # Convert ids to integers if strings are passed
+      ids = ids.map(&:to_i)
+
+      client.entities.delete(
+        collection_name: index_name,
+        expression: "id in #{ids}"
+      )
+    end
+
     # TODO: Add update_texts method
 
     # Create default schema
@@ -83,7 +98,7 @@ module Langchain::Vectorsearch
     # @return [Boolean] The response from the server
     def create_default_index
       client.indices.create(
-        collection_name:
+        collection_name: index_name,
         field_name: "vectors",
         extra_params: [
          {key: "metric_type", value: "L2"},
@@ -125,7 +140,7 @@ module Langchain::Vectorsearch
 
     client.search(
       collection_name: index_name,
-      output_fields: ["id", "content", "vectors"
+      output_fields: ["id", "content"], # Add "vectors" if need to have full vectors returned.
       top_k: k.to_s,
      vectors: [embedding],
      dsl_type: 1,
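The new Milvus remove_texts deletes by a boolean expression built from the integer ids; the interpolation alone looks like this:

    ids = ["7", "8"].map(&:to_i)  # string ids are coerced to integers
    "id in #{ids}"                # => "id in [7, 8]"

    # That string is what client.entities.delete receives as expression:,
    # alongside collection_name:, per the hunk above.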
data/lib/langchain/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.13.4
+  version: 0.14.0
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2024-
+date: 2024-07-12 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: baran
@@ -212,14 +212,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: '0.
+        version: '0.3'
 - !ruby/object:Gem::Dependency
   name: aws-sdk-bedrockruntime
   requirement: !ruby/object:Gem::Requirement
@@ -408,14 +408,14 @@ dependencies:
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.
+        version: 0.9.3
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - "~>"
       - !ruby/object:Gem::Version
-        version: 0.9.
+        version: 0.9.3
 - !ruby/object:Gem::Dependency
   name: llama_cpp
   requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
   - - "~>"
     - !ruby/object:Gem::Version
       version: 0.1.0
-- !ruby/object:Gem::Dependency
-  name: tiktoken_ruby
-  requirement: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
-  type: :development
-  prerelease: false
-  version_requirements: !ruby/object:Gem::Requirement
-    requirements:
-    - - "~>"
-      - !ruby/object:Gem::Version
-        version: 0.0.9
 description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
 email:
 - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
 - lib/langchain/assistants/messages/anthropic_message.rb
 - lib/langchain/assistants/messages/base.rb
 - lib/langchain/assistants/messages/google_gemini_message.rb
+- lib/langchain/assistants/messages/ollama_message.rb
 - lib/langchain/assistants/messages/openai_message.rb
 - lib/langchain/assistants/thread.rb
 - lib/langchain/chunk.rb
@@ -809,12 +796,7 @@ files:
 - lib/langchain/tool/wikipedia/wikipedia.json
 - lib/langchain/tool/wikipedia/wikipedia.rb
 - lib/langchain/utils/cosine_similarity.rb
-- lib/langchain/utils/token_length/ai21_validator.rb
-- lib/langchain/utils/token_length/base_validator.rb
-- lib/langchain/utils/token_length/cohere_validator.rb
-- lib/langchain/utils/token_length/google_palm_validator.rb
-- lib/langchain/utils/token_length/openai_validator.rb
-- lib/langchain/utils/token_length/token_limit_exceeded.rb
+- lib/langchain/utils/hash_transformer.rb
 - lib/langchain/vectorsearch/base.rb
 - lib/langchain/vectorsearch/chroma.rb
 - lib/langchain/vectorsearch/elasticsearch.rb
@@ -852,7 +834,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.5.
+rubygems_version: 3.5.14
 signing_key:
 specification_version: 4
 summary: Build LLM-backed Ruby applications with Ruby's Langchain.rb
data/lib/langchain/utils/token_length/ai21_validator.rb
DELETED
@@ -1,41 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to AI21's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class AI21Validator < BaseValidator
-        TOKEN_LIMITS = {
-          "j2-ultra" => 8192,
-          "j2-mid" => 8192,
-          "j2-light" => 8192
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text)
-          res.dig(:tokens).length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/base_validator.rb
DELETED
@@ -1,42 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
-      #
-      # @param content [String | Array<String>] The text or array of texts to validate
-      # @param model_name [String] The model name to validate against
-      # @return [Integer] Whether the text is valid or not
-      # @raise [TokenLimitExceeded] If the text is too long
-      #
-      class BaseValidator
-        def self.validate_max_tokens!(content, model_name, options = {})
-          text_token_length = if content.is_a?(Array)
-            token_length_from_messages(content, model_name, options)
-          else
-            token_length(content, model_name, options)
-          end
-
-          leftover_tokens = token_limit(model_name) - text_token_length
-
-          # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
-          # We want the lower of the two limits
-          max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if max_tokens < 0
-            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
-          end
-
-          max_tokens
-        end
-
-        def self.limit_exceeded_exception(limit, length)
-          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
-        end
-      end
-    end
-  end
-end
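Before its removal, BaseValidator computed the max_tokens to request as the context left over after the prompt, capped by any separate completion limit. Worked through with illustrative numbers:

    token_limit            = 4096  # model context window
    completion_token_limit = 4096  # some models cap completions separately (e.g. GPT-4 Turbo)
    text_token_length      = 3000  # tokens consumed by the prompt

    leftover_tokens = token_limit - text_token_length           # => 1096
    max_tokens = [leftover_tokens, completion_token_limit].min  # => 1096
    # A negative max_tokens would have raised TokenLimitExceeded.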
data/lib/langchain/utils/token_length/cohere_validator.rb
DELETED
@@ -1,49 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Cohere's API.
-      # It is used to validate the token length before the API call is made
-      #
-
-      class CohereValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # https://docs.cohere.com/docs/models
-          "command-light" => 4096,
-          "command" => 4096,
-          "base-light" => 2048,
-          "base" => 2048,
-          "embed-english-light-v2.0" => 512,
-          "embed-english-v2.0" => 512,
-          "embed-multilingual-v2.0" => 256,
-          "summarize-medium" => 2048,
-          "summarize-xlarge" => 2048
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name, options = {})
-          res = options[:llm].tokenize(text: text)
-          res["tokens"].length
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS[model_name]
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-
-        def self.token_length_from_messages(messages, model_name, options)
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-      end
-    end
-  end
-end
data/lib/langchain/utils/token_length/google_palm_validator.rb
DELETED
@@ -1,57 +0,0 @@
-# frozen_string_literal: true
-
-module Langchain
-  module Utils
-    module TokenLength
-      #
-      # This class is meant to validate the length of the text passed in to Google Palm's API.
-      # It is used to validate the token length before the API call is made
-      #
-      class GooglePalmValidator < BaseValidator
-        TOKEN_LIMITS = {
-          # Source:
-          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
-
-          # chat-bison-001 is the only model that currently supports countMessageTokens functions
-          "chat-bison-001" => {
-            "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
-            "output_token_limit" => 1024
-          }
-          # "text-bison-001" => {
-          #   "input_token_limit" => 8196,
-          #   "output_token_limit" => 1024
-          # },
-          # "embedding-gecko-001" => {
-          #   "input_token_limit" => 1024
-          # }
-        }.freeze
-
-        #
-        # Calculate token length for a given text and model name
-        #
-        # @param text [String] The text to calculate the token length for
-        # @param model_name [String] The model name to validate against
-        # @param options [Hash] the options to create a message with
-        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
-        # @return [Integer] The token length of the text
-        #
-        def self.token_length(text, model_name = "chat-bison-001", options = {})
-          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
-
-          raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
-
-          response.dig("tokenCount")
-        end
-
-        def self.token_length_from_messages(messages, model_name, options = {})
-          messages.sum { |message| token_length(message.to_json, model_name, options) }
-        end
-
-        def self.token_limit(model_name)
-          TOKEN_LIMITS.dig(model_name, "input_token_limit")
-        end
-        singleton_class.alias_method :completion_token_limit, :token_limit
-      end
-    end
-  end
-end