langchainrb 0.5.5 → 0.5.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/Gemfile.lock +3 -1
- data/README.md +7 -5
- data/examples/store_and_query_with_pinecone.rb +5 -4
- data/lib/langchain/agent/base.rb +5 -0
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +22 -10
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.yaml +26 -0
- data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +8 -8
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +11 -0
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +21 -0
- data/lib/langchain/chunker/base.rb +15 -0
- data/lib/langchain/chunker/text.rb +38 -0
- data/lib/langchain/contextual_logger.rb +60 -0
- data/lib/langchain/conversation.rb +35 -4
- data/lib/langchain/data.rb +4 -0
- data/lib/langchain/llm/ai21.rb +16 -2
- data/lib/langchain/llm/cohere.rb +5 -4
- data/lib/langchain/llm/google_palm.rb +15 -7
- data/lib/langchain/llm/openai.rb +67 -17
- data/lib/langchain/llm/prompts/summarize_template.yaml +9 -0
- data/lib/langchain/llm/replicate.rb +6 -5
- data/lib/langchain/prompt/base.rb +2 -2
- data/lib/langchain/tool/base.rb +9 -3
- data/lib/langchain/tool/calculator.rb +7 -9
- data/lib/langchain/tool/database.rb +29 -8
- data/lib/langchain/tool/{serp_api.rb → google_search.rb} +9 -9
- data/lib/langchain/tool/ruby_code_interpreter.rb +1 -1
- data/lib/langchain/tool/weather.rb +2 -2
- data/lib/langchain/tool/wikipedia.rb +1 -1
- data/lib/langchain/utils/token_length/base_validator.rb +38 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +9 -29
- data/lib/langchain/utils/token_length/openai_validator.rb +10 -27
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +17 -0
- data/lib/langchain/vectorsearch/base.rb +6 -0
- data/lib/langchain/vectorsearch/chroma.rb +1 -1
- data/lib/langchain/vectorsearch/hnswlib.rb +2 -2
- data/lib/langchain/vectorsearch/milvus.rb +1 -14
- data/lib/langchain/vectorsearch/pgvector.rb +1 -5
- data/lib/langchain/vectorsearch/pinecone.rb +1 -4
- data/lib/langchain/vectorsearch/qdrant.rb +1 -4
- data/lib/langchain/vectorsearch/weaviate.rb +1 -4
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +28 -12
- metadata +30 -11
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json +0 -10
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +0 -10
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +0 -10
- data/lib/langchain/llm/prompts/summarize_template.json +0 -5
data/lib/langchain/llm/google_palm.rb
CHANGED
@@ -22,14 +22,19 @@ module Langchain::LLM
 
     DEFAULTS = {
       temperature: 0.0,
-      dimension: 768 # This is what the `embedding-gecko-001` model generates
+      dimension: 768, # This is what the `embedding-gecko-001` model generates
+      completion_model_name: "text-bison-001",
+      chat_completion_model_name: "chat-bison-001",
+      embeddings_model_name: "embedding-gecko-001"
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
 
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
       require "google_palm_api"
 
       @client = ::GooglePalmApi::Client.new(api_key: api_key)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
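The merged `@defaults` hash is what the rest of the class now reads instead of the frozen constant, so model names and temperature become per-instance settings. A minimal usage sketch, assuming a `GOOGLE_PALM_API_KEY` environment variable (the model names shown in DEFAULTS above are the gem's own):

    require "langchain"

    palm = Langchain::LLM::GooglePalm.new(
      api_key: ENV["GOOGLE_PALM_API_KEY"],
      # Overrides only the temperature; the completion/chat/embeddings model
      # names and the dimension are kept from DEFAULTS via the merge.
      default_options: {temperature: 0.4}
    )
    palm.complete(prompt: "Name three Ruby web frameworks.")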
@@ -55,7 +60,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       default_params = {
         prompt: prompt,
-        temperature: DEFAULTS[:temperature]
+        temperature: @defaults[:temperature],
+        completion_model_name: @defaults[:completion_model_name]
       }
 
       if params[:stop_sequences]
@@ -84,13 +90,15 @@ module Langchain::LLM
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
       default_params = {
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
+        chat_completion_model_name: @defaults[:chat_completion_model_name],
         context: context,
         messages: compose_chat_messages(prompt: prompt, messages: messages),
         examples: compose_examples(examples)
       }
 
-      Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
+      # chat-bison-001 is the only model that currently supports countMessageTokens functions
+      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
 
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
@@ -116,13 +124,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
data/lib/langchain/llm/openai.rb
CHANGED
@@ -17,12 +17,14 @@ module Langchain::LLM
       embeddings_model_name: "text-embedding-ada-002",
       dimension: 1536
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    def initialize(api_key:, llm_options: {})
+    def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "ruby-openai"
       require "openai"
 
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
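Same pattern as the GooglePalm change above: `llm_options` still goes straight through to the underlying `OpenAI::Client`, while the new `default_options` is overlaid on DEFAULTS. A sketch, assuming an `OPENAI_API_KEY` environment variable (`request_timeout` is a ruby-openai client option):

    openai = Langchain::LLM::OpenAI.new(
      api_key: ENV["OPENAI_API_KEY"],
      llm_options: {request_timeout: 120},  # passed through to OpenAI::Client
      default_options: {temperature: 0.2}   # merged over DEFAULTS
    )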
@@ -33,9 +35,9 @@ module Langchain::LLM
     # @return [Array] The embedding
     #
     def embed(text:, **params)
-      parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
+      parameters = {model: @defaults[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
+      validate_max_tokens(text, parameters[:model])
 
       response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")
@@ -49,37 +51,85 @@ module Langchain::LLM
     # @return [String] The completion
     #
     def complete(prompt:, **params)
-      parameters = compose_parameters DEFAULTS[:completion_model_name], params
+      parameters = compose_parameters @defaults[:completion_model_name], params
 
       parameters[:prompt] = prompt
-      parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
+      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
 
       response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")
     end
 
     #
-    # Generate a chat completion for a given prompt
+    # Generate a chat completion for a given prompt or messages.
+    #
+    # == Examples
+    #
+    #     # simplest case, just give a prompt
+    #     openai.chat prompt: "When was Ruby first released?"
+    #
+    #     # prompt plus some context about how to respond
+    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
+    #
+    #     # full control over messages that get sent, equivalent to the above
+    #     openai.chat messages: [
+    #       {
+    #         role: "system",
+    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    #       },
+    #       {
+    #         role: "user",
+    #         content: "Does Ruby have a REPL like IPython?"
+    #       }
+    #     ]
+    #
+    #     # few-shot prompting with examples
+    #     openai.chat prompt: "When was factory_bot released?",
+    #       examples: [
+    #         {
+    #           role: "user",
+    #           content: "When was Ruby on Rails released?"
+    #         },
+    #         {
+    #           role: "assistant",
+    #           content: "2004"
+    #         }
+    #       ]
     #
     # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array] The messages that have been sent in the conversation
-    #
-    #
-    #
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation.
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param context [String] An initial context to provide as a system message, i.e. "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting.
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param options [Hash] Extra parameters passed to OpenAI::Client#chat
+    # @yield [String] Stream responses back one String at a time
     # @return [String] The chat completion
     #
     def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
+      parameters = compose_parameters @defaults[:chat_completion_model_name], options
       parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
       parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
 
+      if (streaming = block_given?)
+        parameters[:stream] = proc do |chunk, _bytesize|
+          yield chunk.dig("choices", 0, "delta", "content")
+        end
+      end
+
       response = client.chat(parameters: parameters)
 
-      raise "Chat completion failed: #{response}" if response.dig("error")
+      raise "Chat completion failed: #{response}" if !response.empty? && response.dig("error")
 
-      response.dig("choices", 0, "message", "content")
+      unless streaming
+        response.dig("choices", 0, "message", "content")
+      end
     end
 
     #
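The `@yield` tag documents the new streaming path: when a block is given, `chat` sets `parameters[:stream]` to a proc and yields each delta String as it arrives, rather than returning the assembled completion. A minimal sketch, using the client built above:

    openai.chat(prompt: "Write a haiku about Ruby") do |token|
      # each yielded value is chunk.dig("choices", 0, "delta", "content")
      print token
    end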
@@ -90,17 +140,17 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: DEFAULTS[:temperature])
+      complete(prompt: prompt, temperature: @defaults[:temperature])
     end
 
     private
 
     def compose_parameters(model, params)
-      default_params = {model: model, temperature: DEFAULTS[:temperature]}
+      default_params = {model: model, temperature: @defaults[:temperature]}
 
       default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
 
@@ -140,7 +190,7 @@ module Langchain::LLM
     end
 
     def validate_max_tokens(messages, model)
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
     end
   end
 end
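Both clients now funnel through their LENGTH_VALIDATOR, so an over-long input surfaces as the new `Langchain::Utils::TokenLength::TokenLimitExceeded` (added in this release, per the file list above). A hedged sketch of guarding a call:

    very_long_text = "lorem ipsum " * 4_000
    begin
      openai.complete(prompt: very_long_text)
    rescue Langchain::Utils::TokenLength::TokenLimitExceeded => e
      # the exception also carries the overflow amount, passed as the
      # constructor's second argument in base_validator.rb below
      warn "Prompt too long: #{e.message}"
    end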
data/lib/langchain/llm/replicate.rb
CHANGED
@@ -32,7 +32,7 @@ module Langchain::LLM
     #
     # @param api_key [String] The API key to use
     #
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "replicate-ruby"
       require "replicate"
 
@@ -41,6 +41,7 @@ module Langchain::LLM
       end
 
       @client = ::Replicate.client
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
@@ -94,13 +95,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
      )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
@@ -111,11 +112,11 @@ module Langchain::LLM
     private
 
     def completion_model
-      @completion_model ||= client.retrieve_model(DEFAULTS[:completion_model_name]).latest_version
+      @completion_model ||= client.retrieve_model(@defaults[:completion_model_name]).latest_version
     end
 
     def embeddings_model
-      @embeddings_model ||= client.retrieve_model(DEFAULTS[:embeddings_model_name]).latest_version
+      @embeddings_model ||= client.retrieve_model(@defaults[:embeddings_model_name]).latest_version
     end
   end
 end
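With the private model helpers reading from `@defaults`, a caller can point Replicate at different models without subclassing. A sketch; the model slug here is purely illustrative:

    replicate = Langchain::LLM::Replicate.new(
      api_key: ENV["REPLICATE_API_TOKEN"],
      default_options: {completion_model_name: "replicate/vicuna-13b"}  # hypothetical slug
    )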
data/lib/langchain/prompt/base.rb
CHANGED
@@ -45,11 +45,11 @@ module Langchain::Prompt
     end
 
     #
-    # Save the object to a file in JSON format.
+    # Save the object to a file in JSON or YAML format.
     #
     # @param file_path [String, Pathname] The path to the file to save the object to
     #
-    # @raise [ArgumentError] If file_path doesn't end with .json
+    # @raise [ArgumentError] If file_path doesn't end with .json, .yaml, or .yml
     #
     # @return [void]
     #
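In practice this means a prompt template can round-trip through YAML as well as JSON; `save` presumably dispatches on the file extension. A sketch using the gem's PromptTemplate:

    prompt = Langchain::Prompt::PromptTemplate.new(
      template: "Tell me a {adjective} joke.",
      input_variables: ["adjective"]
    )
    prompt.save(file_path: "joke_prompt.yaml")  # .json, .yaml, or .yml
    Langchain::Prompt.load_from_path(file_path: "joke_prompt.yaml")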
data/lib/langchain/tool/base.rb
CHANGED
@@ -9,7 +9,7 @@ module Langchain::Tool
   #
   # - {Langchain::Tool::Calculator}: Calculate the result of a math expression
   # - {Langchain::Tool::RubyCodeInterpreter}: Runs ruby code
-  # - {Langchain::Tool::SerpApi}: search on Google (via SerpAPI)
+  # - {Langchain::Tool::GoogleSearch}: search on Google (via SerpAPI)
   # - {Langchain::Tool::Wikipedia}: search on Wikipedia
   #
   # == Usage
@@ -30,13 +30,13 @@ module Langchain::Tool
   #    agent = Langchain::Agent::ChainOfThoughtAgent.new(
   #      llm: :openai, # or :cohere, :hugging_face, :google_palm or :replicate
   #      llm_api_key: ENV["OPENAI_API_KEY"],
-  #      tools: ["search", "calculator", "wikipedia"]
+  #      tools: ["google_search", "calculator", "wikipedia"]
   #    )
   #
   # 4. Confirm that the Agent is using the Tools you passed in:
   #
   #    agent.tools
-  #    # => ["search", "calculator", "wikipedia"]
+  #    # => ["google_search", "calculator", "wikipedia"]
   #
   # == Adding Tools
   #
@@ -57,6 +57,12 @@ module Langchain::Tool
       self.class.const_get(:NAME)
     end
 
+    def self.logger_options
+      {
+        color: :light_blue
+      }
+    end
+
     #
     # Returns the DESCRIPTION constant of the tool
     #
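`logger_options` feeds the new `Langchain::ContextualLogger` (also added in this release), which tags and colors log lines per originating class. A custom tool can override it; a minimal sketch:

    class Langchain::Tool::MyTool < Langchain::Tool::Base
      NAME = "my_tool"

      def self.logger_options
        {color: :yellow}  # instead of Base's :light_blue
      end
    end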
data/lib/langchain/tool/calculator.rb
CHANGED
@@ -16,6 +16,11 @@ module Langchain::Tool
       Useful for getting the result of a math expression.
 
       The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.
+      Usage:
+        Action Input: 1 + 1
+        Action Input: 3 * 2 / 4
+        Action Input: 9 - 7
+        Action Input: (4.1 + 2.3) / (2.0 - 5.6) * 3
     DESC
 
     def initialize
@@ -28,18 +33,11 @@ module Langchain::Tool
     # @param input [String] math expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       Eqn::Calculator.calc(input)
     rescue Eqn::ParseError, Eqn::NoVariableValueError
-      #
-      # We can use the google answer box to evaluate this expression
-      # TODO: Figure out to find a better way to evaluate these language expressions.
-      hash_results = Langchain::Tool::SerpApi
-        .new(api_key: ENV["SERPAPI_API_KEY"])
-        .execute_search(input: input)
-      hash_results.dig(:answer_box, :to) ||
-        hash_results.dig(:answer_box, :result)
+      "\"#{input}\" is an invalid mathematical expression"
     end
   end
 end
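A calculator failure now degrades to a plain error string instead of a SerpApi round-trip, so the tool no longer carries SERPAPI_API_KEY as an implicit dependency. Usage matching the new description:

    calculator = Langchain::Tool::Calculator.new
    calculator.execute(input: "(4.1 + 2.3) / (2.0 - 5.6) * 3")
    # => -5.333... on success
    calculator.execute(input: "what is 12F in Celsius")
    # => "\"what is 12F in Celsius\" is an invalid mathematical expression"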
data/lib/langchain/tool/database.rb
CHANGED
@@ -14,15 +14,18 @@ module Langchain::Tool
       The input to this tool should be valid SQL.
     DESC
 
-    attr_reader :db
+    attr_reader :db, :requested_tables, :except_tables
 
     #
     # Establish a database connection
     #
     # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    # @param tables [Array<Symbol>] The tables to use. Will use all if empty.
+    # @param except_tables [Array<Symbol>] The tables to exclude. Will exclude none if empty.
+
     # @return [Database] Database object
     #
-    def initialize(connection_string:)
+    def initialize(connection_string:, tables: [], except_tables: [])
       depends_on "sequel"
       require "sequel"
       require "sequel/extensions/schema_dumper"
@@ -30,7 +33,8 @@ module Langchain::Tool
       raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
       @db = Sequel.connect(connection_string)
-      @schema = db.dump_schema_migration(same_db: true, indexes: false)
+      @requested_tables = tables
+      @except_tables = except_tables
     end
 
     #
@@ -38,9 +42,26 @@ module Langchain::Tool
     #
     # @return [String] schema
     #
-    def schema
-      Langchain.logger.info("[#{NAME}] Dumping schema")
-      @schema
+    def dump_schema
+      Langchain.logger.info("Dumping schema tables and keys", for: self.class)
+      schema = ""
+      db.tables.each do |table|
+        next if except_tables.include?(table)
+        next unless requested_tables.empty? || requested_tables.include?(table)
+
+        schema << "CREATE TABLE #{table}(\n"
+        db.schema(table).each do |column|
+          schema << "#{column[0]} #{column[1][:type]}"
+          schema << " PRIMARY KEY" if column[1][:primary_key] == true
+          schema << "," unless column == db.schema(table).last
+          schema << "\n"
+        end
+        schema << ");\n"
+        db.foreign_key_list(table).each do |fk|
+          schema << "ALTER TABLE #{table} ADD FOREIGN KEY (#{fk[:columns][0]}) REFERENCES #{fk[:table]}(#{fk[:key][0]});\n"
+        end
+      end
+      schema
     end
 
     #
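A sketch of the reworked tool: table filtering happens in `dump_schema`, which now emits compact CREATE TABLE / FOREIGN KEY DDL instead of a Sequel migration dump. The connection string is illustrative:

    db_tool = Langchain::Tool::Database.new(
      connection_string: "postgres://user:password@localhost:5432/db_name",
      tables: [:users, :orders],            # dump only these...
      except_tables: [:schema_migrations]   # ...minus these
    )
    puts db_tool.dump_schema
    db_tool.execute(input: "SELECT COUNT(*) FROM users;")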
@@ -50,11 +71,11 @@ module Langchain::Tool
     # @return [Array] results
     #
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       db[input].to_a
     rescue Sequel::DatabaseError => e
-      Langchain.logger.error("[#{NAME}] #{e.message}")
+      Langchain.logger.error(e.message, for: self.class)
     end
   end
 end
data/lib/langchain/tool/{serp_api.rb → google_search.rb}
RENAMED
@@ -1,18 +1,18 @@
 # frozen_string_literal: true
 
 module Langchain::Tool
-  class SerpApi < Base
+  class GoogleSearch < Base
     #
-    # Wrapper around SerpApi
+    # Wrapper around the Google Serp API
     #
     # Gem requirements: gem "google_search_results", "~> 2.0.0"
     #
     # Usage:
-    # search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+    # search = Langchain::Tool::GoogleSearch.new(api_key: "YOUR_API_KEY")
     # search.execute(input: "What is the capital of France?")
     #
 
-    NAME = "search"
+    NAME = "google_search"
 
     description <<~DESC
       A wrapper around Google Search.
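The rename means agent tool lists must use the new "google_search" name (see the tool/base.rb docs above). Direct usage, per the updated comment:

    search = Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
    search.execute(input: "What is the capital of France?")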
@@ -26,10 +26,10 @@ module Langchain::Tool
     attr_reader :api_key
 
     #
-    # Initializes the SerpApi tool
+    # Initializes the Google Search tool
     #
-    # @param api_key [String] SerpApi API key
-    # @return [Langchain::Tool::SerpApi] SerpApi tool
+    # @param api_key [String] Search API key
+    # @return [Langchain::Tool::GoogleSearch] Google search tool
     #
     def initialize(api_key:)
       depends_on "google_search_results"
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @return [String] Answer
     #
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       hash_results = execute_search(input: input)
 
@@ -72,7 +72,7 @@ module Langchain::Tool
     # @return [Hash] hash_results JSON
     #
     def execute_search(input:)
-      GoogleSearch
+      ::GoogleSearch
         .new(q: input, serp_api_key: api_key)
         .get_hash
     end
data/lib/langchain/tool/ruby_code_interpreter.rb
CHANGED
@@ -21,7 +21,7 @@ module Langchain::Tool
     # @param input [String] ruby code expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       safe_eval(input)
     end
data/lib/langchain/tool/weather.rb
CHANGED
@@ -21,7 +21,7 @@ module Langchain::Tool
 
     description <<~DESC
       Useful for getting current weather data
-
+
       The input to this tool should be a city name followed by the units (imperial, metric, or standard)
       Usage:
         Action Input: St Louis, Missouri; metric
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @param input [String] comma separated city and unit (optional: imperial, metric, or standard)
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing for \"#{input}\"")
+      Langchain.logger.info("Executing for \"#{input}\"", for: self.class)
 
       input_array = input.split(";")
       city, units = *input_array.map(&:strip)
data/lib/langchain/tool/wikipedia.rb
CHANGED
@@ -26,7 +26,7 @@ module Langchain::Tool
     # @param input [String] search query
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       page = ::Wikipedia.find(input)
       # It would be nice to figure out a way to provide page.content but the LLM token limit is an issue
data/lib/langchain/utils/token_length/base_validator.rb
ADDED
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
+      #
+      # @param content [String | Array<String>] The text or array of texts to validate
+      # @param model_name [String] The model name to validate against
+      # @return [Integer] Whether the text is valid or not
+      # @raise [TokenLimitExceeded] If the text is too long
+      #
+      class BaseValidator
+        def self.validate_max_tokens!(content, model_name, options = {})
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(item.to_json, model_name, options) }
+          else
+            token_length(content, model_name, options)
+          end
+
+          leftover_tokens = token_limit(model_name) - text_token_length
+
+          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
+          if leftover_tokens <= 0
+            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
+          end
+
+          leftover_tokens
+        end
+
+        def self.limit_exceeded_exception(limit, length)
+          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
+        end
+      end
+    end
+  end
+end
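BaseValidator is a template-method class: subclasses only supply `token_limit(model_name)` and `token_length(text, model_name, options)`, as the OpenAI and GooglePalm validators in this release do. A hedged sketch of a custom validator (the whitespace tokenization is purely illustrative):

    class FixedWindowValidator < Langchain::Utils::TokenLength::BaseValidator
      def self.token_limit(_model_name)
        4096
      end

      def self.token_length(text, _model_name, _options = {})
        text.split.size  # crude stand-in for a real tokenizer
      end
    end

    FixedWindowValidator.validate_max_tokens!("a short prompt", "any-model")
    # => tokens left in the window, or raises TokenLimitExceeded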
data/lib/langchain/utils/token_length/google_palm_validator.rb
CHANGED
@@ -7,7 +7,7 @@ module Langchain
      # This class is meant to validate the length of the text passed in to Google Palm's API.
      # It is used to validate the token length before the API call is made
      #
-      class GooglePalmValidator
+      class GooglePalmValidator < BaseValidator
        TOKEN_LIMITS = {
          # Source:
          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
@@ -26,43 +26,23 @@ module Langchain
          # }
        }.freeze
 
-        #
-        # Validate the context length of the text
-        #
-        # @param content [String | Array<String>] The text or array of texts to validate
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] Whether the text is valid or not
-        # @raise [TokenLimitExceeded] If the text is too long
-        #
-        def self.validate_max_tokens!(google_palm_llm, content, model_name)
-          text_token_length = if content.is_a?(Array)
-            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
-          else
-            token_length(google_palm_llm, content, model_name)
-          end
-
-          leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if leftover_tokens <= 0
-            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
-          end
-
-          leftover_tokens
-        end
-
        #
        # Calculate token length for a given text and model name
        #
-        # @param llm [Langchain::LLM:GooglePalm] The Langchain::LLM:GooglePalm instance
        # @param text [String] The text to calculate the token length for
        # @param model_name [String] The model name to validate against
+        # @param options [Hash] the options to create a message with
+        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
        # @return [Integer] The token length of the text
        #
-        def self.token_length(llm, text, model_name = "chat-bison-001")
-          response = llm.client.count_message_tokens(model: model_name, prompt: text)
+        def self.token_length(text, model_name = "chat-bison-001", options)
+          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
          response.dig("tokenCount")
        end
+
+        def self.token_limit(model_name)
+          TOKEN_LIMITS.dig(model_name, "input_token_limit")
+        end
      end
    end
  end
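The upshot: the validator now matches the BaseValidator calling convention, with the LLM instance (needed because PaLM counts tokens server-side via countMessageTokens) moved into the trailing options hash:

    Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(
      "some text",
      "chat-bison-001",
      llm: palm  # a Langchain::LLM::GooglePalm instance
    )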