langchainrb 0.5.5 → 0.5.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/Gemfile.lock +3 -1
- data/README.md +7 -5
- data/examples/store_and_query_with_pinecone.rb +5 -4
- data/lib/langchain/agent/base.rb +5 -0
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +22 -10
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.yaml +26 -0
- data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +8 -8
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +11 -0
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +21 -0
- data/lib/langchain/chunker/base.rb +15 -0
- data/lib/langchain/chunker/text.rb +38 -0
- data/lib/langchain/contextual_logger.rb +60 -0
- data/lib/langchain/conversation.rb +35 -4
- data/lib/langchain/data.rb +4 -0
- data/lib/langchain/llm/ai21.rb +16 -2
- data/lib/langchain/llm/cohere.rb +5 -4
- data/lib/langchain/llm/google_palm.rb +15 -7
- data/lib/langchain/llm/openai.rb +67 -17
- data/lib/langchain/llm/prompts/summarize_template.yaml +9 -0
- data/lib/langchain/llm/replicate.rb +6 -5
- data/lib/langchain/prompt/base.rb +2 -2
- data/lib/langchain/tool/base.rb +9 -3
- data/lib/langchain/tool/calculator.rb +7 -9
- data/lib/langchain/tool/database.rb +29 -8
- data/lib/langchain/tool/{serp_api.rb → google_search.rb} +9 -9
- data/lib/langchain/tool/ruby_code_interpreter.rb +1 -1
- data/lib/langchain/tool/weather.rb +2 -2
- data/lib/langchain/tool/wikipedia.rb +1 -1
- data/lib/langchain/utils/token_length/base_validator.rb +38 -0
- data/lib/langchain/utils/token_length/google_palm_validator.rb +9 -29
- data/lib/langchain/utils/token_length/openai_validator.rb +10 -27
- data/lib/langchain/utils/token_length/token_limit_exceeded.rb +17 -0
- data/lib/langchain/vectorsearch/base.rb +6 -0
- data/lib/langchain/vectorsearch/chroma.rb +1 -1
- data/lib/langchain/vectorsearch/hnswlib.rb +2 -2
- data/lib/langchain/vectorsearch/milvus.rb +1 -14
- data/lib/langchain/vectorsearch/pgvector.rb +1 -5
- data/lib/langchain/vectorsearch/pinecone.rb +1 -4
- data/lib/langchain/vectorsearch/qdrant.rb +1 -4
- data/lib/langchain/vectorsearch/weaviate.rb +1 -4
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +28 -12
- metadata +30 -11
- data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json +0 -10
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +0 -10
- data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +0 -10
- data/lib/langchain/llm/prompts/summarize_template.json +0 -5
data/lib/langchain/llm/google_palm.rb
CHANGED
@@ -22,14 +22,19 @@ module Langchain::LLM
 
     DEFAULTS = {
       temperature: 0.0,
-      dimension: 768 # This is what the `embedding-gecko-001` model generates
+      dimension: 768, # This is what the `embedding-gecko-001` model generates
+      completion_model_name: "text-bison-001",
+      chat_completion_model_name: "chat-bison-001",
+      embeddings_model_name: "embedding-gecko-001"
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
 
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
      depends_on "google_palm_api"
      require "google_palm_api"
 
      @client = ::GooglePalmApi::Client.new(api_key: api_key)
+      @defaults = DEFAULTS.merge(default_options)
    end
 
    #
@@ -55,7 +60,8 @@ module Langchain::LLM
    def complete(prompt:, **params)
      default_params = {
        prompt: prompt,
-        temperature: DEFAULTS[:temperature]
+        temperature: @defaults[:temperature],
+        completion_model_name: @defaults[:completion_model_name]
      }
 
      if params[:stop_sequences]
@@ -84,13 +90,15 @@ module Langchain::LLM
      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
      default_params = {
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
+        chat_completion_model_name: @defaults[:chat_completion_model_name],
        context: context,
        messages: compose_chat_messages(prompt: prompt, messages: messages),
        examples: compose_examples(examples)
      }
 
-      Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
+      # chat-bison-001 is the only model that currently supports countMessageTokens functions
+      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
 
      if options[:stop_sequences]
        default_params[:stop] = options.delete(:stop_sequences)
@@ -116,13 +124,13 @@ module Langchain::LLM
    #
    def summarize(text:)
      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
      )
      prompt = prompt_template.format(text: text)
 
      complete(
        prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
        max_tokens: 2048
      )
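
Taken together, these hunks let callers override the PaLM model names and temperature per instance. A minimal usage sketch (not part of the diff; assumes the google_palm_api gem and a key in GOOGLE_PALM_API_KEY):

    require "langchain"

    # Any key omitted from default_options falls back to DEFAULTS,
    # e.g. completion_model_name: "text-bison-001".
    google_palm = Langchain::LLM::GooglePalm.new(
      api_key: ENV["GOOGLE_PALM_API_KEY"],
      default_options: {temperature: 0.4}
    )

    google_palm.complete(prompt: "Name three Ruby web frameworks.")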
data/lib/langchain/llm/openai.rb
CHANGED
@@ -17,12 +17,14 @@ module Langchain::LLM
      embeddings_model_name: "text-embedding-ada-002",
      dimension: 1536
    }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    def initialize(api_key:, llm_options: {})
+    def initialize(api_key:, llm_options: {}, default_options: {})
      depends_on "ruby-openai"
      require "openai"
 
      @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+      @defaults = DEFAULTS.merge(default_options)
    end
 
    #
@@ -33,9 +35,9 @@ module Langchain::LLM
    # @return [Array] The embedding
    #
    def embed(text:, **params)
-      parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
+      parameters = {model: @defaults[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
+      validate_max_tokens(text, parameters[:model])
 
      response = client.embeddings(parameters: parameters.merge(params))
      response.dig("data").first.dig("embedding")
@@ -49,37 +51,85 @@ module Langchain::LLM
    # @return [String] The completion
    #
    def complete(prompt:, **params)
-      parameters = compose_parameters DEFAULTS[:completion_model_name], params
+      parameters = compose_parameters @defaults[:completion_model_name], params
 
      parameters[:prompt] = prompt
-      parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
+      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
 
      response = client.completions(parameters: parameters)
      response.dig("choices", 0, "text")
    end
 
    #
-    # Generate a chat completion for a given prompt
+    # Generate a chat completion for a given prompt or messages.
+    #
+    # == Examples
+    #
+    #     # simplest case, just give a prompt
+    #     openai.chat prompt: "When was Ruby first released?"
+    #
+    #     # prompt plus some context about how to respond
+    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
+    #
+    #     # full control over messages that get sent, equivilent to the above
+    #     openai.chat messages: [
+    #       {
+    #         role: "system",
+    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
+    #       },
+    #       {
+    #         role: "user",
+    #         content: "When was Ruby first released?"
+    #       }
+    #     ]
+    #
+    #     # few-short prompting with examples
+    #     openai.chat prompt: "When was factory_bot released?",
+    #       examples: [
+    #         {
+    #           role: "user",
+    #           content: "When was Ruby on Rails released?"
+    #         }
+    #         {
+    #           role: "assistant",
+    #           content: "2004"
+    #         },
+    #       ]
    #
    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array] The messages that have been sent in the conversation
-    #
-    #
-    #
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param options <Hash> extra parameters passed to OpenAI::Client#chat
+    # @yield [String] Stream responses back one String at a time
    # @return [String] The chat completion
    #
    def chat(prompt: "", messages: [], context: "", examples: [], **options)
      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
+      parameters = compose_parameters @defaults[:chat_completion_model_name], options
      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
 
+      if (streaming = block_given?)
+        parameters[:stream] = proc do |chunk, _bytesize|
+          yield chunk.dig("choices", 0, "delta", "content")
+        end
+      end
+
      response = client.chat(parameters: parameters)
 
-      raise "Chat completion failed: #{response}" if response.dig("error")
+      raise "Chat completion failed: #{response}" if !response.empty? && response.dig("error")
 
-      response.dig("choices", 0, "message", "content")
+      unless streaming
+        response.dig("choices", 0, "message", "content")
+      end
    end
 
    #
@@ -90,17 +140,17 @@ module Langchain::LLM
    #
    def summarize(text:)
      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
      )
      prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: DEFAULTS[:temperature])
+      complete(prompt: prompt, temperature: @defaults[:temperature])
    end
 
    private
 
    def compose_parameters(model, params)
-      default_params = {model: model, temperature: DEFAULTS[:temperature]}
+      default_params = {model: model, temperature: @defaults[:temperature]}
 
      default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
 
@@ -140,7 +190,7 @@ module Langchain::LLM
    end
 
    def validate_max_tokens(messages, model)
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
    end
  end
 end
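
The new streaming path is the most visible behavior change here. A short sketch of both call styles (not part of the diff; assumes the ruby-openai gem and an OPENAI_API_KEY):

    require "langchain"

    openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Blocking call: returns the whole completion as a String.
    openai.chat(prompt: "When was Ruby first released?")

    # Streaming call: each delta is yielded as it arrives. Note the
    # `unless streaming` guard above: the return value is nil here.
    openai.chat(prompt: "Tell me a short story") do |chunk|
      print(chunk)
    end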
data/lib/langchain/llm/replicate.rb
CHANGED
@@ -32,7 +32,7 @@ module Langchain::LLM
    #
    # @param api_key [String] The API key to use
    #
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
      depends_on "replicate-ruby"
      require "replicate"
 
@@ -41,6 +41,7 @@ module Langchain::LLM
      end
 
      @client = ::Replicate.client
+      @defaults = DEFAULTS.merge(default_options)
    end
 
    #
@@ -94,13 +95,13 @@ module Langchain::LLM
    #
    def summarize(text:)
      prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
      )
      prompt = prompt_template.format(text: text)
 
      complete(
        prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
        max_tokens: 2048
      )
@@ -111,11 +112,11 @@ module Langchain::LLM
    private
 
    def completion_model
-      @completion_model ||= client.retrieve_model(DEFAULTS[:completion_model_name]).latest_version
+      @completion_model ||= client.retrieve_model(@defaults[:completion_model_name]).latest_version
    end
 
    def embeddings_model
-      @embeddings_model ||= client.retrieve_model(DEFAULTS[:embeddings_model_name]).latest_version
+      @embeddings_model ||= client.retrieve_model(@defaults[:embeddings_model_name]).latest_version
    end
  end
 end
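
Replicate gets the same default_options pattern as the other adapters. A hypothetical sketch; the model slug below is illustrative only, not a default shipped by the gem:

    require "langchain"

    replicate = Langchain::LLM::Replicate.new(
      api_key: ENV["REPLICATE_API_TOKEN"],
      # Illustrative model name; substitute a real Replicate model slug.
      default_options: {completion_model_name: "replicate/vicuna-13b"}
    )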
data/lib/langchain/prompt/base.rb
CHANGED
@@ -45,11 +45,11 @@ module Langchain::Prompt
    end
 
    #
-    # Save the object to a file in JSON format.
+    # Save the object to a file in JSON or YAML format.
    #
    # @param file_path [String, Pathname] The path to the file to save the object to
    #
-    # @raise [ArgumentError] If file_path doesn't end with .json
+    # @raise [ArgumentError] If file_path doesn't end with .json or .yaml or .yml
    #
    # @return [void]
    #
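
With the loosened extension check, saving a template to YAML should now work. A sketch using the gem's PromptTemplate (file name is arbitrary):

    require "langchain"

    prompt = Langchain::Prompt::PromptTemplate.new(
      template: "Tell me a {adjective} joke about {subject}.",
      input_variables: ["adjective", "subject"]
    )

    prompt.save(file_path: "joke_prompt.yaml") # .json and .yml are also accepted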
data/lib/langchain/tool/base.rb
CHANGED
@@ -9,7 +9,7 @@ module Langchain::Tool
 #
 # - {Langchain::Tool::Calculator}: Calculate the result of a math expression
 # - {Langchain::Tool::RubyCodeInterpretor}: Runs ruby code
-# - {Langchain::Tool::SerpApi}: search on Google (via SerpAPI)
+# - {Langchain::Tool::GoogleSearch}: search on Google (via SerpAPI)
 # - {Langchain::Tool::Wikipedia}: search on Wikipedia
 #
 # == Usage
@@ -30,13 +30,13 @@ module Langchain::Tool
 #    agent = Langchain::Agent::ChainOfThoughtAgent.new(
 #      llm: :openai, # or :cohere, :hugging_face, :google_palm or :replicate
 #      llm_api_key: ENV["OPENAI_API_KEY"],
-#      tools: ["serp_api", "calculator", "wikipedia"]
+#      tools: ["google_search", "calculator", "wikipedia"]
 #    )
 #
 # 4. Confirm that the Agent is using the Tools you passed in:
 #
 #    agent.tools
-#    # => ["serp_api", "calculator", "wikipedia"]
+#    # => ["google_search", "calculator", "wikipedia"]
 #
 # == Adding Tools
 #
@@ -57,6 +57,12 @@ module Langchain::Tool
      self.class.const_get(:NAME)
    end
 
+    def self.logger_options
+      {
+        color: :light_blue
+      }
+    end
+
    #
    # Returns the DESCRIPTION constant of the tool
    #
data/lib/langchain/tool/calculator.rb
CHANGED
@@ -16,6 +16,11 @@ module Langchain::Tool
      Useful for getting the result of a math expression.
 
      The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.
+      Usage:
+        Action Input: 1 + 1
+        Action Input: 3 * 2 / 4
+        Action Input: 9 - 7
+        Action Input: (4.1 + 2.3) / (2.0 - 5.6) * 3
    DESC
 
    def initialize
@@ -28,18 +33,11 @@ module Langchain::Tool
    # @param input [String] math expression
    # @return [String] Answer
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
      Eqn::Calculator.calc(input)
    rescue Eqn::ParseError, Eqn::NoVariableValueError
-      #
-      # We can use the google answer box to evaluate this expression
-      # TODO: Figure out to find a better way to evaluate these language expressions.
-      hash_results = Langchain::Tool::SerpApi
-        .new(api_key: ENV["SERPAPI_API_KEY"])
-        .execute_search(input: input)
-      hash_results.dig(:answer_box, :to) ||
-        hash_results.dig(:answer_box, :result)
+      "\"#{input}\" is an invalid mathematical expression"
    end
  end
 end
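
The rescue branch no longer falls back to a SerpApi lookup; invalid expressions now return an error string instead. A sketch (assumes the eqn gem; sample inputs follow the new Usage block):

    require "langchain"

    calculator = Langchain::Tool::Calculator.new

    calculator.execute(input: "(4.1 + 2.3) / (2.0 - 5.6) * 3")
    # => -5.333...

    calculator.execute(input: "what is the square root of belgium")
    # => "\"what is the square root of belgium\" is an invalid mathematical expression"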
data/lib/langchain/tool/database.rb
CHANGED
@@ -14,15 +14,18 @@ module Langchain::Tool
      The input to this tool should be valid SQL.
    DESC
 
-    attr_reader :db
+    attr_reader :db, :requested_tables, :except_tables
 
    #
    # Establish a database connection
    #
    # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    # @param tables [Array<Symbol>] The tables to use. Will use all if empty.
+    # @param except_tables [Array<Symbol>] The tables to exclude. Will exclude none if empty.
+
    # @return [Database] Database object
    #
-    def initialize(connection_string:)
+    def initialize(connection_string:, tables: [], except_tables: [])
      depends_on "sequel"
      require "sequel"
      require "sequel/extensions/schema_dumper"
@@ -30,7 +33,8 @@ module Langchain::Tool
      raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
      @db = Sequel.connect(connection_string)
-      @db.extension :schema_dumper
+      @requested_tables = tables
+      @except_tables = except_tables
    end
 
    #
@@ -38,9 +42,26 @@ module Langchain::Tool
    #
    # @return [String] schema
    #
-    def schema
-      Langchain.logger.info("[#{NAME}] Dumping schema")
-      db.dump_schema_migration(same_db: true, indexes: false)
+    def dump_schema
+      Langchain.logger.info("Dumping schema tables and keys", for: self.class)
+      schema = ""
+      db.tables.each do |table|
+        next if except_tables.include?(table)
+        next unless requested_tables.empty? || requested_tables.include?(table)
+
+        schema << "CREATE TABLE #{table}(\n"
+        db.schema(table).each do |column|
+          schema << "#{column[0]} #{column[1][:type]}"
+          schema << " PRIMARY KEY" if column[1][:primary_key] == true
+          schema << "," unless column == db.schema(table).last
+          schema << "\n"
+        end
+        schema << ");\n"
+        db.foreign_key_list(table).each do |fk|
+          schema << "ALTER TABLE #{table} ADD FOREIGN KEY (#{fk[:columns][0]}) REFERENCES #{fk[:table]}(#{fk[:key][0]});\n"
+        end
+      end
+      schema
    end
 
    #
@@ -50,11 +71,11 @@ module Langchain::Tool
    # @return [Array] results
    #
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
      db[input].to_a
    rescue Sequel::DatabaseError => e
-      Langchain.logger.error("[#{NAME}] #{e.message}")
+      Langchain.logger.error(e.message, for: self.class)
    end
  end
 end
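
A sketch of the new table filtering and schema dump (the connection string is the placeholder from the tool's own docs; assumes the sequel gem and a reachable database):

    require "langchain"

    database = Langchain::Tool::Database.new(
      connection_string: "postgres://user:password@localhost:5432/db_name",
      tables: [:users, :orders],          # only dump these...
      except_tables: [:schema_migrations] # ...and skip these
    )

    puts database.dump_schema # CREATE TABLE / FOREIGN KEY statements, per the new dumper
    database.execute(input: "SELECT * FROM users LIMIT 5")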
data/lib/langchain/tool/{serp_api.rb → google_search.rb}
RENAMED
@@ -1,18 +1,18 @@
 # frozen_string_literal: true
 
 module Langchain::Tool
-  class SerpApi < Base
+  class GoogleSearch < Base
    #
-    # Wrapper around SerpApi
+    # Wrapper around Google Serp SPI
    #
    # Gem requirements: gem "google_search_results", "~> 2.0.0"
    #
    # Usage:
-    # search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+    # search = Langchain::Tool::GoogleSearch.new(api_key: "YOUR_API_KEY")
    # search.execute(input: "What is the capital of France?")
    #
 
-    NAME = "serp_api"
+    NAME = "google_search"
 
    description <<~DESC
      A wrapper around Google Search.
@@ -26,10 +26,10 @@ module Langchain::Tool
    attr_reader :api_key
 
    #
-    # Initializes the SerpApi tool
+    # Initializes the Google Search tool
    #
-    # @param api_key [String] SerpApi API key
-    # @return [Langchain::Tool::SerpApi] SerpApi tool
+    # @param api_key [String] Search API key
+    # @return [Langchain::Tool::GoogleSearch] Google search tool
    #
    def initialize(api_key:)
      depends_on "google_search_results"
@@ -54,7 +54,7 @@ module Langchain::Tool
    # @return [String] Answer
    #
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
      hash_results = execute_search(input: input)
 
@@ -72,7 +72,7 @@ module Langchain::Tool
    # @return [Hash] hash_results JSON
    #
    def execute_search(input:)
-      GoogleSearch
+      ::GoogleSearch
        .new(q: input, serp_api_key: api_key)
        .get_hash
    end
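
Usage after the rename, per the class's own comment (the key is still a SerpApi key, since the tool wraps the google_search_results gem — hence the new `::GoogleSearch` to reach the gem's class past the tool's own name):

    require "langchain"

    search = Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
    search.execute(input: "What is the capital of France?")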
data/lib/langchain/tool/ruby_code_interpreter.rb
CHANGED
@@ -21,7 +21,7 @@ module Langchain::Tool
    # @param input [String] ruby code expression
    # @return [String] Answer
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
      safe_eval(input)
    end
data/lib/langchain/tool/weather.rb
CHANGED
@@ -21,7 +21,7 @@ module Langchain::Tool
 
    description <<~DESC
      Useful for getting current weather data
-
+
      The input to this tool should be a city name followed by the units (imperial, metric, or standard)
      Usage:
        Action Input: St Louis, Missouri; metric
@@ -54,7 +54,7 @@ module Langchain::Tool
    # @param input [String] comma separated city and unit (optional: imperial, metric, or standard)
    # @return [String] Answer
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing for \"#{input}\"")
+      Langchain.logger.info("Executing for \"#{input}\"", for: self.class)
 
      input_array = input.split(";")
      city, units = *input_array.map(&:strip)
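
A sketch of the documented "city; units" input format; the api_key: initializer is assumed from the tool's OpenWeather dependency and is not shown in these hunks:

    require "langchain"

    # api_key: keyword assumed; not part of this diff.
    weather = Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
    weather.execute(input: "St Louis, Missouri; metric")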
data/lib/langchain/tool/wikipedia.rb
CHANGED
@@ -26,7 +26,7 @@ module Langchain::Tool
    # @param input [String] search query
    # @return [String] Answer
    def execute(input:)
-      Langchain.logger.info("[#{NAME}] Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
      page = ::Wikipedia.find(input)
      # It would be nice to figure out a way to provide page.content but the LLM token limit is an issue
data/lib/langchain/utils/token_length/base_validator.rb
ADDED
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
+      #
+      # @param content [String | Array<String>] The text or array of texts to validate
+      # @param model_name [String] The model name to validate against
+      # @return [Integer] Whether the text is valid or not
+      # @raise [TokenLimitExceeded] If the text is too long
+      #
+      class BaseValidator
+        def self.validate_max_tokens!(content, model_name, options = {})
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(item.to_json, model_name, options) }
+          else
+            token_length(content, model_name, options)
+          end
+
+          leftover_tokens = token_limit(model_name) - text_token_length
+
+          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
+          if leftover_tokens <= 0
+            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
+          end
+
+          leftover_tokens
+        end
+
+        def self.limit_exceeded_exception(limit, length)
+          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
+        end
+      end
+    end
+  end
+end
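
A hypothetical subclass showing the contract this template method expects: concrete validators supply .token_length and .token_limit and inherit .validate_max_tokens!. NaiveValidator and its whitespace "tokenizer" are illustrative, not part of the gem:

    require "langchain"

    # Hypothetical validator, for illustration only.
    module Langchain
      module Utils
        module TokenLength
          class NaiveValidator < BaseValidator
            # Crude whitespace tokenizer standing in for a real one.
            def self.token_length(text, model_name, options = {})
              text.split.size
            end

            def self.token_limit(model_name)
              4096
            end
          end
        end
      end
    end

    Langchain::Utils::TokenLength::NaiveValidator.validate_max_tokens!(
      "a short prompt", "any-model"
    )
    # => 4093 leftover tokens; raises TokenLimitExceeded when <= 0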
data/lib/langchain/utils/token_length/google_palm_validator.rb
CHANGED
@@ -7,7 +7,7 @@ module Langchain
      # This class is meant to validate the length of the text passed in to Google Palm's API.
      # It is used to validate the token length before the API call is made
      #
-      class GooglePalmValidator
+      class GooglePalmValidator < BaseValidator
        TOKEN_LIMITS = {
          # Source:
          # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
@@ -26,43 +26,23 @@ module Langchain
          # }
        }.freeze
 
-        #
-        # Validate the context length of the text
-        #
-        # @param content [String | Array<String>] The text or array of texts to validate
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] Whether the text is valid or not
-        # @raise [TokenLimitExceeded] If the text is too long
-        #
-        def self.validate_max_tokens!(google_palm_llm, content, model_name)
-          text_token_length = if content.is_a?(Array)
-            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
-          else
-            token_length(google_palm_llm, content, model_name)
-          end
-
-          leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if leftover_tokens <= 0
-            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
-          end
-
-          leftover_tokens
-        end
-
        #
        # Calculate token length for a given text and model name
        #
-        # @param llm [Langchain::LLM:GooglePalm] The Langchain::LLM:GooglePalm instance
        # @param text [String] The text to calculate the token length for
        # @param model_name [String] The model name to validate against
+        # @param options [Hash] the options to create a message with
+        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
        # @return [Integer] The token length of the text
        #
-        def self.token_length(llm, text, model_name = "chat-bison-001")
-          response = llm.client.count_message_tokens(model: model_name, prompt: text)
+        def self.token_length(text, model_name = "chat-bison-001", options)
+          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
          response.dig("tokenCount")
        end
+
+        def self.token_limit(model_name)
+          TOKEN_LIMITS.dig(model_name, "input_token_limit")
+        end
      end
    end
  end
 end
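
The call shape matching the new signature mirrors what google_palm.rb now does, with the LLM instance passed via the llm: keyword instead of the first positional argument:

    require "langchain"

    llm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])

    Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(
      "What is deep learning?",
      "chat-bison-001",
      llm: llm
    )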