langchainrb 0.5.5 → 0.5.7

Files changed (49)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/Gemfile.lock +3 -1
  4. data/README.md +7 -5
  5. data/examples/store_and_query_with_pinecone.rb +5 -4
  6. data/lib/langchain/agent/base.rb +5 -0
  7. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +22 -10
  8. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.yaml +26 -0
  9. data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +8 -8
  10. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +11 -0
  11. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +21 -0
  12. data/lib/langchain/chunker/base.rb +15 -0
  13. data/lib/langchain/chunker/text.rb +38 -0
  14. data/lib/langchain/contextual_logger.rb +60 -0
  15. data/lib/langchain/conversation.rb +35 -4
  16. data/lib/langchain/data.rb +4 -0
  17. data/lib/langchain/llm/ai21.rb +16 -2
  18. data/lib/langchain/llm/cohere.rb +5 -4
  19. data/lib/langchain/llm/google_palm.rb +15 -7
  20. data/lib/langchain/llm/openai.rb +67 -17
  21. data/lib/langchain/llm/prompts/summarize_template.yaml +9 -0
  22. data/lib/langchain/llm/replicate.rb +6 -5
  23. data/lib/langchain/prompt/base.rb +2 -2
  24. data/lib/langchain/tool/base.rb +9 -3
  25. data/lib/langchain/tool/calculator.rb +7 -9
  26. data/lib/langchain/tool/database.rb +29 -8
  27. data/lib/langchain/tool/{serp_api.rb → google_search.rb} +9 -9
  28. data/lib/langchain/tool/ruby_code_interpreter.rb +1 -1
  29. data/lib/langchain/tool/weather.rb +2 -2
  30. data/lib/langchain/tool/wikipedia.rb +1 -1
  31. data/lib/langchain/utils/token_length/base_validator.rb +38 -0
  32. data/lib/langchain/utils/token_length/google_palm_validator.rb +9 -29
  33. data/lib/langchain/utils/token_length/openai_validator.rb +10 -27
  34. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +17 -0
  35. data/lib/langchain/vectorsearch/base.rb +6 -0
  36. data/lib/langchain/vectorsearch/chroma.rb +1 -1
  37. data/lib/langchain/vectorsearch/hnswlib.rb +2 -2
  38. data/lib/langchain/vectorsearch/milvus.rb +1 -14
  39. data/lib/langchain/vectorsearch/pgvector.rb +1 -5
  40. data/lib/langchain/vectorsearch/pinecone.rb +1 -4
  41. data/lib/langchain/vectorsearch/qdrant.rb +1 -4
  42. data/lib/langchain/vectorsearch/weaviate.rb +1 -4
  43. data/lib/langchain/version.rb +1 -1
  44. data/lib/langchain.rb +28 -12
  45. metadata +30 -11
  46. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json +0 -10
  47. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +0 -10
  48. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +0 -10
  49. data/lib/langchain/llm/prompts/summarize_template.json +0 -5
data/lib/langchain/llm/google_palm.rb
@@ -22,14 +22,19 @@ module Langchain::LLM
 
     DEFAULTS = {
       temperature: 0.0,
-      dimension: 768 # This is what the `embedding-gecko-001` model generates
+      dimension: 768, # This is what the `embedding-gecko-001` model generates
+      completion_model_name: "text-bison-001",
+      chat_completion_model_name: "chat-bison-001",
+      embeddings_model_name: "embedding-gecko-001"
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
 
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "google_palm_api"
       require "google_palm_api"
 
       @client = ::GooglePalmApi::Client.new(api_key: api_key)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
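
A quick sketch of what the new `default_options:` keyword enables, based on the signature above; the API key and override values are placeholders:

```ruby
require "langchain"

# Anything passed in default_options: is merged over DEFAULTS for this instance.
palm = Langchain::LLM::GooglePalm.new(
  api_key: ENV["GOOGLE_PALM_API_KEY"],
  default_options: {temperature: 0.4, completion_model_name: "text-bison-001"}
)

palm.complete(prompt: "Name three Ruby web frameworks.")
```
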
@@ -55,7 +60,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       default_params = {
         prompt: prompt,
-        temperature: DEFAULTS[:temperature]
+        temperature: @defaults[:temperature],
+        completion_model_name: @defaults[:completion_model_name]
       }
 
       if params[:stop_sequences]
@@ -84,13 +90,15 @@ module Langchain::LLM
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
       default_params = {
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
+        chat_completion_model_name: @defaults[:chat_completion_model_name],
         context: context,
         messages: compose_chat_messages(prompt: prompt, messages: messages),
         examples: compose_examples(examples)
       }
 
-      Langchain::Utils::TokenLength::GooglePalmValidator.validate_max_tokens!(self, default_params[:messages], "chat-bison-001")
+      # chat-bison-001 is currently the only model that supports the countMessageTokens function
+      LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
 
       if options[:stop_sequences]
         default_params[:stop] = options.delete(:stop_sequences)
@@ -116,13 +124,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
data/lib/langchain/llm/openai.rb
@@ -17,12 +17,14 @@ module Langchain::LLM
       embeddings_model_name: "text-embedding-ada-002",
       dimension: 1536
     }.freeze
+    LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    def initialize(api_key:, llm_options: {})
+    def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "ruby-openai"
       require "openai"
 
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
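
The OpenAI class gains the same `default_options:` keyword; a sketch with placeholder values (`llm_options:` still goes straight to `OpenAI::Client`, while `default_options:` is merged over `DEFAULTS`):

```ruby
openai = Langchain::LLM::OpenAI.new(
  api_key: ENV["OPENAI_API_KEY"],
  llm_options: {request_timeout: 120}, # forwarded to OpenAI::Client
  default_options: {temperature: 0.2, completion_model_name: "text-davinci-003"}
)
```
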
@@ -33,9 +35,9 @@ module Langchain::LLM
     # @return [Array] The embedding
     #
     def embed(text:, **params)
-      parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
+      parameters = {model: @defaults[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(text, parameters[:model])
+      validate_max_tokens(text, parameters[:model])
 
       response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")
@@ -49,37 +51,85 @@ module Langchain::LLM
     # @return [String] The completion
     #
     def complete(prompt:, **params)
-      parameters = compose_parameters DEFAULTS[:completion_model_name], params
+      parameters = compose_parameters @defaults[:completion_model_name], params
 
       parameters[:prompt] = prompt
-      parameters[:max_tokens] = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(prompt, parameters[:model])
+      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
 
       response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")
     end
 
     #
-    # Generate a chat completion for a given prompt
+    # Generate a chat completion for a given prompt or messages.
+    #
+    # == Examples
+    #
+    #     # simplest case, just give a prompt
+    #     openai.chat prompt: "When was Ruby first released?"
+    #
+    #     # prompt plus some context about how to respond
+    #     openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
+    #
+    #     # full control over the messages that get sent, equivalent to the above
+    #     openai.chat messages: [
+    #       {
+    #         role: "system",
+    #         content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    #       },
+    #       {
+    #         role: "user",
+    #         content: "Does Ruby have a REPL like IPython?"
+    #       }
+    #     ]
+    #
+    #     # few-shot prompting with examples
+    #     openai.chat prompt: "When was factory_bot released?",
+    #       examples: [
+    #         {
+    #           role: "user",
+    #           content: "When was Ruby on Rails released?"
+    #         },
+    #         {
+    #           role: "assistant",
+    #           content: "2004"
+    #         }
+    #       ]
     #
     # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array] The messages that have been sent in the conversation
-    # @param context [String] The context of the conversation
-    # @param examples [Array] Examples of messages provide model with
-    # @param options extra parameters passed to OpenAI::Client#chat
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param context [String] An initial context to provide as a system message, e.g. "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for few-shot prompting
+    #   Each message should be a Hash with the following keys:
+    #   - :content [String] The content of the message
+    #   - :role [String] The role of the sender (system, user, assistant, or function)
+    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
+    # @yield [String] Stream responses back one String at a time
     # @return [String] The chat completion
     #
     def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
 
-      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], options
+      parameters = compose_parameters @defaults[:chat_completion_model_name], options
       parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
       parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
 
+      if (streaming = block_given?)
+        parameters[:stream] = proc do |chunk, _bytesize|
+          yield chunk.dig("choices", 0, "delta", "content")
+        end
+      end
+
       response = client.chat(parameters: parameters)
 
-      raise "Chat completion failed: #{response}" if response.dig("error")
+      raise "Chat completion failed: #{response}" if !response.empty? && response.dig("error")
 
-      response.dig("choices", 0, "message", "content")
+      unless streaming
+        response.dig("choices", 0, "message", "content")
+      end
     end
 
     #
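
Per the new `@yield` tag, passing a block to `#chat` switches the request into streaming mode and yields each delta's content as it arrives. A minimal sketch:

```ruby
# Stream the completion to stdout one chunk at a time.
openai.chat(prompt: "Write a haiku about Ruby.") do |chunk|
  print(chunk)
end
```

Note that when a block is given, the `unless streaming` branch means `#chat` no longer returns the assembled completion string; callers are expected to accumulate the chunks themselves.
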
@@ -90,17 +140,17 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: DEFAULTS[:temperature])
+      complete(prompt: prompt, temperature: @defaults[:temperature])
     end
 
     private
 
     def compose_parameters(model, params)
-      default_params = {model: model, temperature: DEFAULTS[:temperature]}
+      default_params = {model: model, temperature: @defaults[:temperature]}
 
       default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
 
@@ -140,7 +190,7 @@ module Langchain::LLM
     end
 
     def validate_max_tokens(messages, model)
-      Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(messages, model)
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
     end
   end
 end
data/lib/langchain/llm/prompts/summarize_template.yaml
@@ -0,0 +1,9 @@
+_type: prompt
+input_variables:
+  - text
+template: |
+  Write a concise summary of the following:
+
+  {text}
+
+  CONCISE SUMMARY:
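
The summarize methods above load this template via `Langchain::Prompt.load_from_path`; a sketch of doing the same directly (sample text is a placeholder):

```ruby
template = Langchain::Prompt.load_from_path(
  file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
)

# Substitutes the {text} input variable into the template body.
prompt = template.format(text: "Ruby 3.2 shipped with big YJIT improvements...")
```
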
data/lib/langchain/llm/replicate.rb
@@ -32,7 +32,7 @@ module Langchain::LLM
     #
     # @param api_key [String] The API key to use
     #
-    def initialize(api_key:)
+    def initialize(api_key:, default_options: {})
       depends_on "replicate-ruby"
       require "replicate"
 
@@ -41,6 +41,7 @@ module Langchain::LLM
       end
 
       @client = ::Replicate.client
+      @defaults = DEFAULTS.merge(default_options)
     end
 
     #
@@ -94,13 +95,13 @@ module Langchain::LLM
     #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
-        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.json")
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
       complete(
         prompt: prompt,
-        temperature: DEFAULTS[:temperature],
+        temperature: @defaults[:temperature],
         # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
         max_tokens: 2048
       )
@@ -111,11 +112,11 @@ module Langchain::LLM
     private
 
     def completion_model
-      @completion_model ||= client.retrieve_model(DEFAULTS[:completion_model_name]).latest_version
+      @completion_model ||= client.retrieve_model(@defaults[:completion_model_name]).latest_version
     end
 
     def embeddings_model
-      @embeddings_model ||= client.retrieve_model(DEFAULTS[:embeddings_model_name]).latest_version
+      @embeddings_model ||= client.retrieve_model(@defaults[:embeddings_model_name]).latest_version
     end
   end
 end
data/lib/langchain/prompt/base.rb
@@ -45,11 +45,11 @@ module Langchain::Prompt
     end
 
     #
-    # Save the object to a file in JSON format.
+    # Save the object to a file in JSON or YAML format.
     #
     # @param file_path [String, Pathname] The path to the file to save the object to
     #
-    # @raise [ArgumentError] If file_path doesn't end with .json
+    # @raise [ArgumentError] If file_path doesn't end with .json, .yaml, or .yml
     #
     # @return [void]
     #
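
Saving now mirrors loading; a sketch, assuming a `PromptTemplate` built elsewhere and a placeholder file name:

```ruby
prompt = Langchain::Prompt::PromptTemplate.new(
  template: "Tell me a {adjective} joke.",
  input_variables: ["adjective"]
)

prompt.save(file_path: "prompt_template.yaml") # .json and .yml should work too
```
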
data/lib/langchain/tool/base.rb
@@ -9,7 +9,7 @@ module Langchain::Tool
  #
  # - {Langchain::Tool::Calculator}: Calculate the result of a math expression
  # - {Langchain::Tool::RubyCodeInterpretor}: Runs ruby code
- # - {Langchain::Tool::Search}: search on Google (via SerpAPI)
+ # - {Langchain::Tool::GoogleSearch}: search on Google (via SerpAPI)
  # - {Langchain::Tool::Wikipedia}: search on Wikipedia
  #
  # == Usage
@@ -30,13 +30,13 @@ module Langchain::Tool
  #     agent = Langchain::Agent::ChainOfThoughtAgent.new(
  #       llm: :openai, # or :cohere, :hugging_face, :google_palm or :replicate
  #       llm_api_key: ENV["OPENAI_API_KEY"],
- #       tools: ["search", "calculator", "wikipedia"]
+ #       tools: ["google_search", "calculator", "wikipedia"]
  #     )
  #
  # 4. Confirm that the Agent is using the Tools you passed in:
  #
  #     agent.tools
- #     # => ["search", "calculator", "wikipedia"]
+ #     # => ["google_search", "calculator", "wikipedia"]
  #
  # == Adding Tools
  #
@@ -57,6 +57,12 @@ module Langchain::Tool
       self.class.const_get(:NAME)
     end
 
+    def self.logger_options
+      {
+        color: :light_blue
+      }
+    end
+
     #
     # Returns the DESCRIPTION constant of the tool
     #
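
`logger_options` presumably feeds the new `Langchain::ContextualLogger` (added in this release), letting calls such as `Langchain.logger.info(..., for: self.class)` pick up per-class styling. A hypothetical custom tool overriding the color; the class and value are illustrative only:

```ruby
module Langchain::Tool
  class MyTool < Base # hypothetical tool
    NAME = "my_tool"

    # Override the :light_blue default defined on Base
    def self.logger_options
      {color: :yellow}
    end
  end
end
```
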
data/lib/langchain/tool/calculator.rb
@@ -16,6 +16,11 @@ module Langchain::Tool
       Useful for getting the result of a math expression.
 
       The input to this tool should be a valid mathematical expression that could be executed by a simple calculator.
+      Usage:
+        Action Input: 1 + 1
+        Action Input: 3 * 2 / 4
+        Action Input: 9 - 7
+        Action Input: (4.1 + 2.3) / (2.0 - 5.6) * 3
     DESC
 
     def initialize
@@ -28,18 +33,11 @@ module Langchain::Tool
     # @param input [String] math expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       Eqn::Calculator.calc(input)
     rescue Eqn::ParseError, Eqn::NoVariableValueError
-      # Sometimes the input is not a pure math expression, e.g: "12F in Celsius"
-      # We can use the google answer box to evaluate this expression
-      # TODO: Figure out to find a better way to evaluate these language expressions.
-      hash_results = Langchain::Tool::SerpApi
-        .new(api_key: ENV["SERPAPI_API_KEY"])
-        .execute_search(input: input)
-      hash_results.dig(:answer_box, :to) ||
-        hash_results.dig(:answer_box, :result)
+      "\"#{input}\" is an invalid mathematical expression"
     end
   end
 end
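
With the SerpAPI fallback gone, non-math input now returns an explanatory string instead of triggering a web search. A sketch of the new behavior (return values approximate):

```ruby
calculator = Langchain::Tool::Calculator.new

calculator.execute(input: "(4.1 + 2.3) / (2.0 - 5.6) * 3")
# => roughly -5.33

calculator.execute(input: "12F in Celsius")
# => "\"12F in Celsius\" is an invalid mathematical expression"
```
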
data/lib/langchain/tool/database.rb
@@ -14,15 +14,18 @@ module Langchain::Tool
       The input to this tool should be valid SQL.
     DESC
 
-    attr_reader :db
+    attr_reader :db, :requested_tables, :except_tables
 
     #
     # Establish a database connection
     #
     # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    # @param tables [Array<Symbol>] The tables to use. Will use all if empty.
+    # @param except_tables [Array<Symbol>] The tables to exclude. Will exclude none if empty.
+    #
     # @return [Database] Database object
     #
-    def initialize(connection_string:)
+    def initialize(connection_string:, tables: [], except_tables: [])
       depends_on "sequel"
       require "sequel"
       require "sequel/extensions/schema_dumper"
@@ -30,7 +33,8 @@ module Langchain::Tool
       raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
       @db = Sequel.connect(connection_string)
-      @db.extension :schema_dumper
+      @requested_tables = tables
+      @except_tables = except_tables
     end
 
     #
@@ -38,9 +42,26 @@ module Langchain::Tool
     #
     # @return [String] schema
     #
-    def schema
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
-      db.dump_schema_migration(same_db: true, indexes: false) unless db.adapter_scheme == :mock
+    def dump_schema
+      Langchain.logger.info("Dumping schema tables and keys", for: self.class)
+      schema = ""
+      db.tables.each do |table|
+        next if except_tables.include?(table)
+        next unless requested_tables.empty? || requested_tables.include?(table)
+
+        schema << "CREATE TABLE #{table}(\n"
+        db.schema(table).each do |column|
+          schema << "#{column[0]} #{column[1][:type]}"
+          schema << " PRIMARY KEY" if column[1][:primary_key] == true
+          schema << "," unless column == db.schema(table).last
+          schema << "\n"
+        end
+        schema << ");\n"
+        db.foreign_key_list(table).each do |fk|
+          schema << "ALTER TABLE #{table} ADD FOREIGN KEY (#{fk[:columns][0]}) REFERENCES #{fk[:table]}(#{fk[:key][0]});\n"
+        end
+      end
+      schema
     end
 
     #
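
A sketch of the reworked tool, with a placeholder connection string: `dump_schema` replaces the old `schema` method and emits `CREATE TABLE`/`ALTER TABLE` statements only for tables that pass the `tables:`/`except_tables:` filters:

```ruby
db_tool = Langchain::Tool::Database.new(
  connection_string: "postgres://user:password@localhost:5432/app_db",
  tables: [:users, :orders],          # use only these tables...
  except_tables: [:schema_migrations] # ...and skip these
)

puts db_tool.dump_schema
db_tool.execute(input: "SELECT COUNT(*) FROM users;")
```
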
@@ -50,11 +71,11 @@ module Langchain::Tool
     # @return [Array] results
     #
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       db[input].to_a
     rescue Sequel::DatabaseError => e
-      Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
+      Langchain.logger.error(e.message, for: self.class)
     end
   end
 end
data/lib/langchain/tool/{serp_api.rb → google_search.rb}
@@ -1,18 +1,18 @@
 # frozen_string_literal: true
 
 module Langchain::Tool
-  class SerpApi < Base
+  class GoogleSearch < Base
     #
-    # Wrapper around SerpAPI
+    # Wrapper around the Google SERP API
     #
     # Gem requirements: gem "google_search_results", "~> 2.0.0"
     #
     # Usage:
-    #     search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+    #     search = Langchain::Tool::GoogleSearch.new(api_key: "YOUR_API_KEY")
     #     search.execute(input: "What is the capital of France?")
     #
 
-    NAME = "search"
+    NAME = "google_search"
 
     description <<~DESC
       A wrapper around Google Search.
@@ -26,10 +26,10 @@ module Langchain::Tool
     attr_reader :api_key
 
     #
-    # Initializes the SerpAPI tool
+    # Initializes the Google Search tool
     #
-    # @param api_key [String] SerpAPI API key
-    # @return [Langchain::Tool::SerpApi] SerpAPI tool
+    # @param api_key [String] Search API key
+    # @return [Langchain::Tool::GoogleSearch] Google search tool
     #
     def initialize(api_key:)
       depends_on "google_search_results"
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @return [String] Answer
     #
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       hash_results = execute_search(input: input)
 
@@ -72,7 +72,7 @@ module Langchain::Tool
     # @return [Hash] hash_results JSON
     #
     def execute_search(input:)
-      GoogleSearch
+      ::GoogleSearch
         .new(q: input, serp_api_key: api_key)
         .get_hash
     end
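
Usage under the new name, per the class docs above (key is a placeholder):

```ruby
search = Langchain::Tool::GoogleSearch.new(api_key: ENV["SERPAPI_API_KEY"])
search.execute(input: "What is the capital of France?")
```

The `::GoogleSearch` prefix inside `execute_search` is needed because the google_search_results gem defines a top-level `GoogleSearch` class; without the leading `::`, the constant would now resolve to `Langchain::Tool::GoogleSearch` itself.
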
data/lib/langchain/tool/ruby_code_interpreter.rb
@@ -21,7 +21,7 @@ module Langchain::Tool
     # @param input [String] ruby code expression
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       safe_eval(input)
     end
data/lib/langchain/tool/weather.rb
@@ -21,7 +21,7 @@ module Langchain::Tool
 
     description <<~DESC
       Useful for getting current weather data
-
+
       The input to this tool should be a city name followed by the units (imperial, metric, or standard)
       Usage:
         Action Input: St Louis, Missouri; metric
@@ -54,7 +54,7 @@ module Langchain::Tool
     # @param input [String] comma separated city and unit (optional: imperial, metric, or standard)
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing for \"#{input}\"")
+      Langchain.logger.info("Executing for \"#{input}\"", for: self.class)
 
       input_array = input.split(";")
       city, units = *input_array.map(&:strip)
data/lib/langchain/tool/wikipedia.rb
@@ -26,7 +26,7 @@ module Langchain::Tool
     # @param input [String] search query
     # @return [String] Answer
     def execute(input:)
-      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      Langchain.logger.info("Executing \"#{input}\"", for: self.class)
 
       page = ::Wikipedia.find(input)
       # It would be nice to figure out a way to provide page.content but the LLM token limit is an issue
data/lib/langchain/utils/token_length/base_validator.rb
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Utils
+    module TokenLength
+      #
+      # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
+      #
+      # @param content [String | Array<String>] The text or array of texts to validate
+      # @param model_name [String] The model name to validate against
+      # @return [Integer] The number of leftover tokens
+      # @raise [TokenLimitExceeded] If the text is too long
+      #
+      class BaseValidator
+        def self.validate_max_tokens!(content, model_name, options = {})
+          text_token_length = if content.is_a?(Array)
+            content.sum { |item| token_length(item.to_json, model_name, options) }
+          else
+            token_length(content, model_name, options)
+          end
+
+          leftover_tokens = token_limit(model_name) - text_token_length
+
+          # Raise an error even if the whole prompt is equal to the model's token limit (leftover_tokens == 0)
+          if leftover_tokens <= 0
+            raise limit_exceeded_exception(token_limit(model_name), text_token_length)
+          end
+
+          leftover_tokens
+        end
+
+        def self.limit_exceeded_exception(limit, length)
+          TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
+        end
+      end
+    end
+  end
+end
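
Both validators now share this template-method base: subclasses supply `token_length` and `token_limit`, and `validate_max_tokens!` returns the leftover token budget or raises. A sketch of calling the OpenAI subclass; the model name and the `token_overflow` reader are assumptions based on this diff:

```ruby
begin
  leftover = Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(
    "A very long prompt...", "text-davinci-003"
  )
  puts "Up to #{leftover} tokens left for the completion"
rescue Langchain::Utils::TokenLength::TokenLimitExceeded => e
  # The exception's second constructor argument (length - limit) is presumably
  # exposed by the new token_limit_exceeded.rb, e.g. as e.token_overflow.
  warn e.message
end
```
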
data/lib/langchain/utils/token_length/google_palm_validator.rb
@@ -7,7 +7,7 @@ module Langchain
       # This class is meant to validate the length of the text passed in to Google Palm's API.
       # It is used to validate the token length before the API call is made
       #
-      class GooglePalmValidator
+      class GooglePalmValidator < BaseValidator
         TOKEN_LIMITS = {
           # Source:
           # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
@@ -26,43 +26,23 @@ module Langchain
         # }
         }.freeze
 
-        #
-        # Validate the context length of the text
-        #
-        # @param content [String | Array<String>] The text or array of texts to validate
-        # @param model_name [String] The model name to validate against
-        # @return [Integer] Whether the text is valid or not
-        # @raise [TokenLimitExceeded] If the text is too long
-        #
-        def self.validate_max_tokens!(google_palm_llm, content, model_name)
-          text_token_length = if content.is_a?(Array)
-            content.sum { |item| token_length(google_palm_llm, item.to_json, model_name) }
-          else
-            token_length(google_palm_llm, content, model_name)
-          end
-
-          leftover_tokens = TOKEN_LIMITS.dig(model_name, "input_token_limit") - text_token_length
-
-          # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
-          if leftover_tokens <= 0
-            raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS.dig(model_name, "input_token_limit")} tokens, but the given text is #{text_token_length} tokens long."
-          end
-
-          leftover_tokens
-        end
-
         #
         # Calculate token length for a given text and model name
         #
-        # @param llm [Langchain::LLM:GooglePalm] The Langchain::LLM:GooglePalm instance
         # @param text [String] The text to calculate the token length for
         # @param model_name [String] The model name to validate against
+        # @param options [Hash] the options to create a message with
+        # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
         # @return [Integer] The token length of the text
         #
-        def self.token_length(llm, text, model_name = "chat-bison-001")
-          response = llm.client.count_message_tokens(model: model_name, prompt: text)
+        def self.token_length(text, model_name = "chat-bison-001", options)
+          response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
           response.dig("tokenCount")
         end
+
+        def self.token_limit(model_name)
+          TOKEN_LIMITS.dig(model_name, "input_token_limit")
+        end
       end
     end
   end