langchainrb 0.5.5 → 0.5.7

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (49)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/Gemfile.lock +3 -1
  4. data/README.md +7 -5
  5. data/examples/store_and_query_with_pinecone.rb +5 -4
  6. data/lib/langchain/agent/base.rb +5 -0
  7. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +22 -10
  8. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.yaml +26 -0
  9. data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +8 -8
  10. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +11 -0
  11. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +21 -0
  12. data/lib/langchain/chunker/base.rb +15 -0
  13. data/lib/langchain/chunker/text.rb +38 -0
  14. data/lib/langchain/contextual_logger.rb +60 -0
  15. data/lib/langchain/conversation.rb +35 -4
  16. data/lib/langchain/data.rb +4 -0
  17. data/lib/langchain/llm/ai21.rb +16 -2
  18. data/lib/langchain/llm/cohere.rb +5 -4
  19. data/lib/langchain/llm/google_palm.rb +15 -7
  20. data/lib/langchain/llm/openai.rb +67 -17
  21. data/lib/langchain/llm/prompts/summarize_template.yaml +9 -0
  22. data/lib/langchain/llm/replicate.rb +6 -5
  23. data/lib/langchain/prompt/base.rb +2 -2
  24. data/lib/langchain/tool/base.rb +9 -3
  25. data/lib/langchain/tool/calculator.rb +7 -9
  26. data/lib/langchain/tool/database.rb +29 -8
  27. data/lib/langchain/tool/{serp_api.rb → google_search.rb} +9 -9
  28. data/lib/langchain/tool/ruby_code_interpreter.rb +1 -1
  29. data/lib/langchain/tool/weather.rb +2 -2
  30. data/lib/langchain/tool/wikipedia.rb +1 -1
  31. data/lib/langchain/utils/token_length/base_validator.rb +38 -0
  32. data/lib/langchain/utils/token_length/google_palm_validator.rb +9 -29
  33. data/lib/langchain/utils/token_length/openai_validator.rb +10 -27
  34. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +17 -0
  35. data/lib/langchain/vectorsearch/base.rb +6 -0
  36. data/lib/langchain/vectorsearch/chroma.rb +1 -1
  37. data/lib/langchain/vectorsearch/hnswlib.rb +2 -2
  38. data/lib/langchain/vectorsearch/milvus.rb +1 -14
  39. data/lib/langchain/vectorsearch/pgvector.rb +1 -5
  40. data/lib/langchain/vectorsearch/pinecone.rb +1 -4
  41. data/lib/langchain/vectorsearch/qdrant.rb +1 -4
  42. data/lib/langchain/vectorsearch/weaviate.rb +1 -4
  43. data/lib/langchain/version.rb +1 -1
  44. data/lib/langchain.rb +28 -12
  45. metadata +30 -11
  46. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json +0 -10
  47. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +0 -10
  48. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +0 -10
  49. data/lib/langchain/llm/prompts/summarize_template.json +0 -5

data/lib/langchain/llm/google_palm.rb
@@ -22,14 +22,19 @@ module Langchain::LLM
 
     DEFAULTS = {
       temperature: 0.0,
-      dimension: 768 # This is what the `embedding-gecko-001` model generates
+      dimension: 768, # This is what the `embedding-gecko-001` model generates
+      completion_model_name: "text-bison-001",
+      chat_completion_model_name: "chat-bison-001",
+      embeddings_model_name: "embedding-gecko-001"
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
 
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
       require "google_palm_api"
 
       @client = ::GooglePalmApi::Client.new(api_key: api_key)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
@@ -55,7 +60,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       default_params = {
         prompt: prompt,
-        temperature: DEFAULTS[:temperature]
+        temperature: @defaults[:temperature],
+        completion_model_name: @defaults[:completion_model_name]
       }
 
       if params[:stop_sequences]
@@ -84,13 +90,15 @@ module Langchain::LLM
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
       default_params = {
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
+        chat_completion_model_name: @defaults[:chat_completion_model_name],
         context: context,
         messages: compose_chat_messages(prompt: prompt, messages: messages),
         examples: compose_examples(examples)
       }
 
-      Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
+      # chat-bison-001 is the only model that currently supports the countMessageTokens function
+      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
 
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
@@ -116,13 +124,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
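
Taken together, these changes let callers override any of the GooglePalm defaults at construction time. A minimal sketch of the new `default_options:` keyword; the constructor signature and the DEFAULTS keys come from the diff above, while the specific values are illustrative:

    require "langchain"

    palm = Langchain::LLM::GooglePalm.new(
      api_key: ENV["GOOGLE_PALM_API_KEY"],
      default_options: {
        temperature: 0.2,                        # overrides DEFAULTS[:temperature]
        completion_model_name: "text-bison-001"  # explicit, same as the default
      }
    )
    palm.complete(prompt: "Summarize Ruby's history in one sentence.")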

data/lib/langchain/llm/openai.rb
@@ -17,12 +17,14 @@ module Langchain::LLM
       embeddings_model_name: "text-embedding-ada-002",
       dimension: 1536
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    def initialize(api_key:, llm_options: {})
+    def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "ruby-openai"
       require "openai"
 
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
@@ -33,9 +35,9 @@ module Langchain::LLM
     # @return [Array] The embedding
     #
     def embed(text:, **params)
-      parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
+      parameters = {model: @defaults[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
+      validate_max_tokens(text, parameters[:model])
 
       response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")
@@ -49,37 +51,85 @@ module Langchain::LLM
     # @return [String] The completion
     #
     def complete(prompt:, **params)
-      parameters = compose_parameters DEFAULTS[:completion_model_name], params
+      parameters = compose_parameters @defaults[:completion_model_name], params
 
       parameters[:prompt] = prompt
-      parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
+      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
 
       response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")
     end
 
     #
-    # Generate a chat completion for a given prompt
+    # Generate a chat completion for a given prompt or messages.
+    #
+    # == Examples
+    #
+    #   # simplest case, just give a prompt
+    #   openai.chat prompt: "When was Ruby first released?"
+    #
+    #   # prompt plus some context about how to respond
+    #   openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
+    #
+    #   # full control over the messages that get sent, equivalent to the above
+    #   openai.chat messages: [
+    #     {
+    #       role: "system",
+    #       content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    #     },
+    #     {
+    #       role: "user",
+    #       content: "Does Ruby have a REPL like IPython?"
+    #     }
+    #   ]
+    #
+    #   # few-shot prompting with examples
+    #   openai.chat prompt: "When was factory_bot released?",
+    #     examples: [
+    #       {
+    #         role: "user",
+    #         content: "When was Ruby on Rails released?"
+    #       },
+    #       {
+    #         role: "assistant",
+    #         content: "2004"
+    #       }
+    #     ]
     #
     # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array] The messages that have been sent in the conversation
-    # @param context [String] The context of the conversation
-    # @param examples [Array] Examples of messages provide model with
-    # @param options extra parameters passed to OpenAI::Client#chat
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param context [String] An initial context to provide as a system message, e.g. "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for few-shot prompting
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
+    # @yield [String] Stream responses back one String at a time
     # @return [String] The chat completion
     #
     def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
+      parameters = compose_parameters @defaults[:chat_completion_model_name], options
       parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
       parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
 
+      if (streaming = block_given?)
+        parameters[:stream] = proc do |chunk, _bytesize|
+          yield chunk.dig("choices", 0, "delta", "content")
+        end
+      end
+
       response = client.chat(parameters: parameters)
 
-      raise "Chat completion failed: #{response}" if response.dig("error")
+      raise "Chat completion failed: #{response}" if !response.empty? && response.dig("error")
 
-      response.dig("choices", 0, "message", "content")
+      unless streaming
+        response.dig("choices", 0, "message", "content")
+      end
     end
 
     #
@@ -90,17 +140,17 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: DEFAULTS[:temperature])
+      complete(prompt: prompt, temperature: @defaults[:temperature])
     end
 
     private
 
     def compose_parameters(model, params)
-      default_params = {model: model, temperature: DEFAULTS[:temperature]}
+      default_params = {model: model, temperature: @defaults[:temperature]}
 
       default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
 
@@ -140,7 +190,7 @@ module Langchain::LLM
     end
 
     def validate_max_tokens(messages, model)
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
     end
   end
 end
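
The new streaming branch sets `parameters[:stream]` to a proc whenever `chat` is called with a block, and yields each content delta instead of returning the assembled completion. A usage sketch based on the `@yield` tag above (assumes a valid OpenAI API key in the environment):

    openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Each chunk is one "delta" content String; in streaming mode
    # chat() returns nothing at the end.
    openai.chat(prompt: "Tell me a short story") do |chunk|
      print chunk
    end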

data/lib/langchain/llm/prompts/summarize_template.yaml
@@ -0,0 +1,9 @@
+_type: prompt
+input_variables:
+  - text
+template: |
+  Write a concise summary of the following:
+
+  {text}
+
+  CONCISE SUMMARY:

data/lib/langchain/llm/replicate.rb
@@ -32,7 +32,7 @@ module Langchain::LLM
     #
     # @param api_key [String] The API key to use
     #
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "replicate-ruby"
       require "replicate"
 
@@ -41,6 +41,7 @@ module Langchain::LLM
       end
 
       @client = ::Replicate.client
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
@@ -94,13 +95,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
@@ -111,11 +112,11 @@ module Langchain::LLM
     private
 
     def completion_model
-      @completion_model ||= client.retrieve_model(DEFAULTS[:completion_model_name]).latest_version
+      @completion_model ||= client.retrieve_model(@defaults[:completion_model_name]).latest_version
     end
 
     def embeddings_model
-      @embeddings_model ||= client.retrieve_model(DEFAULTS[:embeddings_model_name]).latest_version
+      @embeddings_model ||= client.retrieve_model(@defaults[:embeddings_model_name]).latest_version
     end
   end
 end

data/lib/langchain/prompt/base.rb
@@ -45,11 +45,11 @@ module Langchain::Prompt
     end
 
     #
-    # Save the object to a file in JSON format.
+    # Save the object to a file in JSON or YAML format.
     #
     # @param file_path [String, Pathname] The path to the file to save the object to
     #
-    # @raise [ArgumentError] If file_path doesn't end with .json
+    # @raise [ArgumentError] If file_path doesn't end with .json, .yaml, or .yml
     #
     # @return [void]
     #
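
Since prompt templates can now be saved and loaded in YAML as well as JSON, a round trip might look like the sketch below. `save` and `load_from_path` appear in this diff; the template itself is illustrative:

    prompt = Langchain::Prompt::PromptTemplate.new(
      template: "Tell me a {adjective} joke about {content}.",
      input_variables: ["adjective", "content"]
    )
    prompt.save(file_path: "./prompt_template.yaml") # .json, .yaml, or .yml

    reloaded = Langchain::Prompt.load_from_path(file_path: "./prompt_template.yaml")
    reloaded.format(adjective: "funny", content: "chickens")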

data/lib/langchain/tool/base.rb
@@ -9,7 +9,7 @@ module Langchain::Tool
   #
   # - {Langchain::Tool::Calculator}: Calculate the result of a math expression
   # - {Langchain::Tool::RubyCodeInterpreter}: Runs ruby code
-  # - {Langchain::Tool::Search}: search on Google (via SerpAPI)
+  # - {Langchain::Tool::GoogleSearch}: search on Google (via SerpAPI)
   # - {Langchain::Tool::Wikipedia}: search on Wikipedia
   #
   # == Usage
@@ -30,13 +30,13 @@ module Langchain::Tool
   #    agent = Langchain::Agent::ChainOfThoughtAgent.new(
   #      llm: :openai, # or :cohere, :hugging_face, :google_palm or :replicate
   #      llm_api_key: ENV["OPENAI_API_KEY"],
-  #      tools: ["search", "calculator", "wikipedia"]
+  #      tools: ["google_search", "calculator", "wikipedia"]
   #    )
   #
   # 4. Confirm that the Agent is using the Tools you passed in:
   #
   #    agent.tools
-  #    # => ["search", "calculator", "wikipedia"]
+  #    # => ["google_search", "calculator", "wikipedia"]
   #
   # == Adding Tools
   #
@@ -57,6 +57,12 @@ module Langchain::Tool
       self.class.const_get(:NAME)
     end
 
+    def self.logger_options
+      {
+        color: :light_blue
+      }
+    end
+
     #
     # Returns the DESCRIPTION constant of the tool
     #

data/lib/langchain/tool/calculator.rb
@@ -16,6 +16,11 @@ module Langchain::Tool
       Useful for getting the result of a math expression.
 
       The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.
+      Usage:
+        Action Input: 1 + 1
+        Action Input: 3 * 2 / 4
+        Action Input: 9 - 7
+        Action Input: (4.1 + 2.3) / (2.0 - 5.6) * 3
     DESC
 
     def initialize
@@ -28,18 +33,11 @@ module Langchain::Tool
     # @param input [String] math expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       Eqn::Calculator.calc(input)
     rescue Eqn::ParseError, Eqn::NoVariableValueError
-      # Sometimes the input is not a pure math expression, e.g: "12F in Celsius"
-      # We can use the google answer box to evaluate this expression
-      # TODO: Figure out to find a better way to evaluate these language expressions.
-      hash_results = Langchain::Tool::SerpApi
-        .new(api_key: ENV["SERPAPI_API_KEY"])
-        .execute_search(input: input)
-      hash_results.dig(:answer_box, :to) ||
-        hash_results.dig(:answer_box, :result)
+      "\"#{input}\" is an invalid mathematical expression"
     end
   end
 end
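
Note the behavior change here: input that fails to parse is no longer forwarded to SerpAPI's answer box; the tool now returns an error string instead. A sketch of both paths:

    calculator = Langchain::Tool::Calculator.new
    calculator.execute(input: "(4.1 + 2.3) / (2.0 - 5.6) * 3")
    # => result computed by Eqn::Calculator

    calculator.execute(input: "12F in Celsius") # natural language, not math
    # => "\"12F in Celsius\" is an invalid mathematical expression"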

data/lib/langchain/tool/database.rb
@@ -14,15 +14,18 @@ module Langchain::Tool
       The input to this tool should be valid SQL.
     DESC
 
-    attr_reader :db
+    attr_reader :db, :requested_tables, :except_tables
 
     #
     # Establish a database connection
     #
     # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    # @param tables [Array<Symbol>] The tables to use. Will use all if empty.
+    # @param except_tables [Array<Symbol>] The tables to exclude. Will exclude none if empty.
+
     # @return [Database] Database object
     #
-    def initialize(connection_string:)
+    def initialize(connection_string:, tables: [], except_tables: [])
       depends_on "sequel"
       require "sequel"
       require "sequel/extensions/schema_dumper"
@@ -30,7 +33,8 @@ module Langchain::Tool
       raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
       @db = Sequel.connect(connection_string)
-      @db.extension :schema_dumper
+      @requested_tables = tables
+      @except_tables = except_tables
     end
 
     #
@@ -38,9 +42,26 @@ module Langchain::Tool
     #
     # @return [String] schema
     #
-    def schema
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
-      db.dump_schema_migration(same_db: true, indexes: false) unless db.adapter_scheme == :mock
+    def dump_schema
+      Langchain.logger.info("Dumping schema tables and keys", for: self.class)
+      schema = ""
+      db.tables.each do |table|
+        next if except_tables.include?(table)
+        next unless requested_tables.empty? || requested_tables.include?(table)
+
+        schema << "CREATE TABLE #{table}(\n"
+        db.schema(table).each do |column|
+          schema << "#{column[0]} #{column[1][:type]}"
+          schema << " PRIMARY KEY" if column[1][:primary_key] == true
+          schema << "," unless column == db.schema(table).last
+          schema << "\n"
+        end
+        schema << ");\n"
+        db.foreign_key_list(table).each do |fk|
+          schema << "ALTER TABLE #{table} ADD FOREIGN KEY (#{fk[:columns][0]}) REFERENCES #{fk[:table]}(#{fk[:key][0]});\n"
+        end
+      end
+      schema
     end
 
     #
@@ -50,11 +71,11 @@ module Langchain::Tool
     # @return [Array] results
     #
    def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       db[input].to_a
     rescue Sequel::DatabaseError => e
-      Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
+      Langchain.logger.error(e.message, for: self.class)
     end
   end
 end
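
The rewritten tool builds its own schema string from `db.tables`, `db.schema`, and `db.foreign_key_list` rather than relying on Sequel's schema_dumper extension. A sketch of the new constructor options and `dump_schema`; the connection string and table names are illustrative:

    db_tool = Langchain::Tool::Database.new(
      connection_string: "postgres://user:password@localhost:5432/db_name",
      tables: [:users, :orders],          # use only these tables (all if empty)
      except_tables: [:schema_migrations] # skip these
    )

    puts db_tool.dump_schema # CREATE TABLE / FOREIGN KEY statements as built above
    db_tool.execute(input: "SELECT COUNT(*) FROM users;")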

data/lib/langchain/tool/{serp_api.rb → google_search.rb}
@@ -1,18 +1,18 @@
 # frozen_string_literal: true
 
 module Langchain::Tool
-  class SerpApi < Base
+  class GoogleSearch < Base
     #
-    # Wrapper around SerpAPI
+    # Wrapper around the Google Serp API
     #
     # Gem requirements: gem "google_search_results", "~> 2.0.0"
     #
     # Usage:
-    #    search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+    #    search = Langchain::Tool::GoogleSearch.new(api_key: "YOUR_API_KEY")
     #    search.execute(input: "What is the capital of France?")
     #
 
-    NAME = "search"
+    NAME = "google_search"
 
     description <<~DESC
       A wrapper around Google Search.
@@ -26,10 +26,10 @@ module Langchain::Tool
     attr_reader :api_key
 
     #
-    # Initializes the SerpAPI tool
+    # Initializes the Google Search tool
     #
-    # @param api_key [String] SerpAPI API key
-    # @return [Langchain::Tool::SerpApi] SerpAPI tool
+    # @param api_key [String] Search API key
+    # @return [Langchain::Tool::GoogleSearch] Google Search tool
     #
     def initialize(api_key:)
       depends_on "google_search_results"
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @return [String] Answer
     #
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       hash_results = execute_search(input: input)
 
@@ -72,7 +72,7 @@ module Langchain::Tool
     # @return [Hash] hash_results JSON
     #
     def execute_search(input:)
-      GoogleSearch
+      ::GoogleSearch
         .new(q: input, serp_api_key: api_key)
         .get_hash
     end

data/lib/langchain/tool/ruby_code_interpreter.rb
@@ -21,7 +21,7 @@ module Langchain::Tool
     # @param input [String] ruby code expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       safe_eval(input)
     end

data/lib/langchain/tool/weather.rb
@@ -21,7 +21,7 @@ module Langchain::Tool
 
     description <<~DESC
       Useful for getting current weather data
-
+
       The input to this tool should be a city name followed by the units (imperial, metric, or standard)
       Usage:
         Action Input: St Louis, Missouri; metric
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @param input [String] comma separated city and unit (optional: imperial, metric, or standard)
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing for \"#{input}\"")
+      Langchain.logger.info("Executing for \"#{input}\"", for: self.class)
 
       input_array = input.split(";")
       city, units = *input_array.map(&:strip)

data/lib/langchain/tool/wikipedia.rb
@@ -26,7 +26,7 @@ module Langchain::Tool
     # @param input [String] search query
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       page = ::Wikipedia.find(input)
       # It would be nice to figure out a way to provide page.content but the LLM token limit is an issue

data/lib/langchain/utils/token_length/base_validator.rb
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
+      #
+      # @param content [String | Array<String>] The text or array of texts to validate
+      # @param model_name [String] The model name to validate against
+      # @return [Integer] The number of tokens left over for the completion
+      # @raise [TokenLimitExceeded] If the text is too long
+      #
+      class BaseValidator
+        def self.validate_max_tokens!(content, model_name, options = {})
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(item.to_json, model_name, options) }
+          else
+            token_length(content, model_name, options)
+          end
+
+          leftover_tokens = token_limit(model_name) - text_token_length
+
+          # Raise an error even if the whole prompt is equal to the model's token limit (leftover_tokens == 0)
+          if leftover_tokens <= 0
+            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
+          end
+
+          leftover_tokens
+        end
+
+        def self.limit_exceeded_exception(limit, length)
+          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
+        end
+      end
+    end
+  end
+end
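
BaseValidator is a template method: `validate_max_tokens!` is shared, while each subclass supplies `token_length` and `token_limit` (as GooglePalmValidator does below). A hypothetical subclass showing the expected contract:

    # Hypothetical validator: a rough 4-characters-per-token heuristic.
    class FixedRatioValidator < Langchain::Utils::TokenLength::BaseValidator
      def self.token_length(text, _model_name = nil, _options = {})
        (text.length / 4.0).ceil
      end

      def self.token_limit(_model_name)
        4096
      end
    end

    FixedRatioValidator.validate_max_tokens!("some prompt", "any-model")
    # => tokens left for the completion, or raises TokenLimitExceeded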

data/lib/langchain/utils/token_length/google_palm_validator.rb
@@ -7,7 +7,7 @@ module Langchain
       # This class is meant to validate the length of the text passed in to Google Palm's API.
       # It is used to validate the token length before the API call is made
       #
-      class GooglePalmValidator
+      class GooglePalmValidator < BaseValidator
         TOKEN_LIMITS = {
           # Source:
           # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
@@ -26,43 +26,23 @@ module Langchain
           # }
         }.freeze
 
-        #
-        # Validate the context length of the text
-        #
-        # @param content [String | Array<String>] The text or array of texts to validate
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] Whether the text is valid or not
-        # @raise [TokenLimitExceeded] If the text is too long
-        #
-        def self.validate_max_tokens!(google_palm_llm, content, model_name)
-          text_token_length = if content.is_a?(Array)
-            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
-          else
-            token_length(google_palm_llm, content, model_name)
-          end
-
-          leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if leftover_tokens <= 0
-            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
-          end
-
-          leftover_tokens
-        end
-
         #
         # Calculate token length for a given text and model name
         #
-        # @param llm [Langchain::LLM:GooglePalm] The Langchain::LLM:GooglePalm instance
         # @param text [String] The text to calculate the token length for
         # @param model_name [String] The model name to validate against
+        # @param options [Hash] the options to create a message with
+        # @option options [Langchain::LLM::GooglePalm] :llm The Langchain::LLM::GooglePalm instance
        # @return [Integer] The token length of the text
        #
-        def self.token_length(llm, text, model_name = "chat-bison-001")
-          response = llm.client.count_message_tokens(model: model_name, prompt: text)
+        def self.token_length(text, model_name = "chat-bison-001", options)
+          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
           response.dig("tokenCount")
         end
+
+        def self.token_limit(model_name)
+          TOKEN_LIMITS.dig(model_name, "input_token_limit")
+        end
       end
     end
   end